From ac23b653ecd8f73a2501e617be7cfb943b067945 Mon Sep 17 00:00:00 2001 From: endigma Date: Tue, 17 Mar 2026 10:59:17 +0000 Subject: [PATCH 01/52] wip: subs v2 --- examples/federation/go.mod | 2 + examples/federation/go.sum | 12 +- execution/engine/config_factory_federation.go | 8 +- execution/engine/engine_config.go | 23 +- execution/engine/engine_config_test.go | 4 +- execution/engine/execution_engine_test.go | 10 +- execution/go.mod | 7 +- execution/go.sum | 18 +- execution/graphql/result_writer.go | 2 +- execution/subscription/legacy_handler_test.go | 7 +- .../subscription/websocket/handler_test.go | 7 +- go.work.sum | 25 +- v2/go.mod | 22 +- v2/go.sum | 72 +- .../graphql_datasource/graphql_datasource.go | 27 +- .../graphql_datasource_test.go | 88 +- .../graphql_datasource/graphql_sse_handler.go | 275 -- .../graphql_sse_handler_test.go | 646 ---- .../graphql_subscription_client.go | 1038 ++----- .../graphql_subscription_client_test.go | 2594 +---------------- .../graphql_datasource/graphql_tws_handler.go | 379 --- .../graphql_tws_handler_test.go | 312 -- .../graphql_datasource/graphql_ws_handler.go | 301 -- .../graphql_ws_handler_test.go | 371 --- .../graphql_ws_proto_types.go | 47 - .../subscriptionclient/client.go | 103 + .../subscriptionclient/client_test.go | 289 ++ .../subscriptionclient/common/message.go | 27 + .../subscriptionclient/common/options.go | 55 + .../subscriptionclient/exports.go | 51 + .../protocol/graphql_transport_ws.go | 174 ++ .../protocol/graphql_transport_ws_test.go | 378 +++ .../subscriptionclient/protocol/graphql_ws.go | 175 ++ .../protocol/graphql_ws_test.go | 375 +++ .../subscriptionclient/protocol/protocol.go | 81 + .../subscriptionclient/transport/sse_conn.go | 190 ++ .../transport/sse_conn_test.go | 152 + .../transport/sse_transport.go | 259 ++ .../transport/sse_transport_test.go | 862 ++++++ .../subscriptionclient/transport/transport.go | 14 + .../transport/transport_test.go | 19 + .../subscriptionclient/transport/ws_conn.go | 309 ++ .../transport/ws_conn_test.go | 649 +++++ .../transport/ws_transport.go | 395 +++ .../transport/ws_transport_test.go | 1159 ++++++++ v2/pkg/engine/resolve/datasource.go | 5 - v2/pkg/engine/resolve/resolve.go | 1161 ++++---- v2/pkg/engine/resolve/resolve_test.go | 84 +- ..._test.go => resolver_subscription_test.go} | 85 +- v2/pkg/engine/resolve/response.go | 15 +- 50 files changed, 6878 insertions(+), 6485 deletions(-) delete mode 100644 v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go delete mode 100644 v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler_test.go delete mode 100644 v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go delete mode 100644 v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go delete mode 100644 v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go delete mode 100644 v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go delete mode 100644 v2/pkg/engine/datasource/graphql_datasource/graphql_ws_proto_types.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client_test.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/options.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws_test.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws_test.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport_test.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go create mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go rename v2/pkg/engine/resolve/{event_loop_test.go => resolver_subscription_test.go} (69%) diff --git a/examples/federation/go.mod b/examples/federation/go.mod index 882ee7fd46..2c6235bf58 100644 --- a/examples/federation/go.mod +++ b/examples/federation/go.mod @@ -19,6 +19,7 @@ require ( github.com/bufbuild/protocompile v0.14.1 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/coder/websocket v1.8.14 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dlclark/regexp2 v1.11.0 // indirect @@ -40,6 +41,7 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/r3labs/sse/v2 v2.8.1 // indirect + github.com/rs/xid v1.6.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/sosodev/duration v1.3.1 // indirect diff --git a/examples/federation/go.sum b/examples/federation/go.sum index 4bcbb63885..b7957488be 100644 --- a/examples/federation/go.sum +++ b/examples/federation/go.sum @@ -20,8 +20,8 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/chzyer/logex v1.2.0/go.mod h1:9+9sk7u7pGNWYMkh0hdiL++6OeibzJccyQU4p4MedaY= github.com/chzyer/readline v1.5.0/go.mod h1:x22KAscuvRqlLoK9CsoYsmxoXZMMFVyOl86cAH8qUic= github.com/chzyer/test v0.0.0-20210722231415-061457976a23/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= -github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= -github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -59,8 +59,7 @@ github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U= github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= @@ -121,6 +120,8 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sebdah/goldie/v2 v2.7.1 h1:PkBHymaYdtvEkZV7TmyqKxdmn5/Vcj+8TpATWZjnG5E= @@ -154,8 +155,7 @@ github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= github.com/vektah/gqlparser/v2 v2.5.30 h1:EqLwGAFLIzt1wpx1IPpY67DwUujF1OfzgEyDsLrN6kE= github.com/vektah/gqlparser/v2 v2.5.30/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= +github.com/wundergraph/astjson v1.0.0 h1:rETLJuQkMWWW03HCF6WBttEBOu8gi5vznj5KEUPVV2Q= github.com/wundergraph/astjson v1.0.0/go.mod h1:h12D/dxxnedtLzsKyBLK7/Oe4TAoGpRVC9nDpDrZSWw= github.com/wundergraph/cosmo/composition-go v0.0.0-20241020204711-78f240a77c99 h1:TGXDYfDhwFLFTuNuCwkuqXT5aXGz47zcurXLfTBS9w4= github.com/wundergraph/cosmo/composition-go v0.0.0-20241020204711-78f240a77c99/go.mod h1:fUuOAUAXUFB/mlSkAaImGeE4A841AKR5dTMWhV4ibxI= diff --git a/execution/engine/config_factory_federation.go b/execution/engine/config_factory_federation.go index fca8b342b1..1c70f3d8a6 100644 --- a/execution/engine/config_factory_federation.go +++ b/execution/engine/config_factory_federation.go @@ -458,16 +458,16 @@ func (f *FederationEngineConfigFactory) subscriptionClient( switch subscriptionType { case SubscriptionTypeGraphQLTransportWS: graphqlSubscriptionClient = subscriptionClientFactory.NewSubscriptionClient( - httpClient, - streamingClient, f.engineCtx, + graphql_datasource.WithUpgradeClient(httpClient), + graphql_datasource.WithStreamingClient(streamingClient), ) default: // for compatibility reasons we fall back to graphql-ws protocol graphqlSubscriptionClient = subscriptionClientFactory.NewSubscriptionClient( - httpClient, - streamingClient, f.engineCtx, + graphql_datasource.WithUpgradeClient(httpClient), + graphql_datasource.WithStreamingClient(streamingClient), ) } diff --git a/execution/engine/engine_config.go b/execution/engine/engine_config.go index 551a9858d6..67a38039b7 100644 --- a/execution/engine/engine_config.go +++ b/execution/engine/engine_config.go @@ -143,7 +143,7 @@ func (d *graphqlDataSourceGenerator) Generate(dsID string, config graphql_dataso return nil, err } - return plan.NewDataSourceConfiguration[graphql_datasource.Configuration]( + return plan.NewDataSourceConfiguration( dsID, factory, &plan.DataSourceMetadata{ @@ -155,27 +155,16 @@ func (d *graphqlDataSourceGenerator) Generate(dsID string, config graphql_dataso } func (d *graphqlDataSourceGenerator) generateSubscriptionClient(httpClient *http.Client, definedOptions *dataSourceGeneratorOptions) (graphql_datasource.GraphQLSubscriptionClient, error) { - var graphqlSubscriptionClient graphql_datasource.GraphQLSubscriptionClient - switch definedOptions.subscriptionType { - case SubscriptionTypeGraphQLTransportWS: - graphqlSubscriptionClient = definedOptions.subscriptionClientFactory.NewSubscriptionClient( - httpClient, - definedOptions.streamingClient, - nil, - ) - default: - // for compatibility reasons we fall back to graphql-ws protocol - graphqlSubscriptionClient = definedOptions.subscriptionClientFactory.NewSubscriptionClient( - httpClient, - definedOptions.streamingClient, - nil, - ) - } + graphqlSubscriptionClient := definedOptions.subscriptionClientFactory.NewSubscriptionClient(d.engineCtx, + graphql_datasource.WithUpgradeClient(httpClient), + graphql_datasource.WithStreamingClient(definedOptions.streamingClient), + ) ok := graphql_datasource.IsDefaultGraphQLSubscriptionClient(graphqlSubscriptionClient) if !ok { return nil, errors.New("invalid subscriptionClient was instantiated") } + return graphqlSubscriptionClient, nil } diff --git a/execution/engine/engine_config_test.go b/execution/engine/engine_config_test.go index db6427d70b..334202f939 100644 --- a/execution/engine/engine_config_test.go +++ b/execution/engine/engine_config_test.go @@ -280,11 +280,11 @@ func TestGraphqlFieldConfigurationsGenerator_Generate(t *testing.T) { } -var mockSubscriptionClient = graphqlDataSource.NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, context.Background()) +var mockSubscriptionClient = graphqlDataSource.NewGraphQLSubscriptionClient(context.Background()) type MockSubscriptionClientFactory struct{} -func (m *MockSubscriptionClientFactory) NewSubscriptionClient(httpClient, streamingClient *http.Client, engineCtx context.Context, options ...graphqlDataSource.Options) graphqlDataSource.GraphQLSubscriptionClient { +func (m *MockSubscriptionClientFactory) NewSubscriptionClient(engineCtx context.Context, options ...graphqlDataSource.SubscriptionClientOption) graphqlDataSource.GraphQLSubscriptionClient { return mockSubscriptionClient } diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index 0f7c48ac00..1a9ca7a3c3 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -54,7 +54,8 @@ func mustConfiguration(t *testing.T, input graphql_datasource.ConfigurationInput func mustFactory(t testing.TB, httpClient *http.Client) plan.PlannerFactory[graphql_datasource.Configuration] { t.Helper() - factory, err := graphql_datasource.NewFactory(context.Background(), httpClient, graphql_datasource.NewGraphQLSubscriptionClient(httpClient, httpClient, context.Background())) + factory, err := graphql_datasource.NewFactory(context.Background(), httpClient, graphql_datasource.NewGraphQLSubscriptionClient(context.Background(), + graphql_datasource.WithUpgradeClient(httpClient), graphql_datasource.WithStreamingClient(httpClient))) require.NoError(t, err) return factory @@ -6077,10 +6078,9 @@ func newFederationEngineStaticConfig(ctx context.Context, setup *federationtesti return } - subscriptionClient := graphql_datasource.NewGraphQLSubscriptionClient( - httpclient.DefaultNetHttpClient, - httpclient.DefaultNetHttpClient, - ctx, + subscriptionClient := graphql_datasource.NewGraphQLSubscriptionClient(ctx, + graphql_datasource.WithUpgradeClient(httpclient.DefaultNetHttpClient), + graphql_datasource.WithStreamingClient(httpclient.DefaultNetHttpClient), ) graphqlFactory, err := graphql_datasource.NewFactory(ctx, httpclient.DefaultNetHttpClient, subscriptionClient) diff --git a/execution/go.mod b/execution/go.mod index 8fb7c1fcb4..85fa1b2538 100644 --- a/execution/go.mod +++ b/execution/go.mod @@ -14,12 +14,12 @@ require ( github.com/sebdah/goldie/v2 v2.7.1 github.com/stretchr/testify v1.11.1 github.com/vektah/gqlparser/v2 v2.5.30 - github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 + github.com/wundergraph/astjson v1.0.0 github.com/wundergraph/cosmo/composition-go v0.0.0-20241020204711-78f240a77c99 github.com/wundergraph/cosmo/router v0.0.0-20251013094319-c611abf26b17 github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.231 go.uber.org/atomic v1.11.0 - google.golang.org/grpc v1.68.1 + google.golang.org/grpc v1.71.0 google.golang.org/protobuf v1.36.9 ) @@ -29,6 +29,7 @@ require ( github.com/bufbuild/protocompile v0.14.1 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/coder/websocket v1.8.14 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dlclark/regexp2 v1.11.0 // indirect @@ -41,6 +42,7 @@ require ( github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect github.com/golang/protobuf v1.5.4 // indirect + github.com/google/go-cmp v0.7.0 // indirect github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/hashicorp/go-hclog v1.6.3 // indirect @@ -65,6 +67,7 @@ require ( github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect github.com/urfave/cli/v2 v2.27.7 // indirect + github.com/wundergraph/go-arena v1.1.0 // indirect github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect diff --git a/execution/go.sum b/execution/go.sum index babde00b17..195d3b5f22 100644 --- a/execution/go.sum +++ b/execution/go.sum @@ -18,8 +18,7 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/chzyer/logex v1.2.0/go.mod h1:9+9sk7u7pGNWYMkh0hdiL++6OeibzJccyQU4p4MedaY= github.com/chzyer/readline v1.5.0/go.mod h1:x22KAscuvRqlLoK9CsoYsmxoXZMMFVyOl86cAH8qUic= github.com/chzyer/test v0.0.0-20210722231415-061457976a23/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= -github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= -github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -60,6 +59,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U= github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= @@ -163,18 +164,24 @@ github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= github.com/vektah/gqlparser/v2 v2.5.30 h1:EqLwGAFLIzt1wpx1IPpY67DwUujF1OfzgEyDsLrN6kE= github.com/vektah/gqlparser/v2 v2.5.30/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= +github.com/wundergraph/astjson v1.0.0 h1:rETLJuQkMWWW03HCF6WBttEBOu8gi5vznj5KEUPVV2Q= github.com/wundergraph/cosmo/composition-go v0.0.0-20241020204711-78f240a77c99 h1:TGXDYfDhwFLFTuNuCwkuqXT5aXGz47zcurXLfTBS9w4= github.com/wundergraph/cosmo/composition-go v0.0.0-20241020204711-78f240a77c99/go.mod h1:fUuOAUAXUFB/mlSkAaImGeE4A841AKR5dTMWhV4ibxI= github.com/wundergraph/cosmo/router v0.0.0-20251013094319-c611abf26b17 h1:GjO2E8LTf3U5JiQJCY4MmlRcAjVt7IvAbWFSgEjQdl8= github.com/wundergraph/cosmo/router v0.0.0-20251013094319-c611abf26b17/go.mod h1:7kt64e0LOLMBqOzrfu9PuLRn9cVT9YN1Bb3EennVtws= +github.com/wundergraph/go-arena v1.1.0 h1:9+wSRkJAkA2vbYHp6s8tEGhPViRGQNGXqPHT0QzhdIc= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.231 h1:2C8LNFGs8MtI2yPy2/a2WRf9/X2FoMqXlEJkpTjvsTg= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.231/go.mod h1:ErOQH1ki2+SZB8JjpTyGVnoBpg5picIyjvuWQJP4abg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/otel v1.36.0 h1:UumtzIklRBY6cI/lllNZlALOF5nNIzJVb16APdvgTXg= +go.opentelemetry.io/otel/metric v1.36.0 h1:MoWPKVhQvJ+eeXWHFBOPoBOi20jh6Iq2CcCREuTYufE= +go.opentelemetry.io/otel/sdk v1.36.0 h1:b6SYIuLRs88ztox4EyrvRti80uXIFy+Sqzoh9kFULbs= +go.opentelemetry.io/otel/sdk/metric v1.36.0 h1:r0ntwwGosWGaa0CrSt8cuNuTcccMXERFwHX4dThiPis= +go.opentelemetry.io/otel/trace v1.36.0 h1:ahxWNuqZjpdiFAyrIoQ4GIiAIhxAunQR6MUoKrsNd4w= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= @@ -258,8 +265,7 @@ gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0= gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU= google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f h1:OxYkA3wjPsZyBylwymxSHa7ViiW1Sml4ToBrncvFehI= google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:+2Yz8+CLJbIfL9z73EW45avw8Lmge3xVElCP9zEKi50= -google.golang.org/grpc v1.68.1 h1:oI5oTa11+ng8r8XMMN7jAOmWfPZWbYpCFaMUTACxkM0= -google.golang.org/grpc v1.68.1/go.mod h1:+q1XYFJjShcqn0QZHvCyeR4CXPA+llXIeUIfIe00waw= +google.golang.org/grpc v1.71.0 h1:kF77BGdPTQ4/JZWMlb9VpJ5pa25aqvVqogsxNHHdeBg= google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw= google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/cenkalti/backoff.v1 v1.1.0 h1:Arh75ttbsvlpVA7WtVpH4u9h6Zl46xuptxqLxPiSo4Y= diff --git a/execution/graphql/result_writer.go b/execution/graphql/result_writer.go index f97cbf721b..7e9b3f24a5 100644 --- a/execution/graphql/result_writer.go +++ b/execution/graphql/result_writer.go @@ -39,7 +39,7 @@ func (e *EngineResultWriter) Heartbeat() error { return nil } -func (e *EngineResultWriter) Close(_ resolve.SubscriptionCloseKind) { +func (e *EngineResultWriter) Error(_ []byte) { } diff --git a/execution/subscription/legacy_handler_test.go b/execution/subscription/legacy_handler_test.go index 1ca4933258..efc42b4354 100644 --- a/execution/subscription/legacy_handler_test.go +++ b/execution/subscription/legacy_handler_test.go @@ -567,10 +567,9 @@ func setupEngineV2(t *testing.T, ctx context.Context, chatServerURL string) (*Ex engineConf := engine.NewConfiguration(chatSchema) - subscriptionClient := graphql_datasource.NewGraphQLSubscriptionClient( - httpclient.DefaultNetHttpClient, - httpclient.DefaultNetHttpClient, - ctx, + subscriptionClient := graphql_datasource.NewGraphQLSubscriptionClient(ctx, + graphql_datasource.WithUpgradeClient(httpclient.DefaultNetHttpClient), + graphql_datasource.WithStreamingClient(httpclient.DefaultNetHttpClient), ) factory, err := graphql_datasource.NewFactory(ctx, httpclient.DefaultNetHttpClient, subscriptionClient) diff --git a/execution/subscription/websocket/handler_test.go b/execution/subscription/websocket/handler_test.go index acfe4d4e24..21849b271b 100644 --- a/execution/subscription/websocket/handler_test.go +++ b/execution/subscription/websocket/handler_test.go @@ -279,10 +279,9 @@ func setupExecutorPoolV2(t *testing.T, ctx context.Context, chatServerURL string engineConf := engine.NewConfiguration(chatSchema) engineConf.SetWebsocketBeforeStartHook(onBeforeStartHook) - subscriptionClient := graphql_datasource.NewGraphQLSubscriptionClient( - httpclient.DefaultNetHttpClient, - httpclient.DefaultNetHttpClient, - ctx, + subscriptionClient := graphql_datasource.NewGraphQLSubscriptionClient(ctx, + graphql_datasource.WithUpgradeClient(httpclient.DefaultNetHttpClient), + graphql_datasource.WithStreamingClient(httpclient.DefaultNetHttpClient), ) factory, err := graphql_datasource.NewFactory(ctx, httpclient.DefaultNetHttpClient, subscriptionClient) diff --git a/go.work.sum b/go.work.sum index 1aecd8d220..f3972893ca 100644 --- a/go.work.sum +++ b/go.work.sum @@ -49,6 +49,7 @@ github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a h1:8d1CEOF1xlde github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a/go.mod h1:rzgs2ZOiguV6/NpiDgADjRLPNyZlApIWxKpkT+X8SdY= github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3 h1:boJj011Hh+874zpIySeApCX4GeOjPl9qhRF3QuIZq+Q= github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= +github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/containerd/cgroups/v3 v3.0.2 h1:f5WFqIVSgo5IZmtTT3qVBo6TzI1ON6sycSBKkymb9L0= github.com/containerd/cgroups/v3 v3.0.2/go.mod h1:JUgITrzdFqp42uI2ryGA+ge0ap/nxzYgkGmIcetmErE= github.com/containerd/stargz-snapshotter/estargz v0.16.3 h1:7evrXtoh1mSbGj/pfRccTampEyKpjpOnS3CyiV1Ebr8= @@ -56,6 +57,7 @@ github.com/containerd/stargz-snapshotter/estargz v0.16.3/go.mod h1:uyr4BfYfOj3G9 github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.9 h1:uDmaGzcdjhF4i/plgjmEsriH11Y0o7RKapEf/LDaM3w= +github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgraph-io/ristretto/v2 v2.1.0 h1:59LjpOJLNDULHh8MC4UaegN52lC4JnO2dITsie/Pa8I= github.com/dgraph-io/ristretto/v2 v2.1.0/go.mod h1:uejeqfYXpUomfse0+lO+13ATz4TypQYLJZzBSAemuB4= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= @@ -81,6 +83,8 @@ github.com/envoyproxy/protoc-gen-validate v1.2.1 h1:DEo3O99U8j4hBFwbJfrz9VtgcDfU github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU= github.com/expr-lang/expr v1.17.6 h1:1h6i8ONk9cexhDmowO/A64VPxHScu7qfSl2k8OlINec= github.com/expr-lang/expr v1.17.6/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= +github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/go-chi/chi/v5 v5.2.2 h1:CMwsvRVTbXVytCk1Wd72Zy1LAsAh9GxMmSNWLHCG618= @@ -116,8 +120,7 @@ github.com/golang/glog v1.2.4 h1:CNNw5U8lSiiBk7druxtSHHTsRWcxKoac6kZKm2peBBc= github.com/golang/glog v1.2.4/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= -github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-containerregistry v0.20.3 h1:oNx7IdTI936V8CQRveCjaxOiegWwvM7kqkbXTpyiovI= github.com/google/go-containerregistry v0.20.3/go.mod h1:w00pIgBRDVUDFM6bq+Qx8lwNWK+cxgCuX1vd3PIBDNI= github.com/google/renameio v0.1.0 h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA= @@ -132,6 +135,7 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v0.14.1/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU= @@ -141,6 +145,7 @@ github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47 github.com/ianlancetaylor/demangle v0.0.0-20220319035150-800ac71e25c2 h1:rcanfLhLDA8nozr/K289V1zcntHr3V+SHlXwzz1ZI2g= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/jensneuse/byte-template v0.0.0-20200214152254-4f3cf06e5c68/go.mod h1:0D5r/VSW6D/o65rKLL9xk7sZxL2+oku2HvFPYeIMFr4= github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGARJA= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= @@ -164,7 +169,10 @@ github.com/mark3labs/mcp-go v0.36.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZ github.com/matryer/moq v0.2.7 h1:RtpiPUM8L7ZSCbSwK+QcZH/E9tgqAkFjKQxsRs25b4w= github.com/matryer/moq v0.5.2 h1:b2bsanSaO6IdraaIvPBzHnqcrkkQmk1/310HdT2nNQs= github.com/matryer/moq v0.5.2/go.mod h1:W/k5PLfou4f+bzke9VPXTbfJljxoeR1tLHigsmbshmU= +github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM= @@ -195,6 +203,7 @@ github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFu github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= +github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posthog/posthog-go v1.5.5 h1:2o3j7IrHbTIfxRtj4MPaXKeimuTYg49onNzNBZbwksM= github.com/posthog/posthog-go v1.5.5/go.mod h1:3RqUmSnPuwmeVj/GYrS75wNGqcAKdpODiwc83xZWgdE= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= @@ -213,6 +222,7 @@ github.com/redis/go-redis/v9 v9.4.0 h1:Yzoz33UZw9I/mFhx4MNrB6Fk+XHO1VukNcCa1+lwy github.com/redis/go-redis/v9 v9.4.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/sanity-io/litter v1.5.8/go.mod h1:9gzJgR2i4ZpjZHsKvUXIRQVk7P+yM3e+jAF7bU2UI5U= github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 h1:PKK9DyHxif4LZo+uQSgXNqs0jj5+xZwwfKHgph2lxBw= github.com/santhosh-tekuri/jsonschema/v6 v6.0.1/go.mod h1:JXeL+ps8p7/KNMjDQk3TCwPpBy0wYklyWTfbkIzdIFU= github.com/shirou/gopsutil/v3 v3.24.3 h1:eoUGJSmdfLzJ3mxIhmOAhgKEKgQkeOwKpz1NbhVnuPE= @@ -227,10 +237,12 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v0.0.0-20161117074351-18a02ba4a312/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/sjson v1.0.4 h1:UcdIRXff12Lpnu3OLtZvnc03g4vH2suXDXhBwBqmzYg= github.com/tidwall/sjson v1.0.4/go.mod h1:bURseu1nuBkFpIES5cz6zBtjmYeOQmEESshn7VpF15Y= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= @@ -246,9 +258,12 @@ github.com/twmb/franz-go/pkg/kmsg v1.7.0/go.mod h1:se9Mjdt0Nwzc9lnjJ0HyDtLyBnaBD github.com/vbatts/tar-split v0.12.1 h1:CqKoORW7BUWBe7UL/iqTVvkTBOF8UvOMKOIZykxnnbo= github.com/vbatts/tar-split v0.12.1/go.mod h1:eF6B6i6ftWQcDqEn3/iGFRFRo8cBIMSJVOpnNdfTMFA= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= +github.com/wundergraph/astjson v1.1.0 h1:xORDosrZ87zQFJwNGe/HIHXqzpdHOFmqWgykCLVL040= +github.com/wundergraph/astjson v1.1.0/go.mod h1:h12D/dxxnedtLzsKyBLK7/Oe4TAoGpRVC9nDpDrZSWw= github.com/wundergraph/go-arena v0.0.0-20251008210416-55cb97e6f68f h1:5snewyMaIpajTu4wj22L/DgrGimICqXtUVjkZInBH3Y= github.com/wundergraph/go-arena v0.0.0-20251008210416-55cb97e6f68f/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw= -github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= @@ -289,6 +304,7 @@ go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN8 go.uber.org/ratelimit v0.3.1 h1:K4qVE+byfv/B3tC+4nYWP7v/6SimcO7HzHekoMNBma0= go.uber.org/ratelimit v0.3.1/go.mod h1:6euWsTB6U/Nb3X++xEUXA8ciPJvr19Q/0h1+oDcJhRk= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= +go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= go.withmatt.com/connect-brotli v0.4.0 h1:7ObWkYmEbUXK3EKglD0Lgj0BBnnD3jNdAxeDRct3l8E= go.withmatt.com/connect-brotli v0.4.0/go.mod h1:c2eELz56za+/Mxh1yJrlglZ4VM9krpOCPqS2Vxf8NVk= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= @@ -326,6 +342,8 @@ golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -367,6 +385,7 @@ gonum.org/v1/plot v0.10.1 h1:dnifSs43YJuNMDzB7v8wV64O4ABBHReuAVAoBxqBqS4= gonum.org/v1/plot v0.10.1/go.mod h1:VZW5OlhkL1mysU9vaqNHnsy86inf6Ot+jB3r+BczCEo= google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 h1:GVIKPyP/kLIyVOgOnTwFOrvQaQUzOzGMCxgFUOEmm24= google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422/go.mod h1:b6h1vNKhxaSoEI+5jc3PJUCustfli/mRab7295pY7rw= +google.golang.org/grpc v1.68.1/go.mod h1:+q1XYFJjShcqn0QZHvCyeR4CXPA+llXIeUIfIe00waw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.36.4/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= diff --git a/v2/go.mod b/v2/go.mod index ad5d096fc1..534dfc0540 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -7,38 +7,36 @@ require ( github.com/bufbuild/protocompile v0.14.1 github.com/buger/jsonparser v1.1.1 github.com/cespare/xxhash/v2 v2.3.0 - github.com/coder/websocket v1.8.12 + github.com/coder/websocket v1.8.14 github.com/davecgh/go-spew v1.1.1 - github.com/gobwas/ws v1.4.0 github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.1 github.com/hashicorp/go-plugin v1.6.3 github.com/jensneuse/abstractlogger v0.0.4 - github.com/jensneuse/byte-template v0.0.0-20200214152254-4f3cf06e5c68 + github.com/jensneuse/byte-template v0.0.0-20231025215717-69252eb3ed56 github.com/jensneuse/diffview v1.0.0 github.com/kingledion/go-tools v0.6.0 github.com/kylelemons/godebug v1.1.0 github.com/phf/go-queue v0.0.0-20170504031614-9abe38d0371d github.com/pkg/errors v0.9.1 github.com/r3labs/sse/v2 v2.8.1 + github.com/rs/xid v1.6.0 github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/sebdah/goldie/v2 v2.7.1 github.com/stretchr/testify v1.11.1 - github.com/tidwall/gjson v1.17.0 - github.com/tidwall/sjson v1.0.4 + github.com/tidwall/gjson v1.18.0 + github.com/tidwall/sjson v1.2.5 github.com/vektah/gqlparser/v2 v2.5.30 github.com/wundergraph/astjson v1.1.0 github.com/wundergraph/go-arena v1.1.0 - go.uber.org/atomic v1.11.0 go.uber.org/goleak v1.3.0 - go.uber.org/zap v1.26.0 golang.org/x/sync v0.17.0 golang.org/x/sys v0.37.0 golang.org/x/text v0.30.0 gonum.org/v1/gonum v0.14.0 - google.golang.org/grpc v1.68.1 + google.golang.org/grpc v1.71.0 google.golang.org/protobuf v1.36.9 gopkg.in/yaml.v2 v2.4.0 ) @@ -50,12 +48,11 @@ require ( github.com/dnephin/pflag v1.0.7 // indirect github.com/fatih/color v1.18.0 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-logr/logr v1.4.3 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect - github.com/gobwas/httphead v0.1.0 // indirect - github.com/gobwas/pool v0.2.1 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect - github.com/hashicorp/go-hclog v0.14.1 // indirect + github.com/hashicorp/go-hclog v1.6.3 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/hashicorp/yamux v0.1.1 // indirect github.com/kr/pretty v0.3.1 // indirect @@ -72,7 +69,10 @@ require ( github.com/tidwall/pretty v1.2.1 // indirect github.com/urfave/cli/v2 v2.27.7 // indirect github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 // indirect + go.opentelemetry.io/otel v1.36.0 // indirect + go.opentelemetry.io/otel/sdk/metric v1.36.0 // indirect go.uber.org/multierr v1.11.0 // indirect + go.uber.org/zap v1.27.0 // indirect golang.org/x/mod v0.29.0 // indirect golang.org/x/net v0.46.0 // indirect golang.org/x/term v0.36.0 // indirect diff --git a/v2/go.sum b/v2/go.sum index 13adfeb881..f8d6249fc8 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -15,8 +15,8 @@ github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMU github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= -github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -27,19 +27,17 @@ github.com/dgryski/trifles v0.0.0-20230903005119-f50d829f2e54 h1:SG7nF6SRlWhcT7c github.com/dgryski/trifles v0.0.0-20230903005119-f50d829f2e54/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= github.com/dnephin/pflag v1.0.7 h1:oxONGlWxhmUct0YzKTgrpQv9AUA1wtPBn7zuSjJqptk= github.com/dnephin/pflag v1.0.7/go.mod h1:uxE91IoWURlOiTUIA8Mq5ZZkAv3dPUfZNaT80Zm7OQE= -github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= -github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= -github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= -github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= -github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= -github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs= -github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= @@ -53,8 +51,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= -github.com/hashicorp/go-hclog v0.14.1 h1:nQcJDQwIAGnmoUWp8ubocEX40cCml/17YkF6csQLReU= -github.com/hashicorp/go-hclog v0.14.1/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= +github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= +github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= github.com/hashicorp/go-plugin v1.6.3 h1:xgHB+ZUSYeuJi96WtxEjzi23uh7YQpznjGh0U0UUrwg= github.com/hashicorp/go-plugin v1.6.3/go.mod h1:MRobyh+Wc/nYy1V4KAXUiYfzxoYhs7V1mlH1Z7iY2h0= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= @@ -63,8 +61,8 @@ github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= github.com/jensneuse/abstractlogger v0.0.4 h1:sa4EH8fhWk3zlTDbSncaWKfwxYM8tYSlQ054ETLyyQY= github.com/jensneuse/abstractlogger v0.0.4/go.mod h1:6WuamOHuykJk8zED/R0LNiLhWR6C7FIAo43ocUEB3mo= -github.com/jensneuse/byte-template v0.0.0-20200214152254-4f3cf06e5c68 h1:E80wOd3IFQcoBxLkAUpUQ3BoGrZ4DxhQdP21+HH1s6A= -github.com/jensneuse/byte-template v0.0.0-20200214152254-4f3cf06e5c68/go.mod h1:0D5r/VSW6D/o65rKLL9xk7sZxL2+oku2HvFPYeIMFr4= +github.com/jensneuse/byte-template v0.0.0-20231025215717-69252eb3ed56 h1:wo26fh6a6Za0cOMZIopD2sfH/kq83SJ89ixUWl7pCWc= +github.com/jensneuse/byte-template v0.0.0-20231025215717-69252eb3ed56/go.mod h1:0D5r/VSW6D/o65rKLL9xk7sZxL2+oku2HvFPYeIMFr4= github.com/jensneuse/diffview v1.0.0 h1:4b6FQJ7y3295JUHU3tRko6euyEboL825ZsXeZZM47Z4= github.com/jensneuse/diffview v1.0.0/go.mod h1:i6IacuD8LnEaPuiyzMHA+Wfz5mAuycMOf3R/orUY9y4= github.com/jhump/protoreflect v1.15.1 h1:HUMERORf3I3ZdX05WaQ6MIpd/NJ434hTp5YiKgfCL6c= @@ -83,11 +81,12 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= -github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/oklog/run v1.0.0 h1:Ru7dDtJNOyC66gQ5dQmaCa0qIsAUFY3sFpK1Xk8igrw= @@ -106,6 +105,8 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 h1:lZUw3E0/J3roVtGQ+SCrUrg3ON6NgVqpn3+iol9aGu4= @@ -126,17 +127,19 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= -github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= -github.com/tidwall/sjson v1.0.4 h1:UcdIRXff12Lpnu3OLtZvnc03g4vH2suXDXhBwBqmzYg= -github.com/tidwall/sjson v1.0.4/go.mod h1:bURseu1nuBkFpIES5cz6zBtjmYeOQmEESshn7VpF15Y= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= github.com/vektah/gqlparser/v2 v2.5.30 h1:EqLwGAFLIzt1wpx1IPpY67DwUujF1OfzgEyDsLrN6kE= @@ -148,9 +151,19 @@ github.com/wundergraph/go-arena v1.1.0/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nX github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/otel v1.36.0 h1:UumtzIklRBY6cI/lllNZlALOF5nNIzJVb16APdvgTXg= +go.opentelemetry.io/otel v1.36.0/go.mod h1:/TcFMXYjyRNh8khOAO9ybYkqaDBb/70aVwkNML4pP8E= +go.opentelemetry.io/otel/metric v1.36.0 h1:MoWPKVhQvJ+eeXWHFBOPoBOi20jh6Iq2CcCREuTYufE= +go.opentelemetry.io/otel/metric v1.36.0/go.mod h1:zC7Ks+yeyJt4xig9DEw9kuUFe5C3zLbVjV2PzT6qzbs= +go.opentelemetry.io/otel/sdk v1.36.0 h1:b6SYIuLRs88ztox4EyrvRti80uXIFy+Sqzoh9kFULbs= +go.opentelemetry.io/otel/sdk v1.36.0/go.mod h1:+lC+mTgD+MUWfjJubi2vvXWcVxyr9rmlshZni72pXeY= +go.opentelemetry.io/otel/sdk/metric v1.36.0 h1:r0ntwwGosWGaa0CrSt8cuNuTcccMXERFwHX4dThiPis= +go.opentelemetry.io/otel/sdk/metric v1.36.0/go.mod h1:qTNOhFDfKRwX0yXOqJYegL5WRaW376QbB7P4Pb0qva4= +go.opentelemetry.io/otel/trace v1.36.0 h1:ahxWNuqZjpdiFAyrIoQ4GIiAIhxAunQR6MUoKrsNd4w= +go.opentelemetry.io/otel/trace v1.36.0/go.mod h1:gQ+OnDZzrybY4k4seLzPAWNwVBBVlF2szhehOBB/tGA= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= -go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= -go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= @@ -158,8 +171,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= -go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= -go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -180,13 +193,16 @@ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= @@ -214,8 +230,8 @@ gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0= gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU= google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f h1:OxYkA3wjPsZyBylwymxSHa7ViiW1Sml4ToBrncvFehI= google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:+2Yz8+CLJbIfL9z73EW45avw8Lmge3xVElCP9zEKi50= -google.golang.org/grpc v1.68.1 h1:oI5oTa11+ng8r8XMMN7jAOmWfPZWbYpCFaMUTACxkM0= -google.golang.org/grpc v1.68.1/go.mod h1:+q1XYFJjShcqn0QZHvCyeR4CXPA+llXIeUIfIe00waw= +google.golang.org/grpc v1.71.0 h1:kF77BGdPTQ4/JZWMlb9VpJ5pa25aqvVqogsxNHHdeBg= +google.golang.org/grpc v1.71.0/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw= google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/cenkalti/backoff.v1 v1.1.0 h1:Arh75ttbsvlpVA7WtVpH4u9h6Zl46xuptxqLxPiSo4Y= diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index f4268d1f6a..e2de738182 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -425,7 +425,7 @@ func (p *Planner[T]) ConfigureSubscription() plan.SubscriptionConfiguration { p.stopWithError(errors.WithStack(fmt.Errorf("ConfigureSubscription: failed to marshal header: %w", err))) return plan.SubscriptionConfiguration{} } - if err == nil && len(header) != 0 && !bytes.Equal(header, literal.NULL) { + if len(header) != 0 && !bytes.Equal(header, literal.NULL) { input = httpclient.SetInputHeader(input, header) } @@ -1960,8 +1960,6 @@ func (s *Source) Load(ctx context.Context, headers http.Header, input []byte) (d type GraphQLSubscriptionClient interface { // Subscribe to the origin source. The implementation must not block the calling goroutine. Subscribe(ctx *resolve.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error - SubscribeAsync(ctx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error - Unsubscribe(id uint64) } type GraphQLSubscriptionOptions struct { @@ -1996,25 +1994,6 @@ type SubscriptionSource struct { subscriptionOnStartFns []SubscriptionOnStartFn } -func (s *SubscriptionSource) AsyncStart(ctx *resolve.Context, id uint64, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { - var options GraphQLSubscriptionOptions - err := json.Unmarshal(input, &options) - if err != nil { - return err - } - options.Header = headers - if options.Body.Query == "" { - return resolve.ErrUnableToResolve - } - return s.client.SubscribeAsync(ctx, id, options, updater) -} - -// AsyncStop stops the subscription with the given id. AsyncStop is only effective when netPoll is enabled -// because without netPoll we manage the lifecycle of the connection in the subscription client. -func (s *SubscriptionSource) AsyncStop(id uint64) { - s.client.Unsubscribe(id) -} - // Start the subscription. The updater is called on new events. Start needs to be called in a separate goroutine. func (s *SubscriptionSource) Start(ctx *resolve.Context, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { var options GraphQLSubscriptionOptions @@ -2029,10 +2008,6 @@ func (s *SubscriptionSource) Start(ctx *resolve.Context, headers http.Header, in return s.client.Subscribe(ctx, options, updater) } -var ( - dataSouceName = []byte("graphql") -) - // SubscriptionOnStart is called when a subscription is started. // Hooks are invoked sequentially, short-circuiting on the first error. func (s *SubscriptionSource) SubscriptionOnStart(ctx resolve.StartupHookContext, input []byte) error { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index 05b07df9e1..9432dab865 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -4006,7 +4006,7 @@ func TestGraphQLDataSource(t *testing.T) { Trigger: resolve.GraphQLSubscriptionTrigger{ Input: []byte(`{"url":"wss://swapi.com/graphql","body":{"query":"subscription{remainingJedis}"}}`), Source: &SubscriptionSource{ - client: NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, ctx), + client: NewGraphQLSubscriptionClient(ctx), }, PostProcessing: DefaultPostProcessingConfiguration, SourceName: "ds-id", @@ -4049,7 +4049,7 @@ func TestGraphQLDataSource(t *testing.T) { }, ), Source: &SubscriptionSource{ - client: NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, ctx), + client: NewGraphQLSubscriptionClient(ctx), }, PostProcessing: DefaultPostProcessingConfiguration, SourceName: "ds-id", @@ -8384,14 +8384,16 @@ func (f *FailingSubscriptionClient) Subscribe(ctx *resolve.Context, options Grap type testSubscriptionUpdaterChan struct { updates chan string complete chan struct{} - closed chan resolve.SubscriptionCloseKind + errors chan []byte + done chan struct{} } func newTestSubscriptionUpdaterChan() *testSubscriptionUpdaterChan { return &testSubscriptionUpdaterChan{ updates: make(chan string), complete: make(chan struct{}), - closed: make(chan resolve.SubscriptionCloseKind), + errors: make(chan []byte, 1), + done: make(chan struct{}), } } @@ -8408,7 +8410,7 @@ func (t *testSubscriptionUpdaterChan) UpdateSubscription(id resolve.Subscription } // empty method to satisfy the interface, not used in this tests -func (t *testSubscriptionUpdaterChan) CloseSubscription(kind resolve.SubscriptionCloseKind, id resolve.SubscriptionIdentifier) { +func (t *testSubscriptionUpdaterChan) CloseSubscription(id resolve.SubscriptionIdentifier) { } // empty method to satisfy the interface, not used in this tests @@ -8420,8 +8422,12 @@ func (t *testSubscriptionUpdaterChan) Complete() { close(t.complete) } -func (t *testSubscriptionUpdaterChan) Close(kind resolve.SubscriptionCloseKind) { - t.closed <- kind +func (t *testSubscriptionUpdaterChan) Error(data []byte) { + t.errors <- data +} + +func (t *testSubscriptionUpdaterChan) Done() { + close(t.done) } func (t *testSubscriptionUpdaterChan) AwaitUpdateWithT(tt *testing.T, timeout time.Duration, f func(t *testing.T, update string), msgAndArgs ...any) { @@ -8435,24 +8441,25 @@ func (t *testSubscriptionUpdaterChan) AwaitUpdateWithT(tt *testing.T, timeout ti } } -func (t *testSubscriptionUpdaterChan) AwaitClose(tt *testing.T, timeout time.Duration, msgAndArgs ...any) { +func (t *testSubscriptionUpdaterChan) AwaitError(tt *testing.T, timeout time.Duration, msgAndArgs ...any) []byte { tt.Helper() select { - case <-t.closed: + case data := <-t.errors: + return data case <-time.After(timeout): - require.Fail(tt, "updater not closed before timeout", msgAndArgs...) + require.Fail(tt, "updater error not received before timeout", msgAndArgs...) + return nil } } -func (t *testSubscriptionUpdaterChan) AwaitCloseKind(tt *testing.T, timeout time.Duration, expectedCloseKind resolve.SubscriptionCloseKind, msgAndArgs ...any) { +func (t *testSubscriptionUpdaterChan) AwaitDone(tt *testing.T, timeout time.Duration, msgAndArgs ...any) { tt.Helper() select { - case closeKind := <-t.closed: - require.Equal(tt, expectedCloseKind, closeKind, msgAndArgs...) + case <-t.done: case <-time.After(timeout): - require.Fail(tt, "updater not closed before timeout", msgAndArgs...) + require.Fail(tt, "updater not done before timeout", msgAndArgs...) } } @@ -8470,8 +8477,8 @@ func (t *testSubscriptionUpdaterChan) AwaitComplete(tt *testing.T, timeout time. // It's faster, more ergonomic and more reliable. See SSE handler tests for usage examples. type testSubscriptionUpdater struct { updates []string + errors []string done bool - closed bool mux sync.Mutex } @@ -8496,6 +8503,27 @@ func (t *testSubscriptionUpdater) AwaitUpdates(tt *testing.T, timeout time.Durat } } +func (t *testSubscriptionUpdater) AwaitErrors(tt *testing.T, timeout time.Duration, count int) { + tt.Helper() + + ticker := time.NewTicker(timeout) + defer ticker.Stop() + for { + time.Sleep(10 * time.Millisecond) + select { + case <-ticker.C: + tt.Fatalf("timed out waiting for errors") + default: + t.mux.Lock() + if len(t.errors) == count { + t.mux.Unlock() + return + } + t.mux.Unlock() + } + } +} + func (t *testSubscriptionUpdater) AwaitDone(tt *testing.T, timeout time.Duration) { tt.Helper() @@ -8535,14 +8563,20 @@ func (t *testSubscriptionUpdater) Complete() { t.done = true } -func (t *testSubscriptionUpdater) Close(kind resolve.SubscriptionCloseKind) { +func (t *testSubscriptionUpdater) Error(data []byte) { t.mux.Lock() defer t.mux.Unlock() - t.closed = true + t.errors = append(t.errors, string(data)) +} + +func (t *testSubscriptionUpdater) Done() { + t.mux.Lock() + defer t.mux.Unlock() + t.done = true } // empty method to satisfy the interface, not used in this tests -func (t *testSubscriptionUpdater) CloseSubscription(kind resolve.SubscriptionCloseKind, id resolve.SubscriptionIdentifier) { +func (t *testSubscriptionUpdater) CloseSubscription(id resolve.SubscriptionIdentifier) { } // empty method to satisfy the interface, not used in this tests @@ -8591,8 +8625,7 @@ func TestSubscriptionSource_Start(t *testing.T) { } newSubscriptionSource := func(ctx context.Context) SubscriptionSource { - httpClient := http.Client{} - subscriptionSource := SubscriptionSource{client: NewGraphQLSubscriptionClient(&httpClient, http.DefaultClient, ctx)} + subscriptionSource := SubscriptionSource{client: NewGraphQLSubscriptionClient(ctx)} return subscriptionSource } @@ -8631,9 +8664,9 @@ func TestSubscriptionSource_Start(t *testing.T) { chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomNam: \"#test\") { text createdBy } }"}`) err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) - updater.AwaitUpdates(t, time.Second, 1) - assert.Len(t, updater.updates, 1) - assert.Equal(t, `{"errors":[{"message":"Unknown argument \"roomNam\" on field \"Subscription.messageAdded\". Did you mean \"roomName\"?","locations":[{"line":1,"column":29}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"Field \"messageAdded\" argument \"roomName\" of type \"String!\" is required, but it was not provided.","locations":[{"line":1,"column":29}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}]}`, updater.updates[0]) + updater.AwaitErrors(t, time.Second, 1) + assert.Len(t, updater.errors, 1) + assert.Equal(t, `{"errors":[{"message":"Unknown argument \"roomNam\" on field \"Subscription.messageAdded\". Did you mean \"roomName\"?","locations":[{"line":1,"column":29}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"Field \"messageAdded\" argument \"roomName\" of type \"String!\" is required, but it was not provided.","locations":[{"line":1,"column":29}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}]}`, updater.errors[0]) updater.AwaitDone(t, time.Second) }) @@ -8719,9 +8752,8 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { } newSubscriptionSource := func(ctx context.Context) SubscriptionSource { - httpClient := http.Client{} subscriptionSource := SubscriptionSource{ - client: NewGraphQLSubscriptionClient(&httpClient, http.DefaultClient, ctx), + client: NewGraphQLSubscriptionClient(ctx), } return subscriptionSource } @@ -8737,9 +8769,9 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) - updater.AwaitUpdates(t, time.Second, 1) - assert.Len(t, updater.updates, 1) - assert.Equal(t, `{"errors":[{"message":"Unknown argument \"roomNam\" on field \"Subscription.messageAdded\". Did you mean \"roomName\"?","locations":[{"line":1,"column":29}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"Field \"messageAdded\" argument \"roomName\" of type \"String!\" is required, but it was not provided.","locations":[{"line":1,"column":29}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}]}`, updater.updates[0]) + updater.AwaitErrors(t, time.Second, 1) + assert.Len(t, updater.errors, 1) + assert.Equal(t, `{"errors":[{"message":"Unknown argument \"roomNam\" on field \"Subscription.messageAdded\". Did you mean \"roomName\"?","locations":[{"line":1,"column":29}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"Field \"messageAdded\" argument \"roomName\" of type \"String!\" is required, but it was not provided.","locations":[{"line":1,"column":29}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}]}`, updater.errors[0]) updater.AwaitDone(t, time.Second) assert.Equal(t, true, updater.done) }) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go deleted file mode 100644 index 512d7ca4b0..0000000000 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go +++ /dev/null @@ -1,275 +0,0 @@ -package graphql_datasource - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "math" - "net/http" - - "github.com/buger/jsonparser" - log "github.com/jensneuse/abstractlogger" - "github.com/r3labs/sse/v2" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" -) - -var ( - headerData = []byte("data:") - headerEvent = []byte("event:") - - eventTypeComplete = []byte("complete") - eventTypeNext = []byte("next") -) - -type gqlSSEConnectionHandler struct { - conn *http.Client - requestContext, engineContext context.Context - log log.Logger - options GraphQLSubscriptionOptions - updater resolve.SubscriptionUpdater -} - -func newSSEConnectionHandler(requestContext, engineContext context.Context, conn *http.Client, updater resolve.SubscriptionUpdater, options GraphQLSubscriptionOptions, l log.Logger) *gqlSSEConnectionHandler { - return &gqlSSEConnectionHandler{ - conn: conn, - requestContext: requestContext, - engineContext: engineContext, - log: l, - updater: updater, - options: options, - } -} - -func (h *gqlSSEConnectionHandler) StartBlocking() { - defer h.updater.Close(resolve.SubscriptionCloseKindNormal) - - resp, err := h.performSubscriptionRequest() - if err != nil { - h.log.Error("failed to perform subscription request", log.Error(err)) - - if h.requestContext.Err() != nil { - // request context was canceled do not send an error as channel will be closed - return - } - - h.updater.Update([]byte(internalError)) - - return - } - - defer func() { - _ = resp.Body.Close() - }() - - reader := sse.NewEventStreamReader(resp.Body, math.MaxInt) - - for { - select { - case <-h.requestContext.Done(): - return - case <-h.engineContext.Done(): - return - default: - } - - msg, err := reader.ReadEvent() - if err != nil { - if err == io.EOF { - return - } - - h.log.Error("failed to read event", log.Error(err)) - h.updater.Update([]byte(internalError)) - return - } - - if len(msg) == 0 { - continue - } - - // normalize the crlf to lf to make it easier to split the lines. - // split the line by "\n" or "\r", per the spec. - lines := bytes.FieldsFunc(msg, func(r rune) bool { return r == '\n' || r == '\r' }) - for _, line := range lines { - switch { - case bytes.HasPrefix(line, headerData): - data := trim(line[len(headerData):]) - - if len(data) == 0 { - continue - } - - if h.requestContext.Err() != nil { - // request context was canceled do not send an error as channel will be closed - return - } - - h.updater.Update(data) - case bytes.HasPrefix(line, headerEvent): - event := trim(line[len(headerEvent):]) - - switch { - case bytes.Equal(event, eventTypeComplete): - h.updater.Complete() - return - case bytes.Equal(event, eventTypeNext): - continue - } - case bytes.HasPrefix(msg, []byte(":")): - // according to the spec, we ignore messages starting with a colon - continue - default: - // ideally we should not get here, or if we do, we should ignore it - // but some providers send a json object with the error messages, without the event header - - // check for errors which came without event header - data := trim(line) - - val, valueType, _, err := jsonparser.Get(data, "errors") - switch err { - case jsonparser.KeyPathNotFoundError: - continue - case jsonparser.MalformedJsonError: - // ignore garbage - continue - case nil: - switch valueType { - case jsonparser.Array: - response := []byte(`{}`) - response, err = jsonparser.Set(response, val, "errors") - if err != nil { - h.log.Error("failed to set errors", log.Error(err)) - h.updater.Update([]byte(internalError)) - return - } - h.updater.Update(response) - return - case jsonparser.Object: - response := []byte(`{"errors":[]}`) - response, err = jsonparser.Set(response, val, "errors", "[0]") - if err != nil { - h.log.Error("failed to set errors", log.Error(err)) - h.updater.Update([]byte(internalError)) - return - } - h.updater.Update(response) - return - default: - // don't crash on unexpected payloads from upstream - h.log.Error(fmt.Sprintf("unexpected value type: %d", valueType)) - h.updater.Update([]byte(internalError)) - return - } - - default: - h.log.Error("failed to parse errors", log.Error(err)) - h.updater.Update([]byte(internalError)) - return - } - } - } - } -} - -func trim(data []byte) []byte { - // remove the leading space - data = bytes.TrimLeft(data, " \t") - - // remove the trailing new line - data = bytes.TrimRight(data, "\n") - - return data -} - -func (h *gqlSSEConnectionHandler) performSubscriptionRequest() (*http.Response, error) { - var req *http.Request - var err error - - // default to GET requests when SSEMethodPost is not enabled in the SubscriptionConfiguration - if h.options.SSEMethodPost { - req, err = h.buildPOSTRequest(h.requestContext) - } else { - req, err = h.buildGETRequest(h.requestContext) - } - - if err != nil { - return nil, err - } - - resp, err := h.conn.Do(req) - if err != nil { - return nil, err - } - - switch resp.StatusCode { - case http.StatusOK: - return resp, nil - default: - return nil, fmt.Errorf("failed to connect to stream unexpected resp status code: %d", resp.StatusCode) - } -} - -func (h *gqlSSEConnectionHandler) buildGETRequest(ctx context.Context) (*http.Request, error) { - req, err := http.NewRequestWithContext(ctx, "GET", h.options.URL, nil) - if err != nil { - return nil, err - } - - if h.options.Header != nil { - req.Header = h.options.Header - } - - query := req.URL.Query() - query.Add("query", h.options.Body.Query) - - if h.options.Body.Variables != nil { - variables, _ := h.options.Body.Variables.MarshalJSON() - - query.Add("variables", string(variables)) - } - - if h.options.Body.OperationName != "" { - query.Add("operationName", h.options.Body.OperationName) - } - - if h.options.Body.Extensions != nil { - extensions, _ := h.options.Body.Extensions.MarshalJSON() - - query.Add("extensions", string(extensions)) - } - - req.URL.RawQuery = query.Encode() - h.setSSEHeaders(req) - - return req, nil -} - -func (h *gqlSSEConnectionHandler) buildPOSTRequest(ctx context.Context) (*http.Request, error) { - body, err := json.Marshal(h.options.Body) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, "POST", h.options.URL, bytes.NewBuffer(body)) - if err != nil { - return nil, err - } - - if h.options.Header != nil { - req.Header = h.options.Header - } - - req.Header.Set("Content-Type", "application/json") - h.setSSEHeaders(req) - return req, nil -} - -// setSSEHeaders sets the headers required for SSE for both GET and POST requests -func (h *gqlSSEConnectionHandler) setSSEHeaders(req *http.Request) { - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Connection", "keep-alive") - req.Header.Set("Cache-Control", "no-cache") -} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler_test.go deleted file mode 100644 index 4c50f19176..0000000000 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler_test.go +++ /dev/null @@ -1,646 +0,0 @@ -package graphql_datasource - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" - "github.com/wundergraph/graphql-go-tools/v2/pkg/testing/flags" -) - -func TestGraphQLSubscriptionClientSubscribe_SSE(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - urlQuery := r.URL.Query() - assert.Equal(t, "subscription {messageAdded(roomName: \"room\"){text}}", urlQuery.Get("query")) - - // Make sure that the writer supports flushing. - flusher, ok := w.(http.Flusher) - require.True(t, ok) - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - _, _ = fmt.Fprintf(w, "data: %s\n\n", `{"data":{"messageAdded":{"text":"first"}}}`) - flusher.Flush() - - _, _ = fmt.Fprintf(w, "data: %s\n\n", `{"data":{"messageAdded":{"text":"second"}}}`) - flusher.Flush() - - close(serverDone) - })) - defer server.Close() - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - ctx, clientCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := &testSubscriptionUpdater{} - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - UseSSE: true, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdates(t, time.Second, 2) - assert.Equal(t, 2, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() -} - -func TestGraphQLSubscriptionClientSubscribe_SSE_RequestAbort(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - ctx, clientCancel := context.WithCancel(context.Background()) - // cancel after start the request - clientCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, t.Context(), - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := newTestSubscriptionUpdaterChan() - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: "http://dummy", - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - UseSSE: true, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitClose(t, time.Second) -} - -func TestGraphQLSubscriptionClientSubscribe_SSE_POST(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - postReqBody := GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - } - expectedReqBody, err := json.Marshal(postReqBody) - assert.NoError(t, err) - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, http.MethodPost, r.Method) - - actualReqBody, err := io.ReadAll(r.Body) - assert.NoError(t, err) - assert.Equal(t, expectedReqBody, actualReqBody) - - // Make sure that the writer supports flushing. - flusher, ok := w.(http.Flusher) - require.True(t, ok) - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - _, _ = fmt.Fprintf(w, "data: %s\n\n", `{"data":{"messageAdded":{"text":"first"}}}`) - flusher.Flush() - - _, _ = fmt.Fprintf(w, "data: %s\n\n", `{"data":{"messageAdded":{"text":"second"}}}`) - flusher.Flush() - - _, _ = fmt.Fprintf(w, "event: complete\n\n") - flusher.Flush() - - close(serverDone) - })) - defer server.Close() - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - ctx, clientCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := newTestSubscriptionUpdaterChan() - - go func() { - err = client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: postReqBody, - UseSSE: true, - SSEMethodPost: true, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdateWithT(t, time.Second, func(t *testing.T, update string) { - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, update) - }) - - updater.AwaitUpdateWithT(t, time.Second, func(t *testing.T, update string) { - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, update) - }) - - updater.AwaitComplete(t, time.Second) - - clientCancel() - - select { - case <-serverDone: - case <-time.After(time.Second): - require.Fail(t, "server did not close") - } - - serverCancel() -} - -func TestGraphQLSubscriptionClientSubscribe_SSE_WithEvents(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Make sure that the writer supports flushing. - flusher, ok := w.(http.Flusher) - require.True(t, ok) - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - _, _ = fmt.Fprintf(w, "event: next\ndata: %s\n\n", `{"data":{"messageAdded":{"text":"first"}}}`) - flusher.Flush() - - _, _ = fmt.Fprintf(w, "event: next\ndata: %s\n\n", `{"data":{"messageAdded":{"text":"second"}}}`) - flusher.Flush() - - _, _ = fmt.Fprintf(w, "event: complete\n\n") - flusher.Flush() - - close(serverDone) - })) - defer server.Close() - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - ctx, clientCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := newTestSubscriptionUpdaterChan() - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - UseSSE: true, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdateWithT(t, time.Second, func(t *testing.T, update string) { - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, update) - }) - - updater.AwaitUpdateWithT(t, time.Second, func(t *testing.T, update string) { - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, update) - }) - - updater.AwaitComplete(t, time.Second) - - clientCancel() - - select { - case <-serverDone: - case <-time.After(time.Second): - require.Fail(t, "server did not close") - } - - serverCancel() -} - -func TestGraphQLSubscriptionClientSubscribe_SSE_Error(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Make sure that the writer supports flushing. - flusher, ok := w.(http.Flusher) - require.True(t, ok) - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - _, _ = fmt.Fprintf(w, "data: %s\n\n", `{"errors":[{"message":"Unexpected error.","locations":[{"line":2,"column":3}],"path":["countdown"]}]}`) - flusher.Flush() - - close(serverDone) - })) - defer server.Close() - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - ctx, clientCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := newTestSubscriptionUpdaterChan() - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - UseSSE: true, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdateWithT(t, time.Second, func(t *testing.T, update string) { - assert.Equal(t, `{"errors":[{"message":"Unexpected error.","locations":[{"line":2,"column":3}],"path":["countdown"]}]}`, update) - }) - - updater.AwaitClose(t, time.Second) - - clientCancel() - - select { - case <-serverDone: - case <-time.After(time.Second): - require.Fail(t, "server did not close") - } - - serverCancel() -} - -func TestGraphQLSubscriptionClientSubscribe_SSE_Error_Without_Header(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - testCases := []struct { - name string - errorMessage string - expectedErr string - }{ - { - name: "object_error_value", - errorMessage: `{"errors":{"message":"Unexpected error.","locations":[{"line":2,"column":3}],"path":["countdown"]},"data":null}`, - expectedErr: `{"errors":[{"message":"Unexpected error.","locations":[{"line":2,"column":3}],"path":["countdown"]}]}`, - }, - { - name: "list_error_value", - errorMessage: `{"errors":[{"message":"Unexpected error.","locations":[{"line":2,"column":3}],"path":["countdown"]}],"data":null}`, - expectedErr: `{"errors":[{"message":"Unexpected error.","locations":[{"line":2,"column":3}],"path":["countdown"]}]}`, - }, - { - name: "string_error_value", - errorMessage: `{"errors": "some string error"}`, - expectedErr: `{"errors":[{"message":"internal error"}]}`, - }, - { - name: "number_error_value", - errorMessage: `{"errors": 123}`, - expectedErr: `{"errors":[{"message":"internal error"}]}`, - }, - { - name: "boolean_true_error_value", - errorMessage: `{"errors": true}`, - expectedErr: `{"errors":[{"message":"internal error"}]}`, - }, - { - name: "null_error_value", - errorMessage: `{"errors": null}`, - expectedErr: `{"errors":[{"message":"internal error"}]}`, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Make sure that the writer supports flushing. - flusher, ok := w.(http.Flusher) - require.True(t, ok) - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - // Send the malformed error message WITHOUT the "data:" prefix - // This triggers the error parsing logic in the default case - _, _ = fmt.Fprintf(w, "%s\n\n", tc.errorMessage) - flusher.Flush() - - close(serverDone) - })) - defer server.Close() - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - ctx, clientCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := newTestSubscriptionUpdaterChan() - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - UseSSE: true, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdateWithT(t, time.Second, func(t *testing.T, update string) { - assert.Equal(t, tc.expectedErr, update) - }) - - updater.AwaitClose(t, time.Second) - - clientCancel() - - select { - case <-serverDone: - case <-time.After(time.Second): - require.Fail(t, "server did not close") - } - - serverCancel() - }) - } -} - -func TestGraphQLSubscriptionClientSubscribe_QueryParams(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - urlQuery := r.URL.Query() - assert.Equal(t, "subscription($a: Int!){countdown(from: $a)}", urlQuery.Get("query")) - assert.Equal(t, "CountDown", urlQuery.Get("operationName")) - assert.Equal(t, `{"a":5}`, urlQuery.Get("variables")) - assert.Equal(t, `{"persistedQuery":{"version":1,"sha256Hash":"d41d8cd98f00b204e9800998ecf8427e"}}`, urlQuery.Get("extensions")) - - // Make sure that the writer supports flushing. - flusher, ok := w.(http.Flusher) - require.True(t, ok) - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - _, _ = fmt.Fprintf(w, "data: %s\n\n", `{"data":{"countdown":5}}`) - flusher.Flush() - - close(serverDone) - })) - defer server.Close() - - serverCtx, serverCancel := context.WithCancel(context.Background()) - ctx, clientCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := newTestSubscriptionUpdaterChan() - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription($a: Int!){countdown(from: $a)}`, - OperationName: "CountDown", - Variables: []byte(`{"a":5}`), - Extensions: []byte(`{"persistedQuery":{"version":1,"sha256Hash":"d41d8cd98f00b204e9800998ecf8427e"}}`), - }, - UseSSE: true, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdateWithT(t, time.Second, func(t *testing.T, update string) { - assert.Equal(t, `{"data":{"countdown":5}}`, update) - }) - - updater.AwaitClose(t, time.Second) - - clientCancel() - - select { - case <-serverDone: - case <-time.After(time.Second): - require.Fail(t, "server did not close") - } - - serverCancel() -} - -func TestBuildPOSTRequestSSE(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - subscriptionOptions := GraphQLSubscriptionOptions{ - URL: "test", - Body: GraphQLBody{ - Query: `subscription($a: Int!){countdown(from: $a)}`, - OperationName: "CountDown", - Variables: []byte(`{"a":5}`), - Extensions: []byte(`{"persistedQuery":{"version":1,"sha256Hash":"d41d8cd98f00b204e9800998ecf8427e"}}`), - }, - } - - h := gqlSSEConnectionHandler{ - options: subscriptionOptions, - } - - req, err := h.buildPOSTRequest(context.Background()) - assert.NoError(t, err) - - expectedReqBody, err := json.Marshal(subscriptionOptions.Body) - assert.NoError(t, err) - - assert.Equal(t, http.MethodPost, req.Method) - - actualReqBody, err := io.ReadAll(req.Body) - assert.NoError(t, err) - assert.Equal(t, expectedReqBody, actualReqBody) -} - -func TestBuildGETRequestSSE(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - subscriptionOptions := GraphQLSubscriptionOptions{ - URL: "test", - Body: GraphQLBody{ - Query: `subscription($a: Int!){countdown(from: $a)}`, - OperationName: "CountDown", - Variables: []byte(`{"a":5}`), - Extensions: []byte(`{"persistedQuery":{"version":1,"sha256Hash":"d41d8cd98f00b204e9800998ecf8427e"}}`), - }, - } - - h := gqlSSEConnectionHandler{ - options: subscriptionOptions, - } - - req, err := h.buildGETRequest(context.Background()) - assert.NoError(t, err) - - assert.Equal(t, http.MethodGet, req.Method) - - urlQuery := req.URL.Query() - assert.Equal(t, subscriptionOptions.Body.Query, urlQuery.Get("query")) - assert.Equal(t, subscriptionOptions.Body.OperationName, urlQuery.Get("operationName")) - - assert.Equal(t, string(subscriptionOptions.Body.Variables), urlQuery.Get("variables")) - assert.Equal(t, string(subscriptionOptions.Body.Extensions), urlQuery.Get("extensions")) - -} - -func TestGraphQLSubscriptionClientSubscribe_SSE_Upstream_Dies(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - urlQuery := r.URL.Query() - assert.Equal(t, "subscription {messageAdded(roomName: \"room\"){text}}", urlQuery.Get("query")) - - // Make sure that the writer supports flushing. - flusher, ok := w.(http.Flusher) - require.True(t, ok) - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - _, _ = fmt.Fprintf(w, "data: %s\n\n", `{"data":{"messageAdded":{"text":"first"}}}`) - flusher.Flush() - - // Kill the upstream server. We should catch this event as an "unexpected EOF" - // error and return an error message to the subscriber. - h, ok := w.(http.Hijacker) - require.True(t, ok) - rawConn, _, err := h.Hijack() - require.NoError(t, err) - _ = rawConn.Close() - - close(serverDone) - })) - defer server.Close() - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - ctx, clientCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := newTestSubscriptionUpdaterChan() - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - UseSSE: true, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdateWithT(t, time.Second, func(t *testing.T, update string) { - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, update) - }) - - updater.AwaitUpdateWithT(t, time.Second, func(t *testing.T, update string) { - assert.Equal(t, `{"errors":[{"message":"internal error"}]}`, update) - }) - - updater.AwaitClose(t, time.Second) - - clientCancel() - - select { - case <-serverDone: - case <-time.After(time.Second): - require.Fail(t, "server did not close") - } - - serverCancel() -} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index c8a08df03f..4d74ac2d1c 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -2,930 +2,332 @@ package graphql_datasource import ( "context" - "crypto/rand" - "crypto/sha1" - "crypto/tls" - "encoding/base64" + "encoding/json" "errors" "fmt" - "io" - "net" "net/http" - "net/http/httptrace" - "strings" - "sync" - "syscall" "time" - "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" - "github.com/gobwas/ws" - "github.com/gobwas/ws/wsutil" + "github.com/coder/websocket" "github.com/jensneuse/abstractlogger" - "go.uber.org/atomic" + client "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" - "github.com/wundergraph/graphql-go-tools/v2/pkg/netpoll" ) -const ( - // The time to write a message to the server connection before timing out - writeTimeout = 10 * time.Second - // The time to wait for a connection ack message from the server before timing out - ackWaitTimeout = 30 * time.Second -) - -type netPollState struct { - // connections is a map of fd -> connection to keep track of all active connections - connections map[int]*connection - hasConnections atomic.Bool - // triggers is a map of subscription id -> fd to easily look up the connection for a subscription id - triggers map[uint64]int - - // clientUnsubscribe is a channel to signal to the netPoll run loop that a client needs to be unsubscribed - clientUnsubscribe chan uint64 - // addConn is a channel to signal to the netPoll run loop that a new connection needs to be added - addConn chan *connection - // waitForEventsTicker is the ticker for the netPoll run loop - // it is used to prevent busy waiting and to limit the CPU usage - // instead of polling the netPoll instance all the time, we wait until the next tick to throttle the netPoll loop - waitForEventsTicker *time.Ticker - - // waitForEventsTick is the channel to receive the tick from the waitForEventsTicker - waitForEventsTick <-chan time.Time -} - -// subscriptionClient allows running multiple subscriptions via the same WebSocket either SSE connection -// It takes care of de-duplicating connections to the same origin under certain circumstances -// If Hash(URL,Body,Headers) result in the same result, an existing connection is re-used -type subscriptionClient struct { - streamingClient *http.Client - httpClient *http.Client - - useHttpClientWithSkipRoundTrip bool - - engineCtx context.Context - log abstractlogger.Logger - hashPool sync.Pool - onWsConnectionInitCallback *OnWsConnectionInitCallback - - readTimeout time.Duration - pingInterval time.Duration - frameTimeout time.Duration - pingTimeout time.Duration - - netPoll netpoll.Poller - netPollConfig NetPollConfiguration - netPollState *netPollState -} - -func (c *subscriptionClient) SubscribeAsync(ctx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { - if options.UseSSE { - return c.subscribeSSE(ctx.Context(), c.engineCtx, options, updater) - } - - if strings.HasPrefix(options.URL, "https") { - options.URL = "wss" + options.URL[5:] - } else if strings.HasPrefix(options.URL, "http") { - options.URL = "ws" + options.URL[4:] - } - - return c.asyncSubscribeWS(ctx.Context(), c.engineCtx, id, options, updater) -} - -func (c *subscriptionClient) Unsubscribe(id uint64) { - // if we don't have netPoll, we don't have a channel consumer of the clientUnsubscribe channel - // we have to return to prevent a deadlock - if c.netPoll == nil { - return - } - c.netPollState.clientUnsubscribe <- id -} +// SubscriptionClientConfig holds the subscription client configuration. +type SubscriptionClientConfig struct { + UpgradeClient *http.Client + StreamingClient *http.Client + Logger abstractlogger.Logger -type InvalidWsSubprotocolError struct { - InvalidProtocol string + // Timeouts + PingInterval time.Duration + PingTimeout time.Duration + AckTimeout time.Duration + WriteTimeout time.Duration + ReadLimit int64 } -func (e InvalidWsSubprotocolError) Error() string { - return fmt.Sprintf("provided websocket subprotocol '%s' is not supported. The supported subprotocols are graphql-ws and graphql-transport-ws. Please configure your subsciptions with the mentioned subprotocols", e.InvalidProtocol) -} +func defaultSubscriptionClientConfig() *SubscriptionClientConfig { + return &SubscriptionClientConfig{ + UpgradeClient: http.DefaultClient, + StreamingClient: http.DefaultClient, + Logger: abstractlogger.NoopLogger, -func NewInvalidWsSubprotocolError(invalidProtocol string) InvalidWsSubprotocolError { - return InvalidWsSubprotocolError{ - InvalidProtocol: invalidProtocol, + PingInterval: 30 * time.Second, + PingTimeout: 10 * time.Second, + AckTimeout: 30 * time.Second, } } -type Options func(options *opts) - -func WithLogger(log abstractlogger.Logger) Options { - return func(options *opts) { - options.log = log - } -} +// SubscriptionClientOption configures the subscription client. +type SubscriptionClientOption func(*SubscriptionClientConfig) -func WithReadTimeout(timeout time.Duration) Options { - return func(options *opts) { - options.readTimeout = timeout +// WithUpgradeClient sets the HTTP client used for WebSocket upgrade requests. +func WithUpgradeClient(c *http.Client) SubscriptionClientOption { + return func(cfg *SubscriptionClientConfig) { + if c != nil { + cfg.UpgradeClient = c + } } } -func WithPingInterval(interval time.Duration) Options { - return func(options *opts) { - options.pingInterval = interval +// WithStreamingClient sets the HTTP client used for SSE requests. +// This client should have appropriate timeouts for long-lived connections. +func WithStreamingClient(c *http.Client) SubscriptionClientOption { + return func(cfg *SubscriptionClientConfig) { + if c != nil { + cfg.StreamingClient = c + } } } -func WithFrameTimeout(timeout time.Duration) Options { - return func(options *opts) { - options.frameTimeout = timeout +// WithLogger sets the logger for the client and its transports. +// If not set, logging is disabled (silent operation). +func WithLogger(log abstractlogger.Logger) SubscriptionClientOption { + return func(cfg *SubscriptionClientConfig) { + cfg.Logger = log } } -func WithPingTimeout(timeout time.Duration) Options { - return func(options *opts) { - options.pingTimeout = timeout +// WithPingInterval sets the interval between ping messages for connection health checks. +// Only applies to graphql-transport-ws protocol (legacy graphql-ws uses server-initiated keepalive). +// Default: 30s. Set to 0 to disable client-initiated pings. +func WithPingInterval(d time.Duration) SubscriptionClientOption { + return func(cfg *SubscriptionClientConfig) { + cfg.PingInterval = d } } -type NetPollConfiguration struct { - // Enable can be set to true to enable netPoll - Enable bool - // BufferSize defines the size of the buffer for the netPoll loop - BufferSize int - // WaitForNumEvents defines how many events are waited for in the netPoll loop before TickInterval cancels the wait - WaitForNumEvents int - // MaxEventWorkers defines the parallelism of how many connections can be handled at the same time - // The higher the number, the more CPU is used. - MaxEventWorkers int - // TickInterval defines the time between each netPoll loop when WaitForNumEvents is not reached - TickInterval time.Duration -} - -func (e *NetPollConfiguration) ApplyDefaults() { - e.Enable = true - - if e.BufferSize == 0 { - e.BufferSize = 1024 - } - if e.MaxEventWorkers == 0 { - e.MaxEventWorkers = 6 - } - if e.WaitForNumEvents == 0 { - e.WaitForNumEvents = 1024 - } - if e.TickInterval == 0 { - e.TickInterval = time.Millisecond * 100 +// WithPingTimeout sets the maximum time to wait for a pong response. +// If no pong is received within this duration, the connection is considered dead. +// Default: 10s. +func WithPingTimeout(d time.Duration) SubscriptionClientOption { + return func(cfg *SubscriptionClientConfig) { + cfg.PingTimeout = d } } -func WithNetPollConfiguration(config NetPollConfiguration) Options { - return func(options *opts) { - options.netPollConfiguration = config +// WithAckTimeout sets the maximum time to wait for connection_ack after connection_init. +// Default: 30s. +func WithAckTimeout(d time.Duration) SubscriptionClientOption { + return func(cfg *SubscriptionClientConfig) { + cfg.AckTimeout = d } } -type opts struct { - readTimeout time.Duration - pingInterval time.Duration - pingTimeout time.Duration - frameTimeout time.Duration - log abstractlogger.Logger - onWsConnectionInitCallback *OnWsConnectionInitCallback - netPollConfiguration NetPollConfiguration -} - -// GraphQLSubscriptionClientFactory abstracts the way of creating a new GraphQLSubscriptionClient. -// This can be very handy for testing purposes. -type GraphQLSubscriptionClientFactory interface { - NewSubscriptionClient(httpClient, streamingClient *http.Client, engineCtx context.Context, options ...Options) GraphQLSubscriptionClient -} - -type DefaultSubscriptionClientFactory struct{} - -func (d *DefaultSubscriptionClientFactory) NewSubscriptionClient(httpClient, streamingClient *http.Client, engineCtx context.Context, options ...Options) GraphQLSubscriptionClient { - return NewGraphQLSubscriptionClient(httpClient, streamingClient, engineCtx, options...) -} - -func IsDefaultGraphQLSubscriptionClient(client GraphQLSubscriptionClient) bool { - _, ok := client.(*subscriptionClient) - return ok -} - -func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engineCtx context.Context, options ...Options) GraphQLSubscriptionClient { - - // Defaults - op := &opts{ - readTimeout: 5 * time.Second, - pingInterval: 15 * time.Second, - pingTimeout: 30 * time.Second, - frameTimeout: 100 * time.Millisecond, - log: abstractlogger.NoopLogger, - } - - op.netPollConfiguration.ApplyDefaults() - - for _, option := range options { - option(op) - } - - client := &subscriptionClient{ - httpClient: httpClient, - streamingClient: streamingClient, - engineCtx: engineCtx, - log: op.log, - readTimeout: op.readTimeout, - pingInterval: op.pingInterval, - pingTimeout: op.pingTimeout, - frameTimeout: op.frameTimeout, - hashPool: sync.Pool{ - New: func() interface{} { - return xxhash.New() - }, - }, - onWsConnectionInitCallback: op.onWsConnectionInitCallback, - netPollConfig: op.netPollConfiguration, +// WithWriteTimeout sets the timeout for WebSocket write operations (subscribe, unsubscribe, ping, pong). +// Default: 5s. +func WithWriteTimeout(d time.Duration) SubscriptionClientOption { + return func(cfg *SubscriptionClientConfig) { + cfg.WriteTimeout = d } - if op.netPollConfiguration.Enable { - client.netPollState = &netPollState{ - connections: make(map[int]*connection), - triggers: make(map[uint64]int), - clientUnsubscribe: make(chan uint64, op.netPollConfiguration.BufferSize), - addConn: make(chan *connection, op.netPollConfiguration.BufferSize), - // this is not needed, but we want to make it explicit that we're starting with nil as the tick channel - // reading from nil channels blocks forever, which allows us to prevent the netPoll loop from starting - // once we add the first connection, we start the ticker and set the tick channel - // after the last connection is removed, we set the tick channel to nil again - // this way we can start and stop the epoll loop dynamically - waitForEventsTick: nil, - } - - // ignore error is ok, it means that netPoll is not supported, which is handled gracefully by the client - poller, _ := netpoll.NewPoller(op.netPollConfiguration.BufferSize, op.netPollConfiguration.TickInterval) - if poller != nil { - client.netPoll = poller - go client.runNetPoll(engineCtx) - } - } - return client } -type connection struct { - id uint64 - fd int - netConn net.Conn - handler ConnectionHandler - shouldClose bool -} - -// Subscribe initiates a new GraphQL Subscription with the origin -// If an existing WS connection with the same ID (Hash) exists, it is being re-used -// If connection protocol is SSE, a new connection is always created -// If no connection exists, the client initiates a new one -func (c *subscriptionClient) Subscribe(ctx *resolve.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { - options.readTimeout = c.readTimeout - if options.UseSSE { - return c.subscribeSSE(ctx.Context(), c.engineCtx, options, updater) +// WithReadLimit sets the maximum size in bytes for incoming WebSocket messages. +// Default: 1MB. +func WithReadLimit(n int64) SubscriptionClientOption { + return func(cfg *SubscriptionClientConfig) { + cfg.ReadLimit = n } - - return c.subscribeWS(ctx.Context(), c.engineCtx, options, updater) } -func (c *subscriptionClient) subscribeSSE(requestContext, engineContext context.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { - options.readTimeout = c.readTimeout - if c.streamingClient == nil { - return fmt.Errorf("streaming http client is nil") - } - - handler := newSSEConnectionHandler(requestContext, engineContext, c.streamingClient, updater, options, c.log) - - go handler.StartBlocking() - - return nil +// subscriptionClientV2 implements GraphQLSubscriptionClient using the new +// channel-based subscription client. +type subscriptionClientV2 struct { + client *client.Client } -func (c *subscriptionClient) subscribeWS(requestContext, engineContext context.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { - options.readTimeout = c.readTimeout - options.pingInterval = c.pingInterval - options.pingTimeout = c.pingTimeout - - if c.httpClient == nil { - return fmt.Errorf("http client is nil") +// NewGraphQLSubscriptionClient creates a new subscription client. +func NewGraphQLSubscriptionClient(ctx context.Context, opts ...SubscriptionClientOption) GraphQLSubscriptionClient { + cfg := defaultSubscriptionClientConfig() + for _, opt := range opts { + opt(cfg) } - conn, err := c.newWSConnectionHandler(requestContext, engineContext, options, updater) - if err != nil { - return err + return &subscriptionClientV2{ + client: client.New(ctx, client.Config{ + UpgradeClient: cfg.UpgradeClient, + StreamingClient: cfg.StreamingClient, + Logger: cfg.Logger, + PingInterval: cfg.PingInterval, + PingTimeout: cfg.PingTimeout, + AckTimeout: cfg.AckTimeout, + WriteTimeout: cfg.WriteTimeout, + ReadLimit: cfg.ReadLimit, + }), } - - go func() { - err := conn.handler.StartBlocking() - if err != nil { - if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { - return - } - c.log.Error("subscriptionClient.subscribeWS", abstractlogger.Error(err)) - } - }() - - return nil } -func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext context.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { - options.readTimeout = c.readTimeout - options.pingInterval = c.pingInterval - options.pingTimeout = c.pingTimeout - - if c.httpClient == nil { - return fmt.Errorf("http client is nil") - } - - conn, err := c.newWSConnectionHandler(requestContext, engineContext, options, updater) +// Subscribe implements GraphQLSubscriptionClient. +// It bridges the channel-based new client API to the callback-based updater interface. +func (c *subscriptionClientV2) Subscribe(ctx *resolve.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { + opts, req, err := convertToClientOptions(options) if err != nil { return err } - if c.netPoll == nil { - go func() { - err := conn.handler.StartBlocking() - if err != nil && !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { - c.log.Error("subscriptionClient.asyncSubscribeWS", abstractlogger.Error(err)) - } - }() - return nil - } - - // if we have netPoll, we need to add the connection to the netPoll - - // init the subscription - err = conn.handler.Subscribe() + msgCh, cancel, err := c.client.Subscribe(ctx.Context(), req, opts) if err != nil { return err } - var fd int + go c.readLoop(ctx.Context(), msgCh, cancel, updater) - // we have to check if the connection is a tls connection to get the underlying net.Conn - if tlsConn, ok := conn.netConn.(*tls.Conn); ok { - netConn := tlsConn.NetConn() - fd = netpoll.SocketFD(netConn) - } else { - fd = netpoll.SocketFD(conn.netConn) - } - - if fd == 0 { - c.log.Error("failed to get file descriptor from connection. This indicates a problem with the netPoll implementation") - return fmt.Errorf("failed to get file descriptor from connection") - } - - conn.id, conn.fd = id, fd - // submit the connection to the netPoll run loop - c.netPollState.addConn <- conn return nil } -type UpgradeRequestError struct { - URL string - StatusCode int -} - -func (u *UpgradeRequestError) Error() string { - return fmt.Sprintf("failed to upgrade connection to %s, status code: %d", u.URL, u.StatusCode) -} - -func (c *subscriptionClient) newWSConnectionHandler(requestContext, engineContext context.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) (*connection, error) { - // Any failure will be stored here, needed for deferred body closer. - var err error - - conn, subProtocol, err := c.dial(requestContext, options) - if err != nil { - return nil, err - } - - if conn == nil { - return nil, fmt.Errorf("failed to dial connection") - } - - // conn is not nil. Any errored return below could lead to a leaking connection. - // To avoid this, close connection if failure happened. - defer func() { - if err != nil { - conn.Close() - } - }() - - initMsg, err := c.getConnectionInitMessage(requestContext, options.URL, options.Header) - if err != nil { - return nil, err - } - - if len(options.InitialPayload) > 0 { - initMsg, err = jsonparser.Set(initMsg, options.InitialPayload, "payload") - if err != nil { - return nil, err - } - } - - if options.Body.Extensions != nil { - initMsg, err = jsonparser.Set(initMsg, options.Body.Extensions, "payload", "extensions") - if err != nil { - return nil, err - } - } - - // init + ack - if err = conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { - return nil, err - } - err = wsutil.WriteClientText(conn, initMsg) - if err != nil { - return nil, err - } - - if err = waitForAck(conn, c.readTimeout, writeTimeout); err != nil { - return nil, err - } - - switch subProtocol { - case ProtocolGraphQLWS: - return newGQLWSConnectionHandler(requestContext, engineContext, conn, options, updater, c.log), nil - case ProtocolGraphQLTWS: - return newGQLTWSConnectionHandler(requestContext, engineContext, conn, options, updater, c.log), nil - default: - return nil, NewInvalidWsSubprotocolError(subProtocol) - } -} - -func (c *subscriptionClient) dial(ctx context.Context, options GraphQLSubscriptionOptions) (conn net.Conn, subProtocol string, err error) { - subProtocols := []string{ProtocolGraphQLTWS, ProtocolGraphQLWS} - if options.WsSubProtocol != "" && options.WsSubProtocol != "auto" { - subProtocols = []string{options.WsSubProtocol} - } +// readLoop bridges the channel-based API to the callback-based updater. +func (c *subscriptionClientV2) readLoop(ctx context.Context, msgCh <-chan *client.Message, cancel func(), updater resolve.SubscriptionUpdater) { + defer cancel() - clientTrace := &httptrace.ClientTrace{ - GotConn: func(info httptrace.GotConnInfo) { - conn = info.Conn - }, - } - clientTraceCtx := httptrace.WithClientTrace(ctx, clientTrace) - u := options.URL - if strings.HasPrefix(options.URL, "wss") { - u = "https" + options.URL[3:] - } else if strings.HasPrefix(options.URL, "ws") { - u = "http" + options.URL[2:] - } - req, err := http.NewRequestWithContext(clientTraceCtx, http.MethodGet, u, nil) - if err != nil { - return nil, "", err - } - req.Proto = "HTTP/1.1" - req.ProtoMajor = 1 - req.ProtoMinor = 1 - if options.Header != nil { - req.Header = options.Header - } - req.Header.Set("Sec-WebSocket-Protocol", strings.Join(subProtocols, ",")) - req.Header.Set("Sec-WebSocket-Version", "13") - req.Header.Set("Connection", "Upgrade") - req.Header.Set("Upgrade", "websocket") - - challengeKey, err := generateChallengeKey() - if err != nil { - return nil, "", err - } - - req.Header.Set("Sec-WebSocket-Key", challengeKey) - - upgradeResponse, err := c.httpClient.Do(req) - if err != nil { - return nil, "", err - } - - // On failed upgrades, we close the body without transferring ownership to the caller. - - if upgradeResponse.StatusCode != http.StatusSwitchingProtocols { - // Drain to EOF to allow connection reuse by net/http. - _, _ = io.Copy(io.Discard, upgradeResponse.Body) - upgradeResponse.Body.Close() - return nil, "", &UpgradeRequestError{ - URL: u, - StatusCode: upgradeResponse.StatusCode, - } - } - - accept := computeAcceptKey(challengeKey) - if upgradeResponse.Header.Get("Sec-WebSocket-Accept") != accept { - _, _ = io.Copy(io.Discard, upgradeResponse.Body) - upgradeResponse.Body.Close() - return nil, "", fmt.Errorf("invalid Sec-WebSocket-Accept") - } - - subProtocol = subProtocols[0] - if options.WsSubProtocol == "" || options.WsSubProtocol == "auto" { - subProtocol = upgradeResponse.Header.Get("Sec-WebSocket-Protocol") - if subProtocol == "" { - subProtocol = ProtocolGraphQLWS - } - } - - return conn, subProtocol, nil -} - -func generateChallengeKey() (string, error) { - p := make([]byte, 16) - if _, err := io.ReadFull(rand.Reader, p); err != nil { - return "", err - } - return base64.StdEncoding.EncodeToString(p), nil -} - -var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") - -func computeAcceptKey(challengeKey string) string { - h := sha1.New() // #nosec G401 -- (CWE-326) https://datatracker.ietf.org/doc/html/rfc6455#page-54 - h.Write([]byte(challengeKey)) - h.Write(keyGUID) - return base64.StdEncoding.EncodeToString(h.Sum(nil)) -} - -func (c *subscriptionClient) getConnectionInitMessage(ctx context.Context, url string, header http.Header) ([]byte, error) { - if c.onWsConnectionInitCallback == nil { - return connectionInitMessage, nil - } - - callback := *c.onWsConnectionInitCallback - - payload, err := callback(ctx, url, header) - if err != nil { - return nil, err - } - - if len(payload) == 0 { - return connectionInitMessage, nil - } - - msg, err := jsonparser.Set(connectionInitMessage, payload, "payload") - if err != nil { - return nil, err - } - - return msg, nil -} - -type ConnectionHandler interface { - // StartBlocking starts the connection handler and blocks until the connection is closed - // Only used as fallback when epoll is not available - StartBlocking() error - // HandleMessage handles the incoming message from the connection - HandleMessage(data []byte) (done bool) - // Ping sends a ping message to the upstream server to keep the connection alive. - // Implementers must keep track of the last ping time to initiate a connection shutdown - // if the upstream is not sending a pong. - Ping() - // ServerClose closes the connection from the server side - ServerClose() - // ClientClose closes the connection from the client side - ClientClose() - // Subscribe subscribes to the connection - Subscribe() error -} - -func waitForAck(conn net.Conn, readTimeout, writeTimeout time.Duration) error { - timer := time.NewTimer(ackWaitTimeout) for { select { - case <-timer.C: - return fmt.Errorf("timeout while waiting for connection_ack") - default: - } - if err := conn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil { - return fmt.Errorf("failed to set read deadline: %w", err) - } - msg, err := wsutil.ReadServerText(conn) - if err != nil { - return err - } - respType, err := jsonparser.GetString(msg, "type") - if err != nil { - return err - } + case <-ctx.Done(): + updater.Done() + return - switch respType { - // TODO this method mixes message types from different protocols. We should - // move the specific protocol handling to the concrete implementation - case messageTypeConnectionKeepAlive: - continue - case messageTypePing: - if err := conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { - return fmt.Errorf("failed to set write deadline: %w", err) - } - err = wsutil.WriteClientText(conn, []byte(pongMessage)) - if err != nil { - return fmt.Errorf("failed to send pong message: %w", err) + case msg, ok := <-msgCh: + if !ok { + updater.Done() + return } - continue - case messageTypeConnectionAck: - return nil - default: - return fmt.Errorf("expected connection_ack or ka, got %s", respType) - } - } -} -type connResult struct { - fd int - shouldClose bool -} - -func (c *subscriptionClient) runNetPoll(ctx context.Context) { - defer c.close() - done := ctx.Done() - // both handleConnCh and connResults are buffered channels with a size of WaitForNumEvents - // this is important because we submit all events before we start processing them - // and we start evaluating the results only after all events have been submitted - // this would not be possible with unbuffered channels - handleConnCh := make(chan *connection, c.netPollConfig.WaitForNumEvents) - connResults := make(chan connResult, c.netPollConfig.WaitForNumEvents) - pingCh := make(chan *connection, c.netPollConfig.WaitForNumEvents) - - // Start workers to handle connection events - // MaxEventWorkers defines the parallelism of how many connections can be handled at the same time - // This is the critical number on how much CPU is used - for i := 0; i < c.netPollConfig.MaxEventWorkers; i++ { - go func() { - for { - select { - case conn := <-pingCh: - conn.handler.Ping() - case conn := <-handleConnCh: - shouldClose := c.handleConnectionEvent(conn) - connResults <- connResult{fd: conn.fd, shouldClose: shouldClose} - case <-done: - return + if msg.Err != nil { + if isConnectionError(msg.Err) { + updater.Error(formatUpstreamServiceError(msg.Err)) + } else { + updater.Error(formatSubscriptionError(msg.Err)) } - } - }() - } - - pingTicker := time.NewTicker(c.pingInterval) - defer pingTicker.Stop() - - // This is the main netPoll run loop - // It's a single threaded event loop that reacts to several events, such as added connections, clients unsubscribing, etc. - for { - select { - // if the engine context is done, we close the netPoll loop - case <-done: - return - case <-pingTicker.C: - // Send a ping to all connections - // We distribute the ping to all workers to prevent single threaded bottlenecks - // However, this required state synchronization with the last ping time on the handler - // because PING and PONG can be handled on different go routines - for _, conn := range c.netPollState.connections { - pingCh <- conn - } - case conn := <-c.netPollState.addConn: - c.handleAddConn(conn) - case id := <-c.netPollState.clientUnsubscribe: - c.handleClientUnsubscribe(id) - // while len(c.connections) == 0, this channel is nil, so we will never try to wait for netPoll events - // this is important to prevent busy waiting - // once we add the first connection, we start the ticker and set the tick channel - // the ticker ensures that we don't poll the netPoll instance all the time, - // but at most every TickInterval - case <-c.netPollState.waitForEventsTick: - events, err := c.netPoll.Wait(c.netPollConfig.WaitForNumEvents) - if err != nil { - c.log.Error("netPoll.Wait", abstractlogger.Error(err)) - continue + updater.Done() + return } - waitForEvents := len(events) - - for i := range events { - fd := netpoll.SocketFD(events[i]) - conn, ok := c.netPollState.connections[fd] - if !ok { - // This can happen if the client was unsubscribed - // and the ticker is still running because we haven't removed the last connection yet - continue + if msg.Payload != nil { + data, err := json.Marshal(msg.Payload) + if err != nil { + updater.Error(formatSubscriptionError(err)) + updater.Done() + return } - // submit the connection to the worker pool - handleConnCh <- conn - } - - // we submit all events to the worker pool to handle all events in parallel - // instead of just waiting until all handlers are done, we can handle newly added connections or clients unsubscribing simultaneously - // we keep doing this until we have results for all events or the engine context is done - // this allows us to keep handling events in parallel while being able to manage connections without locks - // as a result, we can handle a large number of connections with a single threaded event loop - - // once we have results for all events, we can return to the top level loop and wait for the next tick - for waitForEvents > 0 { - select { - case result := <-connResults: - // if the connection indicates that it should be closed, we close and remove it - if result.shouldClose { - c.handleServerUnsubscribe(result.fd) - } - // we decrease the number of events we're waiting for to eventually break the loop - waitForEvents-- - case conn := <-c.netPollState.addConn: - c.handleAddConn(conn) - case id := <-c.netPollState.clientUnsubscribe: - c.handleClientUnsubscribe(id) - case <-done: + if msg.Done { + updater.Error(data) + updater.Done() return } + updater.Update(data) + } + + if msg.Done { + updater.Complete() + updater.Done() + return } } } } -func (c *subscriptionClient) close() { - defer c.log.Debug("subscriptionClient.close", abstractlogger.String("reason", "netPoll closed by context")) - if c.netPollState.waitForEventsTicker != nil { - c.netPollState.waitForEventsTicker.Stop() - } - for _, conn := range c.netPollState.connections { - _ = c.netPoll.Remove(conn.netConn) - conn.handler.ServerClose() - } - if c.netPoll != nil { - err := c.netPoll.Close(false) - if err != nil { - c.log.Error("subscriptionClient.close", abstractlogger.Error(err)) - } - } +func isConnectionError(err error) bool { + return errors.Is(err, client.ErrConnectionClosed) || + errors.Is(err, client.ErrConnectionError) || + errors.Is(err, context.Canceled) || + errors.Is(err, context.DeadlineExceeded) } -func (c *subscriptionClient) handleAddConn(conn *connection) { - var netConn net.Conn +// convertToClientOptions converts GraphQLSubscriptionOptions to the new client's types. +func convertToClientOptions(options GraphQLSubscriptionOptions) (client.Options, *client.Request, error) { + opts := client.Options{ + Endpoint: options.URL, + Headers: options.Header, + } - if tlsConn, ok := conn.netConn.(*tls.Conn); ok { - netConn = tlsConn.NetConn() + // Transport selection + if options.UseSSE { + opts.Transport = client.TransportSSE + if options.SSEMethodPost { + opts.SSEMethod = client.SSEMethodPOST + } else { + opts.SSEMethod = client.SSEMethodGET + } } else { - netConn = conn.netConn + opts.Transport = client.TransportWS + opts.WSSubprotocol = mapWSSubprotocol(options.WsSubProtocol) } - if err := c.netPoll.Add(netConn); err != nil { - c.log.Error("subscriptionClient.handleAddConn", abstractlogger.Error(err)) - conn.handler.ServerClose() - return + // Convert InitialPayload from json.RawMessage to map[string]any + if len(options.InitialPayload) > 0 { + var initPayload map[string]any + if err := json.Unmarshal(options.InitialPayload, &initPayload); err != nil { + return client.Options{}, nil, fmt.Errorf("failed to unmarshal initial payload: %w", err) + } + opts.InitPayload = initPayload } - c.netPollState.connections[conn.fd] = conn - c.netPollState.triggers[conn.id] = conn.fd - // when we previously had 0 connections, we will have 1 connection now - // this means we need to start the ticker so that we get netPoll events - if len(c.netPollState.connections) == 1 { - c.netPollState.waitForEventsTicker = time.NewTicker(c.netPollConfig.TickInterval) - c.netPollState.waitForEventsTick = c.netPollState.waitForEventsTicker.C - c.netPollState.hasConnections.Store(true) + req := &client.Request{ + Query: options.Body.Query, + OperationName: options.Body.OperationName, + Variables: options.Body.Variables, + Extensions: options.Body.Extensions, } -} -func (c *subscriptionClient) handleClientUnsubscribe(id uint64) { - fd, ok := c.netPollState.triggers[id] - if !ok { - return - } - delete(c.netPollState.triggers, id) - conn, ok := c.netPollState.connections[fd] - if !ok { - return - } - delete(c.netPollState.connections, fd) - _ = c.netPoll.Remove(conn.netConn) - conn.handler.ClientClose() - // if we have no connections left, we stop the ticker - if len(c.netPollState.connections) == 0 { - c.netPollState.waitForEventsTicker.Stop() - c.netPollState.waitForEventsTick = nil - c.netPollState.hasConnections.Store(false) - } + return opts, req, nil } -func (c *subscriptionClient) handleServerUnsubscribe(fd int) { - conn, ok := c.netPollState.connections[fd] - if !ok { - return - } - delete(c.netPollState.connections, fd) - delete(c.netPollState.triggers, conn.id) - _ = c.netPoll.Remove(conn.netConn) - conn.handler.ServerClose() - // if we have no connections left, we stop the ticker - if len(c.netPollState.connections) == 0 { - c.netPollState.waitForEventsTicker.Stop() - c.netPollState.waitForEventsTick = nil - c.netPollState.hasConnections.Store(false) - } -} - -func (c *subscriptionClient) handleConnectionEvent(conn *connection) bool { - data, err := readMessage(conn.netConn, c.frameTimeout, c.readTimeout) - if err != nil { - return handleConnectionError(err) +// mapWSSubprotocol maps the string subprotocol to the client.WSSubprotocol type. +func mapWSSubprotocol(proto string) client.WSSubprotocol { + switch proto { + case "graphql-ws": + return client.SubprotocolGraphQLWS + case "graphql-transport-ws": + return client.SubprotocolGraphQLTransportWS + default: + return client.SubprotocolAuto } - return conn.handler.HandleMessage(data) } -func handleConnectionError(err error) (done bool) { - netOpErr := &net.OpError{} - if errors.As(err, &netOpErr) { - return !netOpErr.Timeout() +// formatUpstreamServiceError formats a connection-level error as a GraphQL error +// response with the UPSTREAM_SERVICE_ERROR extension code. If the error is a +// WebSocket close error, the close code and reason are included in extensions. +func formatUpstreamServiceError(err error) []byte { + type errorExtensions struct { + Code string `json:"code"` + CloseCode int `json:"closeCode,omitempty"` + Reason string `json:"closeReason,omitempty"` } - // Check if we have errors during reading from the connection - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { - return true + type graphqlError struct { + Message string `json:"message"` + Extensions errorExtensions `json:"extensions"` } - // Check if we have a context error - if errors.Is(err, context.DeadlineExceeded) { - return false - } - - // Check if the error is a connection reset by peer - if errors.Is(err, syscall.ECONNRESET) { - return true - } - if errors.Is(err, syscall.EPIPE) { - return true - } + ext := errorExtensions{Code: "UPSTREAM_SERVICE_ERROR"} - // Check if the error is a closed network connection. Introduced in go 1.16. - // This replaces the string match of "use of closed network connection" - if errors.Is(err, net.ErrClosed) { - return true + var closeErr websocket.CloseError + if errors.As(err, &closeErr) { + ext.CloseCode = int(closeErr.Code) + ext.Reason = closeErr.Reason } - // Check if the error is closed websocket connection - if errors.As(err, &wsutil.ClosedError{}) { - return true + resp := struct { + Errors []graphqlError `json:"errors"` + }{ + Errors: []graphqlError{{Message: "upstream service closed the connection", Extensions: ext}}, } - - return false + data, _ := json.Marshal(resp) + return data } -// readMessage reads a message from the connection -func readMessage(conn net.Conn, frameTimeout time.Duration, readTimeout time.Duration) ([]byte, error) { - controlHandler := wsutil.ControlFrameHandler(conn, ws.StateClientSide) - rd := &wsutil.Reader{ - Source: conn, - State: ws.StateClientSide, - CheckUTF8: true, - SkipHeaderCheck: false, - OnIntermediate: controlHandler, +// formatSubscriptionError formats an error as a GraphQL error response. +func formatSubscriptionError(err error) []byte { + errResponse := struct { + Errors []struct { + Message string `json:"message"` + } `json:"errors"` + }{ + Errors: []struct { + Message string `json:"message"` + }{ + {Message: err.Error()}, + }, } + data, _ := json.Marshal(errResponse) + return data +} - for { - // This method is used to check if we have data on the connection. The timeout needs to be much smaller - // than the readTimeout to ensure we don't block the connection for too long. If we have no data, we move - // on to the next connection. - err := conn.SetReadDeadline(time.Now().Add(frameTimeout)) - if err != nil { - return nil, err - } +// GraphQLSubscriptionClientFactory abstracts the way of creating a new GraphQLSubscriptionClient. +// This can be very handy for testing purposes. +type GraphQLSubscriptionClientFactory interface { + NewSubscriptionClient(ctx context.Context, options ...SubscriptionClientOption) GraphQLSubscriptionClient +} - // If we have data, we can read it. Otherwise, it will timeout and we wait for the next epoll tick - hdr, err := rd.NextFrame() - if err != nil { - // A timeout will not close the connection but return an error - return nil, err - } - if hdr.OpCode.IsControl() { - // The controlHandler writes the control frames. - // We need to work with a proper timeout to ensure we don't block forever. - err := conn.SetWriteDeadline(time.Now().Add(frameTimeout)) - if err != nil { - return nil, err - } - // Handles PING/PONG and CLOSE frames, but only on the ws protocol level - // We still need to handle the PING/PONG frames on the application protocol level - if err := controlHandler(hdr, rd); err != nil { - return nil, err - } - continue - } +type DefaultSubscriptionClientFactory struct{} - // We are only interested in text frames - if hdr.OpCode&ws.OpText == 0 { - // If we see anything else than a text frame, we need to discard the frame - if err := rd.Discard(); err != nil { - return nil, err - } - continue - } +func (d *DefaultSubscriptionClientFactory) NewSubscriptionClient(ctx context.Context, options ...SubscriptionClientOption) GraphQLSubscriptionClient { + return NewGraphQLSubscriptionClient(ctx, options...) +} - // We limit the amount of time we wait for a message to be read from the connection - // This is important to ensure we don't block the connection for too long - err = conn.SetReadDeadline(time.Now().Add(readTimeout)) - if err != nil { - return nil, err - } - return io.ReadAll(rd) - } +func IsDefaultGraphQLSubscriptionClient(client GraphQLSubscriptionClient) bool { + _, ok := client.(*subscriptionClientV2) + return ok } diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 86dd57c030..956cd80b73 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -2,2570 +2,114 @@ package graphql_datasource import ( "context" - "encoding/base64" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "runtime" - "strings" - "sync" + "errors" "testing" - "time" - "github.com/buger/jsonparser" - "github.com/coder/websocket" - ll "github.com/jensneuse/abstractlogger" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.uber.org/atomic" - "go.uber.org/goleak" - "go.uber.org/zap" + client "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) -func logger() ll.Logger { - logger, err := zap.NewDevelopmentConfig().Build() - if err != nil { - panic(err) - } - - return ll.NewZapLogger(logger, ll.DebugLevel) -} - -func TestGetConnectionInitMessageHelper(t *testing.T) { - var callback OnWsConnectionInitCallback = func(ctx context.Context, url string, header http.Header) (json.RawMessage, error) { - return json.RawMessage(`{"authorization":"secret"}`), nil - } - - tests := []struct { - name string - callback *OnWsConnectionInitCallback - want string - }{ - { - name: "without payload", - callback: nil, - want: `{"type":"connection_init"}`, - }, - { - name: "with payload", - callback: &callback, - want: `{"type":"connection_init","payload":{"authorization":"secret"}}`, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - client := subscriptionClient{onWsConnectionInitCallback: tt.callback} - got, err := client.getConnectionInitMessage(context.Background(), "", nil) - require.NoError(t, err) - require.NotEmpty(t, got) - - assert.Equal(t, tt.want, string(got)) - }) - } +type testBridgeUpdater struct { + updates [][]byte + errors [][]byte + completed bool + done bool } -func TestWebsocketSubscriptionClientImmediateClientCancel(t *testing.T) { - serverInvocations := atomic.NewInt64(0) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - serverInvocations.Inc() - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.Error(t, err) - }() - assert.Eventuallyf(t, func() bool { - return serverInvocations.Load() == 0 - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() -} - -func TestWebsocketSubscriptionClientWithServerDisconnect(t *testing.T) { - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - - _, _, err = conn.Read(ctx) - assert.Error(t, err) - - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - serverCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") -} - -// didnt configure subprotocol, but the subgraph return graphql-ws -func TestSubprotocolNegotiationWithGraphQLWS(t *testing.T) { - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"graphql-ws"}, - }) - assert.NoError(t, err) - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() +func (t *testBridgeUpdater) Update(data []byte) { + t.updates = append(t.updates, data) } -// didnt configure subprotocol, but the subgraph return graphql-transport-ws -func TestSubprotocolNegotiationWithGraphQLTransportWS(t *testing.T) { - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"graphql-transport-ws"}, - }) - assert.NoError(t, err) - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) +func (t *testBridgeUpdater) UpdateSubscription(id resolve.SubscriptionIdentifier, data []byte) {} - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() +func (t *testBridgeUpdater) Complete() { + t.completed = true } -// In this case the subgraph doesnt return the subprotocol and we didnt configure the subprotocol, so falls back to graphql-ws -func TestSubprotocolNegotiationWithNoSubprotocol(t *testing.T) { - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() +func (t *testBridgeUpdater) Error(data []byte) { + t.errors = append(t.errors, data) } -func TestSubprotocolNegotiationWithConfiguredGraphQLWS(t *testing.T) { - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLWS, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() +func (t *testBridgeUpdater) Done() { + t.done = true } -func TestSubprotocolNegotiationWithConfiguredGraphQLTransportWS(t *testing.T) { - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdates(t, time.Second, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() +func (t *testBridgeUpdater) CloseSubscription(id resolve.SubscriptionIdentifier) { } -func TestWebSocketClientLeaks(t *testing.T) { - defer goleak.VerifyNone(t, - goleak.IgnoreCurrent(), // ignore the test itself - ) - serverDone := &sync.WaitGroup{} - serverDone.Add(2) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - serverDone.Done() - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - })) - defer server.Close() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - - WithLogger(logger()), - ).(*subscriptionClient) - wg := &sync.WaitGroup{} - wg.Add(2) - for i := 0; i < 2; i++ { - go func(i int) { - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - updater := &testSubscriptionUpdater{} - err := client.SubscribeAsync(resolve.NewContext(ctx), uint64(i), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - client.Unsubscribe(uint64(i)) - clientCancel() - wg.Done() - }(i) - } - wg.Wait() - time.Sleep(time.Second) - serverCancel() - time.Sleep(time.Second) - serverDone.Wait() +func (t *testBridgeUpdater) Subscriptions() map[context.Context]resolve.SubscriptionIdentifier { + return map[context.Context]resolve.SubscriptionIdentifier{} } -func TestAsyncSubscribe(t *testing.T) { - t.Parallel() - if runtime.GOOS == "windows" { - t.SkipNow() - } - t.Run("subscribe async", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - defer close(serverDone) - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() +func TestReadLoopErrorHandling(t *testing.T) { + t.Run("connection errors deliver error and done without updates", func(t *testing.T) { + updater := &testBridgeUpdater{} + msgCh := make(chan *client.Message, 1) + msgCh <- &client.Message{Err: client.ErrConnectionClosed} + close(msgCh) - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} + subClient := &subscriptionClientV2{} + subClient.readLoop(context.Background(), msgCh, func() {}, updater) - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second*5, time.Millisecond*10, "server did not close") - serverCancel() + require.True(t, updater.done) + require.Len(t, updater.errors, 1) + require.Len(t, updater.updates, 0) + require.False(t, updater.completed) }) - t.Run("server timeout", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - close(serverDone) - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 2) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() + t.Run("non-connection errors deliver error and done without updates", func(t *testing.T) { + updater := &testBridgeUpdater{} + msgCh := make(chan *client.Message, 1) + msgCh <- &client.Message{Err: errors.New("validation failed")} + close(msgCh) - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} + subClient := &subscriptionClientV2{} + subClient.readLoop(context.Background(), msgCh, func() {}, updater) - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second*5, time.Millisecond*10, "server did not close") - serverCancel() + require.True(t, updater.done) + require.Len(t, updater.errors, 1) + require.Len(t, updater.updates, 0) + require.False(t, updater.completed) }) - t.Run("server complete", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"complete","id":"1"}`)) - assert.NoError(t, err) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} + t.Run("context cancellation calls done without complete", func(t *testing.T) { + updater := &testBridgeUpdater{} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + msgCh := make(chan *client.Message) - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) + subClient := &subscriptionClientV2{} + subClient.readLoop(ctx, msgCh, func() {}, updater) - updater.AwaitUpdates(t, time.Second*10, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() + require.True(t, updater.done) + require.False(t, updater.completed) }) - t.Run("server ka", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ka"}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ka"}`)) - assert.NoError(t, err) + t.Run("channel close calls done without complete", func(t *testing.T) { + updater := &testBridgeUpdater{} + msgCh := make(chan *client.Message) + close(msgCh) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"complete","id":"1"}`)) - assert.NoError(t, err) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() + subClient := &subscriptionClientV2{} + subClient.readLoop(context.Background(), msgCh, func() {}, updater) - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() + require.True(t, updater.done) + require.False(t, updater.completed) }) - t.Run("long timeout", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 2) - - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} + t.Run("done message calls complete then done", func(t *testing.T) { + updater := &testBridgeUpdater{} + msgCh := make(chan *client.Message, 1) + msgCh <- &client.Message{Done: true} + close(msgCh) - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) + subClient := &subscriptionClientV2{} + subClient.readLoop(context.Background(), msgCh, func() {}, updater) - updater.AwaitUpdates(t, time.Second*10, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second*5, time.Millisecond*10, "server did not close") - time.Sleep(time.Second) - assert.Equal(t, false, client.netPollState.hasConnections.Load()) + require.True(t, updater.done) + require.True(t, updater.completed) + require.Len(t, updater.errors, 0) }) - t.Run("forever timeout", func(t *testing.T) { - t.Parallel() - globalCtx, cancel := context.WithCancel(context.Background()) - defer cancel() - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - <-globalCtx.Done() - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*3, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - time.Sleep(time.Second * 2) - }) - t.Run("graphql-ws", func(t *testing.T) { - t.Parallel() - t.Run("happy path", func(t *testing.T) { - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"data","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"data","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"data","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - - ctx = conn.CloseRead(ctx) - <-ctx.Done() - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - client.Unsubscribe(1) - - clientCancel() - - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second*5, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("connection error", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"connection_error"}`)) - assert.NoError(t, err) - - _ = conn.Close(websocket.StatusNormalClosure, "done") - - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*5, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"errors":[{"message":"connection error"}]}`, updater.updates[0]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() - assert.Equal(t, false, client.netPollState.hasConnections.Load()) - }) - t.Run("error object", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"error","payload":{"message":"ws error"}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*5, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"errors":[{"message":"ws error"}]}`, updater.updates[0]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second*5, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("error array", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - close(serverDone) - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"error","payload":[{"message":"ws error"}]}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*5, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"errors":[{"message":"ws error"}]}`, updater.updates[0]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second*5, time.Millisecond*10, "server did not close") - serverCancel() - }) - }) - t.Run("graphql-transport-ws", func(t *testing.T) { - t.Parallel() - t.Run("happy path", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - - ctx = conn.CloseRead(ctx) - <-ctx.Done() - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second*5, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("happy path no netPoll", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - WithNetPollConfiguration(NetPollConfiguration{ - Enable: false, - }), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("happy path no netPoll two clients", func(t *testing.T) { - t.Parallel() - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - })) - defer server.Close() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - WithNetPollConfiguration(NetPollConfiguration{ - Enable: false, - }), - ).(*subscriptionClient) - wg := &sync.WaitGroup{} - wg.Add(2) - for i := 0; i < 2; i++ { - go func(i int) { - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - updater := &testSubscriptionUpdater{} - err := client.SubscribeAsync(resolve.NewContext(ctx), uint64(i), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - clientCancel() - wg.Done() - }(i) - } - wg.Wait() - }) - t.Run("ping", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ping"}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"pong"}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second*5, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("ka", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ka"}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second*5, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("error object", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"error","payload":{"message":"ws error"}}`)) - assert.NoError(t, err) - - _ = conn.Close(websocket.StatusNormalClosure, "done") - - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*5, 2) - assert.Equal(t, 2, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"errors":[{"message":"ws error"}]}`, updater.updates[1]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("error array", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"error","payload":[{"message":"ws error"}]}`)) - assert.NoError(t, err) - - _ = conn.Close(websocket.StatusNormalClosure, "done") - - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*5, 2) - assert.Equal(t, 2, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"errors":[{"message":"ws error"}]}`, updater.updates[1]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("data error", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"data","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - updater.AwaitUpdates(t, time.Second*5, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("connection error", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"connection_error"}`)) - assert.NoError(t, err) - - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - updater.AwaitUpdates(t, time.Second*5, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() - }) - }) -} - -func TestClientToSubgraphPingPong(t *testing.T) { - t.Parallel() - if runtime.GOOS == "windows" { - t.Skip("Skipping test on Windows as it's not reliable") - } - - t.Run("client sends ping message after configured interval", func(t *testing.T) { - t.Parallel() - - serverDone := make(chan struct{}) // to signal server done - // buffered channels and non-blocking send to avoid double-close panics if events repeat - pingReceived := make(chan struct{}, 1) // signaled when the server receives a ping - payloadSend := make(chan struct{}, 1) // signaled when the server sends a payload - - // Create test server that will handle the WebSocket connection - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Sec-WebSocket-Protocol", ProtocolGraphQLTWS) - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{ProtocolGraphQLTWS}, - }) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - close(serverDone) - }() - - ctx := context.Background() - - // Handle connection initialization - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - // Handle subscription start - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - // Send initial data - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"initial data"}}}}`)) - assert.NoError(t, err) - - // Track what messages we've received - receivedPing := false - receivedComplete := false - - // Create a context with timeout for reading messages - readCtx, cancel := context.WithTimeout(ctx, 3*time.Second) - defer cancel() - - // Process up to 5 messages or until we see both a ping and a complete - for i := 0; i < 5; i++ { - if receivedPing && receivedComplete { - break - } - - _, data, err = conn.Read(readCtx) - if err != nil { - // Connection closed or timeout - t.Logf("Connection read ended: %v", err) - break - } - - messageStr := string(data) - t.Logf("Received message: %s", messageStr) - - switch messageStr { - case `{"type":"ping"}`: - receivedPing = true - select { - case pingReceived <- struct{}{}: - default: - } - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"pong"}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"after ping-pong"}}}}`)) - assert.NoError(t, err) - select { - case payloadSend <- struct{}{}: - default: - } - case `{"id":"1","type":"complete"}`: - receivedComplete = true - } - } - - // Test is successful if we received a ping message - if !receivedPing { - t.Error("Did not receive ping message from client") - } - })) - defer server.Close() - - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - // Create subscription client with a short ping interval for testing - pingInterval := 400 * time.Millisecond - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - WithPingInterval(pingInterval), - WithPingTimeout(1*time.Second), - WithNetPollConfiguration(NetPollConfiguration{ - Enable: true, - TickInterval: 100 * time.Millisecond, - BufferSize: 10, - MaxEventWorkers: 2, - WaitForNumEvents: 1, - }), - ).(*subscriptionClient) - - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - // Wait for ping to be received before unsubscribing - select { - case <-pingReceived: - t.Log("Ping received successfully") - case <-time.After(2 * time.Second): - t.Log("Timed out waiting for ping, will unsubscribe anyway") - } - // don't unsubscribe immediately, give the server time to send a payload - select { - case <-payloadSend: - t.Log("Payload sent successfully") - case <-time.After(2 * time.Second): - t.Log("Timed out waiting for sent payload, will unsubscribe anyway") - } - - // Cleanup - client.Unsubscribe(1) - - // Wait for server to finish - select { - case <-serverDone: - // Server completed successfully - case <-time.After(5 * time.Second): - t.Fatal("Timed out waiting for server to complete") - } - - clientCancel() - serverCancel() - }) - - t.Run("client responds with pong when server sends ping", func(t *testing.T) { - t.Parallel() - - pongReceived := make(chan struct{}) - serverDone := make(chan struct{}) - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Sec-WebSocket-Protocol", ProtocolGraphQLTWS) - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{ProtocolGraphQLTWS}, - }) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - - ctx := context.Background() - - // Handle connection initialization - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - // Handle subscription start - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - // Send initial data - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"initial data"}}}}`)) - assert.NoError(t, err) - - // Send ping message - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ping"}`)) - assert.NoError(t, err) - - // Wait for pong response from client - msgType, data, err = conn.Read(ctx) - if err != nil { - t.Errorf("Error reading pong: %v", err) - return - } - - if string(data) == `{"type":"pong"}` { - assert.Equal(t, websocket.MessageText, msgType) - // Send another data message - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"after ping-pong"}}}}`)) - assert.NoError(t, err) - - close(pongReceived) - } - - // Wait for client to unsubscribe (complete message) - readTimeout := time.NewTimer(3 * time.Second) - defer readTimeout.Stop() - - readDone := make(chan struct{}) - go func() { - msgType, data, _ = conn.Read(ctx) - close(readDone) - }() - - select { - case <-readDone: - // Successfully read client message - case <-readTimeout.C: - // Timeout is fine, we're just waiting for unsubscribe - } - - close(serverDone) - })) - defer server.Close() - - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - WithReadTimeout(100*time.Millisecond), - WithNetPollConfiguration(NetPollConfiguration{ - Enable: true, - TickInterval: 100 * time.Millisecond, - BufferSize: 10, - MaxEventWorkers: 2, - WaitForNumEvents: 1, - }), - ).(*subscriptionClient) - - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 2, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - select { - case <-pongReceived: - t.Log("Server received pong successfully") - case <-time.After(2 * time.Second): - t.Log("Timed out waiting for pong in server, will unsubscribe anyway") - } - - // Verify we receive at least the initial data - updater.mux.Lock() - updatesCount := len(updater.updates) - firstUpdate := "" - if updatesCount > 0 { - firstUpdate = updater.updates[0] - } - updater.mux.Unlock() - - assert.GreaterOrEqual(t, updatesCount, 1) - assert.Equal(t, `{"data":{"messageAdded":{"text":"initial data"}}}`, firstUpdate) - - // Cleanup - client.Unsubscribe(2) - t.Log("client unsubscribed") - - // Wait for server to finish - select { - case <-serverDone: - // Server completed successfully - case <-time.After(2 * time.Second): - t.Fatal("Timed out waiting for server to complete") - } - - clientCancel() - serverCancel() - }) -} - -func TestClientClosesConnectionOnPingTimeout(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("Skipping test on Windows as it's not reliable") - } - - t.Parallel() - - serverDone := make(chan struct{}) - pingReceived := make(chan struct{}, 1) // Buffer 1 in case ping arrives slightly late - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Sec-WebSocket-Protocol", ProtocolGraphQLTWS) - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{ProtocolGraphQLTWS}, - }) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "server done") - close(serverDone) - }() - - ctx := context.Background() - - // Handle connection initialization - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - // Handle subscription start - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - // Use regexp because client might generate different IDs - assert.Regexp(t, `{"id":".*","type":"subscribe","payload":{.*}}`, string(data)) - // More specific check for query - payloadQuery, err := jsonparser.GetString(data, "payload", "query") - if assert.NoError(t, err) { - assert.Equal(t, `subscription {messageAdded(roomName: "room"){text}}`, payloadQuery) - } - - // Send initial data - subID, err := jsonparser.GetString(data, "id") // Get the actual ID used by the client - if !assert.NoError(t, err) { - return - } - initialDataMsg := fmt.Sprintf(`{"id":"%s","type":"next","payload":{"data":{"messageAdded":{"text":"initial data"}}}}`, subID) - err = conn.Write(r.Context(), websocket.MessageText, []byte(initialDataMsg)) - assert.NoError(t, err) - - // Wait for ping, but DO NOT send pong - readCtx, cancelRead := context.WithTimeout(ctx, 5*time.Second) // Timeout for reading messages - defer cancelRead() - - hasReceivedPing := false - for !hasReceivedPing { - _, data, err = conn.Read(readCtx) - if err != nil { - t.Logf("Server read error (expected after client closes): %v", err) - // Expecting an error here eventually as the client should close the connection - assert.Error(t, err, "Server should encounter read error when client closes connection due to ping timeout") - // Signal that the server is done (connection closed) - close(serverDone) - return // Exit handler goroutine - } - - messageStr := string(data) - t.Logf("Server received: %s", messageStr) - if messageStr == `{"type":"ping"}` { - t.Log("Server received ping, NOT sending pong.") - hasReceivedPing = true - select { - case pingReceived <- struct{}{}: - default: // Avoid blocking if channel is full - } - } else if strings.Contains(messageStr, `"type":"complete"`) { - // Client might send complete before closing if test runs fast - t.Log("Server received complete from client.") - } else { - t.Logf("Server received unexpected message type: %s", messageStr) - } - } - - // Keep reading until the connection is closed by the client - for { - // Use a timeout context to make sure we don't hang indefinitely - readCtx, cancel := context.WithTimeout(ctx, 2*time.Second) - _, data, err = conn.Read(readCtx) - cancel() - - if err != nil { - t.Logf("Server read error after ping (expected): %v", err) - assert.Error(t, err, "Server should encounter read error after client closes connection") - return // Exit handler goroutine - } - - // Log any messages received before connection close - t.Logf("Server still receiving messages: %s", string(data)) - } - })) - defer server.Close() - - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - WithPingInterval(500*time.Millisecond), - WithPingTimeout(100*time.Millisecond), - // Need netpoll enabled for ping/pong handling - WithNetPollConfiguration(NetPollConfiguration{ - Enable: true, - TickInterval: 100 * time.Millisecond, - BufferSize: 10, - MaxEventWorkers: 2, - WaitForNumEvents: 1, - }), - ).(*subscriptionClient) - - updater := &testSubscriptionUpdater{} - - // Use a unique ID for async subscription - subscriptionID := uint64(42) - - err := client.SubscribeAsync(resolve.NewContext(ctx), subscriptionID, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - // Wait for initial data - updater.AwaitUpdates(t, 3*time.Second, 1) - updater.mux.Lock() - updatesCount := len(updater.updates) - firstUpdate := "" - if updatesCount > 0 { - firstUpdate = updater.updates[0] - } - updater.mux.Unlock() - - require.Equal(t, 1, updatesCount, "Client should receive initial data") - assert.Equal(t, `{"data":{"messageAdded":{"text":"initial data"}}}`, firstUpdate) - - // Wait for the server to confirm it received a ping - select { - case <-pingReceived: - t.Log("Test confirmed server received ping.") - case <-time.After(3 * time.Second): // Should receive ping within ~pingInterval + read time - t.Fatal("Timed out waiting for server to receive ping") - } - - // Wait for server to signal it's done (connection closed by client) - select { - case <-serverDone: - t.Log("Server confirmed connection closed.") - // Success: server detected connection closure - case <-time.After(5 * time.Second): // Should happen within ~2*pingInterval + processing time - t.Fatal("Timed out waiting for server to detect connection closure") - } - - // Explicitly unsubscribe just in case, although it should be closed already - client.Unsubscribe(subscriptionID) - clientCancel() // Cancel client context - serverCancel() // Cancel server context (though serverDone should ensure it exited) -} - -func TestWebSocketUpgradeFailures(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - statusCode int - headers map[string]string - expectError bool - errorContains string - }{ - { - name: "HTTP 400 Bad Request", - statusCode: http.StatusBadRequest, - headers: map[string]string{"Content-Type": "text/plain"}, - expectError: true, - errorContains: "failed to upgrade connection", - }, - { - name: "HTTP 401 Unauthorized", - statusCode: http.StatusUnauthorized, - headers: map[string]string{"WWW-Authenticate": "Bearer"}, - expectError: true, - errorContains: "failed to upgrade connection", - }, - { - name: "HTTP 403 Forbidden", - statusCode: http.StatusForbidden, - headers: map[string]string{"Content-Type": "application/json"}, - expectError: true, - errorContains: "failed to upgrade connection", - }, - { - name: "HTTP 404 Not Found", - statusCode: http.StatusNotFound, - headers: map[string]string{"Content-Type": "text/html"}, - expectError: true, - errorContains: "failed to upgrade connection", - }, - { - name: "HTTP 500 Internal Server Error", - statusCode: http.StatusInternalServerError, - headers: map[string]string{"Content-Type": "application/json"}, - expectError: true, - errorContains: "failed to upgrade connection", - }, - { - name: "HTTP 502 Bad Gateway", - statusCode: http.StatusBadGateway, - headers: map[string]string{"Content-Type": "text/html"}, - expectError: true, - errorContains: "failed to upgrade connection", - }, - { - name: "HTTP 503 Service Unavailable", - statusCode: http.StatusServiceUnavailable, - headers: map[string]string{"Retry-After": "60"}, - expectError: true, - errorContains: "failed to upgrade connection", - }, - { - name: "HTTP 200 OK (wrong status for WebSocket)", - statusCode: http.StatusOK, - headers: map[string]string{"Content-Type": "application/json"}, - expectError: true, - errorContains: "failed to upgrade connection", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - for key, value := range tc.headers { - w.Header().Set(key, value) - } - w.WriteHeader(tc.statusCode) - _, _ = fmt.Fprintf(w, `{"error": "WebSocket upgrade failed", "status": %d}`, tc.statusCode) - })) - defer server.Close() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - - wsURL := strings.Replace(server.URL, "http://", "ws://", 1) - - updater := &testSubscriptionUpdater{} - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: wsURL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - - if tc.expectError { - require.ErrorContains(t, err, tc.errorContains) - - // Verify the error is of the correct type - var upgradeErr *UpgradeRequestError - require.ErrorAs(t, err, &upgradeErr) - require.Equal(t, tc.statusCode, upgradeErr.StatusCode) - require.Equal(t, server.URL, upgradeErr.URL) - } else { - assert.NoError(t, err, "Expected no error for status code %d", tc.statusCode) - } - }) - } -} - -func TestInvalidWebSocketAcceptKey(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - acceptKeyHandler func(challengeKey string) string - expectError bool - errorContains string - }{ - { - name: "Missing Sec-WebSocket-Accept header", - acceptKeyHandler: func(challengeKey string) string { - return "" // Don't set the header - }, - expectError: true, - errorContains: "invalid Sec-WebSocket-Accept", - }, - { - name: "Malformed base64 Sec-WebSocket-Accept", - acceptKeyHandler: func(challengeKey string) string { - return "not-valid-base64!!!" - }, - expectError: true, - errorContains: "invalid Sec-WebSocket-Accept", - }, - { - name: "Correct length but wrong content", - acceptKeyHandler: func(challengeKey string) string { - // 20 bytes (not the SHA-1 of challengeKey+GUID) - return base64.StdEncoding.EncodeToString([]byte("12345678901234567890")) - }, - expectError: true, - errorContains: "invalid Sec-WebSocket-Accept", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - var receivedChallengeKey string - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedChallengeKey = r.Header.Get("Sec-WebSocket-Key") - require.NotEmpty(t, receivedChallengeKey, "Challenge key should be present in request") - - w.Header().Set("Upgrade", "websocket") - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Sec-WebSocket-Version", "13") - - acceptKey := tc.acceptKeyHandler(receivedChallengeKey) - if acceptKey != "" { - w.Header().Set("Sec-WebSocket-Accept", acceptKey) - } - // If acceptKey is empty, we don't set the header (simulating missing header) - - w.WriteHeader(http.StatusSwitchingProtocols) - - // Close the connection immediately to prevent hanging - // This simulates a server that sends 101 but then closes - if hijacker, ok := w.(http.Hijacker); ok { - conn, _, err := hijacker.Hijack() - if err == nil { - conn.Close() - } - } - })) - defer server.Close() - - // Create subscription client - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - - wsURL := strings.Replace(server.URL, "http://", "ws://", 1) - - updater := &testSubscriptionUpdater{} - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: wsURL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - - require.Error(t, err) - require.ErrorContains(t, err, tc.errorContains) - require.NotEmpty(t, receivedChallengeKey) - }) - } } diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go deleted file mode 100644 index 8ebf5447cd..0000000000 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go +++ /dev/null @@ -1,379 +0,0 @@ -package graphql_datasource - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net" - "sync/atomic" - "time" - - "github.com/buger/jsonparser" - "github.com/gobwas/ws" - "github.com/gobwas/ws/wsutil" - "github.com/jensneuse/abstractlogger" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" -) - -// gqlTWSConnectionHandler is responsible for handling a connection to an origin -// it is responsible for managing all subscriptions using the underlying WebSocket connection -// if all Subscriptions are complete or cancelled/unsubscribed the handler will terminate -type gqlTWSConnectionHandler struct { - // The underlying net.Conn. Only used for netPoll. Should not be used to shutdown the connection. - conn net.Conn - requestContext, engineContext context.Context - log abstractlogger.Logger - options GraphQLSubscriptionOptions - updater resolve.SubscriptionUpdater - lastPingSentUnix atomic.Int64 - pingTimeout time.Duration - shuttingDown atomic.Bool -} - -func (h *gqlTWSConnectionHandler) ServerClose() { - h.shuttingDown.Store(true) - - // Because the server closes the connection, we need to send a close frame to the event loop. - h.updater.Close(resolve.SubscriptionCloseKindDownstreamServiceError) - - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - _ = ws.WriteFrame(h.conn, ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "Normal Closure")))) - _ = h.conn.Close() -} - -// ClientClose is called when the client closes the connection. Is called when the trigger is shutdown with all subscriptions. -func (h *gqlTWSConnectionHandler) ClientClose() { - h.shuttingDown.Store(true) - - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - _ = wsutil.WriteClientText(h.conn, []byte(`{"id":"1","type":"complete"}`)) - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - _ = ws.WriteFrame(h.conn, ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "Normal Closure")))) - _ = h.conn.Close() - -} - -func (h *gqlTWSConnectionHandler) Subscribe() error { - return h.subscribe() -} - -func (h *gqlTWSConnectionHandler) HandleMessage(data []byte) (done bool) { - messageType, err := jsonparser.GetString(data, "type") - if err != nil { - return false - } - switch messageType { - case messageTypePing: - h.handleMessageTypePing() - return false - case messageTypeNext: - h.handleMessageTypeNext(data) - return false - case messageTypeComplete: - h.handleMessageTypeComplete(data) - return true - case messageTypeError: - h.handleMessageTypeError(data) - return false - case messageTypeData, messageTypeConnectionError: - h.log.Error("Invalid subprotocol. The subprotocol should be set to graphql-transport-ws, but currently it is set to graphql-ws") - return true - case messageTypePong: - h.handleMessageTypePong() - return false - default: - h.log.Error("unknown message type", abstractlogger.String("type", messageType)) - return false - } -} - -func (h *gqlTWSConnectionHandler) NetConn() net.Conn { - return h.conn -} - -func newGQLTWSConnectionHandler(requestContext, engineContext context.Context, conn net.Conn, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater, l abstractlogger.Logger) *connection { - handler := &gqlTWSConnectionHandler{ - conn: conn, - requestContext: requestContext, - engineContext: engineContext, - log: l, - updater: updater, - options: options, - pingTimeout: options.pingTimeout, - } - - return &connection{ - handler: handler, - netConn: conn, - } -} - -func (h *gqlTWSConnectionHandler) StartBlocking() error { - readCtx, cancel := context.WithCancel(h.requestContext) - dataCh := make(chan []byte) - errCh := make(chan error) - - defer func() { - cancel() - h.unsubscribeAllAndCloseConn() - }() - - err := h.subscribe() - if err != nil { - return err - } - - pingTicker := time.NewTicker(h.options.pingInterval) - defer pingTicker.Stop() - - go h.readBlocking(readCtx, h.options.readTimeout, dataCh, errCh) - - for { - select { - case <-h.engineContext.Done(): - return h.engineContext.Err() - case <-readCtx.Done(): - return readCtx.Err() - case <-pingTicker.C: - h.Ping() - case err := <-errCh: - h.log.Error("gqlWSConnectionHandler.StartBlocking", abstractlogger.Error(err)) - h.broadcastErrorMessage(err) - return err - case data := <-dataCh: - messageType, err := jsonparser.GetString(data, "type") - if err != nil { - continue - } - - switch messageType { - case messageTypePong: - h.handleMessageTypePong() - continue - case messageTypePing: - h.handleMessageTypePing() - continue - case messageTypeNext: - h.handleMessageTypeNext(data) - continue - case messageTypeComplete: - h.handleMessageTypeComplete(data) - return nil - case messageTypeError: - h.handleMessageTypeError(data) - continue - case messageTypeConnectionKeepAlive: - continue - case messageTypeData, messageTypeConnectionError: - h.log.Error("Invalid subprotocol. The subprotocol should be set to graphql-transport-ws, but currently it is set to graphql-ws") - return errors.New("invalid subprotocol") - default: - h.log.Error("unknown message type", abstractlogger.String("type", messageType)) - continue - } - } - } -} - -func (h *gqlTWSConnectionHandler) unsubscribeAllAndCloseConn() { - h.unsubscribe() - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - _ = ws.WriteFrame(h.conn, ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "Normal Closure")))) - _ = h.conn.Close() -} - -func (h *gqlTWSConnectionHandler) unsubscribe() { - h.updater.Complete() - req := fmt.Sprintf(completeMessage, "1") - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - err := wsutil.WriteClientText(h.conn, []byte(req)) - if err != nil { - h.log.Error("failed to write complete message", abstractlogger.Error(err)) - } -} - -// subscribe adds a new Subscription to the gqlTWSConnectionHandler and sends the subscribeMessage to the origin -func (h *gqlTWSConnectionHandler) subscribe() error { - graphQLBody, err := json.Marshal(h.options.Body) - if err != nil { - return err - } - subscribeRequest := fmt.Sprintf(subscribeMessage, "1", string(graphQLBody)) - if err = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { - return err - } - if err = wsutil.WriteClientText(h.conn, []byte(subscribeRequest)); err != nil { - return err - } - return nil -} - -func (h *gqlTWSConnectionHandler) broadcastErrorMessage(err error) { - errMsg := fmt.Sprintf(errorMessageTemplate, err) - h.updater.Update([]byte(errMsg)) -} - -func (h *gqlTWSConnectionHandler) handleMessageTypeComplete(data []byte) { - id, err := jsonparser.GetString(data, "id") - if err != nil { - return - } - if id != "1" { - return - } - h.updater.Complete() -} - -func (h *gqlTWSConnectionHandler) handleMessageTypeError(data []byte) { - id, err := jsonparser.GetString(data, "id") - if err != nil { - return - } - if id != "1" { - return - } - value, valueType, _, err := jsonparser.Get(data, "payload") - if err != nil { - h.log.Error( - "failed to get payload from error message", - abstractlogger.Error(err), - abstractlogger.ByteString("raw message", data), - ) - h.updater.Update([]byte(internalError)) - return - } - - switch valueType { - case jsonparser.Array: - response := []byte(`{}`) - response, err = jsonparser.Set(response, value, "errors") - if err != nil { - h.log.Error( - "failed to set errors response", - abstractlogger.Error(err), - abstractlogger.ByteString("raw message", value), - ) - h.updater.Update([]byte(internalError)) - return - } - h.updater.Update(response) - case jsonparser.Object: - response := []byte(`{"errors":[]}`) - response, err = jsonparser.Set(response, value, "errors", "[0]") - if err != nil { - h.updater.Update([]byte(internalError)) - return - } - h.updater.Update(response) - default: - h.updater.Update([]byte(internalError)) - } -} - -func (h *gqlTWSConnectionHandler) handleMessageTypePing() { - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - err := wsutil.WriteClientText(h.conn, []byte(pongMessage)) - if err != nil { - h.log.Error("failed to write pong message", abstractlogger.Error(err)) - } -} - -func (h *gqlTWSConnectionHandler) handleMessageTypePong() { - // We received a pong message from the server. We can reset it. - h.lastPingSentUnix.Store(0) -} - -func (h *gqlTWSConnectionHandler) Ping() { - - // Do nothing if the connection is shutting down - if h.shuttingDown.Load() { - h.log.Debug("ping skipped. connection is shutting down") - return - } - - lastPingTimestamp := h.lastPingSentUnix.Load() - - // We will detect a dead connection not immediately but on the next ping interval. - if lastPingTimestamp > 0 { - pingTime := time.Unix(0, lastPingTimestamp) - duration := time.Since(pingTime) - - if duration > h.pingTimeout { - h.log.Error("ping timeout exceeded. Closing connection") - // We close the connection because the ping timeout has been exceeded, - // and we assume the connection is dead. ServerClose will send a done event to the client - // event loop to close all triggers and subscriptions - h.ServerClose() - } - - // We don't want to send another ping if one is already in flight - return - } - - // Start measuring the time since to write the message to the connection, including the IO time - h.lastPingSentUnix.Store(time.Now().UnixNano()) - - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - err := wsutil.WriteClientText(h.conn, []byte(pingMessage)) - - if err != nil { - h.log.Error("failed to write ping message", abstractlogger.Error(err)) - return - } -} - -func (h *gqlTWSConnectionHandler) handleMessageTypeNext(data []byte) { - id, err := jsonparser.GetString(data, "id") - if err != nil { - return - } - if id != "1" { - return - } - value, _, _, err := jsonparser.Get(data, "payload") - if err != nil { - h.log.Error( - "failed to get payload from next message", - abstractlogger.Error(err), - ) - h.updater.Update([]byte(internalError)) - return - } - - h.updater.Update(value) -} - -// readBlocking is a dedicated loop running in a separate goroutine -// because the library "github.com/coder/websocket" doesn't allow reading with a context with Timeout -// we'll block forever on reading until the context of the gqlTWSConnectionHandler stops -func (h *gqlTWSConnectionHandler) readBlocking(ctx context.Context, readTimeout time.Duration, dataCh chan []byte, errCh chan error) { - netOpErr := &net.OpError{} - for { - _ = h.conn.SetReadDeadline(time.Now().Add(readTimeout)) - data, err := wsutil.ReadServerText(h.conn) - if err != nil { - if errors.As(err, &netOpErr) { - if netOpErr.Timeout() { - select { - case <-ctx.Done(): - return - default: - continue - } - } - } - select { - case errCh <- err: - case <-ctx.Done(): - } - return - } - select { - case dataCh <- data: - case <-ctx.Done(): - return - } - } -} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go deleted file mode 100644 index cc69902198..0000000000 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go +++ /dev/null @@ -1,312 +0,0 @@ -//go:build !race - -package graphql_datasource - -import ( - "context" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/coder/websocket" - "github.com/stretchr/testify/assert" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" - "github.com/wundergraph/graphql-go-tools/v2/pkg/testing/flags" -) - -func TestWebsocketSubscriptionClient_GQLTWS(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"graphql-transport-ws"}, - }) - assert.NoError(t, err) - - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ).(*subscriptionClient) - - updater := &testSubscriptionUpdater{} - go func() { - rCtx := resolve.NewContext(ctx) - err := client.Subscribe(rCtx, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second*5, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() -} - -func TestWebsocketSubscriptionClientPing_GQLTWS(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"graphql-transport-ws"}, - }) - assert.NoError(t, err) - - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ping"}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"pong"}`, string(data)) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ).(*subscriptionClient) - - updater := &testSubscriptionUpdater{} - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdates(t, time.Second, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() -} - -func TestWebsocketSubscriptionClientError_GQLTWS(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"graphql-transport-ws"}, - }) - assert.NoError(t, err) - - msgType, data, err := conn.Read(r.Context()) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(r.Context()) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"wrongQuery {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"payload":[{"message":"Unexpected Name \"wrongQuery\"","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}],"id":"1","type":"error"}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1",type":"complete"}`)) - assert.NoError(t, err) - - close(serverDone) - })) - defer server.Close() - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - clientCtx, clientCancel := context.WithCancel(context.Background()) - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := &testSubscriptionUpdater{} - - go func() { - err := client.Subscribe(resolve.NewContext(clientCtx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `wrongQuery {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdates(t, time.Second, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"errors":[{"message":"Unexpected Name \"wrongQuery\"","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}]}`, updater.updates[0]) - - clientCancel() - updater.AwaitDone(t, time.Second) - - serverCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") -} - -func TestWebsocketSubscriptionClient_GQLTWS_Upstream_Dies(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"graphql-transport-ws"}, - }) - assert.NoError(t, err) - - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - <-serverCtx.Done() - })) - - // Wrap the listener to hijack the underlying TCP connection. - // Hijacking via http.ResponseWriter doesn't work because the WebSocket - // client already hijacks the connection before us. - wrappedListener := &listenerWrapper{ - listener: server.Listener, - } - server.Listener = wrappedListener - server.Start() - - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Second), - WithLogger(logger()), - ).(*subscriptionClient) - - updater := &testSubscriptionUpdater{} - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdates(t, time.Second, 1) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - - // Kill the upstream here. We should get an End-of-File error. - assert.NoError(t, wrappedListener.underlyingConnection.Close()) - updater.AwaitUpdates(t, time.Second, 2) - assert.Equal(t, `{"errors":[{"message":"EOF"}]}`, updater.updates[1]) - - clientCancel() - serverCancel() -} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go deleted file mode 100644 index 423080e2d7..0000000000 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go +++ /dev/null @@ -1,301 +0,0 @@ -package graphql_datasource - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net" - "time" - - "github.com/buger/jsonparser" - "github.com/gobwas/ws" - "github.com/gobwas/ws/wsutil" - "github.com/jensneuse/abstractlogger" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" -) - -// gqlWSConnectionHandler is responsible for handling a connection to an origin -// it is responsible for managing all subscriptions using the underlying WebSocket connection -// if all Subscriptions are complete or cancelled/unsubscribed the handler will terminate -type gqlWSConnectionHandler struct { - // The underlying net.Conn. Only used for netPoll. Should not be used to shutdown the connection. - conn net.Conn - requestContext, engineContext context.Context - log abstractlogger.Logger - options GraphQLSubscriptionOptions - updater resolve.SubscriptionUpdater -} - -func (h *gqlWSConnectionHandler) ServerClose() { - // Because the server closes the connection, we need to send a close frame to the event loop. - h.updater.Close(resolve.SubscriptionCloseKindDownstreamServiceError) - - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - _ = ws.WriteFrame(h.conn, ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "Normal Closure")))) - _ = h.conn.Close() -} - -// ClientClose is called when the client closes the connection. Is called when the trigger is shutdown with all subscriptions. -func (h *gqlWSConnectionHandler) ClientClose() { - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - _ = wsutil.WriteClientText(h.conn, []byte(`{"type":"stop","id":"1"}`)) - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - _ = ws.WriteFrame(h.conn, ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "Normal Closure")))) - _ = h.conn.Close() -} - -func (h *gqlWSConnectionHandler) Subscribe() error { - return h.subscribe() -} - -func (h *gqlWSConnectionHandler) HandleMessage(data []byte) (done bool) { - messageType, err := jsonparser.GetString(data, "type") - if err != nil { - return false - } - switch messageType { - case messageTypeConnectionKeepAlive: - return false - case messageTypeData: - h.handleMessageTypeData(data) - return false - case messageTypeComplete: - h.handleMessageTypeComplete(data) - return true - case messageTypeConnectionError: - h.handleMessageTypeConnectionError() - return true - case messageTypeError: - h.handleMessageTypeError(data) - return false - default: - return false - } -} - -func (h *gqlWSConnectionHandler) NetConn() net.Conn { - return h.conn -} - -func newGQLWSConnectionHandler(requestContext, engineContext context.Context, conn net.Conn, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater, log abstractlogger.Logger) *connection { - handler := &gqlWSConnectionHandler{ - conn: conn, - requestContext: requestContext, - engineContext: engineContext, - log: log, - updater: updater, - options: options, - } - return &connection{ - handler: handler, - netConn: conn, - } -} - -// StartBlocking starts the single threaded event loop of the handler -// if the global context returns or the websocket connection is terminated, it will stop -func (h *gqlWSConnectionHandler) StartBlocking() error { - dataCh := make(chan []byte) - errCh := make(chan error) - readCtx, cancel := context.WithCancel(h.requestContext) - - defer func() { - cancel() - h.unsubscribeAllAndCloseConn() - }() - - err := h.subscribe() - if err != nil { - return err - } - - go h.readBlocking(readCtx, h.options.readTimeout, dataCh, errCh) - - for { - select { - case <-h.engineContext.Done(): - return h.engineContext.Err() - case <-readCtx.Done(): - return readCtx.Err() - case err := <-errCh: - if !errors.Is(err, context.Canceled) && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - h.log.Error("gqlWSConnectionHandler.StartBlocking", abstractlogger.Error(err)) - } - h.broadcastErrorMessage(err) - return err - case data := <-dataCh: - messageType, err := jsonparser.GetString(data, "type") - if err != nil { - continue - } - switch messageType { - case messageTypeData: - h.handleMessageTypeData(data) - continue - case messageTypeComplete: - h.handleMessageTypeComplete(data) - return nil - case messageTypeConnectionError: - h.handleMessageTypeConnectionError() - return nil - case messageTypeError: - h.handleMessageTypeError(data) - continue - case messageTypeConnectionKeepAlive: - continue - case messageTypePing, messageTypeNext: - h.log.Error("Invalid subprotocol. The subprotocol should be set to graphql-ws, but currently it is set to graphql-transport-ws") - return errors.New("invalid subprotocol") - default: - h.log.Error("unknown message type", abstractlogger.String("type", messageType)) - continue - } - } - } -} - -// readBlocking is a dedicated loop running in a separate goroutine -// because the library "github.com/coder/websocket" doesn't allow reading with a context with Timeout -// we'll block forever on reading until the context of the gqlWSConnectionHandler stops -func (h *gqlWSConnectionHandler) readBlocking(ctx context.Context, readTimeout time.Duration, dataCh chan []byte, errCh chan error) { - netOpErr := &net.OpError{} - for { - _ = h.conn.SetReadDeadline(time.Now().Add(readTimeout)) - data, err := wsutil.ReadServerText(h.conn) - if err != nil { - if errors.As(err, &netOpErr) { - if netOpErr.Timeout() { - select { - case <-ctx.Done(): - return - default: - continue - } - } - } - select { - case errCh <- err: - case <-ctx.Done(): - } - return - } - select { - case dataCh <- data: - case <-ctx.Done(): - return - } - } -} - -func (h *gqlWSConnectionHandler) unsubscribeAllAndCloseConn() { - h.unsubscribe() - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - _ = ws.WriteFrame(h.conn, ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "Normal Closure")))) - _ = h.conn.Close() -} - -func (h *gqlWSConnectionHandler) Ping() { - // This protocol has no client side ping/pong mechanism. The server send a ka message to understand - // if the connection is still alive. The client only acknowledges the retrieval of the ka message - // by consuming it in the readBlocking loop. - - // TODO We could check if we receive a ka message in a certain time frame and if not, we could close the connection - // However, because we don't send something to the server, we can't verify if the connection is still healthy and - // responsive from both sides. -} - -// subscribe adds a new Subscription to the gqlWSConnectionHandler and sends the startMessage to the origin -func (h *gqlWSConnectionHandler) subscribe() error { - graphQLBody, err := json.Marshal(h.options.Body) - if err != nil { - return err - } - startRequest := fmt.Sprintf(startMessage, "1", string(graphQLBody)) - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - err = wsutil.WriteClientText(h.conn, []byte(startRequest)) - if err != nil { - return err - } - return nil -} - -func (h *gqlWSConnectionHandler) handleMessageTypeData(data []byte) { - id, err := jsonparser.GetString(data, "id") - if err != nil { - return - } - if id != "1" { - return - } - payload, _, _, err := jsonparser.Get(data, "payload") - if err != nil { - return - } - - h.updater.Update(payload) -} - -func (h *gqlWSConnectionHandler) handleMessageTypeConnectionError() { - h.updater.Update([]byte(connectionError)) -} - -func (h *gqlWSConnectionHandler) broadcastErrorMessage(err error) { - errMsg := fmt.Sprintf(errorMessageTemplate, err) - h.updater.Update([]byte(errMsg)) -} - -func (h *gqlWSConnectionHandler) handleMessageTypeComplete(data []byte) { - id, err := jsonparser.GetString(data, "id") - if err != nil { - return - } - if id != "1" { - return - } - h.updater.Complete() -} - -func (h *gqlWSConnectionHandler) handleMessageTypeError(data []byte) { - id, err := jsonparser.GetString(data, "id") - if err != nil { - return - } - if id != "1" { - return - } - value, valueType, _, err := jsonparser.Get(data, "payload") - if err != nil { - h.updater.Update([]byte(internalError)) - return - } - switch valueType { - case jsonparser.Array: - response := []byte(`{}`) - response, err = jsonparser.Set(response, value, "errors") - if err != nil { - h.updater.Update([]byte(internalError)) - return - } - h.updater.Update(response) - case jsonparser.Object: - response := []byte(`{"errors":[]}`) - response, err = jsonparser.Set(response, value, "errors", "[0]") - if err != nil { - h.updater.Update([]byte(internalError)) - return - } - h.updater.Update(response) - default: - h.updater.Update([]byte(internalError)) - } -} - -func (h *gqlWSConnectionHandler) unsubscribe() { - h.updater.Complete() - stopRequest := fmt.Sprintf(stopMessage, "1") - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - _ = wsutil.WriteClientText(h.conn, []byte(stopRequest)) -} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go deleted file mode 100644 index eddc47253c..0000000000 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go +++ /dev/null @@ -1,371 +0,0 @@ -//go:build !race - -package graphql_datasource - -import ( - "context" - "net" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/coder/websocket" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" - "github.com/wundergraph/graphql-go-tools/v2/pkg/testing/flags" -) - -func TestWebSocketSubscriptionClientInitIncludeKA_GQLWS(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - assertion := require.New(t) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assertion.NoError(err) - - // write "ka" every second - go func() { - for { - err := conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ka"}`)) - if err != nil { - break - } - time.Sleep(time.Second) - } - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assertion.NoError(err) - assertion.Equal(websocket.MessageText, msgType) - assertion.Equal(`{"type":"connection_init"}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assertion.NoError(err) - - msgType, data, err = conn.Read(ctx) - assertion.NoError(err) - assertion.Equal(websocket.MessageText, msgType) - assertion.Equal(`{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assertion.NoError(err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assertion.NoError(err) - assertion.NoError(err) - - msgType, data, err = conn.Read(ctx) - assertion.NoError(err) - assertion.Equal(websocket.MessageText, msgType) - assertion.Equal(`{"type":"stop","id":"1"}`, string(data)) - close(serverDone) - })) - - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assertion.NoError(err) - }() - updater.AwaitUpdates(t, time.Second, 2) - assertion.Equal(`{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assertion.Equal(`{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - clientCancel() - assertion.Eventuallyf(func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() -} - -func TestWebsocketSubscriptionClient_GQLWS(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - go func() { - rCtx := resolve.NewContext(ctx) - err := client.Subscribe(rCtx, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second*5, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() -} - -func TestWebsocketSubscriptionClientErrorArray(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - msgType, data, err := conn.Read(r.Context()) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(r.Context()) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomNam: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"error","id":"1","payload":[{"message":"error"},{"message":"error"}]}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(r.Context()) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - _, _, err = conn.Read(r.Context()) - assert.NotNil(t, err) - close(serverDone) - })) - defer server.Close() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - clientCtx, clientCancel := context.WithCancel(context.Background()) - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(clientCtx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomNam: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second, 1) - assert.Equal(t, `{"errors":[{"message":"error"},{"message":"error"}]}`, updater.updates[0]) - clientCancel() - updater.AwaitDone(t, time.Second) - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") -} - -func TestWebsocketSubscriptionClientErrorObject(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - msgType, data, err := conn.Read(r.Context()) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(r.Context()) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomNam: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"error","id":"1","payload":{"message":"error"}}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(r.Context()) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - _, _, err = conn.Read(r.Context()) - assert.NotNil(t, err) - close(serverDone) - })) - defer server.Close() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - clientCtx, clientCancel := context.WithCancel(context.Background()) - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(clientCtx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomNam: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"errors":[{"message":"error"}]}`, updater.updates[0]) - clientCancel() - updater.AwaitDone(t, time.Second) - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") -} - -func TestWebsocketSubscriptionClient_GQLWS_Upstream_Dies(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - <-serverCtx.Done() - })) - - // Wrap the listener to hijack the underlying TCP connection. - // Hijacking via http.ResponseWriter doesn't work because the WebSocket - // client already hijacks the connection before us. - wrappedListener := &listenerWrapper{ - listener: server.Listener, - } - server.Listener = wrappedListener - server.Start() - - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - - // Start a new GQL subscription and exchange some messages. - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Second), - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - - // Kill the upstream here. We should get an End-of-File error. - assert.NoError(t, wrappedListener.underlyingConnection.Close()) - updater.AwaitUpdates(t, time.Second, 2) - assert.Equal(t, `{"errors":[{"message":"EOF"}]}`, updater.updates[1]) - - serverCancel() - clientCancel() -} - -type listenerWrapper struct { - listener net.Listener - underlyingConnection net.Conn -} - -func (l *listenerWrapper) Accept() (net.Conn, error) { - conn, err := l.listener.Accept() - if err != nil { - return nil, err - } - l.underlyingConnection = conn - return l.underlyingConnection, nil -} - -func (l *listenerWrapper) Close() error { - return l.listener.Close() -} - -func (l *listenerWrapper) Addr() net.Addr { - return l.listener.Addr() -} - -var _ net.Listener = (*listenerWrapper)(nil) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_proto_types.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_proto_types.go deleted file mode 100644 index 16b7b98d04..0000000000 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_proto_types.go +++ /dev/null @@ -1,47 +0,0 @@ -package graphql_datasource - -// common -var ( - connectionInitMessage = []byte(`{"type":"connection_init"}`) -) - -const ( - messageTypeConnectionAck = "connection_ack" - messageTypeComplete = "complete" - messageTypeError = "error" -) - -// websocket sub-protocol: -// https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md -const ( - ProtocolGraphQLWS = "graphql-ws" - - startMessage = `{"type":"start","id":"%s","payload":%s}` - stopMessage = `{"type":"stop","id":"%s"}` - - messageTypeConnectionKeepAlive = "ka" - messageTypeData = "data" - messageTypeConnectionError = "connection_error" -) - -// websocket sub-protocol: -// https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md -const ( - ProtocolGraphQLTWS = "graphql-transport-ws" - - subscribeMessage = `{"id":"%s","type":"subscribe","payload":%s}` - pongMessage = `{"type":"pong"}` - pingMessage = `{"type":"ping"}` - completeMessage = `{"id":"%s","type":"complete"}` - - messageTypePing = "ping" - messageTypePong = "pong" - messageTypeNext = "next" -) - -// internal -const ( - internalError = `{"errors":[{"message":"internal error"}]}` - connectionError = `{"errors":[{"message":"connection error"}]}` - errorMessageTemplate = `{"errors":[{"message":"%s"}]}` -) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go new file mode 100644 index 0000000000..5fa58ee0c3 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go @@ -0,0 +1,103 @@ +package client + +import ( + "context" + "errors" + "net/http" + "time" + + "github.com/jensneuse/abstractlogger" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport" +) + +var ErrClientClosed = errors.New("client closed") + +type Client struct { + ctx context.Context + log abstractlogger.Logger + + ws *transport.WSTransport + sse *transport.SSETransport +} + +// Stats contains client statistics. +type Stats struct { + WSConns int // active WebSocket connections + SSEConns int // active SSE connections +} + +// Config holds the client configuration. +type Config struct { + UpgradeClient *http.Client + StreamingClient *http.Client + Logger abstractlogger.Logger + PingInterval time.Duration + PingTimeout time.Duration + AckTimeout time.Duration + WriteTimeout time.Duration + ReadLimit int64 +} + +// New creates a new subscription client with the provided config. +func New(ctx context.Context, cfg Config) *Client { + if cfg.UpgradeClient == nil { + cfg.UpgradeClient = http.DefaultClient + } + if cfg.StreamingClient == nil { + cfg.StreamingClient = http.DefaultClient + } + if cfg.Logger == nil { + cfg.Logger = abstractlogger.NoopLogger + } + + c := &Client{ + ctx: ctx, + log: cfg.Logger, + + ws: transport.NewWSTransport(ctx, + transport.WithUpgradeClient(cfg.UpgradeClient), + transport.WithLogger(cfg.Logger), + transport.WithPingInterval(cfg.PingInterval), + transport.WithPingTimeout(cfg.PingTimeout), + transport.WithAckTimeout(cfg.AckTimeout), + transport.WithWriteTimeout(cfg.WriteTimeout), + transport.WithReadLimit(cfg.ReadLimit), + ), + sse: transport.NewSSETransport(ctx, cfg.StreamingClient, cfg.Logger), + } + + c.log.Debug("subscriptionClient.New", abstractlogger.String("status", "initialized")) + + return c +} + +// Subscribe creates a new upstream via the appropriate transport. +func (c *Client) Subscribe(ctx context.Context, req *common.Request, opts common.Options) (<-chan *common.Message, func(), error) { + if c.ctx.Err() != nil { + return nil, nil, ErrClientClosed + } + + // Route to transport + var source <-chan *common.Message + var cancel func() + var err error + + if opts.Transport == common.TransportSSE { + source, cancel, err = c.sse.Subscribe(ctx, req, opts) + } else { + source, cancel, err = c.ws.Subscribe(ctx, req, opts) + } + + return source, cancel, err +} + +// Stats returns client statistics. +func (c *Client) Stats() Stats { + stats := Stats{ + WSConns: c.ws.ConnCount(), + SSEConns: c.sse.ConnCount(), + } + return stats +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client_test.go new file mode 100644 index 0000000000..c42b508e87 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client_test.go @@ -0,0 +1,289 @@ +package client + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +func TestClient(t *testing.T) { + t.Run("new creates client with transports", func(t *testing.T) { + c := New(t.Context(), Config{}) + + assert.NotNil(t, c.ws) + assert.NotNil(t, c.sse) + }) + + t.Run("context cancellation is idempotent", func(t *testing.T) { + assert.NotPanics(t, func() { + ctx, cancel := context.WithCancel(context.Background()) + _ = New(ctx, Config{}) + cancel() + cancel() // should not panic + }) + }) + + t.Run("subscribe fails after context cancelled", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := New(ctx, Config{}) + cancel() + + _, _, err := c.Subscribe(t.Context(), &Request{Query: "subscription { a }"}, Options{ + Endpoint: "ws://localhost/graphql", + }) + + assert.Equal(t, ErrClientClosed, err) + }) +} + +func TestClient_SubscriberDrain(t *testing.T) { + // These tests verify that cancelling all subscriptions properly cleans up all goroutines. + // Connections close themselves when their last subscriber is removed. + + t.Run("subscriber drain cleans up", func(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreAnyFunction("net/http/httptest.(*Server).goServe.func1")) + + server := newTestWSServer(t) + + c := New(t.Context(), Config{}) + + ch, subCancel, err := c.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + }) + require.NoError(t, err) + + // subscription is working + select { + case msg := <-ch: + require.NotNil(t, msg.Payload) + case <-time.After(time.Second): + t.Fatal("timeout waiting for message") + } + + subCancel() + + // Give ReadLoop goroutine time to exit after connection close + assert.Eventually(t, func() bool { + return c.Stats().WSConns == 0 + }, time.Second, 10*time.Millisecond) + }) + + t.Run("subscriber drain cleans up multiple connections", func(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreAnyFunction("net/http/httptest.(*Server).goServe.func1")) + + server := newTestWSServer(t) + + c := New(t.Context(), Config{}) + + cancels := make([]func(), 3) + // Start subscriptions with different headers (forces multiple connections) + for i := range 3 { + headers := http.Header{"X-Request-ID": []string{string(rune('A' + i))}} + ch, subCancel, err := c.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + Headers: headers, + }) + require.NoError(t, err) + cancels[i] = subCancel + + // Drain first message + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("timeout") + } + } + + // Should have 3 connections + stats := c.Stats() + require.Equal(t, 3, stats.WSConns) + + for _, fn := range cancels { + fn() + } + + assert.Eventually(t, func() bool { + return c.Stats().WSConns == 0 + }, time.Second, 10*time.Millisecond) + }) +} + +func TestClient_CancelSendsComplete(t *testing.T) { + t.Run("cancel sends complete to server", func(t *testing.T) { + defer goleak.VerifyNone(t, + goleak.IgnoreAnyFunction("net/http/httptest.(*Server).goServe.func1"), + ) + + completeReceived := make(chan string, 1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-transport-ws"}, + }) + if err != nil { + return + } + defer func() { + _ = conn.CloseNow() + }() + + ctx := r.Context() + + // Handle connection_init + var initMsg map[string]any + if err := wsjson.Read(ctx, conn, &initMsg); err != nil { + return + } + if initMsg["type"] != "connection_init" { + return + } + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + + // Read subscribe + var subMsg map[string]any + if err := wsjson.Read(ctx, conn, &subMsg); err != nil { + return + } + if subMsg["type"] != "subscribe" { + return + } + subID := subMsg["id"].(string) + + // Send a next message + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": subID, + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + + // Wait for complete message from client + var completeMsg map[string]any + if err := wsjson.Read(ctx, conn, &completeMsg); err != nil { + return + } + if completeMsg["type"] == "complete" { + completeReceived <- completeMsg["id"].(string) + } + })) + t.Cleanup(server.Close) + + c := New(t.Context(), Config{}) + + ch, cancel, err := c.Subscribe(t.Context(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + }) + require.NoError(t, err) + + // Wait for first message + select { + case msg := <-ch: + require.NotNil(t, msg.Payload) + case <-time.After(time.Second): + t.Fatal("timeout waiting for message") + } + + // Cancel the subscription - this should send complete to server + cancel() + + // Verify server received complete + select { + case id := <-completeReceived: + assert.NotEmpty(t, id, "complete message should have subscription ID") + case <-time.After(time.Second): + t.Fatal("timeout waiting for complete message on server") + } + + // Wait for ReadLoop goroutine to exit after connection close + assert.Eventually(t, func() bool { + return c.Stats().WSConns == 0 + }, time.Second, 10*time.Millisecond) + }) +} + +// Test helper: creates a WebSocket server that sends periodic messages +func newTestWSServer(t *testing.T) *httptest.Server { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-transport-ws"}, + }) + if err != nil { + return + } + defer func() { + _ = conn.CloseNow() + }() + + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + + // Handle connection_init + var initMsg map[string]any + if err := wsjson.Read(ctx, conn, &initMsg); err != nil { + return + } + if initMsg["type"] != "connection_init" { + return + } + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + + // Read messages in background, cancel context when connection closes + go func() { + defer cancel() + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + // Handle subscribe by sending first message + if msg["type"] == "subscribe" { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + } + } + }() + + // Send periodic messages until context cancelled + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + // Ignore write errors (connection may be closed) + _ = wsjson.Write(ctx, conn, map[string]any{ + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + } + } + })) + + t.Cleanup(server.Close) + return server +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go new file mode 100644 index 0000000000..70c48654d7 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go @@ -0,0 +1,27 @@ +package common + +import ( + "encoding/json" + "errors" +) + +var ErrConnectionClosed = errors.New("connection closed") + +type Message struct { + Payload *ExecutionResult + Err error + Done bool +} + +type ExecutionResult struct { + Data json.RawMessage `json:"data,omitempty"` + Errors json.RawMessage `json:"errors,omitempty"` + Extensions json.RawMessage `json:"extensions,omitempty"` +} + +type Request struct { + Query string `json:"query"` + OperationName string `json:"operationName,omitempty"` + Variables json.RawMessage `json:"variables,omitempty"` + Extensions json.RawMessage `json:"extensions,omitempty"` +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/options.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/options.go new file mode 100644 index 0000000000..ab706d02f0 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/options.go @@ -0,0 +1,55 @@ +package common + +import ( + "net/http" +) + +type TransportType string + +const ( + TransportWS TransportType = "ws" + TransportSSE TransportType = "sse" +) + +type WSSubprotocol string + +const ( + SubprotocolAuto WSSubprotocol = "" // Auto, negotiated with the server + SubprotocolGraphQLTransportWS WSSubprotocol = "graphql-transport-ws" // Modern subprotocol + SubprotocolGraphQLWS WSSubprotocol = "graphql-ws" // Legacy subprotocol, deprecated +) + +func (s WSSubprotocol) Subprotocols() []string { + switch s { + case SubprotocolAuto: + return []string{"graphql-transport-ws", "graphql-ws"} + case SubprotocolGraphQLTransportWS: + return []string{"graphql-transport-ws"} + case SubprotocolGraphQLWS: + return []string{"graphql-ws"} + default: + return nil + } +} + +type SSEMethod string + +const ( + SSEMethodAuto SSEMethod = "" // Auto: POST for graphql-sse (default) + SSEMethodPOST SSEMethod = "POST" // POST with JSON body (graphql-sse spec) + SSEMethodGET SSEMethod = "GET" // GET with query parameters (traditional SSE) +) + +type Options struct { + Endpoint string + Headers http.Header + InitPayload map[string]any + Transport TransportType + + // Only affects the WebSocket transport. + WSSubprotocol WSSubprotocol + + // Only affects the SSE transport. + // Defaults to POST (graphql-sse spec). + SSEMethod SSEMethod +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go new file mode 100644 index 0000000000..4338508522 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go @@ -0,0 +1,51 @@ +package client + +import ( + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport" +) + +// Re-export common types for single-import convenience. + +type ( + Message = common.Message + ExecutionResult = common.ExecutionResult + Request = common.Request + Options = common.Options + TransportType = common.TransportType + WSSubprotocol = common.WSSubprotocol + SSEMethod = common.SSEMethod +) + +// Re-export constants. + +const ( + TransportWS = common.TransportWS + TransportSSE = common.TransportSSE + + SubprotocolAuto = common.SubprotocolAuto + SubprotocolGraphQLTransportWS = common.SubprotocolGraphQLTransportWS + SubprotocolGraphQLWS = common.SubprotocolGraphQLWS + + SSEMethodAuto = common.SSEMethodAuto + SSEMethodPOST = common.SSEMethodPOST + SSEMethodGET = common.SSEMethodGET +) + +// Re-export error types. + +type ( + ErrFailedUpgrade = transport.ErrFailedUpgrade + ErrInvalidSubprotocol = transport.ErrInvalidSubprotocol +) + +// Re-export sentinel errors. + +var ( + ErrConnectionClosed = common.ErrConnectionClosed + ErrConnectionError = protocol.ErrConnectionError + ErrAckTimeout = protocol.ErrAckTimeout + ErrAckNotReceived = protocol.ErrAckNotReceived + ErrSubscriptionExists = transport.ErrSubscriptionExists +) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go new file mode 100644 index 0000000000..80bd884f92 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go @@ -0,0 +1,174 @@ +package protocol + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +const ( + gtwsTypeConnectionInit = "connection_init" + gtwsTypeConnectionAck = "connection_ack" + gtwsTypePing = "ping" + gtwsTypePong = "pong" + gtwsTypeSubscribe = "subscribe" + gtwsTypeNext = "next" + gtwsTypeError = "error" + gtwsTypeComplete = "complete" +) + +type outgoingMessage struct { + ID string `json:"id,omitempty"` + Type string `json:"type"` + Payload any `json:"payload,omitempty"` +} + +type incomingMessage struct { + ID string `json:"id,omitempty"` + Type string `json:"type"` + Payload json.RawMessage `json:"payload,omitempty"` +} + +type GraphQLTransportWS struct { + AckTimeout time.Duration +} + +func NewGraphQLTransportWS() *GraphQLTransportWS { + return &GraphQLTransportWS{ + AckTimeout: 30 * time.Second, + } +} + +// Init implements Protocol. +func (p *GraphQLTransportWS) Init(ctx context.Context, conn *websocket.Conn, payload map[string]any) error { + initMsg := outgoingMessage{ + Type: gtwsTypeConnectionInit, + } + if payload != nil { + initMsg.Payload = payload + } + if err := wsjson.Write(ctx, conn, initMsg); err != nil { + return fmt.Errorf("write connection_init: %w", err) + } + + timeout := p.AckTimeout + if timeout == 0 { + timeout = 30 * time.Second + } + + ackCtx, ackCancel := context.WithTimeout(ctx, timeout) + defer ackCancel() + + for { + var ackMessage incomingMessage + if err := wsjson.Read(ackCtx, conn, &ackMessage); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return ErrAckTimeout + } + return fmt.Errorf("read connection_ack: %w", err) + } + + switch ackMessage.Type { + case gtwsTypeConnectionAck: + return nil + case gtwsTypePing: + if err := p.Pong(ctx, conn); err != nil { + return fmt.Errorf("pre-init pong: %w", err) + } + continue + default: + return fmt.Errorf("%w: got %q", ErrAckNotReceived, ackMessage.Type) + } + } +} + +// Ping implements Protocol. +func (p *GraphQLTransportWS) Ping(ctx context.Context, conn *websocket.Conn) error { + msg := outgoingMessage{ + Type: gtwsTypePing, + } + return wsjson.Write(ctx, conn, msg) +} + +// Pong implements Protocol. +func (p *GraphQLTransportWS) Pong(ctx context.Context, conn *websocket.Conn) error { + msg := outgoingMessage{ + Type: gtwsTypePong, + } + return wsjson.Write(ctx, conn, msg) +} + +// Read implements Protocol. +func (p *GraphQLTransportWS) Read(ctx context.Context, conn *websocket.Conn) (*Message, error) { + var raw incomingMessage + if err := wsjson.Read(ctx, conn, &raw); err != nil { + return nil, fmt.Errorf("read message: %w", err) + } + + return p.decode(raw) +} + +// Subscribe implements Protocol. +func (p *GraphQLTransportWS) Subscribe(ctx context.Context, conn *websocket.Conn, id string, req *common.Request) error { + msg := outgoingMessage{ + ID: id, + Type: gtwsTypeSubscribe, + Payload: req, + } + return wsjson.Write(ctx, conn, msg) +} + +// Unsubscribe implements Protocol. +func (p *GraphQLTransportWS) Unsubscribe(ctx context.Context, conn *websocket.Conn, id string) error { + msg := outgoingMessage{ + ID: id, + Type: gtwsTypeComplete, + } + return wsjson.Write(ctx, conn, msg) +} + +func (p *GraphQLTransportWS) decode(raw incomingMessage) (*Message, error) { + msg := &Message{ + ID: raw.ID, + } + + switch raw.Type { + case gtwsTypeNext: + msg.Type = MessageData + if raw.Payload != nil { + var resp common.ExecutionResult + if err := json.Unmarshal(raw.Payload, &resp); err != nil { + return nil, fmt.Errorf("unmarshal next payload: %w", err) + } + msg.Payload = &resp + } + case gtwsTypeError: + msg.Type = MessageError + if raw.Payload != nil { + msg.Payload = &common.ExecutionResult{Errors: raw.Payload} + } + + case gtwsTypeComplete: + msg.Type = MessageComplete + + case gtwsTypePing: + msg.Type = MessagePing + + case gtwsTypePong: + msg.Type = MessagePong + + default: + return nil, fmt.Errorf("unknown message type: %s", raw.Type) + } + + return msg, nil +} + +var _ Protocol = (*GraphQLTransportWS)(nil) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws_test.go new file mode 100644 index 0000000000..ca4649ef53 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws_test.go @@ -0,0 +1,378 @@ +package protocol + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +func TestGraphQLTransportWS_Init(t *testing.T) { + t.Parallel() + + t.Run("sends connection_init and receives connection_ack", func(t *testing.T) { + t.Parallel() + + received := make(chan map[string]any, 1) + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err == nil { + received <- msg + } + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + + err := p.Init(t.Context(), conn, map[string]any{"secret": "token"}) + require.NoError(t, err) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "connection_init", msg["type"]) + payload, _ := msg["payload"].(map[string]any) + assert.Equal(t, "token", payload["secret"]) + }) + }) + + t.Run("returns error when ack times out", func(t *testing.T) { + t.Parallel() + + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + time.Sleep(500 * time.Millisecond) + }) + + conn := dialGTWS(t, server) + + p := &GraphQLTransportWS{AckTimeout: 50 * time.Millisecond} + err := p.Init(t.Context(), conn, nil) + + require.ErrorIs(t, err, ErrAckTimeout) + }) + + t.Run("handles ping before ack", func(t *testing.T) { + t.Parallel() + + received := make(chan map[string]any, 2) + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg1 map[string]any + _ = wsjson.Read(ctx, conn, &msg1) + received <- msg1 + + _ = wsjson.Write(ctx, conn, map[string]string{"type": "ping"}) + + var msg2 map[string]any + _ = wsjson.Read(ctx, conn, &msg2) + received <- msg2 + + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + err := p.Init(t.Context(), conn, nil) + require.NoError(t, err) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "connection_init", msg["type"]) + }) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "pong", msg["type"]) + }) + }) + + t.Run("returns error on unexpected message type", func(t *testing.T) { + t.Parallel() + + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + _ = wsjson.Write(ctx, conn, map[string]string{"type": "error"}) + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + err := p.Init(t.Context(), conn, nil) + + assert.ErrorIs(t, err, ErrAckNotReceived) + }) +} + +func TestGraphQLWS_Subscribe(t *testing.T) { + t.Parallel() + + t.Run("sends subscribe message with query and variables", func(t *testing.T) { + t.Parallel() + + received := make(chan map[string]any, 1) + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err == nil { + received <- msg + } + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + err := p.Subscribe(t.Context(), conn, "sub-1", &common.Request{ + Query: "subscription { test }", + Variables: []byte(`{"id": 123}`), + }) + require.NoError(t, err) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "subscribe", msg["type"]) + assert.Equal(t, "sub-1", msg["id"]) + + payload, _ := msg["payload"].(map[string]any) + assert.Equal(t, "subscription { test }", payload["query"]) + + vars, _ := payload["variables"].(map[string]any) + assert.Equal(t, float64(123), vars["id"]) + }) + }) +} + +func TestGraphQLWS_Unsubscribe(t *testing.T) { + t.Parallel() + + t.Run("sends complete message with subscription id", func(t *testing.T) { + t.Parallel() + + received := make(chan map[string]any, 1) + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err == nil { + received <- msg + } + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + err := p.Unsubscribe(t.Context(), conn, "sub-1") + require.NoError(t, err) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "complete", msg["type"]) + assert.Equal(t, "sub-1", msg["id"]) + }) + }) +} + +func TestGraphQLWS_Read(t *testing.T) { + t.Parallel() + + t.Run("decodes next message with data payload", func(t *testing.T) { + t.Parallel() + + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": "sub-1", + "type": "next", + "payload": map[string]any{ + "data": map[string]any{"value": 42}, + }, + }) + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + msg, err := p.Read(t.Context(), conn) + + require.NoError(t, err) + assert.Equal(t, "sub-1", msg.ID) + assert.Equal(t, MessageData, msg.Type) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Data), "42") + }) + + t.Run("decodes error message with graphql errors", func(t *testing.T) { + t.Parallel() + + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": "sub-1", + "type": "error", + "payload": []map[string]any{ + {"message": "something went wrong"}, + }, + }) + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + msg, err := p.Read(t.Context(), conn) + + require.NoError(t, err) + assert.Equal(t, MessageError, msg.Type) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Errors), "something went wrong") + }) + + t.Run("decodes complete message", func(t *testing.T) { + t.Parallel() + + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": "sub-1", + "type": "complete", + }) + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + msg, err := p.Read(t.Context(), conn) + + require.NoError(t, err) + assert.Equal(t, "sub-1", msg.ID) + assert.Equal(t, MessageComplete, msg.Type) + }) + + t.Run("decodes ping message", func(t *testing.T) { + t.Parallel() + + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]string{"type": "ping"}) + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + msg, err := p.Read(t.Context(), conn) + + require.NoError(t, err) + assert.Equal(t, MessagePing, msg.Type) + }) + + t.Run("returns error for unknown message type", func(t *testing.T) { + t.Parallel() + + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]string{"type": "unknown"}) + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + _, err := p.Read(t.Context(), conn) + + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown") + }) +} + +func TestGraphQLTransportWS_PingPong(t *testing.T) { + t.Parallel() + + t.Run("sends ping message", func(t *testing.T) { + t.Parallel() + + received := make(chan map[string]any, 1) + + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err == nil { + received <- msg + } + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + err := p.Ping(t.Context(), conn) + require.NoError(t, err) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "ping", msg["type"]) + }) + }) + + t.Run("sends pong message", func(t *testing.T) { + t.Parallel() + + received := make(chan map[string]any, 1) + + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err == nil { + received <- msg + } + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + err := p.Pong(t.Context(), conn) + require.NoError(t, err) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "pong", msg["type"]) + }) + }) +} + +func newGTWSTestServer(t *testing.T, handler func(ctx context.Context, conn *websocket.Conn)) *httptest.Server { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-transport-ws"}, + }) + if err != nil { + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() + + handler(ctx, conn) + })) + + t.Cleanup(server.Close) + + return server +} + +func dialGTWS(t *testing.T, server *httptest.Server) *websocket.Conn { + t.Helper() + + conn, _, err := websocket.Dial(t.Context(), server.URL, &websocket.DialOptions{ //nolint:bodyclose + Subprotocols: []string{"graphql-transport-ws"}, + }) + require.NoError(t, err) + + t.Cleanup(func() { + _ = conn.Close(websocket.StatusNormalClosure, "") + }) + + return conn +} + +func awaitChannelWithT[A any](t *testing.T, timeout time.Duration, ch <-chan A, f func(*testing.T, A), msgAndArgs ...any) { + t.Helper() + + select { + case args := <-ch: + f(t, args) + case <-time.After(timeout): + require.Fail(t, "unable to receive message before timeout", msgAndArgs...) + } +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go new file mode 100644 index 0000000000..aad78508e3 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go @@ -0,0 +1,175 @@ +package protocol + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +const ( + gwsTypeConnectionInit = "connection_init" + gwsTypeConnectionAck = "connection_ack" + gwsTypeConnectionError = "connection_error" + gwsTypeConnectionKeepAlive = "ka" + gwsTypeStart = "start" + gwsTypeData = "data" + gwsTypeError = "error" + gwsTypeComplete = "complete" + gwsTypeStop = "stop" +) + +// GraphQLWS implements the legacy graphql-ws protocol. +// See: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md +type GraphQLWS struct { + AckTimeout time.Duration +} + +func NewGraphQLWS() *GraphQLWS { + return &GraphQLWS{ + AckTimeout: 30 * time.Second, + } +} + +// Init implements Protocol. +func (p *GraphQLWS) Init(ctx context.Context, conn *websocket.Conn, payload map[string]any) error { + initMsg := outgoingMessage{ + Type: gwsTypeConnectionInit, + } + if payload != nil { + initMsg.Payload = payload + } + if err := wsjson.Write(ctx, conn, initMsg); err != nil { + return fmt.Errorf("write connection_init: %w", err) + } + + timeout := p.AckTimeout + if timeout == 0 { + timeout = 30 * time.Second + } + + ackCtx, ackCancel := context.WithTimeout(ctx, timeout) + defer ackCancel() + + for { + var ackMessage incomingMessage + if err := wsjson.Read(ackCtx, conn, &ackMessage); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return ErrAckTimeout + } + return fmt.Errorf("read connection_ack: %w", err) + } + + switch ackMessage.Type { + case gwsTypeConnectionAck: + return nil + case gwsTypeConnectionKeepAlive: + // Keep-alive messages can arrive before ack, ignore them + continue + case gwsTypeConnectionError: + var errPayload map[string]any + if ackMessage.Payload != nil { + // If this fails, the error will have nil errors anyway, handling it does nothing unique + _ = json.Unmarshal(ackMessage.Payload, &errPayload) + } + return fmt.Errorf("%w: %v", ErrConnectionError, errPayload) + default: + return fmt.Errorf("%w: got %q", ErrAckNotReceived, ackMessage.Type) + } + } +} + +// Subscribe implements Protocol. +func (p *GraphQLWS) Subscribe(ctx context.Context, conn *websocket.Conn, id string, req *common.Request) error { + msg := outgoingMessage{ + ID: id, + Type: gwsTypeStart, + Payload: req, + } + return wsjson.Write(ctx, conn, msg) +} + +// Unsubscribe implements Protocol. +func (p *GraphQLWS) Unsubscribe(ctx context.Context, conn *websocket.Conn, id string) error { + msg := outgoingMessage{ + ID: id, + Type: gwsTypeStop, + } + return wsjson.Write(ctx, conn, msg) +} + +// Read implements Protocol. +func (p *GraphQLWS) Read(ctx context.Context, conn *websocket.Conn) (*Message, error) { + var raw incomingMessage + if err := wsjson.Read(ctx, conn, &raw); err != nil { + return nil, fmt.Errorf("read message: %w", err) + } + + return p.decode(raw) +} + +// Ping implements Protocol. +// Legacy protocol doesn't support client-initiated ping, this is a no-op. +func (p *GraphQLWS) Ping(ctx context.Context, conn *websocket.Conn) error { + // Legacy protocol doesn't have client ping - only server sends ka + return nil +} + +// Pong implements Protocol. +// Legacy protocol doesn't support pong messages, this is a no-op. +func (p *GraphQLWS) Pong(ctx context.Context, conn *websocket.Conn) error { + // Legacy protocol doesn't have pong + return nil +} + +func (p *GraphQLWS) decode(raw incomingMessage) (*Message, error) { + msg := &Message{ + ID: raw.ID, + } + + switch raw.Type { + case gwsTypeData: + msg.Type = MessageData + if raw.Payload != nil { + var resp common.ExecutionResult + if err := json.Unmarshal(raw.Payload, &resp); err != nil { + return nil, fmt.Errorf("unmarshal data payload: %w", err) + } + msg.Payload = &resp + } + + case gwsTypeError: + msg.Type = MessageError + if raw.Payload != nil { + msg.Payload = &common.ExecutionResult{Errors: raw.Payload} + } + + case gwsTypeComplete: + msg.Type = MessageComplete + + case gwsTypeConnectionKeepAlive: + // Map keep-alive to ping for consistent handling + msg.Type = MessagePing + + case gwsTypeConnectionError: + msg.Type = MessageError + var errPayload map[string]any + if raw.Payload != nil { + _ = json.Unmarshal(raw.Payload, &errPayload) + } + msg.Err = fmt.Errorf("%w: %v", ErrConnectionError, errPayload) + + default: + return nil, fmt.Errorf("unknown message type: %s", raw.Type) + } + + return msg, nil +} + +var _ Protocol = (*GraphQLWS)(nil) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws_test.go new file mode 100644 index 0000000000..04e5fe51a7 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws_test.go @@ -0,0 +1,375 @@ +package protocol + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +func TestGraphQLWS_Init(t *testing.T) { + t.Parallel() + + t.Run("sends connection_init and receives connection_ack", func(t *testing.T) { + t.Parallel() + + received := make(chan map[string]any, 1) + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err == nil { + received <- msg + } + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + + err := p.Init(t.Context(), conn, map[string]any{"secret": "token"}) + require.NoError(t, err) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "connection_init", msg["type"]) + payload, _ := msg["payload"].(map[string]any) + assert.Equal(t, "token", payload["secret"]) + }) + }) + + t.Run("returns error when ack times out", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + time.Sleep(500 * time.Millisecond) + }) + + conn := dialGWS(t, server) + + p := &GraphQLWS{AckTimeout: 50 * time.Millisecond} + err := p.Init(t.Context(), conn, nil) + + require.ErrorIs(t, err, ErrAckTimeout) + }) + + t.Run("handles keep-alive before ack", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg1 map[string]any + _ = wsjson.Read(ctx, conn, &msg1) // connection_init + + // Send keep-alive before ack + _ = wsjson.Write(ctx, conn, map[string]string{"type": "ka"}) + + // Then send ack + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + err := p.Init(t.Context(), conn, nil) + require.NoError(t, err) + }) + + t.Run("returns error on connection_error", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + _ = wsjson.Write(ctx, conn, map[string]any{ + "type": "connection_error", + "payload": map[string]any{"message": "auth failed"}, + }) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + err := p.Init(t.Context(), conn, nil) + + require.ErrorIs(t, err, ErrConnectionError) + assert.Contains(t, err.Error(), "auth failed") + }) + + t.Run("returns error on unexpected message type", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + _ = wsjson.Write(ctx, conn, map[string]string{"type": "error"}) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + err := p.Init(t.Context(), conn, nil) + + assert.ErrorIs(t, err, ErrAckNotReceived) + }) +} + +func TestGraphQLWSLegacy_Subscribe(t *testing.T) { + t.Parallel() + + t.Run("sends start message with query and variables", func(t *testing.T) { + t.Parallel() + + received := make(chan map[string]any, 1) + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err == nil { + received <- msg + } + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + err := p.Subscribe(t.Context(), conn, "sub-1", &common.Request{ + Query: "subscription { test }", + Variables: []byte(`{"id": 123}`), + }) + require.NoError(t, err) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "start", msg["type"]) + assert.Equal(t, "sub-1", msg["id"]) + + payload, _ := msg["payload"].(map[string]any) + assert.Equal(t, "subscription { test }", payload["query"]) + + vars, _ := payload["variables"].(map[string]any) + assert.Equal(t, float64(123), vars["id"]) + }) + }) +} + +func TestGraphQLWSLegacy_Unsubscribe(t *testing.T) { + t.Parallel() + + t.Run("sends stop message with subscription id", func(t *testing.T) { + t.Parallel() + + received := make(chan map[string]any, 1) + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err == nil { + received <- msg + } + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + err := p.Unsubscribe(t.Context(), conn, "sub-1") + require.NoError(t, err) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "stop", msg["type"]) + assert.Equal(t, "sub-1", msg["id"]) + }) + }) +} + +func TestGraphQLWSLegacy_Read(t *testing.T) { + t.Parallel() + + t.Run("decodes data message with payload", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": "sub-1", + "type": "data", + "payload": map[string]any{ + "data": map[string]any{"value": 42}, + }, + }) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + msg, err := p.Read(t.Context(), conn) + + require.NoError(t, err) + assert.Equal(t, "sub-1", msg.ID) + assert.Equal(t, MessageData, msg.Type) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Data), "42") + }) + + t.Run("decodes error message with graphql errors", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": "sub-1", + "type": "error", + "payload": []map[string]any{ + {"message": "something went wrong"}, + }, + }) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + msg, err := p.Read(t.Context(), conn) + + require.NoError(t, err) + assert.Equal(t, MessageError, msg.Type) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Errors), "something went wrong") + }) + + t.Run("decodes complete message", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": "sub-1", + "type": "complete", + }) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + msg, err := p.Read(t.Context(), conn) + + require.NoError(t, err) + assert.Equal(t, "sub-1", msg.ID) + assert.Equal(t, MessageComplete, msg.Type) + }) + + t.Run("decodes keep-alive message as ping", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]string{"type": "ka"}) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + msg, err := p.Read(t.Context(), conn) + + require.NoError(t, err) + assert.Equal(t, MessagePing, msg.Type) + }) + + t.Run("decodes connection_error message", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]any{ + "type": "connection_error", + "payload": map[string]any{"reason": "session expired"}, + }) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + msg, err := p.Read(t.Context(), conn) + + require.NoError(t, err) + assert.Equal(t, MessageError, msg.Type) + require.Error(t, msg.Err) + assert.Contains(t, msg.Err.Error(), "session expired") + }) + + t.Run("returns error for unknown message type", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]any{ + "type": "unknown", + }) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + _, err := p.Read(t.Context(), conn) + + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown") + }) +} + +func TestGraphQLWSLegacy_PingPong(t *testing.T) { + t.Parallel() + + t.Run("ping is a no-op for legacy protocol", func(t *testing.T) { + t.Parallel() + + // Legacy protocol doesn't support client-initiated ping + p := NewGraphQLWS() + + // This should not error, just be a no-op + err := p.Ping(context.Background(), nil) + require.NoError(t, err) + }) + + t.Run("pong is a no-op for legacy protocol", func(t *testing.T) { + t.Parallel() + + // Legacy protocol doesn't support pong + p := NewGraphQLWS() + + // This should not error, just be a no-op + err := p.Pong(context.Background(), nil) + require.NoError(t, err) + }) +} + +func newGWSTestServer(t *testing.T, handler func(ctx context.Context, conn *websocket.Conn)) *httptest.Server { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-ws"}, + }) + if err != nil { + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() + + handler(ctx, conn) + })) + + t.Cleanup(server.Close) + + return server +} + +func dialGWS(t *testing.T, server *httptest.Server) *websocket.Conn { + t.Helper() + + conn, _, err := websocket.Dial(t.Context(), server.URL, &websocket.DialOptions{ //nolint:bodyclose + Subprotocols: []string{"graphql-ws"}, + }) + require.NoError(t, err) + + t.Cleanup(func() { + _ = conn.Close(websocket.StatusNormalClosure, "") + }) + + return conn +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go new file mode 100644 index 0000000000..e54cfb6a30 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go @@ -0,0 +1,81 @@ +package protocol + +import ( + "context" + "errors" + + "github.com/coder/websocket" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +type Protocol interface { + Init(ctx context.Context, conn *websocket.Conn, payload map[string]any) error + + Subscribe(ctx context.Context, conn *websocket.Conn, id string, req *common.Request) error + + Unsubscribe(ctx context.Context, conn *websocket.Conn, id string) error + + Read(ctx context.Context, conn *websocket.Conn) (*Message, error) + + Ping(ctx context.Context, conn *websocket.Conn) error + + Pong(ctx context.Context, conn *websocket.Conn) error +} + +var ( + ErrAckTimeout = errors.New("connection_ack timeout") + ErrAckNotReceived = errors.New("expected connection_ack") + ErrConnectionError = errors.New("connection error from server") +) + +type Message struct { + ID string + Type MessageType + Payload *common.ExecutionResult + Err error +} + +func (m *Message) IntoClientMessage() *common.Message { + switch m.Type { + case MessageData: + return &common.Message{Payload: m.Payload} + case MessageError: + if m.Payload != nil { + return &common.Message{Payload: m.Payload, Done: true} + } + return &common.Message{Err: m.Err, Done: true} + case MessageComplete: + return &common.Message{Done: true} + default: + return &common.Message{} + } +} + +// MessageType identifies the message type. +type MessageType int + +const ( + MessageData MessageType = iota + MessageError + MessageComplete + MessagePing + MessagePong +) + +func (t MessageType) String() string { + switch t { + case MessageData: + return "data" + case MessageError: + return "error" + case MessageComplete: + return "complete" + case MessagePing: + return "ping" + case MessagePong: + return "pong" + default: + return "unknown" + } +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go new file mode 100644 index 0000000000..3e3b6892dc --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go @@ -0,0 +1,190 @@ +package transport + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "sync/atomic" + + "github.com/r3labs/sse/v2" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +var ( + headerData = []byte("data:") + headerEvent = []byte("event:") +) + +// sseConnection handles a single SSE subscription stream. +type sseConnection struct { + resp *http.Response + ch chan *common.Message + done chan struct{} + closed atomic.Bool +} + +func newSSEConnection(resp *http.Response) *sseConnection { + return &sseConnection{ + resp: resp, + ch: make(chan *common.Message, 8), + done: make(chan struct{}), + } +} + +// readLoop reads SSE events from the response body and sends them to the channel. +func (c *sseConnection) readLoop() { + defer c.cleanup() + + reader := sse.NewEventStreamReader(c.resp.Body, 1<<16) // 64KB + + for { + if c.closed.Load() { + return + } + + eventBytes, err := reader.ReadEvent() + if err != nil { + if err != io.EOF { + c.sendError(err) + } + return + } + + // Parse the raw event bytes into event type and data + eventType, data := c.parseEventBytes(eventBytes) + + // Skip empty events (e.g., keep-alive comments) + if eventType == "" && data == nil { + continue + } + + msg := c.parseEvent(eventType, data) + + if c.closed.Load() { + return + } + select { + case c.ch <- msg: + case <-c.done: + return + } + + if msg.Done { + return + } + } +} + +// parseEventBytes extracts the event type and data from raw SSE event bytes. +// Based on r3labs/sse's processEvent but simplified for our needs. +func (c *sseConnection) parseEventBytes(msg []byte) (eventType string, data []byte) { + if len(msg) == 0 { + return "", nil + } + + // Split by newlines (normalize CR/LF) + for _, line := range bytes.FieldsFunc(msg, func(r rune) bool { return r == '\n' || r == '\r' }) { + switch { + case bytes.HasPrefix(line, headerEvent): + eventType = string(trimHeader(len(headerEvent), line)) + + case bytes.HasPrefix(line, headerData): + // The spec allows for multiple data fields per event, concatenated with "\n" + data = append(data, trimHeader(len(headerData), line)...) + data = append(data, '\n') + + case bytes.Equal(line, []byte("data")): + // A line that simply contains "data" should be treated as empty data + data = append(data, '\n') + + // Comments (lines starting with ':') are already filtered by EventStreamReader + } + } + + // Trim the trailing "\n" per SSE spec + data = bytes.TrimSuffix(data, []byte("\n")) + + return eventType, data +} + +// trimHeader removes the header prefix and optional leading space. +func trimHeader(size int, data []byte) []byte { + if len(data) < size { + return data + } + + data = data[size:] + // Remove optional leading whitespace (single space after colon) + if len(data) > 0 && data[0] == ' ' { + data = data[1:] + } + return data +} + +// parseEvent converts parsed SSE event data into a shared.Message. +func (c *sseConnection) parseEvent(eventType string, data []byte) *common.Message { + switch eventType { + case "next": + var resp common.ExecutionResult + if err := json.Unmarshal(data, &resp); err != nil { + return &common.Message{ + Err: err, + Done: true, + } + } + return &common.Message{Payload: &resp} + + case "error": + return &common.Message{ + Payload: &common.ExecutionResult{Errors: data}, + Done: true, + } + + case "complete": + return &common.Message{Done: true} + + default: + // Unknown event type or no event type specified - treat as data + // This handles servers that send data without an event type + if len(data) == 0 { + return &common.Message{Done: true} + } + var resp common.ExecutionResult + if err := json.Unmarshal(data, &resp); err != nil { + return &common.Message{ + Err: err, + Done: true, + } + } + return &common.Message{Payload: &resp} + } +} + +func (c *sseConnection) sendError(err error) { + if c.closed.Load() { + return + } + select { + case c.ch <- &common.Message{Err: err, Done: true}: + case <-c.done: + } +} + +func (c *sseConnection) cleanup() { + c.closed.Store(true) + + c.resp.Body.Close() + close(c.ch) // Close channel so fanout exits +} + +// closeConn terminates the SSE connection. +func (c *sseConnection) closeConn() { + if !c.closed.CompareAndSwap(false, true) { + return + } + + close(c.done) + c.resp.Body.Close() +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go new file mode 100644 index 0000000000..c4f4cf3650 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go @@ -0,0 +1,152 @@ +package transport + +import ( + "io" + "net/http" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSSEConnection_ReadLoop(t *testing.T) { + t.Run("reads and parses SSE events", func(t *testing.T) { + body := io.NopCloser(strings.NewReader( + "event: next\ndata: {\"data\":{\"time\":\"12:00\"}}\n\n", + )) + resp := &http.Response{Body: body} + conn := newSSEConnection(resp) + + go conn.readLoop() + + msg := <-conn.ch + require.NotNil(t, msg.Payload) + // Data field contains the raw "data" value from GraphQL response + assert.JSONEq(t, `{"time":"12:00"}`, string(msg.Payload.Data)) + }) + + t.Run("closes channel on EOF", func(t *testing.T) { + body := io.NopCloser(strings.NewReader("")) + resp := &http.Response{Body: body} + conn := newSSEConnection(resp) + + go conn.readLoop() + + select { + case _, ok := <-conn.ch: + assert.False(t, ok, "channel should be closed") + case <-time.After(100 * time.Millisecond): + t.Fatal("channel not closed on EOF") + } + }) + + t.Run("sends error on read failure", func(t *testing.T) { + body := &errorReader{err: io.ErrUnexpectedEOF} + resp := &http.Response{Body: io.NopCloser(body)} + conn := newSSEConnection(resp) + + go conn.readLoop() + + msg := <-conn.ch + require.Error(t, msg.Err) + require.True(t, msg.Done) + }) + + t.Run("stops on complete event", func(t *testing.T) { + body := io.NopCloser(strings.NewReader( + "event: next\ndata: {\"data\":{}}\n\n" + + "event: complete\ndata:\n\n" + + "event: next\ndata: {\"data\":{}}\n\n", // Should not receive this + )) + resp := &http.Response{Body: body} + conn := newSSEConnection(resp) + + go conn.readLoop() + + // First message + msg1 := <-conn.ch + assert.NotNil(t, msg1.Payload) + assert.False(t, msg1.Done) + + // Complete message + msg2 := <-conn.ch + assert.True(t, msg2.Done) + + // Channel should close, no third message + select { + case _, ok := <-conn.ch: + assert.False(t, ok, "channel should be closed after complete") + case <-time.After(100 * time.Millisecond): + t.Fatal("channel not closed after complete") + } + }) +} + +func TestSSEConnection_Close(t *testing.T) { + t.Run("closes channel and body", func(t *testing.T) { + pr, pw := io.Pipe() + body := &trackingCloser{Reader: pr} + resp := &http.Response{Body: body} + conn := newSSEConnection(resp) + + go conn.readLoop() + + conn.closeConn() + pw.Close() // Ensure pipe is fully closed + + // Channel close signals cleanup completed + select { + case _, ok := <-conn.ch: + require.False(t, ok, "channel should be closed") + case <-time.After(100 * time.Millisecond): + t.Fatal("channel should be closed (timeout)") + } + + assert.True(t, body.closed.Load(), "body should be closed") + }) + + t.Run("is idempotent", func(t *testing.T) { + body := io.NopCloser(strings.NewReader("")) + resp := &http.Response{Body: body} + conn := newSSEConnection(resp) + + conn.closeConn() + conn.closeConn() // second call is a no-op + }) +} + +func TestSSEConnection_Channel(t *testing.T) { + t.Run("returns buffered channel", func(t *testing.T) { + body := io.NopCloser(strings.NewReader("")) + resp := &http.Response{Body: body} + conn := newSSEConnection(resp) + + ch := conn.ch + assert.NotNil(t, ch) + assert.Equal(t, 8, cap(ch)) + }) +} + +// errorReader always returns an error +type errorReader struct { + err error +} + +func (r *errorReader) Read(_ []byte) (int, error) { + return 0, r.err +} + +// trackingCloser tracks if Close was called +type trackingCloser struct { + io.Reader + + closed atomic.Bool +} + +func (c *trackingCloser) Close() error { + c.closed.Store(true) + return nil +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go new file mode 100644 index 0000000000..4080d9367b --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go @@ -0,0 +1,259 @@ +package transport + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "maps" + "net/http" + "net/url" + "strings" + "sync" + + "github.com/jensneuse/abstractlogger" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +// SSETransport implements the Transport interface using Server-Sent Events. +// Unlike WebSocket, each subscription creates a separate HTTP request. +// TCP connection reuse is handled by http.Client's connection pool. +// +// Supports both POST (graphql-sse spec) and GET (traditional SSE) methods. +type SSETransport struct { + ctx context.Context + client *http.Client + log abstractlogger.Logger + + mu sync.Mutex + conns map[*sseConnection]struct{} +} + +// NewSSETransport creates a new SSETransport with the provided http.Client. +// The transport will automatically close all connections when ctx is cancelled. +func NewSSETransport(ctx context.Context, client *http.Client, log abstractlogger.Logger) *SSETransport { + if log == nil { + log = abstractlogger.NoopLogger + } + + t := &SSETransport{ + ctx: ctx, + client: client, + log: log, + conns: make(map[*sseConnection]struct{}), + } + + context.AfterFunc(ctx, t.closeAll) + + return t +} + +// Subscribe initiates a GraphQL subscription over SSE. +// Each call creates a new HTTP request (no multiplexing). +// +// The HTTP method is determined by opts.SSEMethod: +// - SSEMethodAuto or SSEMethodPOST: POST with JSON body (graphql-sse spec) +// - SSEMethodGET: GET with query parameters (traditional SSE) +func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts common.Options) (<-chan *common.Message, func(), error) { + var httpReq *http.Request + var err error + + method := opts.SSEMethod + if method == common.SSEMethodAuto { + method = common.SSEMethodPOST // Default to POST (graphql-sse spec) + } + + t.log.Debug("sseTransport.Subscribe", + abstractlogger.String("endpoint", opts.Endpoint), + abstractlogger.String("method", string(method)), + ) + + // Use request context, but with transport requestCancel + requestCtx, requestCancel := context.WithCancel(context.WithoutCancel(ctx)) + + // Attach cancel to transport context + context.AfterFunc(t.ctx, requestCancel) + context.AfterFunc(ctx, requestCancel) + + switch method { + case common.SSEMethodPOST: + httpReq, err = buildPOSTRequest(requestCtx, req, opts) + case common.SSEMethodGET: + httpReq, err = buildGETRequest(requestCtx, req, opts) + default: + return nil, nil, fmt.Errorf("unsupported SSE method: %s", method) + } + + if err != nil { + return nil, nil, err + } + + // Execute request + resp, err := t.client.Do(httpReq) + if err != nil { + t.log.Error("sseTransport.Subscribe", + abstractlogger.String("endpoint", opts.Endpoint), + abstractlogger.Error(err), + ) + return nil, nil, fmt.Errorf("execute request: %w", err) + } + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + t.log.Error("sseTransport.Subscribe", + abstractlogger.String("endpoint", opts.Endpoint), + abstractlogger.Int("status", resp.StatusCode), + ) + if len(body) > 0 { + return nil, nil, fmt.Errorf("unexpected status %d: %s", resp.StatusCode, string(body)) + } + return nil, nil, fmt.Errorf("unexpected status: %d", resp.StatusCode) + } + + // Verify content type (should be text/event-stream) + if err := t.validateContentType(resp); err != nil { + resp.Body.Close() + return nil, nil, err + } + + t.log.Debug("sseTransport.Subscribe", + abstractlogger.String("endpoint", opts.Endpoint), + abstractlogger.String("status", "connected"), + ) + + // Create connection + conn := newSSEConnection(resp) + + t.mu.Lock() + t.conns[conn] = struct{}{} + t.mu.Unlock() + + go conn.readLoop() + + cancelFn := func() { + conn.closeConn() + t.removeConn(conn) + } + + return conn.ch, cancelFn, nil +} + +// buildPOSTRequest creates a POST request with JSON body (graphql-sse spec). +func buildPOSTRequest(ctx context.Context, req *common.Request, opts common.Options) (*http.Request, error) { + body, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, opts.Endpoint, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "text/event-stream") + httpReq.Header.Set("Cache-Control", "no-cache") + + // Add custom headers + maps.Copy(httpReq.Header, opts.Headers) + + return httpReq, nil +} + +// buildGETRequest creates a GET request with query parameters (traditional SSE). +func buildGETRequest(ctx context.Context, req *common.Request, opts common.Options) (*http.Request, error) { + // Parse the endpoint URL + u, err := url.Parse(opts.Endpoint) + if err != nil { + return nil, fmt.Errorf("parse endpoint: %w", err) + } + + // Build query parameters + q := u.Query() + q.Set("query", req.Query) + + if len(req.Variables) > 0 { + varsJSON, err := json.Marshal(req.Variables) + if err != nil { + return nil, fmt.Errorf("marshal variables: %w", err) + } + q.Set("variables", string(varsJSON)) + } + + if req.OperationName != "" { + q.Set("operationName", req.OperationName) + } + + if len(req.Extensions) > 0 { + extJSON, err := json.Marshal(req.Extensions) + if err != nil { + return nil, fmt.Errorf("marshal extensions: %w", err) + } + q.Set("extensions", string(extJSON)) + } + + u.RawQuery = q.Encode() + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + httpReq.Header.Set("Accept", "text/event-stream") + httpReq.Header.Set("Cache-Control", "no-cache") + + // Add custom headers + maps.Copy(httpReq.Header, opts.Headers) + + return httpReq, nil +} + +// validateContentType checks that the response has the correct content type. +func (t *SSETransport) validateContentType(resp *http.Response) error { + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + return nil // Allow missing content-type + } + + // Check if it starts with text/event-stream (may include charset) + if strings.HasPrefix(contentType, "text/event-stream") { + return nil + } + + return fmt.Errorf("unexpected content-type: %s", contentType) +} + +func (t *SSETransport) removeConn(conn *sseConnection) { + t.mu.Lock() + delete(t.conns, conn) + t.mu.Unlock() +} + +// closeAll terminates all active SSE connections. Called automatically when context is cancelled. +func (t *SSETransport) closeAll() { + t.mu.Lock() + conns := make([]*sseConnection, 0, len(t.conns)) + for conn := range t.conns { + conns = append(conns, conn) + } + t.conns = make(map[*sseConnection]struct{}) + t.mu.Unlock() + + t.log.Debug("sseTransport.closeAll", + abstractlogger.Int("connections", len(conns)), + ) + + for _, conn := range conns { + conn.closeConn() + } +} + +// ConnCount returns the number of active SSE connections. +func (t *SSETransport) ConnCount() int { + t.mu.Lock() + defer t.mu.Unlock() + return len(t.conns) +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go new file mode 100644 index 0000000000..a5236f7059 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go @@ -0,0 +1,862 @@ +package transport + +import ( + "context" + "encoding/json" + "fmt" + "io" + "maps" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +func TestSSETransport_Subscribe(t *testing.T) { + t.Parallel() + + t.Run("sends POST request and receives messages", func(t *testing.T) { + t.Parallel() + + var receivedBody map[string]any + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + // Verify POST method + assert.Equal(t, http.MethodPost, r.Method) + + // Verify headers + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + assert.Equal(t, "text/event-stream", r.Header.Get("Accept")) + assert.Equal(t, "no-cache", r.Header.Get("Cache-Control")) + + // Read and verify body + body, _ := io.ReadAll(r.Body) + assert.NoError(t, json.Unmarshal(body, &receivedBody)) + + // Send SSE response + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + + fmt.Fprintf(w, "event: next\ndata: {\"data\": {\"value\": 42}}\n\n") + flusher.Flush() + + fmt.Fprintf(w, "event: complete\ndata:\n\n") + flusher.Flush() + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + Variables: []byte(`{"id": 123}`), + OperationName: "TestSub", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportSSE, + }) + require.NoError(t, err) + defer cancel() + + // Verify request body + assert.Equal(t, "subscription { test }", receivedBody["query"]) + assert.Equal(t, float64(123), receivedBody["variables"].(map[string]any)["id"]) + assert.Equal(t, "TestSub", receivedBody["operationName"]) + + // Receive data message + msg := receiveWithTimeout(t, ch, time.Second) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Data), "42") + + // Receive complete message + msg = receiveWithTimeout(t, ch, time.Second) + assert.True(t, msg.Done) + }) + + t.Run("passes custom headers", func(t *testing.T) { + t.Parallel() + + var receivedAuth string + var receivedCustom string + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + receivedCustom = r.Header.Get("X-Custom-Header") + + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "event: complete\ndata:\n\n") + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + headers := http.Header{ + "Authorization": []string{"Bearer token123"}, + "X-Custom-Header": []string{"custom-value"}, + } + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Headers: headers, + }) + require.NoError(t, err) + defer cancel() + + receiveWithTimeout(t, ch, time.Second) + + assert.Equal(t, "Bearer token123", receivedAuth) + assert.Equal(t, "custom-value", receivedCustom) + }) + + t.Run("handles next event with data", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + fmt.Fprintf(w, "event: next\ndata: {\"data\": {\"user\": {\"name\": \"Alice\"}}}\n\n") + flusher.Flush() + + fmt.Fprintf(w, "event: complete\ndata:\n\n") + flusher.Flush() + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { user { name } }", + }, common.Options{Endpoint: server.URL}) + require.NoError(t, err) + defer cancel() + + msg := receiveWithTimeout(t, ch, time.Second) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Data), "Alice") + assert.False(t, msg.Done) + }) + + t.Run("handles error event", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + fmt.Fprintf(w, "event: error\ndata: [{\"message\": \"Something went wrong\"}]\n\n") + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL}) + require.NoError(t, err) + defer cancel() + + msg := receiveWithTimeout(t, ch, time.Second) + assert.True(t, msg.Done) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Errors), "Something went wrong") + }) + + t.Run("handles complete event", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + fmt.Fprintf(w, "event: complete\ndata:\n\n") + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL}) + require.NoError(t, err) + defer cancel() + + msg := receiveWithTimeout(t, ch, time.Second) + assert.True(t, msg.Done) + assert.Nil(t, msg.Err) + assert.Nil(t, msg.Payload) + }) + + t.Run("handles multi-line data", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + // Multi-line data per SSE spec + fmt.Fprintf(w, "event: next\n") + fmt.Fprintf(w, "data: {\"data\": {\n") + fmt.Fprintf(w, "data: \"value\": 42\n") + fmt.Fprintf(w, "data: }}\n") + fmt.Fprintf(w, "\n") + flusher.Flush() + + fmt.Fprintf(w, "event: complete\ndata:\n\n") + flusher.Flush() + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL}) + require.NoError(t, err) + defer cancel() + + msg := receiveWithTimeout(t, ch, time.Second) + require.NotNil(t, msg.Payload) + // The multi-line data is joined with newlines + assert.Contains(t, string(msg.Payload.Data), "42") + }) + + t.Run("ignores SSE comments", func(t *testing.T) { + t.Parallel() + + var messageCount atomic.Int32 + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + + // Send some keep-alive comments + fmt.Fprintf(w, ": keep-alive\n") + fmt.Fprintf(w, ": another comment\n") + flusher.Flush() + + fmt.Fprintf(w, "event: next\ndata: {\"data\": {\"value\": 1}}\n\n") + flusher.Flush() + + fmt.Fprintf(w, ": more keep-alive\n") + flusher.Flush() + + fmt.Fprintf(w, "event: complete\ndata:\n\n") + flusher.Flush() + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL}) + require.NoError(t, err) + defer cancel() + + // Should only receive 2 messages (next + complete), not comments + for msg := range ch { + messageCount.Add(1) + if msg.Done { + break + } + } + + assert.Equal(t, int32(2), messageCount.Load()) + }) + + t.Run("cancel closes connection", func(t *testing.T) { + t.Parallel() + + serverClosed := make(chan struct{}) + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + fmt.Fprintf(w, "event: next\ndata: {\"data\": {}}\n\n") + flusher.Flush() + + // Wait for client to disconnect + <-r.Context().Done() + close(serverClosed) + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL}) + require.NoError(t, err) + + // Receive first message + receiveWithTimeout(t, ch, time.Second) + + assert.Equal(t, 1, tr.ConnCount()) + + // Cancel should close the connection + cancel() + + select { + case <-serverClosed: + // Good, server detected disconnect + case <-time.After(time.Second): + t.Fatal("server did not detect disconnect") + } + + assert.Equal(t, 0, tr.ConnCount()) + }) + + t.Run("context cancellation stops subscription", func(t *testing.T) { + t.Parallel() + + serverClosed := make(chan struct{}) + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + fmt.Fprintf(w, "event: next\ndata: {\"data\": {}}\n\n") + flusher.Flush() + + <-r.Context().Done() + close(serverClosed) + }) + + transportCtx, transportCancel := context.WithCancel(context.Background()) + + tr := NewSSETransport(transportCtx, http.DefaultClient, nil) + + ch, cancel, err := tr.Subscribe(transportCtx, &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL}) + require.NoError(t, err) + defer cancel() + + _ = receiveWithTimeout(t, ch, time.Second) + + // Cancel context + transportCancel() + + select { + case <-serverClosed: + case <-time.After(10 * time.Second): + t.Fatal("server did not detect context cancellation") + } + }) + + t.Run("handles non-200 response", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + _, _, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL}) + + require.Error(t, err) + assert.Contains(t, err.Error(), "401") + }) + + t.Run("handles non-200 with body", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("Internal server error")) + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + _, _, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL}) + + require.Error(t, err) + assert.Contains(t, err.Error(), "500") + }) + + t.Run("creates separate connection per subscription", func(t *testing.T) { + t.Parallel() + + var reqCount atomic.Int32 + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + reqCount.Add(1) + + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + fmt.Fprintf(w, "event: next\ndata: {\"data\": {}}\n\n") + flusher.Flush() + + // Keep connection open + <-r.Context().Done() + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + opts := common.Options{Endpoint: server.URL} + + ch1, cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts) + require.NoError(t, err) + + ch2, cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts) + require.NoError(t, err) + + receiveWithTimeout(t, ch1, time.Second) + receiveWithTimeout(t, ch2, time.Second) + + // SSE creates separate HTTP requests (no multiplexing) + assert.Equal(t, int32(2), reqCount.Load()) + assert.Equal(t, 2, tr.ConnCount()) + + cancel1() + cancel2() + }) + + t.Run("handles server closing stream", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + fmt.Fprintf(w, "event: next\ndata: {\"data\": {\"value\": 1}}\n\n") + flusher.Flush() + + // Server closes without sending complete + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL}) + require.NoError(t, err) + defer cancel() + + msg := receiveWithTimeout(t, ch, time.Second) + assert.NotNil(t, msg.Payload) + + // Channel should close when server closes stream + select { + case _, ok := <-ch: + assert.False(t, ok, "channel should be closed") + case <-time.After(time.Second): + t.Fatal("channel should have been closed") + } + }) + + t.Run("handles data without event type", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + // Some servers send data without explicit event type + fmt.Fprintf(w, "data: {\"data\": {\"value\": 99}}\n\n") + flusher.Flush() + + fmt.Fprintf(w, "event: complete\ndata:\n\n") + flusher.Flush() + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL}) + require.NoError(t, err) + defer cancel() + + msg := receiveWithTimeout(t, ch, time.Second) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Data), "99") + }) +} + +func TestSSETransport_ContextCancellation(t *testing.T) { + t.Parallel() + + t.Run("context cancellation closes all connections", func(t *testing.T) { + t.Parallel() + + var closedCount atomic.Int32 + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + fmt.Fprintf(w, "event: next\ndata: {\"data\": {}}\n\n") + flusher.Flush() + + <-r.Context().Done() + closedCount.Add(1) + }) + + ctx, cancel := context.WithCancel(context.Background()) + tr := NewSSETransport(ctx, http.DefaultClient, nil) + + opts := common.Options{Endpoint: server.URL} + + ch1, _, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts) + require.NoError(t, err) + + ch2, _, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts) + require.NoError(t, err) + + receiveWithTimeout(t, ch1, time.Second) + receiveWithTimeout(t, ch2, time.Second) + + assert.Equal(t, 2, tr.ConnCount()) + + cancel() + + assert.Eventually(t, func() bool { + return closedCount.Load() == 2 + }, time.Second, 10*time.Millisecond) + + assert.Equal(t, 0, tr.ConnCount()) + }) +} + +func TestSSETransport_CustomClient(t *testing.T) { + t.Parallel() + + t.Run("uses custom http client", func(t *testing.T) { + t.Parallel() + + var customHeaderReceived string + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + customHeaderReceived = r.Header.Get("X-Custom-Client") + + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "event: complete\ndata:\n\n") + }) + + // Custom client with transport that adds a header + customClient := &http.Client{ + Transport: &headerTransport{ + base: http.DefaultTransport, + headers: http.Header{ + "X-Custom-Client": []string{"test-client"}, + }, + }, + } + + tr := NewSSETransport(t.Context(), customClient, nil) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL}) + require.NoError(t, err) + defer cancel() + + receiveWithTimeout(t, ch, time.Second) + + assert.Equal(t, "test-client", customHeaderReceived) + }) +} + +// Test helpers + +func newSSEServer(t *testing.T, handler func(w http.ResponseWriter, r *http.Request)) *httptest.Server { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(handler)) + t.Cleanup(server.Close) + + return server +} + +// headerTransport is a custom RoundTripper that adds headers to requests +type headerTransport struct { + base http.RoundTripper + headers http.Header +} + +func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + maps.Copy(req.Header, t.headers) + return t.base.RoundTrip(req) +} + +func TestSSETransport_ContentTypeValidation(t *testing.T) { + t.Parallel() + + t.Run("accepts text/event-stream with charset", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "event: complete\ndata:\n\n") + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL}) + require.NoError(t, err) + defer cancel() + + msg := receiveWithTimeout(t, ch, time.Second) + assert.True(t, msg.Done) + }) + + t.Run("rejects non-SSE content type", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"error": "not sse"}`)) + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + _, _, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL}) + + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "content-type") || strings.Contains(err.Error(), "Content-Type")) + }) +} + +func TestSSETransport_GETMethod(t *testing.T) { + t.Parallel() + + t.Run("sends GET request with query parameters", func(t *testing.T) { + t.Parallel() + + var receivedMethod string + var receivedQuery string + var receivedVariables string + var receivedOperationName string + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + receivedMethod = r.Method + receivedQuery = r.URL.Query().Get("query") + receivedVariables = r.URL.Query().Get("variables") + receivedOperationName = r.URL.Query().Get("operationName") + + // Verify no body for GET + body, _ := io.ReadAll(r.Body) + assert.Empty(t, body) + + // Verify headers + assert.Equal(t, "text/event-stream", r.Header.Get("Accept")) + assert.Equal(t, "no-cache", r.Header.Get("Cache-Control")) + assert.Empty(t, r.Header.Get("Content-Type")) // No content-type for GET + + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + fmt.Fprintf(w, "event: next\ndata: {\"data\": {\"value\": 42}}\n\n") + flusher.Flush() + + fmt.Fprintf(w, "event: complete\ndata:\n\n") + flusher.Flush() + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + Variables: []byte(`{"id": 123}`), + OperationName: "TestSub", + }, common.Options{ + Endpoint: server.URL, + SSEMethod: common.SSEMethodGET, + }) + require.NoError(t, err) + defer cancel() + + // Verify GET method and query params + assert.Equal(t, http.MethodGet, receivedMethod) + assert.Equal(t, "subscription { test }", receivedQuery) + assert.Equal(t, `{"id":123}`, receivedVariables) + assert.Equal(t, "TestSub", receivedOperationName) + + // Receive data message + msg := receiveWithTimeout(t, ch, time.Second) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Data), "42") + + // Receive complete message + msg = receiveWithTimeout(t, ch, time.Second) + assert.True(t, msg.Done) + }) + + t.Run("GET preserves existing query parameters", func(t *testing.T) { + t.Parallel() + + var receivedToken string + var receivedQuery string + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + receivedToken = r.URL.Query().Get("token") + receivedQuery = r.URL.Query().Get("query") + + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "event: complete\ndata:\n\n") + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL + "?token=abc123", + SSEMethod: common.SSEMethodGET, + }) + require.NoError(t, err) + defer cancel() + + receiveWithTimeout(t, ch, time.Second) + + assert.Equal(t, "abc123", receivedToken) + assert.Equal(t, "subscription { test }", receivedQuery) + }) + + t.Run("GET passes custom headers", func(t *testing.T) { + t.Parallel() + + var receivedAuth string + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "event: complete\ndata:\n\n") + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + headers := http.Header{ + "Authorization": []string{"Bearer token123"}, + } + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + SSEMethod: common.SSEMethodGET, + Headers: headers, + }) + require.NoError(t, err) + defer cancel() + + receiveWithTimeout(t, ch, time.Second) + + assert.Equal(t, "Bearer token123", receivedAuth) + }) + + t.Run("GET omits empty variables and operationName", func(t *testing.T) { + t.Parallel() + + var hasVariables bool + var hasOperationName bool + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + hasVariables = r.URL.Query().Has("variables") + hasOperationName = r.URL.Query().Has("operationName") + + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "event: complete\ndata:\n\n") + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + // No variables or operationName + }, common.Options{ + Endpoint: server.URL, + SSEMethod: common.SSEMethodGET, + }) + require.NoError(t, err) + defer cancel() + + receiveWithTimeout(t, ch, time.Second) + + assert.False(t, hasVariables, "variables should not be in query params") + assert.False(t, hasOperationName, "operationName should not be in query params") + }) +} + +func TestSSETransport_MethodDefault(t *testing.T) { + t.Parallel() + + t.Run("defaults to POST when SSEMethod is auto", func(t *testing.T) { + t.Parallel() + + var receivedMethod string + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + receivedMethod = r.Method + + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "event: complete\ndata:\n\n") + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + SSEMethod: common.SSEMethodAuto, // or just omit it + }) + require.NoError(t, err) + defer cancel() + + receiveWithTimeout(t, ch, time.Second) + + assert.Equal(t, http.MethodPost, receivedMethod) + }) + + t.Run("explicit POST method works", func(t *testing.T) { + t.Parallel() + + var receivedMethod string + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + receivedMethod = r.Method + + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "event: complete\ndata:\n\n") + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + SSEMethod: common.SSEMethodPOST, + }) + require.NoError(t, err) + defer cancel() + + receiveWithTimeout(t, ch, time.Second) + + assert.Equal(t, http.MethodPost, receivedMethod) + }) +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport.go new file mode 100644 index 0000000000..a73cd376c1 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport.go @@ -0,0 +1,14 @@ +package transport + +import ( + "context" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +// Transport defines the interface for subscription transports. +// A transport is responsible for managing the full connection to the upstream server. +type Transport interface { + Subscribe(ctx context.Context, req *common.Request, opts common.Options) (results <-chan *common.Message, cancel func(), err error) + Close() error +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport_test.go new file mode 100644 index 0000000000..530f46c96e --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport_test.go @@ -0,0 +1,19 @@ +package transport + +import ( + "testing" + "time" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +func receiveWithTimeout(t *testing.T, ch <-chan *common.Message, timeout time.Duration) *common.Message { + t.Helper() + select { + case msg := <-ch: + return msg + case <-time.After(timeout): + t.Fatal("timeout waiting for message") + return nil + } +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go new file mode 100644 index 0000000000..e1627e54a9 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go @@ -0,0 +1,309 @@ +package transport + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/coder/websocket" + "github.com/jensneuse/abstractlogger" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol" +) + +var ( + ErrSubscriptionExists = errors.New("subscription ID already exists") + + defaultWriteTimeout = 5 * time.Second + defaultReadLimit = int64(1024 * 1024) // 1MB +) + +type wsConnectionOptions struct { + logger abstractlogger.Logger + writeTimeout time.Duration + onEmpty func() +} + +// wsConnectionOption configures a wsConnection. +type wsConnectionOption func(*wsConnectionOptions) + +// withConnLogger sets the logger for connection-level debug output. +func withConnLogger(l abstractlogger.Logger) wsConnectionOption { + return func(o *wsConnectionOptions) { + if l != nil { + o.logger = l + } + } +} + +// withConnWriteTimeout sets the timeout for write operations (subscribe, unsubscribe, pong). +func withConnWriteTimeout(d time.Duration) wsConnectionOption { + return func(o *wsConnectionOptions) { + if d > 0 { + o.writeTimeout = d + } + } +} + +// withOnEmpty sets a callback invoked when the last subscription is removed or the connection shuts down. +func withOnEmpty(f func()) wsConnectionOption { + return func(o *wsConnectionOptions) { + o.onEmpty = f + } +} + +type wsConnection struct { + conn *websocket.Conn + protocol protocol.Protocol + log abstractlogger.Logger + + // cancel cancels the connection-scoped context, unblocking readLoop and + // any in-flight writes. It is called exactly once inside shutdown(). + cancel context.CancelFunc + ctx context.Context + + subsMu sync.RWMutex + subs map[string]chan<- *common.Message + + closed atomic.Bool + + onEmpty func() + + writeTimeout time.Duration + + // Ping/pong tracking for client-initiated heartbeats. + // Values stored as UnixNano timestamps. + lastPingSentAt atomic.Int64 + lastPongAt atomic.Int64 +} + +func newWSConnection(conn *websocket.Conn, proto protocol.Protocol, opts ...wsConnectionOption) *wsConnection { + o := wsConnectionOptions{ + logger: abstractlogger.NoopLogger, + writeTimeout: defaultWriteTimeout, + } + for _, apply := range opts { + apply(&o) + } + + ctx, cancel := context.WithCancel(context.Background()) + + c := &wsConnection{ + conn: conn, + protocol: proto, + log: o.logger, + cancel: cancel, + ctx: ctx, + subs: make(map[string]chan<- *common.Message), + onEmpty: o.onEmpty, + + writeTimeout: o.writeTimeout, + } + + c.lastPongAt.Store(time.Now().UnixNano()) + + return c +} + +func (c *wsConnection) subscribe(ctx context.Context, id string, req *common.Request) (<-chan *common.Message, func(), error) { + if c.closed.Load() { + return nil, nil, common.ErrConnectionClosed + } + + // Small buffer to absorb bursts + ch := make(chan *common.Message, 8) + + c.subsMu.Lock() + + if _, exists := c.subs[id]; exists { + c.subsMu.Unlock() + return nil, nil, ErrSubscriptionExists + } + + c.subs[id] = ch + c.subsMu.Unlock() + + if err := c.protocol.Subscribe(ctx, c.conn, id, req); err != nil { + c.log.Error("wsConnection.Subscribe", + abstractlogger.String("id", id), + abstractlogger.Error(err), + ) + c.removeSub(id) + return nil, nil, err + } + + c.log.Debug("wsConnection.Subscribe", + abstractlogger.String("id", id), + abstractlogger.String("status", "subscribed"), + ) + + cancel := func() { c.unsubscribe(id) } + + return ch, cancel, nil +} + +func (c *wsConnection) removeSub(id string) { + c.subsMu.Lock() + ch, exists := c.subs[id] + delete(c.subs, id) + isEmpty := len(c.subs) == 0 + c.subsMu.Unlock() + + if exists { + close(ch) + } + + if isEmpty { + c.closeConn() + } +} + +func (c *wsConnection) unsubscribe(id string) { + c.subsMu.Lock() + _, exists := c.subs[id] + c.subsMu.Unlock() + + if !exists { + return + } + + c.log.Debug("wsConnection.unsubscribe", abstractlogger.String("id", id)) + + unsubscribeCtx, cancel := context.WithTimeout(context.Background(), c.writeTimeout) + defer cancel() + + _ = c.protocol.Unsubscribe(unsubscribeCtx, c.conn, id) + + c.removeSub(id) +} + +func (c *wsConnection) readLoop() { + defer c.shutdown(errors.New("read loop exited")) + + for { + if c.closed.Load() { + return + } + + msg, err := c.protocol.Read(c.ctx, c.conn) + if err != nil { + c.log.Debug("wsConnection.ReadLoop", + abstractlogger.String("status", "error"), + abstractlogger.Error(err), + ) + c.shutdown(fmt.Errorf("%w: read: %w", common.ErrConnectionClosed, err)) + return + } + + switch msg.Type { + case protocol.MessagePing: + c.log.Debug("wsConnection.ReadLoop", abstractlogger.String("message", "ping")) + pongCtx, cancel := context.WithTimeout(c.ctx, c.writeTimeout) + _ = c.protocol.Pong(pongCtx, c.conn) + cancel() + case protocol.MessagePong: + c.lastPongAt.Store(time.Now().UnixNano()) + c.log.Debug("wsConnection.ReadLoop", abstractlogger.String("message", "pong")) + case protocol.MessageData, protocol.MessageError, protocol.MessageComplete: + c.dispatch(msg) + } + } +} + +func (c *wsConnection) dispatch(msg *protocol.Message) { + c.subsMu.RLock() + ch, exists := c.subs[msg.ID] + c.subsMu.RUnlock() + + if !exists { + return + } + + ch <- msg.IntoClientMessage() + + if msg.Type == protocol.MessageComplete || msg.Type == protocol.MessageError { + c.unsubscribe(msg.ID) + } +} + +func (c *wsConnection) shutdown(err error) { + if !c.closed.CompareAndSwap(false, true) { + return + } + + c.log.Debug("wsConnection.shutdown", + abstractlogger.Error(err), + ) + + c.conn.Close(websocket.StatusNormalClosure, "shutdown") + + c.subsMu.Lock() + subs := c.subs + c.subs = make(map[string]chan<- *common.Message) + c.subsMu.Unlock() + + errMsg := &common.Message{Err: err, Done: true} + for _, ch := range subs { + select { + case ch <- errMsg: + case <-time.After(100 * time.Millisecond): + // dead consumer + } + close(ch) + } + + // Cancel after dispatching errors so readLoop consumers still have a live + // context when they receive the error message. + c.cancel() + + if c.onEmpty != nil { + c.onEmpty() + } +} + +func (c *wsConnection) closeConn() { + c.shutdown(common.ErrConnectionClosed) +} + +// writeTimeoutDuration returns the configured write timeout. +func (c *wsConnection) writeTimeoutDuration() time.Duration { + return c.writeTimeout +} + +func (c *wsConnection) subCount() int { + c.subsMu.RLock() + defer c.subsMu.RUnlock() + return len(c.subs) +} + +// sendPing sends a protocol-level ping message and records the timestamp. +func (c *wsConnection) sendPing(timeout time.Duration) error { + pingCtx, cancel := context.WithTimeout(c.ctx, timeout) + defer cancel() + + err := c.protocol.Ping(pingCtx, c.conn) + if err != nil { + return err + } + + c.lastPingSentAt.Store(time.Now().UnixNano()) + return nil +} + +// pongOverdue returns true if a pong has not been received since the last ping +// and the ping timeout has elapsed. +func (c *wsConnection) pongOverdue(timeout time.Duration) bool { + pingSent := c.lastPingSentAt.Load() + if pingSent == 0 { + return false + } + return c.lastPongAt.Load() < pingSent && time.Since(time.Unix(0, pingSent)) > timeout +} + +func (c *wsConnection) isClosed() bool { + return c.closed.Load() +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go new file mode 100644 index 0000000000..f21fc14911 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go @@ -0,0 +1,649 @@ +package transport + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/coder/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol" +) + +func TestWSConnection_Subscribe(t *testing.T) { + t.Parallel() + + t.Run("returns channel and calls protocol subscribe", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newWSConnection(conn, proto) + + ch, cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{ + Query: "subscription { test }", + }) + defer cancel() + + require.NoError(t, err) + assert.NotNil(t, ch) + assert.Len(t, proto.SubscribeCalls(), 1) + assert.Equal(t, "sub-1", proto.SubscribeCalls()[0].ID) + }) + + t.Run("returns error for duplicate subscription id", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newWSConnection(conn, proto) + + _, cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + require.NoError(t, err) + defer cancel() + + _, _, err = wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + + assert.ErrorIs(t, err, ErrSubscriptionExists) + }) + + t.Run("returns error when connection is closed", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newWSConnection(conn, proto) + wsc.closeConn() + + _, _, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + + assert.ErrorIs(t, err, common.ErrConnectionClosed) + }) + + t.Run("returns error when protocol subscribe fails", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + proto.subscribeErr = assert.AnError + wsc := newWSConnection(conn, proto) + + _, _, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + + assert.Error(t, err) + assert.Equal(t, 0, wsc.subCount(), "failed subscription should not be registered") + }) +} + +func TestWSConnection_ReadLoop(t *testing.T) { + t.Parallel() + + t.Run("dispatches data message to subscription channel", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newWSConnection(conn, proto) + + ch, cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + require.NoError(t, err) + defer cancel() + + go wsc.readLoop() + + proto.PushMessage(&protocol.Message{ + ID: "sub-1", + Type: protocol.MessageData, + Payload: &common.ExecutionResult{Data: json.RawMessage(`{"value": 42}`)}, + }) + + msg := receiveWithTimeout(t, ch, time.Second) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Data), "42") + }) + + t.Run("closes channel on complete message", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newWSConnection(conn, proto) + + ch, _, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + require.NoError(t, err) + + go wsc.readLoop() + + proto.PushMessage(&protocol.Message{ + ID: "sub-1", + Type: protocol.MessageComplete, + }) + + // Consume the message (blocking send requires consumer) + msg := receiveWithTimeout(t, ch, time.Second) + assert.True(t, msg.Done) + + assertChannelClosed(t, ch) + }) + + t.Run("responds to ping with pong", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newWSConnection(conn, proto) + + go wsc.readLoop() + + proto.PushMessage(&protocol.Message{Type: protocol.MessagePing}) + + assert.Eventually(t, func() bool { + return proto.PongCount() > 0 + }, time.Second, 10*time.Millisecond) + }) + + t.Run("ignores messages for unknown subscription ids", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newWSConnection(conn, proto) + + ch, cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + require.NoError(t, err) + defer cancel() + + go wsc.readLoop() + + proto.PushMessage(&protocol.Message{ + ID: "unknown-sub", + Type: protocol.MessageData, + Payload: &common.ExecutionResult{Data: json.RawMessage(`{"wrong": true}`)}, + }) + + proto.PushMessage(&protocol.Message{ + ID: "sub-1", + Type: protocol.MessageData, + Payload: &common.ExecutionResult{Data: json.RawMessage(`{"right": true}`)}, + }) + + msg := receiveWithTimeout(t, ch, time.Second) + assert.Contains(t, string(msg.Payload.Data), "right") + }) +} + +func TestWSConnection_Unsubscribe(t *testing.T) { + t.Parallel() + + t.Run("calls protocol unsubscribe and closes channel", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newWSConnection(conn, proto) + + ch, cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + require.NoError(t, err) + + cancel() + + assert.Len(t, proto.UnsubscribeCalls(), 1) + assert.Equal(t, "sub-1", proto.UnsubscribeCalls()[0]) + assertChannelClosed(t, ch) + }) + + t.Run("is idempotent", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newWSConnection(conn, proto) + + _, cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + require.NoError(t, err) + + cancel() + cancel() + cancel() + + assert.Len(t, proto.UnsubscribeCalls(), 1) + }) + + t.Run("times out using WriteTimeout", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + proto.unsubscribeDelay = 500 * time.Millisecond + wsc := newWSConnection(conn, proto, + withConnWriteTimeout(50*time.Millisecond), + ) + + _, cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + require.NoError(t, err) + + start := time.Now() + cancel() + elapsed := time.Since(start) + + assert.Less(t, elapsed, 200*time.Millisecond) + }) +} + +func TestWSConnection_OnEmpty(t *testing.T) { + t.Parallel() + + t.Run("calls callback when last subscription removed", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + + emptyCalled := make(chan struct{}, 1) + wsc := newWSConnection(conn, proto, + withOnEmpty(func() { emptyCalled <- struct{}{} }), + ) + + _, cancel, _ := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + cancel() + + select { + case <-emptyCalled: + // success + case <-time.After(100 * time.Millisecond): + t.Error("onEmpty callback not called") + } + + assert.True(t, wsc.isClosed(), "connection should be closed after last subscription removed") + }) + + t.Run("does not call callback when subscriptions remain", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + + emptyCalled := make(chan struct{}, 1) + wsc := newWSConnection(conn, proto, + withOnEmpty(func() { emptyCalled <- struct{}{} }), + ) + + _, cancel1, _ := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + _, cancel2, _ := wsc.subscribe(context.Background(), "sub-2", &common.Request{}) + + cancel1() + + select { + case <-emptyCalled: + t.Error("onEmpty should not be called when subscriptions remain") + case <-time.After(100 * time.Millisecond): + // success + } + + cancel2() + + select { + case <-emptyCalled: + // success + case <-time.After(100 * time.Millisecond): + t.Error("onEmpty should be called after last subscription removed") + } + }) + + t.Run("calls callback on direct Close", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + + emptyCalled := make(chan struct{}, 1) + wsc := newWSConnection(conn, proto, + withOnEmpty(func() { emptyCalled <- struct{}{} }), + ) + + wsc.closeConn() + + select { + case <-emptyCalled: + // success + case <-time.After(100 * time.Millisecond): + t.Error("onEmpty callback not called on Close") + } + }) + + t.Run("calls callback on read loop exit", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + + emptyCalled := make(chan struct{}, 1) + wsc := newWSConnection(conn, proto, + withOnEmpty(func() { emptyCalled <- struct{}{} }), + ) + + go wsc.readLoop() + + // Close the connection to cause the read loop to exit + wsc.closeConn() + + select { + case <-emptyCalled: + // success + case <-time.After(time.Second): + t.Error("onEmpty callback not called on read loop exit") + } + }) +} + +func TestWSConnection_Close(t *testing.T) { + t.Parallel() + + t.Run("notifies all subscriptions with error", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newWSConnection(conn, proto) + + ch1, _, _ := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + ch2, _, _ := wsc.subscribe(context.Background(), "sub-2", &common.Request{}) + + wsc.closeConn() + + // Consume messages (blocking send requires consumer) + msg1 := receiveWithTimeout(t, ch1, 100*time.Millisecond) + assert.Error(t, msg1.Err) + assert.True(t, msg1.Done) + + msg2 := receiveWithTimeout(t, ch2, 100*time.Millisecond) + assert.Error(t, msg2.Err) + assert.True(t, msg2.Done) + + assertChannelClosed(t, ch1) + assertChannelClosed(t, ch2) + }) + + t.Run("is idempotent", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newWSConnection(conn, proto) + + assert.NotPanics(t, func() { + wsc.closeConn() + wsc.closeConn() + wsc.closeConn() + }) + }) +} + +func TestWSConnection_SubCount(t *testing.T) { + t.Parallel() + + t.Run("tracks subscription count accurately", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newWSConnection(conn, proto) + + assert.Equal(t, 0, wsc.subCount()) + + _, cancel1, _ := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + assert.Equal(t, 1, wsc.subCount()) + + _, cancel2, _ := wsc.subscribe(context.Background(), "sub-2", &common.Request{}) + assert.Equal(t, 2, wsc.subCount()) + + cancel1() + assert.Equal(t, 1, wsc.subCount()) + + cancel2() + assert.Equal(t, 0, wsc.subCount()) + }) +} + +func TestWSConnection_WriteTimeout(t *testing.T) { + t.Parallel() + + t.Run("pong write respects WriteTimeout", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + proto.pongDelay = 500 * time.Millisecond + wsc := newWSConnection(conn, proto, + withConnWriteTimeout(50*time.Millisecond), + ) + + ch, cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + require.NoError(t, err) + defer cancel() + + go wsc.readLoop() + + // Send ping (will trigger slow pong) + proto.PushMessage(&protocol.Message{Type: protocol.MessagePing}) + + // Send data message right after + proto.PushMessage(&protocol.Message{ + ID: "sub-1", + Type: protocol.MessageData, + Payload: &common.ExecutionResult{Data: json.RawMessage(`{"test": true}`)}, + }) + + // Should receive data within timeout + small buffer + // If pong blocked for 500ms, this would timeout + msg := receiveWithTimeout(t, ch, 150*time.Millisecond) + assert.NotNil(t, msg.Payload) + }) +} + +func TestWSConnection_Defaults(t *testing.T) { + t.Parallel() + + t.Run("applies default write timeout when omitted", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newWSConnection(conn, proto) + + assert.Equal(t, defaultWriteTimeout, wsc.writeTimeoutDuration()) + }) + + t.Run("applies default write timeout for zero value", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newWSConnection(conn, proto, + withConnWriteTimeout(0), + ) + + assert.Equal(t, defaultWriteTimeout, wsc.writeTimeoutDuration()) + }) + + t.Run("overrides write timeout when provided", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newWSConnection(conn, proto, + withConnWriteTimeout(10*time.Second), + ) + + assert.Equal(t, 10*time.Second, wsc.writeTimeoutDuration()) + }) + + t.Run("ignores negative write timeout", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newWSConnection(conn, proto, + withConnWriteTimeout(-1*time.Second), + ) + + assert.Equal(t, defaultWriteTimeout, wsc.writeTimeoutDuration()) + }) +} + +// Test helpers + +func newTestConn(t *testing.T) (*websocket.Conn, *websocket.Conn) { + t.Helper() + + serverConn := make(chan *websocket.Conn, 1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + require.NoError(t, err) + serverConn <- conn + + for { + _, _, err := conn.Read(r.Context()) + if err != nil { + conn.Close(websocket.StatusNormalClosure, "shutdown") + return + } + } + })) + + t.Cleanup(server.Close) + + url := "ws" + strings.TrimPrefix(server.URL, "http") + clientConn, _, err := websocket.Dial(context.Background(), url, nil) //nolint:bodyclose + require.NoError(t, err) + + t.Cleanup(func() { + clientConn.Close(websocket.StatusNormalClosure, "shutdown") + }) + + srvConn := <-serverConn + t.Cleanup(func() { _ = srvConn.CloseNow() }) + + return clientConn, srvConn +} + +func assertChannelClosed(t *testing.T, ch <-chan *common.Message) { + t.Helper() + select { + case _, ok := <-ch: + assert.False(t, ok, "channel should be closed") + case <-time.After(100 * time.Millisecond): + t.Error("timeout waiting for channel to close") + } +} + +// mockProtocol implements protocol.Protocol for testing. +type mockProtocol struct { + mu sync.Mutex + subscribeCalls []subscribeCall + unsubCalls []string + pongCount int + subscribeErr error + unsubscribeDelay time.Duration + pongDelay time.Duration + + messages chan *protocol.Message +} + +type subscribeCall struct { + ID string + Req *common.Request +} + +func newMockProtocol() *mockProtocol { + return &mockProtocol{ + messages: make(chan *protocol.Message, 100), + } +} + +func (m *mockProtocol) Subprotocol() string { return "graphql-transport-ws" } + +func (m *mockProtocol) Init(ctx context.Context, conn *websocket.Conn, payload map[string]any) error { + return nil +} + +func (m *mockProtocol) Subscribe(ctx context.Context, conn *websocket.Conn, id string, req *common.Request) error { + m.mu.Lock() + defer m.mu.Unlock() + m.subscribeCalls = append(m.subscribeCalls, subscribeCall{ID: id, Req: req}) + return m.subscribeErr +} + +func (m *mockProtocol) Unsubscribe(ctx context.Context, conn *websocket.Conn, id string) error { + if m.unsubscribeDelay > 0 { + select { + case <-time.After(m.unsubscribeDelay): + case <-ctx.Done(): + return ctx.Err() + } + } + + m.mu.Lock() + defer m.mu.Unlock() + m.unsubCalls = append(m.unsubCalls, id) + return nil +} + +func (m *mockProtocol) Read(ctx context.Context, conn *websocket.Conn) (*protocol.Message, error) { + select { + case msg := <-m.messages: + return msg, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (m *mockProtocol) Ping(ctx context.Context, conn *websocket.Conn) error { + return nil +} + +func (m *mockProtocol) Pong(ctx context.Context, conn *websocket.Conn) error { + if m.pongDelay > 0 { + select { + case <-time.After(m.pongDelay): + case <-ctx.Done(): + return ctx.Err() + } + } + + m.mu.Lock() + defer m.mu.Unlock() + m.pongCount++ + return nil +} + +func (m *mockProtocol) PushMessage(msg *protocol.Message) { + m.messages <- msg +} + +func (m *mockProtocol) SubscribeCalls() []subscribeCall { + m.mu.Lock() + defer m.mu.Unlock() + return append([]subscribeCall{}, m.subscribeCalls...) +} + +func (m *mockProtocol) UnsubscribeCalls() []string { + m.mu.Lock() + defer m.mu.Unlock() + return append([]string{}, m.unsubCalls...) +} + +func (m *mockProtocol) PongCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return m.pongCount +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go new file mode 100644 index 0000000000..73b6932d6d --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go @@ -0,0 +1,395 @@ +package transport + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "sync" + "time" + + "github.com/coder/websocket" + "github.com/jensneuse/abstractlogger" + "github.com/rs/xid" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol" + "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" +) + +type ErrFailedUpgrade struct { + URL string + StatusCode int +} + +func (e ErrFailedUpgrade) Error() string { + return fmt.Sprintf("failed to upgrade connection to %s, status code: %d", e.URL, e.StatusCode) +} + +type ErrInvalidSubprotocol string + +func (e ErrInvalidSubprotocol) Error() string { + return fmt.Sprintf("provided websocket subprotocol '%s' is not supported. The supported subprotocols are graphql-ws and graphql-transport-ws. Please configure your subscriptions with the mentioned subprotocols", string(e)) +} + +type wsTransportOptions struct { + upgradeClient *http.Client + logger abstractlogger.Logger + pingInterval time.Duration + pingTimeout time.Duration + ackTimeout time.Duration + writeTimeout time.Duration + readLimit int64 +} + +// WSTransportOption configures a WSTransport. +type WSTransportOption func(*wsTransportOptions) + +// WithUpgradeClient sets the HTTP client used for WebSocket upgrade requests. +func WithUpgradeClient(c *http.Client) WSTransportOption { + return func(o *wsTransportOptions) { + if c != nil { + o.upgradeClient = c + } + } +} + +// WithLogger sets the logger for transport-level debug output. +func WithLogger(l abstractlogger.Logger) WSTransportOption { + return func(o *wsTransportOptions) { + if l != nil { + o.logger = l + } + } +} + +// WithPingInterval sets how often protocol-level pings are sent to all connections. +// Zero disables pinging. +func WithPingInterval(d time.Duration) WSTransportOption { + return func(o *wsTransportOptions) { + if d > 0 { + o.pingInterval = d + } + } +} + +// WithPingTimeout sets how long a connection may go without a pong before being closed. +// Zero disables the timeout (pings are sent but unresponsive connections are not killed). +func WithPingTimeout(d time.Duration) WSTransportOption { + return func(o *wsTransportOptions) { + if d > 0 { + o.pingTimeout = d + } + } +} + +// WithAckTimeout sets the maximum time to wait for a connection_ack after sending +// connection_init. Zero uses the protocol default (30s). +func WithAckTimeout(d time.Duration) WSTransportOption { + return func(o *wsTransportOptions) { + if d > 0 { + o.ackTimeout = d + } + } +} + +// WithWriteTimeout sets the timeout for WebSocket write operations on new connections. +// Zero uses defaultWriteTimeout (5s) at the connection level. +func WithWriteTimeout(d time.Duration) WSTransportOption { + return func(o *wsTransportOptions) { + if d > 0 { + o.writeTimeout = d + } + } +} + +// WithReadLimit sets the maximum size in bytes for incoming WebSocket messages. +// Zero uses the defaultReadLimit (1MB). +func WithReadLimit(n int64) WSTransportOption { + return func(o *wsTransportOptions) { + if n > 0 { + o.readLimit = n + } + } +} + +type WSTransport struct { + ctx context.Context + opts wsTransportOptions + + mu sync.Mutex + dialing map[uint64]*dialResult + conns map[uint64]*wsConnection +} + +type dialResult struct { + done chan struct{} + conn *wsConnection + err error +} + +// NewWSTransport creates a new WSTransport. Connections are not closed when ctx +// is cancelled; instead they close themselves when their last subscriber is +// removed via the resolver's drain chain. The ping loop exits on ctx cancellation. +// +// If WithPingInterval is set, a single goroutine sends protocol-level pings to all +// connections at that cadence. If WithPingTimeout is also set, connections that fail +// to respond with a pong within that window are shut down. +func NewWSTransport(ctx context.Context, opts ...WSTransportOption) *WSTransport { + o := wsTransportOptions{ + upgradeClient: http.DefaultClient, + logger: abstractlogger.NoopLogger, + readLimit: defaultReadLimit, + } + for _, apply := range opts { + apply(&o) + } + + t := &WSTransport{ + ctx: ctx, + opts: o, + conns: make(map[uint64]*wsConnection), + dialing: make(map[uint64]*dialResult), + } + + if o.pingInterval > 0 { + go t.pingLoop() + } + + return t +} + +func (t *WSTransport) Subscribe(ctx context.Context, req *common.Request, opts common.Options) (<-chan *common.Message, func(), error) { + conn, err := t.getOrDial(ctx, opts) + if err != nil { + return nil, nil, err + } + + id := xid.New().String() + return conn.subscribe(ctx, id, req) +} + +// pingLoop sends periodic pings to all active connections and shuts down +// any that have not responded with a pong in time. +func (t *WSTransport) pingLoop() { + tick := time.Tick(t.opts.pingInterval) + for { + select { + case <-t.ctx.Done(): + return + case <-tick: + t.mu.Lock() + conns := make([]*wsConnection, 0, len(t.conns)) + for _, conn := range t.conns { + conns = append(conns, conn) + } + t.mu.Unlock() + + for _, conn := range conns { + if conn.isClosed() { + continue + } + + if t.opts.pingTimeout > 0 && conn.pongOverdue(t.opts.pingTimeout) { + t.opts.logger.Debug("wsTransport.pingLoop", + abstractlogger.String("action", "pong_timeout"), + ) + conn.closeConn() + continue + } + + if err := conn.sendPing(defaultWriteTimeout); err != nil { + t.opts.logger.Debug("wsTransport.pingLoop", + abstractlogger.String("action", "ping_failed"), + abstractlogger.Error(err), + ) + } + } + } + } +} + +// ReadLimit returns the configured read limit. +func (t *WSTransport) ReadLimit() int64 { + return t.opts.readLimit +} + +// WriteTimeout returns the configured write timeout for new connections. +func (t *WSTransport) WriteTimeout() time.Duration { + return t.opts.writeTimeout +} + +func (t *WSTransport) ConnCount() int { + t.mu.Lock() + defer t.mu.Unlock() + + return len(t.conns) +} + +func (t *WSTransport) getOrDial(ctx context.Context, opts common.Options) (*wsConnection, error) { + key := connKey(opts) + + t.mu.Lock() + + if conn, ok := t.conns[key]; ok && !conn.isClosed() { + t.mu.Unlock() + return conn, nil + } + + if result, ok := t.dialing[key]; ok { + t.mu.Unlock() + <-result.done + + if result.err != nil { + return nil, result.err + } + + return result.conn, nil + } + + result := &dialResult{done: make(chan struct{})} + t.dialing[key] = result + t.mu.Unlock() + + conn, err := t.dial(ctx, key, opts) + + result.conn = conn + result.err = err + close(result.done) + + t.mu.Lock() + delete(t.dialing, key) + + if err == nil { + t.conns[key] = conn + } + t.mu.Unlock() + + return conn, err +} + +func (t *WSTransport) dial(ctx context.Context, key uint64, opts common.Options) (*wsConnection, error) { + t.opts.logger.Debug("wsTransport.dial", + abstractlogger.String("endpoint", opts.Endpoint), + abstractlogger.String("subprotocol", string(opts.WSSubprotocol)), + ) + + wsConn, resp, err := websocket.Dial(ctx, opts.Endpoint, &websocket.DialOptions{ //nolint:bodyclose + HTTPClient: t.opts.upgradeClient, + Subprotocols: opts.WSSubprotocol.Subprotocols(), + HTTPHeader: opts.Headers, + }) + if err != nil { + if errors.Is(err, context.Canceled) { + return nil, err + } + + t.opts.logger.Error("wsTransport.dial", + abstractlogger.String("endpoint", opts.Endpoint), + abstractlogger.Error(err), + ) + + // backwards compatibility with error handling in the router + if resp != nil && resp.StatusCode != http.StatusSwitchingProtocols { + return nil, ErrFailedUpgrade{URL: opts.Endpoint, StatusCode: resp.StatusCode} + } + + return nil, err + } + + wsConn.SetReadLimit(t.opts.readLimit) + + proto, err := t.negotiateSubprotocol(opts.WSSubprotocol, wsConn.Subprotocol()) + if err != nil { + t.opts.logger.Error("wsTransport.dial", + abstractlogger.String("endpoint", opts.Endpoint), + abstractlogger.String("error", "subprotocol negotiation failed"), + abstractlogger.Error(err), + ) + wsConn.Close(websocket.StatusProtocolError, err.Error()) + return nil, err + } + + if err := proto.Init(ctx, wsConn, opts.InitPayload); err != nil { + t.opts.logger.Error("wsTransport.dial", + abstractlogger.String("endpoint", opts.Endpoint), + abstractlogger.String("error", "protocol init failed"), + abstractlogger.Error(err), + ) + wsConn.Close(websocket.StatusProtocolError, "init failed") + return nil, err + } + + t.opts.logger.Debug("wsTransport.dial", + abstractlogger.String("endpoint", opts.Endpoint), + abstractlogger.String("status", "connected"), + abstractlogger.String("negotiated_subprotocol", wsConn.Subprotocol()), + ) + + conn := newWSConnection(wsConn, proto, + withConnLogger(t.opts.logger), + withConnWriteTimeout(t.opts.writeTimeout), + withOnEmpty(func() { t.removeConn(key) }), + ) + + go conn.readLoop() + + return conn, nil +} + +func (t *WSTransport) negotiateSubprotocol(requested common.WSSubprotocol, accepted string) (protocol.Protocol, error) { + if requested != common.SubprotocolAuto { + if accepted != string(requested) { + return nil, ErrInvalidSubprotocol(accepted) + } + } + + switch common.WSSubprotocol(accepted) { + case common.SubprotocolGraphQLTransportWS: + p := protocol.NewGraphQLTransportWS() + if t.opts.ackTimeout > 0 { + p.AckTimeout = t.opts.ackTimeout + } + return p, nil + case common.SubprotocolGraphQLWS: + p := protocol.NewGraphQLWS() + if t.opts.ackTimeout > 0 { + p.AckTimeout = t.opts.ackTimeout + } + return p, nil + default: + return nil, ErrInvalidSubprotocol(accepted) + } +} + +func (t *WSTransport) removeConn(key uint64) { + t.mu.Lock() + defer t.mu.Unlock() + delete(t.conns, key) +} + +// connKey computes a hash key for connection pooling. +func connKey(opts common.Options) uint64 { + h := pool.Hash64.Get() + defer pool.Hash64.Put(h) + + _, _ = h.WriteString(opts.Endpoint) + _, _ = h.WriteString("\x00") + + _, _ = h.WriteString(string(opts.WSSubprotocol)) + _, _ = h.WriteString("\x00") + + if len(opts.Headers) > 0 { + _ = opts.Headers.Write(h) + } + _, _ = h.WriteString("\x00") + + if len(opts.InitPayload) > 0 { + if data, err := json.Marshal(opts.InitPayload); err == nil { + _, _ = h.Write(data) + } + } + + return h.Sum64() +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go new file mode 100644 index 0000000000..414d17124d --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go @@ -0,0 +1,1159 @@ +package transport + +import ( + "context" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +func TestWSTransport_Subscribe(t *testing.T) { + t.Parallel() + + t.Run("dials and returns message channel", func(t *testing.T) { + t.Parallel() + + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + // Read subscribe + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + assert.Equal(t, "subscribe", msg["type"]) + + // Send data + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 42}}, + }) + + // Send complete + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "complete", + }) + }) + + tr := NewWSTransport(t.Context()) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + }) + require.NoError(t, err) + defer cancel() + + msg := receiveWithTimeout(t, ch, time.Second) + assert.Contains(t, string(msg.Payload.Data), "42") + + msg = receiveWithTimeout(t, ch, time.Second) + assert.True(t, msg.Done) + }) + + t.Run("reuses connection for same endpoint", func(t *testing.T) { + t.Parallel() + + var dialCount atomic.Int32 + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + dialCount.Add(1) + + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + + if msg["type"] == "subscribe" { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + } + } + }) + + tr := NewWSTransport(t.Context()) + + opts := common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + } + + ch1, cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts) + require.NoError(t, err) + defer cancel1() + + ch2, cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts) + require.NoError(t, err) + defer cancel2() + + // Both should receive messages + receiveWithTimeout(t, ch1, time.Second) + receiveWithTimeout(t, ch2, time.Second) + + // Only one connection should have been made + assert.Equal(t, int32(1), dialCount.Load()) + assert.Equal(t, 1, tr.ConnCount()) + }) + + t.Run("creates new connection for different headers", func(t *testing.T) { + t.Parallel() + + var dialCount atomic.Int32 + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + dialCount.Add(1) + + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + + if msg["type"] == "subscribe" { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + } + } + }) + + tr := NewWSTransport(t.Context()) + + headers1 := http.Header{"Authorization": []string{"Bearer token1"}} + headers2 := http.Header{"Authorization": []string{"Bearer token2"}} + + ch1, cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + Headers: headers1, + }) + require.NoError(t, err) + defer cancel1() + + ch2, cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + Headers: headers2, + }) + require.NoError(t, err) + defer cancel2() + + receiveWithTimeout(t, ch1, time.Second) + receiveWithTimeout(t, ch2, time.Second) + + // Two connections due to different headers + assert.Equal(t, int32(2), dialCount.Load()) + assert.Equal(t, 2, tr.ConnCount()) + }) + + t.Run("creates new connection for different init payload", func(t *testing.T) { + t.Parallel() + + var dialCount atomic.Int32 + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + dialCount.Add(1) + + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + + if msg["type"] == "subscribe" { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + } + } + }) + + tr := NewWSTransport(t.Context()) + + ch1, cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + InitPayload: map[string]any{"token": "abc"}, + }) + require.NoError(t, err) + defer cancel1() + + ch2, cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + InitPayload: map[string]any{"token": "xyz"}, + }) + require.NoError(t, err) + defer cancel2() + + receiveWithTimeout(t, ch1, time.Second) + receiveWithTimeout(t, ch2, time.Second) + + // Two connections due to different init payload + assert.Equal(t, int32(2), dialCount.Load()) + assert.Equal(t, 2, tr.ConnCount()) + }) + + t.Run("removes connection when all subscriptions closed", func(t *testing.T) { + t.Parallel() + + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + } + }) + + tr := NewWSTransport(t.Context()) + + opts := common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + } + + _, cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts) + require.NoError(t, err) + + _, cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts) + require.NoError(t, err) + + assert.Equal(t, 1, tr.ConnCount()) + + cancel1() + assert.Equal(t, 1, tr.ConnCount()) // still has one subscription + + cancel2() + + // Wait for onEmpty callback + assert.Eventually(t, func() bool { + return tr.ConnCount() == 0 + }, time.Second, 10*time.Millisecond) + }) + + t.Run("redials after connection closed", func(t *testing.T) { + t.Parallel() + + var dialCount atomic.Int32 + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + dialCount.Add(1) + + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + + if msg["type"] == "subscribe" { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + } + } + }) + + tr := NewWSTransport(t.Context()) + + opts := common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + } + + // First subscription + ch1, cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts) + require.NoError(t, err) + receiveWithTimeout(t, ch1, time.Second) + cancel1() + + // Wait for connection to be removed + assert.Eventually(t, func() bool { + return tr.ConnCount() == 0 + }, time.Second, 10*time.Millisecond) + + // Second subscription should redial + ch2, cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts) + require.NoError(t, err) + defer cancel2() + receiveWithTimeout(t, ch2, time.Second) + + assert.Equal(t, int32(2), dialCount.Load()) + }) +} + +func TestWSTransport_SubscriberDrain(t *testing.T) { + t.Parallel() + + t.Run("connection closes when last subscriber cancels", func(t *testing.T) { + t.Parallel() + + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + } + }) + + tr := NewWSTransport(t.Context()) + + _, cancel, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + }) + require.NoError(t, err) + + assert.Equal(t, 1, tr.ConnCount()) + + cancel() + + assert.Eventually(t, func() bool { + return tr.ConnCount() == 0 + }, time.Second, 10*time.Millisecond) + }) + + t.Run("connection stays open while subscribers remain", func(t *testing.T) { + t.Parallel() + + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + } + }) + + tr := NewWSTransport(t.Context()) + + opts := common.Options{Endpoint: server.URL, Transport: common.TransportWS} + + _, cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts) + require.NoError(t, err) + + _, cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts) + require.NoError(t, err) + + assert.Equal(t, 1, tr.ConnCount()) + + cancel1() + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, tr.ConnCount(), "connection should stay open with remaining subscriber") + + cancel2() + + assert.Eventually(t, func() bool { + return tr.ConnCount() == 0 + }, time.Second, 10*time.Millisecond) + }) +} + +func TestWSTransport_ConcurrentSubscribe(t *testing.T) { + t.Parallel() + + t.Run("handles concurrent subscribes to same endpoint", func(t *testing.T) { + t.Parallel() + + var dialCount atomic.Int32 + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + dialCount.Add(1) + + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + + if msg["type"] == "subscribe" { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + } + } + }) + + tr := NewWSTransport(t.Context()) + + opts := common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + } + + var wg sync.WaitGroup + for range 10 { + wg.Go(func() { + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { test }"}, opts) + if err != nil { + return + } + defer cancel() + + receiveWithTimeout(t, ch, time.Second) + }) + } + + wg.Wait() + + // Should have only dialed once (or maybe twice due to race, but not 10 times) + assert.LessOrEqual(t, dialCount.Load(), int32(2)) + }) +} + +func TestWSTransport_InitPayloadForwarding(t *testing.T) { + t.Parallel() + + t.Run("forwards init payload to server with graphql-transport-ws protocol", func(t *testing.T) { + t.Parallel() + + receivedPayload := make(chan map[string]any, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-transport-ws"}, + }) + if err != nil { + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + // Read connection_init and capture payload + var initMsg map[string]any + if err := wsjson.Read(ctx, conn, &initMsg); err != nil { + return + } + if initMsg["type"] != "connection_init" { + return + } + if payload, ok := initMsg["payload"].(map[string]any); ok { + receivedPayload <- payload + } else { + receivedPayload <- nil + } + + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + + // Read subscribe and respond + var subMsg map[string]any + if err := wsjson.Read(ctx, conn, &subMsg); err != nil { + return + } + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": subMsg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + })) + t.Cleanup(server.Close) + + tr := NewWSTransport(t.Context()) + + initPayload := map[string]any{ + "Authorization": "Bearer secret-token", + "X-Custom": "custom-value", + "nested": map[string]any{ + "key": "nested-value", + }, + } + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolGraphQLTransportWS, + InitPayload: initPayload, + }) + require.NoError(t, err) + defer cancel() + + // Verify payload was received by server + select { + case payload := <-receivedPayload: + require.NotNil(t, payload, "server should receive init payload") + assert.Equal(t, "Bearer secret-token", payload["Authorization"]) + assert.Equal(t, "custom-value", payload["X-Custom"]) + nested, ok := payload["nested"].(map[string]any) + require.True(t, ok, "nested should be a map") + assert.Equal(t, "nested-value", nested["key"]) + case <-time.After(time.Second): + t.Fatal("timeout waiting for init payload") + } + + // Subscription should work + msg := receiveWithTimeout(t, ch, time.Second) + assert.NotNil(t, msg.Payload) + }) + + t.Run("forwards init payload to server with graphql-ws legacy protocol", func(t *testing.T) { + t.Parallel() + + receivedPayload := make(chan map[string]any, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-ws"}, + }) + if err != nil { + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + // Read connection_init and capture payload + var initMsg map[string]any + if err := wsjson.Read(ctx, conn, &initMsg); err != nil { + return + } + if initMsg["type"] != "connection_init" { + return + } + if payload, ok := initMsg["payload"].(map[string]any); ok { + receivedPayload <- payload + } else { + receivedPayload <- nil + } + + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + + // Read start and respond + var startMsg map[string]any + if err := wsjson.Read(ctx, conn, &startMsg); err != nil { + return + } + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": startMsg["id"], + "type": "data", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + })) + t.Cleanup(server.Close) + + tr := NewWSTransport(t.Context()) + + initPayload := map[string]any{ + "token": "legacy-auth-token", + "version": float64(2), // JSON numbers are float64 + } + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolGraphQLWS, // Legacy protocol + InitPayload: initPayload, + }) + require.NoError(t, err) + defer cancel() + + // Verify payload was received by server + select { + case payload := <-receivedPayload: + require.NotNil(t, payload, "server should receive init payload") + assert.Equal(t, "legacy-auth-token", payload["token"]) + assert.Equal(t, float64(2), payload["version"]) + case <-time.After(time.Second): + t.Fatal("timeout waiting for init payload") + } + + // Subscription should work + msg := receiveWithTimeout(t, ch, time.Second) + assert.NotNil(t, msg.Payload) + }) + + t.Run("sends empty payload when init payload is nil", func(t *testing.T) { + t.Parallel() + + receivedPayload := make(chan map[string]any, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-transport-ws"}, + }) + if err != nil { + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + // Read connection_init and capture payload + var initMsg map[string]any + if err := wsjson.Read(ctx, conn, &initMsg); err != nil { + return + } + if initMsg["type"] != "connection_init" { + return + } + if payload, ok := initMsg["payload"].(map[string]any); ok { + receivedPayload <- payload + } else { + receivedPayload <- nil + } + + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + + // Read subscribe and respond + var subMsg map[string]any + if err := wsjson.Read(ctx, conn, &subMsg); err != nil { + return + } + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": subMsg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + })) + t.Cleanup(server.Close) + + tr := NewWSTransport(t.Context()) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolGraphQLTransportWS, + InitPayload: nil, // No init payload + }) + require.NoError(t, err) + defer cancel() + + // Server should receive nil/empty payload + select { + case payload := <-receivedPayload: + assert.Nil(t, payload, "server should receive nil payload when not provided") + case <-time.After(time.Second): + t.Fatal("timeout waiting for init message") + } + + // Subscription should still work + msg := receiveWithTimeout(t, ch, time.Second) + assert.NotNil(t, msg.Payload) + }) + + t.Run("same endpoint with different init payloads uses separate connections", func(t *testing.T) { + t.Parallel() + + var mu sync.Mutex + receivedPayloads := make([]map[string]any, 0) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-transport-ws"}, + }) + if err != nil { + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + // Read connection_init and capture payload + var initMsg map[string]any + if err := wsjson.Read(ctx, conn, &initMsg); err != nil { + return + } + if initMsg["type"] != "connection_init" { + return + } + + mu.Lock() + if payload, ok := initMsg["payload"].(map[string]any); ok { + receivedPayloads = append(receivedPayloads, payload) + } + mu.Unlock() + + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + + // Handle subscriptions + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + if msg["type"] == "subscribe" { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + } + } + })) + t.Cleanup(server.Close) + + tr := NewWSTransport(t.Context()) + + // First subscription with user1 token + ch1, cancel1, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolGraphQLTransportWS, + InitPayload: map[string]any{"user": "user1"}, + }) + require.NoError(t, err) + defer cancel1() + + receiveWithTimeout(t, ch1, time.Second) + + // Second subscription with user2 token - should create new connection + ch2, cancel2, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolGraphQLTransportWS, + InitPayload: map[string]any{"user": "user2"}, + }) + require.NoError(t, err) + defer cancel2() + + receiveWithTimeout(t, ch2, time.Second) + + // Verify two separate connections were made with different payloads + assert.Equal(t, 2, tr.ConnCount()) + + mu.Lock() + defer mu.Unlock() + require.Len(t, receivedPayloads, 2) + + users := make([]string, 0, 2) + for _, p := range receivedPayloads { + if user, ok := p["user"].(string); ok { + users = append(users, user) + } + } + assert.ElementsMatch(t, []string{"user1", "user2"}, users) + }) +} + +func TestWSTransport_LegacyProtocol(t *testing.T) { + t.Parallel() + + t.Run("connects to legacy graphql-ws server", func(t *testing.T) { + t.Parallel() + + server := newLegacyGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + // Read start message + var msg map[string]any + require.NoError(t, wsjson.Read(ctx, conn, &msg)) + assert.Equal(t, "start", msg["type"]) + + // Send data + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "data", + "payload": map[string]any{"data": map[string]any{"value": 42}}, + }) + + // Send complete + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "complete", + }) + }) + + tr := NewWSTransport(t.Context()) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolGraphQLWS, // Request legacy protocol + }) + require.NoError(t, err) + defer cancel() + + msg := receiveWithTimeout(t, ch, time.Second) + assert.Contains(t, string(msg.Payload.Data), "42") + + msg = receiveWithTimeout(t, ch, time.Second) + assert.True(t, msg.Done) + }) + + t.Run("handles keep-alive messages", func(t *testing.T) { + t.Parallel() + + server := newLegacyGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + // Read start message + var msg map[string]any + require.NoError(t, wsjson.Read(ctx, conn, &msg)) + + // Send keep-alive + _ = wsjson.Write(ctx, conn, map[string]string{"type": "ka"}) + + // Send data + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "data", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + + // Send complete + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "complete", + }) + }) + + tr := NewWSTransport(t.Context()) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolGraphQLWS, + }) + require.NoError(t, err) + defer cancel() + + // Should receive data (keep-alive is handled internally) + msg := receiveWithTimeout(t, ch, time.Second) + assert.NotNil(t, msg.Payload) + + msg = receiveWithTimeout(t, ch, time.Second) + assert.True(t, msg.Done) + }) + + t.Run("auto-negotiates to legacy when modern unavailable", func(t *testing.T) { + t.Parallel() + + // Server only supports legacy protocol + server := newLegacyGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + require.NoError(t, wsjson.Read(ctx, conn, &msg)) + assert.Equal(t, "start", msg["type"]) // Should use legacy message type + + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "data", + "payload": map[string]any{"data": map[string]any{"value": 99}}, + }) + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "complete", + }) + }) + + tr := NewWSTransport(t.Context()) + + ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolAuto, // Auto-negotiate + }) + require.NoError(t, err) + defer cancel() + + msg := receiveWithTimeout(t, ch, time.Second) + assert.Contains(t, string(msg.Payload.Data), "99") + }) +} + +func TestWSTransport_Heartbeat(t *testing.T) { + t.Parallel() + + t.Run("sends pings and receives pongs", func(t *testing.T) { + t.Parallel() + + var pingCount atomic.Int32 + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + // Read subscribe + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + + // Keep connection alive, respond to pings with pongs + for { + var incoming map[string]any + if err := wsjson.Read(ctx, conn, &incoming); err != nil { + return + } + if incoming["type"] == "ping" { + pingCount.Add(1) + _ = wsjson.Write(ctx, conn, map[string]string{"type": "pong"}) + } + } + }) + + tr := NewWSTransport(t.Context(), WithPingInterval(50*time.Millisecond)) + + _, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + }) + require.NoError(t, err) + defer cancel() + + // Wait for at least 2 pings to be sent + assert.Eventually(t, func() bool { + return pingCount.Load() >= 2 + }, time.Second, 10*time.Millisecond) + }) + + t.Run("closes connection on pong timeout", func(t *testing.T) { + t.Parallel() + + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + // Read subscribe + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + + // Read pings but never respond with pong + for { + var incoming map[string]any + if err := wsjson.Read(ctx, conn, &incoming); err != nil { + return + } + } + }) + + tr := NewWSTransport(t.Context(), WithPingInterval(100*time.Millisecond), WithPingTimeout(50*time.Millisecond)) + + ch, _, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + }) + require.NoError(t, err) + + // Connection should be closed due to pong timeout, subscriber gets notified + msg := receiveWithTimeout(t, ch, time.Second) + assert.True(t, msg.Done) + assert.Error(t, msg.Err) + + assert.Eventually(t, func() bool { + return tr.ConnCount() == 0 + }, time.Second, 10*time.Millisecond) + }) + + t.Run("does not kill connection when ping timeout is disabled", func(t *testing.T) { + t.Parallel() + + var pingCount atomic.Int32 + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + // Read subscribe + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + + // Read pings but never respond with pong + for { + var incoming map[string]any + if err := wsjson.Read(ctx, conn, &incoming); err != nil { + return + } + if incoming["type"] == "ping" { + pingCount.Add(1) + } + } + }) + + // PingInterval set, PingTimeout left at zero (disabled) + tr := NewWSTransport(t.Context(), WithPingInterval(50*time.Millisecond)) + + _, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + }) + require.NoError(t, err) + defer cancel() + + // Wait for several ping cycles without any pong responses + assert.Eventually(t, func() bool { + return pingCount.Load() >= 3 + }, time.Second, 10*time.Millisecond) + + // Connection must still be alive despite no pongs + assert.Equal(t, 1, tr.ConnCount()) + }) + + t.Run("keeps connection alive when pongs arrive", func(t *testing.T) { + t.Parallel() + + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + // Read subscribe + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + + // Respond to pings with pongs + for { + var incoming map[string]any + if err := wsjson.Read(ctx, conn, &incoming); err != nil { + return + } + if incoming["type"] == "ping" { + _ = wsjson.Write(ctx, conn, map[string]string{"type": "pong"}) + } + } + }) + + tr := NewWSTransport(t.Context(), WithPingInterval(50*time.Millisecond), WithPingTimeout(200*time.Millisecond)) + + _, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + }) + require.NoError(t, err) + defer cancel() + + // Connection should remain alive after several ping cycles + time.Sleep(250 * time.Millisecond) + assert.Equal(t, 1, tr.ConnCount()) + }) +} + +func TestWSTransport_Defaults(t *testing.T) { + t.Parallel() + + t.Run("applies default read limit when omitted", func(t *testing.T) { + t.Parallel() + + tr := NewWSTransport(t.Context()) + + assert.Equal(t, defaultReadLimit, tr.ReadLimit()) + }) + + t.Run("applies default read limit for zero value", func(t *testing.T) { + t.Parallel() + + tr := NewWSTransport(t.Context(), WithReadLimit(0)) + + assert.Equal(t, defaultReadLimit, tr.ReadLimit()) + }) + + t.Run("overrides read limit when provided", func(t *testing.T) { + t.Parallel() + + tr := NewWSTransport(t.Context(), WithReadLimit(2*1024*1024)) + + assert.Equal(t, int64(2*1024*1024), tr.ReadLimit()) + }) + + t.Run("ignores negative read limit", func(t *testing.T) { + t.Parallel() + + tr := NewWSTransport(t.Context(), WithReadLimit(-1)) + + assert.Equal(t, defaultReadLimit, tr.ReadLimit()) + }) + + t.Run("applies zero write timeout by default", func(t *testing.T) { + t.Parallel() + + tr := NewWSTransport(t.Context()) + + // Zero means connections use their own DefaultWriteTimeout + assert.Equal(t, time.Duration(0), tr.WriteTimeout()) + }) + + t.Run("overrides write timeout when provided", func(t *testing.T) { + t.Parallel() + + tr := NewWSTransport(t.Context(), WithWriteTimeout(10*time.Second)) + + assert.Equal(t, 10*time.Second, tr.WriteTimeout()) + }) +} + +// Test helpers + +func newGraphQLWSServer(t *testing.T, handler func(ctx context.Context, conn *websocket.Conn)) *httptest.Server { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-transport-ws"}, + }) + if err != nil { + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + // Handle connection_init + var initMsg map[string]any + if err := wsjson.Read(ctx, conn, &initMsg); err != nil { + return + } + if initMsg["type"] != "connection_init" { + return + } + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + + handler(ctx, conn) + })) + + t.Cleanup(server.Close) + return server +} + +func newLegacyGraphQLWSServer(t *testing.T, handler func(ctx context.Context, conn *websocket.Conn)) *httptest.Server { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-ws"}, // Legacy protocol only + }) + if err != nil { + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + // Handle connection_init + var initMsg map[string]any + if err := wsjson.Read(ctx, conn, &initMsg); err != nil { + return + } + if initMsg["type"] != "connection_init" { + return + } + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + + handler(ctx, conn) + })) + + t.Cleanup(server.Close) + return server +} diff --git a/v2/pkg/engine/resolve/datasource.go b/v2/pkg/engine/resolve/datasource.go index b03bdd0781..be9169c1d3 100644 --- a/v2/pkg/engine/resolve/datasource.go +++ b/v2/pkg/engine/resolve/datasource.go @@ -18,11 +18,6 @@ type SubscriptionDataSource interface { Start(ctx *Context, headers http.Header, input []byte, updater SubscriptionUpdater) error } -type AsyncSubscriptionDataSource interface { - AsyncStart(ctx *Context, id uint64, headers http.Header, input []byte, updater SubscriptionUpdater) error - AsyncStop(id uint64) -} - // HookableSubscriptionDataSource is a hookable interface for subscription data sources. // It is used to call a function when a subscription is started. // This is useful for data sources that need to do some work when a subscription is started, diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index f735752ef9..32230cf07a 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -10,11 +10,12 @@ import ( "io" "net/http" "runtime" + "sync" + "sync/atomic" "time" "github.com/buger/jsonparser" "github.com/pkg/errors" - "go.uber.org/atomic" "github.com/wundergraph/go-arena" @@ -29,7 +30,7 @@ const ( // ConnectionIDs is used to create unique connection IDs for each subscription // Whenever a new connection is created, use this to generate a new ID // It is public because it can be used in more high level packages to instantiate a new connection -var ConnectionIDs = atomic.NewInt64(0) +var ConnectionIDs atomic.Int64 type Reporter interface { // SubscriptionUpdateSent called when a new subscription update is sent @@ -48,19 +49,19 @@ type AsyncErrorWriter interface { WriteError(ctx *Context, err error, res *GraphQLResponse, w io.Writer) } -// Resolver is a single threaded event loop that processes all events on a single goroutine. -// It is absolutely critical to ensure that all events are processed quickly to prevent blocking -// and that resolver modifications are done on the event loop goroutine. Long-running operations -// should be offloaded to the subscription worker goroutine. If a different goroutine needs to emit -// an event, it should be done through the events channel to avoid race conditions. +// Resolver manages GraphQL subscriptions using a mutex-protected trigger registry. +// All trigger/subscription state is guarded by mu. Long-running I/O (writes to clients) +// is performed outside the lock using a snapshot-and-release pattern. type Resolver struct { ctx context.Context options ResolverOptions maxConcurrency chan struct{} - triggers map[uint64]*trigger - events chan subscriptionEvent - triggerUpdateBuf *bytes.Buffer + mu sync.Mutex + shutdown bool + triggers map[uint64]*trigger + subscriptionsByID map[SubscriptionIdentifier]*subscriptionState + subscriptionsByConnection map[int64]map[SubscriptionIdentifier]*subscriptionState allowedErrorExtensionFields map[string]struct{} allowedErrorFields map[string]struct{} @@ -70,7 +71,7 @@ type Resolver struct { propagateSubgraphErrors bool propagateSubgraphStatusCodes bool - // Subscription heartbeat interval + // Subscription heartbeat interval for periodic updater heartbeats. heartbeatInterval time.Duration // maxSubscriptionFetchTimeout defines the maximum time a subscription fetch can take before it is considered timed out maxSubscriptionFetchTimeout time.Duration @@ -264,11 +265,11 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { options: options, propagateSubgraphErrors: options.PropagateSubgraphErrors, propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes, - events: make(chan subscriptionEvent), triggers: make(map[uint64]*trigger), + subscriptionsByID: make(map[SubscriptionIdentifier]*subscriptionState), + subscriptionsByConnection: make(map[int64]map[SubscriptionIdentifier]*subscriptionState), reporter: options.Reporter, asyncErrorWriter: options.AsyncErrorWriter, - triggerUpdateBuf: bytes.NewBuffer(make([]byte, 0, 1024)), allowedErrorExtensionFields: allowedExtensionFields, allowedErrorFields: allowedErrorFields, heartbeatInterval: options.SubscriptionHeartbeatInterval, @@ -283,7 +284,10 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { resolver.maxConcurrency <- struct{}{} } - go resolver.processEvents() + go resolver.heartbeatLoop() + context.AfterFunc(resolver.ctx, func() { + resolver.shutdownResolver() + }) return resolver } @@ -437,131 +441,78 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe } type trigger struct { + mu sync.RWMutex id uint64 cancel context.CancelFunc - subscriptions map[*Context]*sub + subscriptions map[SubscriptionIdentifier]*subscriptionState + updateBuf *bytes.Buffer // initialized is set to true when the trigger is started and initialized initialized bool updater *subscriptionUpdater } func (t *trigger) subscriptionIds() map[context.Context]SubscriptionIdentifier { + t.mu.RLock() + defer t.mu.RUnlock() + subs := make(map[context.Context]SubscriptionIdentifier, len(t.subscriptions)) - for ctx, sub := range t.subscriptions { - subs[ctx.Context()] = sub.id + for _, sub := range t.subscriptions { + subs[sub.ctx.Context()] = sub.id } return subs } -// workItem is used to encapsulate a function that needs to be -// executed in the worker goroutine. fn will be executed, and if -// final is true the worker will be stopped after fn is executed. -type workItem struct { - fn func() - final bool -} - -type sub struct { +type subscriptionState struct { + triggerID uint64 resolve *GraphQLSubscription - resolver *Resolver ctx *Context writer SubscriptionResponseWriter id SubscriptionIdentifier heartbeat bool completed chan struct{} - // workChan is used to send work to the writer goroutine. All work is processed sequentially. - workChan chan workItem + writeMu sync.Mutex + // removed guards against writes after the subscription has been removed. + // Uses CompareAndSwap to prevent double-close of the completed channel. + removed atomic.Bool + // lastWriteTime stores unix nanos of the last successful data write. + lastWriteTime atomic.Int64 } -// startWorker runs in its own goroutine to process fetches and write data to the client synchronously -// it also takes care of sending heartbeats to the client but only if the subscription supports it -// TODO implement a goroutine pool that is sharded by the subscription id to avoid creating a new goroutine for each subscription -func (s *sub) startWorker() { - if s.heartbeat { - s.startWorkerWithHeartbeat() - return - } - s.startWorkerWithoutHeartbeat() +type subscriptionFinalizer struct { + sub *subscriptionState } -// startWorkerWithHeartbeat is similar to startWorker but sends heartbeats to the client when enabled. -// It sends a heartbeat to the client every heartbeatInterval. Heartbeats are handled by the SubscriptionResponseWriter interface. -// TODO: Implement a shared timer implementation to avoid creating a new ticker for each subscription. -func (s *sub) startWorkerWithHeartbeat() { - heartbeatTicker := time.NewTicker(s.resolver.heartbeatInterval) - defer heartbeatTicker.Stop() - - for { - select { - case <-s.ctx.ctx.Done(): - // Complete when the client request context is done for synchronous subscriptions - s.close(SubscriptionCloseKindGoingAway) - - return - case <-s.resolver.ctx.Done(): - // Abort immediately if the resolver is shutting down - s.close(SubscriptionCloseKindGoingAway) - - return - case <-heartbeatTicker.C: - s.resolver.handleHeartbeat(s) - case work := <-s.workChan: - work.fn() - - if work.final { - return - } - - // Reset the heartbeat ticker after each write to avoid sending unnecessary heartbeats - heartbeatTicker.Reset(s.resolver.heartbeatInterval) - } +func runSubscriptionFinalizers(finalizers []subscriptionFinalizer) { + for _, f := range finalizers { + f.sub.done() } } -func (s *sub) startWorkerWithoutHeartbeat() { - for { - select { - case <-s.ctx.ctx.Done(): - // Complete when the client request context is done for synchronous subscriptions - s.close(SubscriptionCloseKindGoingAway) - - return - case <-s.resolver.ctx.Done(): - // Abort immediately if the resolver is shutting down - s.close(SubscriptionCloseKindGoingAway) - - return - case work := <-s.workChan: - work.fn() - - if work.final { - return - } - } - } +// done closes the completed channel to signal that the subscription is finished. +// It does not send any downstream messages — Complete/Error are sent separately. +func (s *subscriptionState) done() { + s.writeMu.Lock() + defer s.writeMu.Unlock() + close(s.completed) } -// Called when subgraph indicates a "complete" subscription -func (s *sub) complete() { - // The channel is used to communicate that the subscription is done - // It is used only in the synchronous subscription case and to avoid sending events - // to a subscription that is already done. - defer close(s.completed) - +// complete delivers a "subscription done" signal to the downstream writer. +// Called by handleTriggerComplete, not through finalizers. +func (s *subscriptionState) complete() { + s.writeMu.Lock() + defer s.writeMu.Unlock() s.writer.Complete() } -// Called when subgraph becomes unreachable or closes the connection without a "complete" event -func (s *sub) close(kind SubscriptionCloseKind) { - // The channel is used to communicate that the subscription is done - // It is used only in the synchronous subscription case and to avoid sending events - // to a subscription that is already done. - defer close(s.completed) - - s.writer.Close(kind) +// error delivers a terminal error payload to the downstream writer. +// Called by handleTriggerError, not through finalizers. +func (s *subscriptionState) error(data []byte) { + s.writeMu.Lock() + defer s.writeMu.Unlock() + s.writer.Error(data) } -func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, sharedInput []byte) { +func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *subscriptionState, sharedInput []byte) { if r.options.Debug { fmt.Printf("resolver:trigger:subscription:update:%d\n", sub.id.SubscriptionID) } @@ -580,7 +531,11 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar if err := t.resolvable.InitSubscription(resolveCtx, input, sub.resolve.Trigger.PostProcessing); err != nil { r.resolveArenaPool.Release(resolveArena) - r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) + sub.writeMu.Lock() + if !sub.removed.Load() { + r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) + } + sub.writeMu.Unlock() if r.options.Debug { fmt.Printf("resolver:trigger:subscription:init:failed:%d\n", sub.id.SubscriptionID) } @@ -592,7 +547,11 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar if err := t.loader.LoadGraphQLResponseData(resolveCtx, sub.resolve.Response, t.resolvable); err != nil { r.resolveArenaPool.Release(resolveArena) - r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) + sub.writeMu.Lock() + if !sub.removed.Load() { + r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) + } + sub.writeMu.Unlock() if r.options.Debug { fmt.Printf("resolver:trigger:subscription:load:failed:%d\n", sub.id.SubscriptionID) } @@ -602,9 +561,17 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar return } + sub.writeMu.Lock() + if sub.removed.Load() { + sub.writeMu.Unlock() + r.resolveArenaPool.Release(resolveArena) + return + } + if err := t.resolvable.Resolve(resolveCtx.ctx, sub.resolve.Response.Data, sub.resolve.Response.Fetches, sub.writer); err != nil { r.resolveArenaPool.Release(resolveArena) r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) + sub.writeMu.Unlock() if r.options.Debug { fmt.Printf("resolver:trigger:subscription:resolve:failed:%d\n", sub.id.SubscriptionID) } @@ -617,10 +584,13 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar r.resolveArenaPool.Release(resolveArena) if err := sub.writer.Flush(); err != nil { + sub.writeMu.Unlock() // If flush fails (e.g. client disconnected), remove the subscription. - _ = r.AsyncUnsubscribeSubscription(sub.id) + _ = r.UnsubscribeSubscription(sub.id) return } + sub.lastWriteTime.Store(time.Now().UnixNano()) + sub.writeMu.Unlock() if r.options.Debug { fmt.Printf("resolver:trigger:subscription:flushed:%d\n", sub.id.SubscriptionID) @@ -634,95 +604,34 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar } } -// processEvents maintains the single threaded event loop that processes all events -func (r *Resolver) processEvents() { - done := r.ctx.Done() - - // events channel can't be closed here because producers are - // sending events across multiple goroutines - - for { - select { - case <-done: - r.handleShutdown() - return - case event := <-r.events: - r.handleEvent(event) - } - } -} - -// handleEvent is a single threaded function that processes events from the events channel -// All events are processed in the order they are received and need to be processed quickly -// to prevent blocking the event loop and any other events from being processed. -// TODO: consider using a worker pool that distributes events from different triggers to different workers -// to avoid blocking the event loop and improve performance. -func (r *Resolver) handleEvent(event subscriptionEvent) { - switch event.kind { - case subscriptionEventKindAddSubscription: - r.handleAddSubscription(event.triggerID, event.addSubscription) - case subscriptionEventKindRemoveSubscription: - r.handleRemoveSubscription(event.id) - case subscriptionEventKindCompleteSubscription: - r.handleCompleteSubscription(event.id) - case subscriptionEventKindRemoveClient: - r.handleRemoveClient(event.id.ConnectionID) - case subscriptionEventKindUpdateSubscription: - r.handleUpdateSubscription(event.triggerID, event.data, event.id) - case subscriptionEventKindTriggerUpdate: - r.handleTriggerUpdate(event.triggerID, event.data) - case subscriptionEventKindTriggerComplete: - r.handleTriggerComplete(event.triggerID) - case subscriptionEventKindTriggerInitialized: - r.handleTriggerInitialized(event.triggerID) - case subscriptionEventKindTriggerClose: - r.handleTriggerClose(event) - case subscriptionEventKindUnknown: - panic("unknown event") - } -} - -// handleHeartbeat sends a heartbeat to the client. It needs to be executed on the same goroutine as the writer. -func (r *Resolver) handleHeartbeat(sub *sub) { +func (r *Resolver) executeSubscriptionHeartbeat(sub *subscriptionState) { if r.options.Debug { - fmt.Printf("resolver:heartbeat\n") + fmt.Printf("resolver:heartbeat:subscription:%d\n", sub.id.SubscriptionID) } - if r.ctx.Err() != nil { + if r.ctx.Err() != nil || sub.ctx.Context().Err() != nil { return } - if sub.ctx.Context().Err() != nil { - return - } + sub.writeMu.Lock() - if r.options.Debug { - fmt.Printf("resolver:heartbeat:subscription:%d\n", sub.id.SubscriptionID) + if sub.removed.Load() { + sub.writeMu.Unlock() + return } if err := sub.writer.Heartbeat(); err != nil { - // If heartbeat fails (e.g. client disconnected), remove the subscription. - _ = r.AsyncUnsubscribeSubscription(sub.id) + sub.writeMu.Unlock() + _ = r.UnsubscribeSubscription(sub.id) return } - - if r.options.Debug { - fmt.Printf("resolver:heartbeat:subscription:done:%d\n", sub.id.SubscriptionID) - } + sub.writeMu.Unlock() if r.reporter != nil { r.reporter.SubscriptionUpdateSent() } } -func (r *Resolver) handleTriggerClose(s subscriptionEvent) { - if r.options.Debug { - fmt.Printf("resolver:trigger:shutdown:%d:%d\n", s.triggerID, s.id.SubscriptionID) - } - - r.closeTrigger(s.triggerID, s.closeKind) -} - func (r *Resolver) handleTriggerInitialized(triggerID uint64) { trig, ok := r.triggers[triggerID] if !ok { @@ -735,14 +644,6 @@ func (r *Resolver) handleTriggerInitialized(triggerID uint64) { } } -func (r *Resolver) handleTriggerComplete(triggerID uint64) { - if r.options.Debug { - fmt.Printf("resolver:trigger:complete:%d\n", triggerID) - } - - r.completeTrigger(triggerID) -} - type StartupHookContext struct { Context context.Context Updater func(data []byte) @@ -754,8 +655,6 @@ func (r *Resolver) executeStartupHooks(add *addSubscription, updater *subscripti hookCtx := StartupHookContext{ Context: add.ctx.Context(), Updater: func(data []byte) { - // Writing on the updater channel is safe but has to happen outside of the event loop - // to respect order and not block the event loop updater.UpdateSubscription(add.id, data) }, } @@ -766,38 +665,53 @@ func (r *Resolver) executeStartupHooks(add *addSubscription, updater *subscripti fmt.Printf("resolver:trigger:subscription:startup:failed:%d\n", add.id.SubscriptionID) } r.asyncErrorWriter.WriteError(add.ctx, err, add.resolve.Response, add.writer) - _ = r.AsyncUnsubscribeSubscription(add.id) + _ = r.UnsubscribeSubscription(add.id) return err } } return nil } +func (r *Resolver) addSubscriptionIndex(s *subscriptionState) { + id := s.id + r.subscriptionsByID[id] = s + byConn, ok := r.subscriptionsByConnection[id.ConnectionID] + if !ok { + byConn = make(map[SubscriptionIdentifier]*subscriptionState) + r.subscriptionsByConnection[id.ConnectionID] = byConn + } + byConn[id] = s +} + +func (r *Resolver) removeSubscriptionIndex(id SubscriptionIdentifier) { + delete(r.subscriptionsByID, id) + byConn, ok := r.subscriptionsByConnection[id.ConnectionID] + if !ok { + return + } + delete(byConn, id) + if len(byConn) == 0 { + delete(r.subscriptionsByConnection, id.ConnectionID) + } +} + +// handleAddSubscription must be called with r.mu held. func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) { - var ( - err error - ) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:add:%d:%d\n", triggerID, add.id.SubscriptionID) } - s := &sub{ + s := &subscriptionState{ + triggerID: triggerID, ctx: add.ctx, resolve: add.resolve, writer: add.writer, id: add.id, completed: add.completed, - workChan: make(chan workItem, 32), - resolver: r, } - if add.ctx.ExecutionOptions.SendHeartbeat { s.heartbeat = true } - // Start the dedicated worker goroutine where the subscription updates are processed - // and writes are written to the client in a single threaded manner - go s.startWorker() - trig, ok := r.triggers[triggerID] if ok { if r.reporter != nil { @@ -806,17 +720,15 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:added:%d:%d\n", triggerID, add.id.SubscriptionID) } - // Execute the startup hooks in a separate goroutine to avoid blocking the event loop - s.workChan <- workItem{ - fn: func() { - _ = r.executeStartupHooks(add, trig.updater) - // if the startup hooks return an error, we don't have to do anything else - }, - final: false, - } - // After the startup hooks are executed, we can add the subscription to the subscriptions registry - // so that it can start receive events - trig.subscriptions[add.ctx] = s + // Add the subscription to the registry so it can receive events + trig.mu.Lock() + trig.subscriptions[add.id] = s + trig.mu.Unlock() + r.addSubscriptionIndex(s) + // Execute the startup hooks in a goroutine to avoid holding the lock + go func() { + _ = r.executeStartupHooks(add, trig.updater) + }() return } @@ -827,342 +739,435 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) updater := &subscriptionUpdater{ debug: r.options.Debug, triggerID: triggerID, - ch: r.events, + resolver: r, ctx: ctx, } cloneCtx := add.ctx.clone(ctx) trig = &trigger{ id: triggerID, - subscriptions: make(map[*Context]*sub), + subscriptions: make(map[SubscriptionIdentifier]*subscriptionState), + updateBuf: bytes.NewBuffer(make([]byte, 0, 1024)), cancel: cancel, updater: updater, } r.triggers[triggerID] = trig - trig.subscriptions[add.ctx] = s + trig.mu.Lock() + trig.subscriptions[add.id] = s + trig.mu.Unlock() updater.subsFn = trig.subscriptionIds + r.addSubscriptionIndex(s) if r.reporter != nil { r.reporter.SubscriptionCountInc(1) } - var asyncDataSource AsyncSubscriptionDataSource - - if async, ok := add.resolve.Trigger.Source.(AsyncSubscriptionDataSource); ok { - trig.cancel = func() { - cancel() - async.AsyncStop(triggerID) - } - asyncDataSource = async - } - go func() { if r.options.Debug { fmt.Printf("resolver:trigger:start:%d\n", triggerID) } // This is blocking so the startup hook can decide if a subscription should be started or not by returning an error - err = r.executeStartupHooks(add, trig.updater) + err := r.executeStartupHooks(add, trig.updater) if err != nil { return } - if asyncDataSource != nil { - err = asyncDataSource.AsyncStart(cloneCtx, triggerID, add.headers, add.input, trig.updater) - } else { - err = add.resolve.Trigger.Source.Start(cloneCtx, add.headers, add.input, trig.updater) - } + err = add.resolve.Trigger.Source.Start(cloneCtx, add.headers, add.input, trig.updater) if err != nil { if r.options.Debug { fmt.Printf("resolver:trigger:failed:%d\n", triggerID) } r.asyncErrorWriter.WriteError(add.ctx, err, add.resolve.Response, add.writer) - _ = r.emitTriggerClose(triggerID) + r.doneTriggerFromUpdater(triggerID) return } - _ = r.emitTriggerInitialized(triggerID) + r.markTriggerInitialized(triggerID) if r.options.Debug { fmt.Printf("resolver:trigger:started:%d\n", triggerID) } }() +} +// markTriggerInitialized marks a trigger as initialized under the lock. +func (r *Resolver) markTriggerInitialized(triggerID uint64) { + r.mu.Lock() + defer r.mu.Unlock() + r.handleTriggerInitialized(triggerID) } -func (r *Resolver) emitTriggerClose(triggerID uint64) error { +// doneTriggerFromUpdater performs cleanup for a trigger from a datasource/updater goroutine. +// It detaches the trigger, runs done finalizers (close completed channels), and cancels the trigger context. +func (r *Resolver) doneTriggerFromUpdater(triggerID uint64) { if r.options.Debug { fmt.Printf("resolver:trigger:shutdown:%d\n", triggerID) } + r.mu.Lock() + removed, finalizers, cancel, initialized := r.detachTriggerLocked(triggerID) + if r.reporter != nil { + r.reporter.SubscriptionCountDec(removed) + if initialized { + r.reporter.TriggerCountDec(1) + } + } + r.mu.Unlock() + runSubscriptionFinalizers(finalizers) + if cancel != nil { + cancel() + } +} - select { - case <-r.ctx.Done(): - return r.ctx.Err() - case r.events <- subscriptionEvent{ - triggerID: triggerID, - kind: subscriptionEventKindTriggerClose, - closeKind: SubscriptionCloseKindNormal, - }: +// handleTriggerComplete delivers a complete signal to all subscriptions on the trigger. +// Does NOT detach the trigger — Done() does that. +func (r *Resolver) handleTriggerComplete(triggerID uint64) { + r.mu.Lock() + trig, ok := r.triggers[triggerID] + if !ok { + r.mu.Unlock() + return } + trig.mu.Lock() + subs := make([]*subscriptionState, 0, len(trig.subscriptions)) + for _, s := range trig.subscriptions { + subs = append(subs, s) + } + trig.mu.Unlock() + r.mu.Unlock() - return nil + for _, s := range subs { + if !s.removed.Load() { + s.complete() + } + } } -func (r *Resolver) emitTriggerInitialized(triggerID uint64) error { - if r.options.Debug { - fmt.Printf("resolver:trigger:initialized:%d\n", triggerID) +// handleTriggerError delivers a terminal error to all subscriptions on the trigger, +// bypassing the resolve pipeline. Does NOT detach the trigger — Done() does that. +func (r *Resolver) handleTriggerError(triggerID uint64, data []byte) { + r.mu.Lock() + trig, ok := r.triggers[triggerID] + if !ok { + r.mu.Unlock() + return } - - select { - case <-r.ctx.Done(): - return r.ctx.Err() - case r.events <- subscriptionEvent{ - triggerID: triggerID, - kind: subscriptionEventKindTriggerInitialized, - }: + trig.mu.Lock() + subs := make([]*subscriptionState, 0, len(trig.subscriptions)) + for _, s := range trig.subscriptions { + subs = append(subs, s) } + trig.mu.Unlock() + r.mu.Unlock() - return nil + for _, s := range subs { + if !s.removed.Load() { + s.error(data) + } + } } -func (r *Resolver) handleCompleteSubscription(id SubscriptionIdentifier) { +func (r *Resolver) handleCompleteSubscription(id SubscriptionIdentifier) (int, []subscriptionFinalizer, context.CancelFunc, bool) { if r.options.Debug { fmt.Printf("resolver:trigger:subscription:remove:%d:%d\n", id.ConnectionID, id.SubscriptionID) } - removed := 0 - for u := range r.triggers { - trig := r.triggers[u] - removed += r.completeTriggerSubscriptions(u, func(sID SubscriptionIdentifier) bool { - return sID == id - }) - if len(trig.subscriptions) == 0 { - r.completeTrigger(trig.id) - } - } - if r.reporter != nil { - r.reporter.SubscriptionCountDec(removed) - } + return r.removeSubscriptionByID(id) } -func (r *Resolver) handleRemoveSubscription(id SubscriptionIdentifier) { +func (r *Resolver) handleRemoveSubscription(id SubscriptionIdentifier) (int, []subscriptionFinalizer, context.CancelFunc, bool) { if r.options.Debug { fmt.Printf("resolver:trigger:subscription:remove:%d:%d\n", id.ConnectionID, id.SubscriptionID) } - removed := 0 - for u := range r.triggers { - trig := r.triggers[u] - removed += r.closeTriggerSubscriptions(u, SubscriptionCloseKindNormal, func(sID SubscriptionIdentifier) bool { - return sID == id - }) - if len(trig.subscriptions) == 0 { - r.closeTrigger(trig.id, SubscriptionCloseKindNormal) - } - } - if r.reporter != nil { - r.reporter.SubscriptionCountDec(removed) - } + return r.removeSubscriptionByID(id) } -func (r *Resolver) handleRemoveClient(id int64) { +func (r *Resolver) handleRemoveClient(id int64) (int, []subscriptionFinalizer, []context.CancelFunc, int) { if r.options.Debug { fmt.Printf("resolver:trigger:subscription:remove:client:%d\n", id) } removed := 0 - for u := range r.triggers { - removed += r.closeTriggerSubscriptions(u, SubscriptionCloseKindNormal, func(sID SubscriptionIdentifier) bool { - return sID.ConnectionID == id - }) - if len(r.triggers[u].subscriptions) == 0 { - r.closeTrigger(r.triggers[u].id, SubscriptionCloseKindNormal) + finalizers := make([]subscriptionFinalizer, 0) + cancels := make([]context.CancelFunc, 0) + triggerDec := 0 + idsForConn := r.subscriptionsByConnection[id] + ids := make([]SubscriptionIdentifier, 0, len(idsForConn)) + for sid := range idsForConn { + ids = append(ids, sid) + } + for _, sid := range ids { + rem, fz, cancel, initialized := r.removeSubscriptionByID(sid) + removed += rem + finalizers = append(finalizers, fz...) + if cancel != nil { + cancels = append(cancels, cancel) + if initialized { + triggerDec++ + } } } - if r.reporter != nil { - r.reporter.SubscriptionCountDec(removed) - } + return removed, finalizers, cancels, triggerDec } -func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { - trig, ok := r.triggers[id] +// removeSubscriptionByID removes a single subscription by id. +// r.mu must be held by the caller. +func (r *Resolver) removeSubscriptionByID(id SubscriptionIdentifier) (int, []subscriptionFinalizer, context.CancelFunc, bool) { + s, ok := r.subscriptionsByID[id] if !ok { - return - } - if r.options.Debug { - fmt.Printf("resolver:trigger:update:%d\n", id) + return 0, nil, nil, false } - for c, s := range trig.subscriptions { - r.sendUpdateToSubscription(data, c, s) - } -} - -func (r *Resolver) handleUpdateSubscription(id uint64, data []byte, subIdentifier SubscriptionIdentifier) { - trig, ok := r.triggers[id] + trig, ok := r.triggers[s.triggerID] if !ok { - return + r.removeSubscriptionIndex(id) + return 0, nil, nil, false } - if r.options.Debug { - fmt.Printf("resolver:trigger:subscription:update:%d:%d,%d\n", id, subIdentifier.ConnectionID, subIdentifier.SubscriptionID) + trig.mu.Lock() + _, ok = trig.subscriptions[id] + if !ok { + trig.mu.Unlock() + r.removeSubscriptionIndex(id) + return 0, nil, nil, false } - for c, s := range trig.subscriptions { - if s.id != subIdentifier { - continue - } - r.sendUpdateToSubscription(data, c, s) - break + var finalizers []subscriptionFinalizer + if s.removed.CompareAndSwap(false, true) { + finalizers = append(finalizers, subscriptionFinalizer{ + sub: s, + }) } -} + delete(trig.subscriptions, id) + empty := len(trig.subscriptions) == 0 + trig.mu.Unlock() -func (r *Resolver) sendUpdateToSubscription(data []byte, c *Context, s *sub) { - if err := c.ctx.Err(); err != nil { - return // no need to schedule an event update when the client already disconnected - } - skip, err := s.resolve.Filter.SkipEvent(c, data, r.triggerUpdateBuf) - if err != nil { - r.asyncErrorWriter.WriteError(c, err, s.resolve.Response, s.writer) - return - } - if skip { - return - } + r.removeSubscriptionIndex(id) - fn := func() { - r.executeSubscriptionUpdate(c, s, data) + var cancel context.CancelFunc + initialized := false + if empty { + delete(r.triggers, trig.id) + cancel = trig.cancel + initialized = trig.initialized } - select { - case <-r.ctx.Done(): - // Skip sending all events if the resolver is shutting down - return - case <-c.ctx.Done(): - // Skip sending the event if the client disconnected - case s.workChan <- workItem{fn, false}: - // Send the event to the subscription worker - } + return 1, finalizers, cancel, initialized } -func (r *Resolver) closeTrigger(id uint64, kind SubscriptionCloseKind) { - if r.options.Debug { - fmt.Printf("resolver:trigger:close:%d\n", id) - } +// detachTriggerLocked removes all subscriptions for the trigger and removes the trigger from resolver maps. +// r.mu must be held by the caller. +func (r *Resolver) detachTriggerLocked(id uint64) (int, []subscriptionFinalizer, context.CancelFunc, bool) { trig, ok := r.triggers[id] if !ok { - return + return 0, nil, nil, false } - removed := r.closeTriggerSubscriptions(id, kind, nil) + finalizers := make([]subscriptionFinalizer, 0, len(trig.subscriptions)) + removed := 0 - // Cancels the async datasource and cleanup the connection - trig.cancel() + trig.mu.Lock() + for sid, s := range trig.subscriptions { + if s.removed.CompareAndSwap(false, true) { + finalizers = append(finalizers, subscriptionFinalizer{ + sub: s, + }) + } + delete(trig.subscriptions, sid) + r.removeSubscriptionIndex(sid) + removed++ + } + trig.mu.Unlock() delete(r.triggers, id) - if r.reporter != nil { - r.reporter.SubscriptionCountDec(removed) - if trig.initialized { - r.reporter.TriggerCountDec(1) - } - } + return removed, finalizers, trig.cancel, trig.initialized } -func (r *Resolver) completeTrigger(id uint64) { - if r.options.Debug { - fmt.Printf("resolver:trigger:complete:%d\n", id) - } +// pendingWrite holds the context and subscription for a deferred write outside the lock. +type pendingSubscriptionWrite struct { + sub *subscriptionState +} +// handleTriggerUpdate sends data to all subscriptions of a trigger using snapshot-and-release. +// The lock is released before performing I/O to avoid deadlocks when executeSubscriptionUpdate +// calls AsyncUnsubscribeSubscription on flush failure. +func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { + r.mu.Lock() trig, ok := r.triggers[id] + r.mu.Unlock() if !ok { return } + if r.options.Debug { + fmt.Printf("resolver:trigger:update:%d\n", id) + } - removed := r.completeTriggerSubscriptions(id, nil) - - // Cancels the async datasource and cleanup the connection - trig.cancel() - - delete(r.triggers, id) + var pending []pendingSubscriptionWrite + trig.mu.Lock() + for _, s := range trig.subscriptions { + if s.ctx.ctx.Err() != nil { + continue + } + skip, err := s.resolve.Filter.SkipEvent(s.ctx, data, trig.updateBuf) + if err != nil { + r.asyncErrorWriter.WriteError(s.ctx, err, s.resolve.Response, s.writer) + continue + } + if skip { + continue + } + pending = append(pending, pendingSubscriptionWrite{s}) + } + trig.mu.Unlock() - if r.reporter != nil { - r.reporter.SubscriptionCountDec(removed) - if trig.initialized { - r.reporter.TriggerCountDec(1) + var wg sync.WaitGroup + for _, pw := range pending { + if pw.sub.removed.Load() { + continue } + wg.Go(func() { + r.executeSubscriptionUpdate(pw.sub.ctx, pw.sub, data) + }) } + wg.Wait() } -func (r *Resolver) completeTriggerSubscriptions(id uint64, completeMatcher func(a SubscriptionIdentifier) bool) int { +// handleUpdateSubscription sends data to a single subscription using snapshot-and-release. +func (r *Resolver) handleUpdateSubscription(id uint64, data []byte, subIdentifier SubscriptionIdentifier) { + r.mu.Lock() trig, ok := r.triggers[id] + r.mu.Unlock() if !ok { - return 0 + return } - removed := 0 - for c, s := range trig.subscriptions { - if completeMatcher != nil && !completeMatcher(s.id) { - continue - } - - // Send a work item to complete the subscription - s.workChan <- workItem{s.complete, true} - // Because the event loop is single threaded, we can safely close the channel from this sender - // The subscription worker will finish processing all events before the channel is closed. - close(s.workChan) - - // Important because we remove the subscription from the trigger on the same goroutine - // as we send work to the subscription worker. We can ensure that no new work is sent to the worker after this point. - delete(trig.subscriptions, c) + if r.options.Debug { + fmt.Printf("resolver:trigger:subscription:update:%d:%d,%d\n", id, subIdentifier.ConnectionID, subIdentifier.SubscriptionID) + } - if r.options.Debug { - fmt.Printf("resolver:trigger:subscription:closed:%d:%d\n", trig.id, s.id.SubscriptionID) + var target *subscriptionState + trig.mu.Lock() + s, ok := trig.subscriptions[subIdentifier] + if ok { + if s.ctx.ctx.Err() == nil { + skip, err := s.resolve.Filter.SkipEvent(s.ctx, data, trig.updateBuf) + if err != nil { + r.asyncErrorWriter.WriteError(s.ctx, err, s.resolve.Response, s.writer) + } else if !skip { + target = s + } } + } + trig.mu.Unlock() - removed++ + if target != nil && !target.removed.Load() { + r.executeSubscriptionUpdate(target.ctx, target, data) } - return removed } -func (r *Resolver) closeTriggerSubscriptions(id uint64, closeKind SubscriptionCloseKind, closeMatcher func(a SubscriptionIdentifier) bool) int { +func (r *Resolver) heartbeatTriggerSubscriptions(id uint64) { + r.mu.Lock() trig, ok := r.triggers[id] + r.mu.Unlock() if !ok { - return 0 + return } - removed := 0 - for c, s := range trig.subscriptions { - if closeMatcher != nil && !closeMatcher(s.id) { + + targets := make([]*subscriptionState, 0, len(trig.subscriptions)) + trig.mu.RLock() + for _, s := range trig.subscriptions { + if !s.heartbeat || s.removed.Load() { continue } - - // Send a work item to close the subscription - s.workChan <- workItem{func() { s.close(closeKind) }, true} - - // Because the event loop is single threaded, we can safely close the channel from this sender - // The subscription worker will finish processing all events before the channel is closed. - close(s.workChan) - - // Important because we remove the subscription from the trigger on the same goroutine - // as we send work to the subscription worker. We can ensure that no new work is sent to the worker after this point. - delete(trig.subscriptions, c) - - if r.options.Debug { - fmt.Printf("resolver:trigger:subscription:closed:%d:%d\n", trig.id, s.id.SubscriptionID) + if time.Since(time.Unix(0, s.lastWriteTime.Load())) < r.heartbeatInterval { + continue } + targets = append(targets, s) + } + trig.mu.RUnlock() - removed++ + for _, s := range targets { + r.executeSubscriptionHeartbeat(s) } - return removed } -func (r *Resolver) handleShutdown() { +func (r *Resolver) shutdownResolver() { if r.options.Debug { fmt.Printf("resolver:trigger:shutdown\n") } + r.mu.Lock() + if r.shutdown { + r.mu.Unlock() + return + } + + r.shutdown = true + triggerIDs := make([]uint64, 0, len(r.triggers)) for id := range r.triggers { - r.closeTrigger(id, SubscriptionCloseKindGoingAway) + triggerIDs = append(triggerIDs, id) + } + + allFinalizers := make([]subscriptionFinalizer, 0) + cancels := make([]context.CancelFunc, 0, len(triggerIDs)) + removedTotal := 0 + triggerDec := 0 + + for _, id := range triggerIDs { + removed, finalizers, cancel, initialized := r.detachTriggerLocked(id) + removedTotal += removed + allFinalizers = append(allFinalizers, finalizers...) + if cancel != nil { + cancels = append(cancels, cancel) + } + if initialized { + triggerDec++ + } + } + + if r.reporter != nil { + r.reporter.SubscriptionCountDec(removedTotal) + if triggerDec > 0 { + r.reporter.TriggerCountDec(triggerDec) + } + } + + r.triggers = make(map[uint64]*trigger) + r.subscriptionsByID = make(map[SubscriptionIdentifier]*subscriptionState) + r.subscriptionsByConnection = make(map[int64]map[SubscriptionIdentifier]*subscriptionState) + r.mu.Unlock() + + runSubscriptionFinalizers(allFinalizers) + for _, cancel := range cancels { + cancel() } + if r.options.Debug { fmt.Printf("resolver:trigger:shutdown:done\n") } - r.triggers = make(map[uint64]*trigger) +} + +func (r *Resolver) heartbeatLoop() { + ticker := time.NewTicker(r.heartbeatInterval) + defer ticker.Stop() + for { + select { + case <-r.ctx.Done(): + return + case <-ticker.C: + r.sendTriggerHeartbeats() + } + } +} + +func (r *Resolver) sendTriggerHeartbeats() { + r.mu.Lock() + triggerIDs := make([]uint64, 0, len(r.triggers)) + for id := range r.triggers { + triggerIDs = append(triggerIDs, id) + } + r.mu.Unlock() + + for _, id := range triggerIDs { + r.heartbeatTriggerSubscriptions(id) + } } type SubscriptionIdentifier struct { @@ -1170,66 +1175,73 @@ type SubscriptionIdentifier struct { SubscriptionID int64 } -func (r *Resolver) AsyncCompleteSubscription(id SubscriptionIdentifier) error { - select { - case <-r.ctx.Done(): +func (r *Resolver) CompleteSubscription(id SubscriptionIdentifier) error { + r.mu.Lock() + if r.shutdown { + r.mu.Unlock() return r.ctx.Err() - case r.events <- subscriptionEvent{ - id: id, - kind: subscriptionEventKindCompleteSubscription, - }: + } + // Grab the sub before removal so we can send a "complete" frame after releasing r.mu. + sub := r.subscriptionsByID[id] + removed, finalizers, cancel, initialized := r.handleCompleteSubscription(id) + if r.reporter != nil { + r.reporter.SubscriptionCountDec(removed) + if cancel != nil && initialized { + r.reporter.TriggerCountDec(1) + } + } + r.mu.Unlock() + // Send "complete" to the downstream writer under writeMu. + // This ensures any in-flight data write finishes before the complete is sent, + // matching the old behavior where the worker goroutine called sub.complete(). + if sub != nil { + sub.complete() + } + runSubscriptionFinalizers(finalizers) + if cancel != nil { + cancel() } return nil } -func (r *Resolver) AsyncUnsubscribeSubscription(id SubscriptionIdentifier) error { - select { - case <-r.ctx.Done(): +func (r *Resolver) UnsubscribeSubscription(id SubscriptionIdentifier) error { + r.mu.Lock() + if r.shutdown { + r.mu.Unlock() return r.ctx.Err() - case r.events <- subscriptionEvent{ - id: id, - kind: subscriptionEventKindRemoveSubscription, - }: - default: - // In the event we cannot insert immediately, defer insertion a goroutine, this should prevent deadlocks, at the cost of goroutine creation. - go func() { - select { - case <-r.ctx.Done(): - return - case r.events <- subscriptionEvent{ - id: id, - kind: subscriptionEventKindRemoveSubscription, - }: - } - }() + } + removed, finalizers, cancel, initialized := r.handleRemoveSubscription(id) + if r.reporter != nil { + r.reporter.SubscriptionCountDec(removed) + if cancel != nil && initialized { + r.reporter.TriggerCountDec(1) + } + } + r.mu.Unlock() + runSubscriptionFinalizers(finalizers) + if cancel != nil { + cancel() } return nil } -func (r *Resolver) AsyncUnsubscribeClient(connectionID int64) error { - select { - case <-r.ctx.Done(): +func (r *Resolver) UnsubscribeClient(connectionID int64) error { + r.mu.Lock() + if r.shutdown { + r.mu.Unlock() return r.ctx.Err() - case r.events <- subscriptionEvent{ - id: SubscriptionIdentifier{ - ConnectionID: connectionID, - }, - kind: subscriptionEventKindRemoveClient, - }: - default: - // In the event we cannot insert immediately, defer insertion a goroutine, this should prevent deadlocks, at the cost of goroutine creation. - go func() { - select { - case <-r.ctx.Done(): - return - case r.events <- subscriptionEvent{ - id: SubscriptionIdentifier{ - ConnectionID: connectionID, - }, - kind: subscriptionEventKindRemoveClient, - }: - } - }() + } + removed, finalizers, cancels, triggerDec := r.handleRemoveClient(connectionID) + if r.reporter != nil { + r.reporter.SubscriptionCountDec(removed) + if triggerDec > 0 { + r.reporter.TriggerCountDec(triggerDec) + } + } + r.mu.Unlock() + runSubscriptionFinalizers(finalizers) + for _, cancel := range cancels { + cancel() } return nil } @@ -1293,7 +1305,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ headers, triggerID := r.prepareTrigger(ctx, subscription.Trigger.SourceName, input) id := SubscriptionIdentifier{ - ConnectionID: ConnectionIDs.Inc(), + ConnectionID: ConnectionIDs.Add(1), SubscriptionID: 0, } if r.options.Debug { @@ -1302,43 +1314,38 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ completed := make(chan struct{}) - select { - case <-r.ctx.Done(): - // Stop processing if the resolver is shutting down + r.mu.Lock() + if r.shutdown { + r.mu.Unlock() return r.ctx.Err() - case r.events <- subscriptionEvent{ - triggerID: triggerID, - kind: subscriptionEventKindAddSubscription, - addSubscription: &addSubscription{ - ctx: ctx, - input: input, - resolve: subscription, - writer: writer, - id: id, - completed: completed, - sourceName: subscription.Trigger.SourceName, - headers: headers, - }, - }: } + r.handleAddSubscription(triggerID, &addSubscription{ + ctx: ctx, + input: input, + resolve: subscription, + writer: writer, + id: id, + completed: completed, + sourceName: subscription.Trigger.SourceName, + headers: headers, + }) + r.mu.Unlock() // This will immediately block until one of the following conditions is met: select { case <-ctx.ctx.Done(): // Client disconnected, request context canceled. - // We will ignore the error and remove the subscription in the next step. - + _ = r.UnsubscribeSubscription(id) select { case <-completed: // Wait for the subscription to be completed to avoid race conditions // with go sdk request shutdown. case <-r.ctx.Done(): - // Resolver shutdown, no way to gracefully shut down the subscription + // Resolver shutdown return r.ctx.Err() } case <-r.ctx.Done(): - // Resolver shutdown, no way to gracefully shut down the subscription - // because the event loop is not running anymore and shutdown all triggers + subscriptions + // Resolver shutdown return r.ctx.Err() case <-completed: } @@ -1348,12 +1355,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ } // Remove the subscription when the client disconnects. - - r.events <- subscriptionEvent{ - triggerID: triggerID, - kind: subscriptionEventKindRemoveSubscription, - id: id, - } + _ = r.UnsubscribeSubscription(id) return nil } @@ -1395,30 +1397,32 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G return nil } + if err := ctx.ctx.Err(); err != nil { + return err + } + headers, triggerID := r.prepareTrigger(ctx, subscription.Trigger.SourceName, input) - select { - case <-r.ctx.Done(): - // Stop resolving if the resolver is shutting down + r.mu.Lock() + if err := ctx.ctx.Err(); err != nil { + r.mu.Unlock() + return err + } + if r.shutdown { + r.mu.Unlock() return r.ctx.Err() - case <-ctx.ctx.Done(): - // Stop resolving if the client is gone - return ctx.ctx.Err() - case r.events <- subscriptionEvent{ - triggerID: triggerID, - kind: subscriptionEventKindAddSubscription, - addSubscription: &addSubscription{ - ctx: ctx, - input: input, - resolve: subscription, - writer: writer, - id: id, - completed: make(chan struct{}), - sourceName: subscription.Trigger.SourceName, - headers: headers, - }, - }: } + r.handleAddSubscription(triggerID, &addSubscription{ + ctx: ctx, + input: input, + resolve: subscription, + writer: writer, + id: id, + completed: make(chan struct{}), + sourceName: subscription.Trigger.SourceName, + headers: headers, + }) + r.mu.Unlock() return nil } @@ -1447,44 +1451,36 @@ func (r *Resolver) subscriptionInput(ctx *Context, subscription *GraphQLSubscrip type subscriptionUpdater struct { debug bool triggerID uint64 - ch chan subscriptionEvent + resolver *Resolver ctx context.Context subsFn func() map[context.Context]SubscriptionIdentifier } func (s *subscriptionUpdater) Update(data []byte) { + if s.ctx.Err() != nil { + return + } if s.debug { fmt.Printf("resolver:subscription_updater:update:%d\n", s.triggerID) } + s.resolver.handleTriggerUpdate(s.triggerID, data) +} - select { - case <-s.ctx.Done(): - // Skip sending events if trigger is already done +func (s *subscriptionUpdater) Heartbeat() { + if s.ctx.Err() != nil { return - case s.ch <- subscriptionEvent{ - triggerID: s.triggerID, - kind: subscriptionEventKindTriggerUpdate, - data: data, - }: } + s.resolver.heartbeatTriggerSubscriptions(s.triggerID) } func (s *subscriptionUpdater) UpdateSubscription(id SubscriptionIdentifier, data []byte) { + if s.ctx.Err() != nil { + return + } if s.debug { fmt.Printf("resolver:subscription_updater:update:%d\n", s.triggerID) } - - select { - case <-s.ctx.Done(): - // Skip sending events if trigger is already done - return - case s.ch <- subscriptionEvent{ - triggerID: s.triggerID, - kind: subscriptionEventKindUpdateSubscription, - data: data, - id: id, - }: - } + s.resolver.handleUpdateSubscription(s.triggerID, data, id) } func (s *subscriptionUpdater) Subscriptions() map[context.Context]SubscriptionIdentifier { @@ -1492,81 +1488,50 @@ func (s *subscriptionUpdater) Subscriptions() map[context.Context]SubscriptionId } func (s *subscriptionUpdater) Complete() { - if s.debug { - fmt.Printf("resolver:subscription_updater:complete:%d\n", s.triggerID) - } - - select { - case <-s.ctx.Done(): - // Skip sending events if trigger is already done + if s.ctx.Err() != nil { if s.debug { fmt.Printf("resolver:subscription_updater:complete:skip:%d\n", s.triggerID) } return - case s.ch <- subscriptionEvent{ - triggerID: s.triggerID, - kind: subscriptionEventKindTriggerComplete, - }: - if s.debug { - fmt.Printf("resolver:subscription_updater:complete:sent_event:%d\n", s.triggerID) - } } -} - -func (s *subscriptionUpdater) Close(kind SubscriptionCloseKind) { if s.debug { - fmt.Printf("resolver:subscription_updater:close:%d\n", s.triggerID) + fmt.Printf("resolver:subscription_updater:complete:%d\n", s.triggerID) } + s.resolver.handleTriggerComplete(s.triggerID) +} - select { - case <-s.ctx.Done(): - // Skip sending events if trigger is already done +func (s *subscriptionUpdater) Error(data []byte) { + if s.ctx.Err() != nil { if s.debug { - fmt.Printf("resolver:subscription_updater:close:skip:%d\n", s.triggerID) + fmt.Printf("resolver:subscription_updater:error:skip:%d\n", s.triggerID) } return - case s.ch <- subscriptionEvent{ - triggerID: s.triggerID, - kind: subscriptionEventKindTriggerClose, - closeKind: kind, - }: - if s.debug { - fmt.Printf("resolver:subscription_updater:close:sent_event:%d\n", s.triggerID) - } } + if s.debug { + fmt.Printf("resolver:subscription_updater:error:%d\n", s.triggerID) + } + s.resolver.handleTriggerError(s.triggerID, data) } -func (s *subscriptionUpdater) CloseSubscription(kind SubscriptionCloseKind, id SubscriptionIdentifier) { +func (s *subscriptionUpdater) Done() { if s.debug { - fmt.Printf("resolver:subscription_updater:close:%d\n", s.triggerID) + fmt.Printf("resolver:subscription_updater:done:%d\n", s.triggerID) } + s.resolver.doneTriggerFromUpdater(s.triggerID) +} - select { - case <-s.ctx.Done(): - // Skip sending events if trigger is already done +func (s *subscriptionUpdater) CloseSubscription(id SubscriptionIdentifier) { + if s.ctx.Err() != nil { if s.debug { fmt.Printf("resolver:subscription_updater:close:skip:%d\n", s.triggerID) } return - case s.ch <- subscriptionEvent{ - triggerID: s.triggerID, - kind: subscriptionEventKindRemoveSubscription, - closeKind: kind, - id: id, - }: - if s.debug { - fmt.Printf("resolver:subscription_updater:close:sent_event:%d\n", s.triggerID) - } } -} + if s.debug { + fmt.Printf("resolver:subscription_updater:close:%d\n", s.triggerID) + } -type subscriptionEvent struct { - triggerID uint64 - id SubscriptionIdentifier - kind subscriptionEventKind - data []byte - addSubscription *addSubscription - closeKind SubscriptionCloseKind + _ = s.resolver.UnsubscribeSubscription(id) } type addSubscription struct { @@ -1580,32 +1545,22 @@ type addSubscription struct { headers http.Header } -type subscriptionEventKind int - -const ( - subscriptionEventKindUnknown subscriptionEventKind = iota - subscriptionEventKindTriggerUpdate - subscriptionEventKindTriggerComplete - subscriptionEventKindAddSubscription - subscriptionEventKindRemoveSubscription - subscriptionEventKindCompleteSubscription - subscriptionEventKindRemoveClient - subscriptionEventKindTriggerInitialized - subscriptionEventKindTriggerClose - subscriptionEventKindUpdateSubscription -) - type SubscriptionUpdater interface { // Update sends an update to the client. It is not guaranteed that the update is sent immediately. Update(data []byte) // UpdateSubscription sends an update to a single subscription. It is not guaranteed that the update is sent immediately. UpdateSubscription(id SubscriptionIdentifier, data []byte) - // Complete also takes care of cleaning up the trigger and all subscriptions. No more updates should be sent after calling Complete. + // Complete delivers a "subscription done" signal to all subscriptions on the trigger. + // Does not perform cleanup — call Done() after Complete(). Complete() - // Close closes the subscription and cleans up the trigger and all subscriptions. No more updates should be sent after calling Close. - Close(kind SubscriptionCloseKind) + // Error delivers a terminal error to all subscriptions on the trigger, bypassing the resolve pipeline. + // Does not perform cleanup — call Done() after Error(). + Error(data []byte) + // Done performs internal cleanup: detaches the trigger, closes completed channels. + // Must always be the final call. Does not send any downstream messages. + Done() // CloseSubscription closes a single subscription. No more updates should be sent to that subscription after calling CloseSubscription. - CloseSubscription(kind SubscriptionCloseKind, id SubscriptionIdentifier) + CloseSubscription(id SubscriptionIdentifier) // Subscriptions return all the subscriptions associated to this Updater Subscriptions() map[context.Context]SubscriptionIdentifier } diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 82a8e1e635..91123f7953 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -1651,7 +1651,6 @@ func testFnWithPostEvaluationAndOptions(opts ResolverOptions, fn func(t *testing } func TestResolver_ResolveGraphQLResponse(t *testing.T) { - t.Run("empty graphql response", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ Data: &Object{ @@ -2748,8 +2747,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("when data null and errors present not nullable array should result to null data upstream error and resolve error", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ - FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource( - `{"errors":[{"message":"Could not get name","locations":[{"line":3,"column":5}],"path":["todos","0","name"]}],"data":null}`), + FetchConfiguration: FetchConfiguration{ + DataSource: FakeDataSource( + `{"errors":[{"message":"Could not get name","locations":[{"line":3,"column":5}],"path":["todos","0","name"]}],"data":null}`), PostProcessing: PostProcessingConfiguration{ SelectResponseDataPath: []string{"data"}, SelectResponseErrorsPath: []string{"errors"}, @@ -3599,7 +3599,6 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, { - Name: []byte("reviews"), Value: &Array{ Path: []string{"reviews"}, @@ -3646,7 +3645,6 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, *NewContext(context.Background()), `{"data":{"me":{"id":"1234","username":"Me","reviews":[{"body":"foo","product":{"upc":"top-1","name":"Trilby"}},{"body":"bar","product":{"upc":"top-2","name":"Fedora"}},{"body":"baz","product":null},{"body":"bat","product":{"upc":"top-4","name":"Boater"}},{"body":"bal","product":{"upc":"top-5","name":"Top Hat"}},{"body":"ban","product":{"upc":"top-6","name":"Bowler"}}]}}}` })) t.Run("federation with fetch error", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { - userService := NewMockDataSource(ctrl) userService.EXPECT(). Load(gomock.Any(), gomock.Any(), gomock.Any()). @@ -3787,7 +3785,6 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, { - Name: []byte("reviews"), Value: &Array{ Path: []string{"reviews"}, @@ -3833,7 +3830,6 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, *NewContext(context.Background()), `{"errors":[{"message":"Failed to fetch from Subgraph at Path 'query.me.reviews.@.product', Reason: no data or errors in response."},{"message":"Cannot return null for non-nullable field 'Query.me.reviews.product.name'.","path":["me","reviews",0,"product","name"]},{"message":"Cannot return null for non-nullable field 'Query.me.reviews.product.name'.","path":["me","reviews",1,"product","name"]}],"data":{"me":{"id":"1234","username":"Me","reviews":[null,null]}}}` })) t.Run("federation with fetch error and non null fields inside an array", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { - userService := NewMockDataSource(ctrl) userService.EXPECT(). Load(gomock.Any(), gomock.Any(), gomock.Any()). @@ -3973,7 +3969,6 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, { - Name: []byte("reviews"), Value: &Array{ Path: []string{"reviews"}, @@ -4249,7 +4244,6 @@ func testFnArena(fn func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLRe } func TestResolver_ArenaResolveGraphQLResponse(t *testing.T) { - t.Run("empty graphql response", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string) { resolveCtx := NewContext(context.Background()) return &GraphQLResponse{ @@ -5344,7 +5338,8 @@ type SubscriptionRecorder struct { buf *bytes.Buffer messages []string complete atomic.Bool - closed atomic.Bool + done atomic.Bool + errors [][]byte mux sync.Mutex onFlush func(p []byte) } @@ -5399,15 +5394,15 @@ func (s *SubscriptionRecorder) AwaitComplete(t *testing.T, timeout time.Duration } } -func (s *SubscriptionRecorder) AwaitClosed(t *testing.T, timeout time.Duration) { +func (s *SubscriptionRecorder) AwaitDone(t *testing.T, timeout time.Duration) { t.Helper() deadline := time.Now().Add(timeout) for { - if s.closed.Load() { + if s.done.Load() { return } if time.Now().After(deadline) { - t.Fatalf("timed out waiting for close") + t.Fatalf("timed out waiting for done") } time.Sleep(time.Millisecond * 10) } @@ -5438,8 +5433,10 @@ func (s *SubscriptionRecorder) Heartbeat() error { return nil } -func (s *SubscriptionRecorder) Close(_ SubscriptionCloseKind) { - s.closed.Store(true) +func (s *SubscriptionRecorder) Error(data []byte) { + s.mux.Lock() + s.errors = append(s.errors, data) + s.mux.Unlock() } func (s *SubscriptionRecorder) Messages() []string { @@ -5499,7 +5496,7 @@ func (f *_fakeStream) Start(ctx *Context, headers http.Header, input []byte, upd for { select { case <-ctx.ctx.Done(): - updater.Complete() + updater.Done() f.isDone.Store(true) return default: @@ -5508,6 +5505,7 @@ func (f *_fakeStream) Start(ctx *Context, headers http.Header, input []byte, upd if done { time.Sleep(f.delay) updater.Complete() + updater.Done() f.isDone.Store(true) return } @@ -5526,7 +5524,6 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { } setup := func(ctx context.Context, stream SubscriptionDataSource) (*Resolver, *GraphQLSubscription, *SubscriptionRecorder, SubscriptionIdentifier) { - fetches := Sequence() fetches.Trigger = &FetchTreeNode{ Kind: FetchTreeNodeKindTrigger, @@ -5781,7 +5778,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { messages := recorder.Messages() assert.Greater(t, len(messages), 2) - time.Sleep(resolver.heartbeatInterval) + time.Sleep(10 * time.Millisecond) // Validate that despite the time, we don't see any heartbeats sent assert.Contains(t, messages, `{"data":{"counter":0}}`) assert.Contains(t, messages, `{"data":{"counter":1}}`) @@ -5818,7 +5815,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { recorder.AwaitComplete(t, defaultTimeout) - time.Sleep(resolver.heartbeatInterval) + time.Sleep(10 * time.Millisecond) assert.Len(t, recorder.Messages(), 20) @@ -5921,9 +5918,8 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { err := resolver.AsyncResolveGraphQLSubscription(&ctx, plan, recorder, id) assert.NoError(t, err) recorder.AwaitAnyMessageCount(t, defaultTimeout) - err = resolver.AsyncUnsubscribeSubscription(id) + err = resolver.UnsubscribeSubscription(id) assert.NoError(t, err) - recorder.AwaitClosed(t, defaultTimeout) fakeStream.AwaitIsDone(t, defaultTimeout) }) @@ -5946,9 +5942,30 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { err := resolver.AsyncResolveGraphQLSubscription(&ctx, plan, recorder, id) assert.NoError(t, err) recorder.AwaitAnyMessageCount(t, defaultTimeout) - err = resolver.AsyncUnsubscribeClient(id.ConnectionID) + err = resolver.UnsubscribeClient(id.ConnectionID) + assert.NoError(t, err) + fakeStream.AwaitIsDone(t, defaultTimeout) + }) + + t.Run("should stop stream on unsubscribe client with close reason", func(t *testing.T) { + c, cancel := context.WithCancel(context.Background()) + defer cancel() + + fakeStream := createFakeStream(func(counter int) (message string, done bool) { + return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), false + }, time.Millisecond*10, nil, nil) + + resolver, plan, recorder, id := setup(c, fakeStream) + + ctx := Context{ + ctx: context.Background(), + } + + err := resolver.AsyncResolveGraphQLSubscription(&ctx, plan, recorder, id) + assert.NoError(t, err) + recorder.AwaitAnyMessageCount(t, defaultTimeout) + err = resolver.UnsubscribeClient(id.ConnectionID) assert.NoError(t, err) - recorder.AwaitClosed(t, defaultTimeout) fakeStream.AwaitIsDone(t, defaultTimeout) }) @@ -6022,13 +6039,6 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { resolver, plan, _, _ := setup(c, fakeStream) - ctx := &Context{ - ctx: context.Background(), - ExecutionOptions: ExecutionOptions{ - SendHeartbeat: true, - }, - } - const numSubscriptions = 2 var resolverCompleted atomic.Uint32 var recorderCompleted atomic.Uint32 @@ -6040,6 +6050,15 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { } recorder.complete.Store(false) + // Each subscription needs its own Context so they get separate entries + // in the trigger's subscriptions map (keyed by *Context). + subCtx := &Context{ + ctx: context.Background(), + ExecutionOptions: ExecutionOptions{ + SendHeartbeat: true, + }, + } + go func() { defer recorderCompleted.Add(1) @@ -6049,7 +6068,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { go func() { defer resolverCompleted.Add(1) - err := resolver.ResolveGraphQLSubscription(ctx, plan, recorder) + err := resolver.ResolveGraphQLSubscription(subCtx, plan, recorder) assert.ErrorIs(t, err, context.Canceled) }() } @@ -6090,9 +6109,8 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { assert.NoError(t, err) recorder.AwaitAnyMessageCount(t, defaultTimeout) - err = resolver.AsyncUnsubscribeSubscription(id) + err = resolver.UnsubscribeSubscription(id) assert.NoError(t, err) - recorder.AwaitClosed(t, defaultTimeout) fakeStream.AwaitIsDone(t, defaultTimeout) }) diff --git a/v2/pkg/engine/resolve/event_loop_test.go b/v2/pkg/engine/resolve/resolver_subscription_test.go similarity index 69% rename from v2/pkg/engine/resolve/event_loop_test.go rename to v2/pkg/engine/resolve/resolver_subscription_test.go index ba8b7c8e2f..9e8ac319fd 100644 --- a/v2/pkg/engine/resolve/event_loop_test.go +++ b/v2/pkg/engine/resolve/resolver_subscription_test.go @@ -1,7 +1,9 @@ package resolve import ( + "bytes" "context" + "errors" "io" "net/http" "sync" @@ -15,7 +17,6 @@ import ( type FakeErrorWriter struct{} func (f *FakeErrorWriter) WriteError(ctx *Context, err error, res *GraphQLResponse, w io.Writer) { - } type FakeSubscriptionWriter struct { @@ -23,7 +24,6 @@ type FakeSubscriptionWriter struct { buf []byte writtenMessages []string completed bool - closed bool messageCountOnComplete int } @@ -59,11 +59,7 @@ func (f *FakeSubscriptionWriter) Heartbeat() error { return nil } -func (f *FakeSubscriptionWriter) Close(SubscriptionCloseKind) { - f.mu.Lock() - defer f.mu.Unlock() - f.closed = true - f.messageCountOnComplete = len(f.writtenMessages) +func (f *FakeSubscriptionWriter) Error([]byte) { } type FakeSource struct { @@ -80,17 +76,35 @@ func (f *FakeSource) Start(ctx *Context, headers http.Header, input []byte, upda } } updater.Complete() + updater.Done() }() return nil } +type FailingHeartbeatWriter struct{} + +func (f *FailingHeartbeatWriter) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (f *FailingHeartbeatWriter) Flush() error { + return nil +} + +func (f *FailingHeartbeatWriter) Complete() {} + +func (f *FailingHeartbeatWriter) Heartbeat() error { + return errors.New("heartbeat failed") +} + +func (f *FailingHeartbeatWriter) Error([]byte) {} + type TestReporter struct { triggers atomic.Int64 subscriptions atomic.Int64 } func (t *TestReporter) SubscriptionUpdateSent() { - } func (t *TestReporter) SubscriptionCountInc(count int) { @@ -110,7 +124,6 @@ func (t *TestReporter) TriggerCountDec(count int) { } func TestEventLoop(t *testing.T) { - resolverCtx, stopEventLoop := context.WithCancel(context.Background()) t.Cleanup(stopEventLoop) @@ -187,5 +200,59 @@ func TestEventLoop(t *testing.T) { require.Equal(t, int64(0), subscriptionCount) return true }, time.Second, time.Millisecond*10) +} + +func TestResolver_HeartbeatError_DoesNotDeadlockOnUnsubscribe(t *testing.T) { + resolverCtx, cancelResolver := context.WithCancel(context.Background()) + defer cancelResolver() + + resolver := New(resolverCtx, ResolverOptions{ + MaxConcurrency: 1, + AsyncErrorWriter: &FakeErrorWriter{}, + SubscriptionHeartbeatInterval: time.Millisecond, + }) + + subCtx := (&Context{}).WithContext(context.Background()) + subID := SubscriptionIdentifier{ + ConnectionID: 1, + SubscriptionID: 1, + } + triggerID := uint64(42) + s := &subscriptionState{ + triggerID: triggerID, + ctx: subCtx, + writer: &FailingHeartbeatWriter{}, + id: subID, + heartbeat: true, + completed: make(chan struct{}), + } + resolver.mu.Lock() + resolver.triggers[triggerID] = &trigger{ + id: triggerID, + cancel: func() {}, + subscriptions: map[SubscriptionIdentifier]*subscriptionState{subID: s}, + updateBuf: bytes.NewBuffer(make([]byte, 0, 1024)), + } + resolver.subscriptionsByID[subID] = s + resolver.subscriptionsByConnection[subID.ConnectionID] = map[SubscriptionIdentifier]*subscriptionState{subID: s} + resolver.mu.Unlock() + + done := make(chan struct{}) + go func() { + resolver.heartbeatTriggerSubscriptions(triggerID) + close(done) + }() + + select { + case <-done: + case <-time.After(250 * time.Millisecond): + t.Fatal("heartbeatTriggerSubscriptions deadlocked after heartbeat error") + } + + select { + case <-s.completed: + case <-time.After(250 * time.Millisecond): + t.Fatal("subscription was not closed after heartbeat failure") + } } diff --git a/v2/pkg/engine/resolve/response.go b/v2/pkg/engine/resolve/response.go index d8af8d017b..36d2f3a7e1 100644 --- a/v2/pkg/engine/resolve/response.go +++ b/v2/pkg/engine/resolve/response.go @@ -3,8 +3,6 @@ package resolve import ( "io" - "github.com/gobwas/ws" - "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/literal" ) @@ -68,23 +66,12 @@ type ResponseWriter interface { io.Writer } -type SubscriptionCloseKind struct { - WSCode ws.StatusCode - Reason string -} - -var ( - SubscriptionCloseKindNormal SubscriptionCloseKind = SubscriptionCloseKind{ws.StatusNormalClosure, "Normal closure"} - SubscriptionCloseKindDownstreamServiceError SubscriptionCloseKind = SubscriptionCloseKind{ws.StatusGoingAway, "Downstream service error"} - SubscriptionCloseKindGoingAway SubscriptionCloseKind = SubscriptionCloseKind{ws.StatusGoingAway, "Going away"} -) - type SubscriptionResponseWriter interface { ResponseWriter Flush() error Complete() Heartbeat() error - Close(kind SubscriptionCloseKind) + Error(data []byte) } func writeGraphqlResponse(buf *BufPair, writer io.Writer, ignoreData bool) (err error) { From 5d43982ae80e250bbd7c070f6e9bf75440fdc71f Mon Sep 17 00:00:00 2001 From: endigma Date: Tue, 17 Mar 2026 18:42:27 +0000 Subject: [PATCH 02/52] wip: error handling improvements --- .../graphql_subscription_client.go | 53 +++++++++++-------- .../graphql_subscription_client_test.go | 6 +-- .../subscriptionclient/common/message.go | 20 ++++++- .../subscriptionclient/exports.go | 9 ++++ .../subscriptionclient/protocol/protocol.go | 10 ++-- .../subscriptionclient/transport/sse_conn.go | 24 ++++----- .../transport/sse_conn_test.go | 8 +-- .../transport/sse_transport_test.go | 14 ++--- .../subscriptionclient/transport/ws_conn.go | 2 +- .../transport/ws_conn_test.go | 6 +-- .../transport/ws_transport.go | 13 ++++- .../transport/ws_transport_test.go | 8 +-- 12 files changed, 106 insertions(+), 67 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 4d74ac2d1c..7da4a34f15 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -150,6 +150,11 @@ func (c *subscriptionClientV2) Subscribe(ctx *resolve.Context, options GraphQLSu msgCh, cancel, err := c.client.Subscribe(ctx.Context(), req, opts) if err != nil { + if isUpstreamError(err) { + updater.Error(formatUpstreamServiceError(err)) + updater.Done() + return nil + } return err } @@ -174,32 +179,28 @@ func (c *subscriptionClientV2) readLoop(ctx context.Context, msgCh <-chan *clien return } - if msg.Err != nil { - if isConnectionError(msg.Err) { - updater.Error(formatUpstreamServiceError(msg.Err)) - } else { - updater.Error(formatSubscriptionError(msg.Err)) - } + switch msg.Type { + case client.MessageTypeConnectionError: + updater.Error(formatUpstreamServiceError(msg.Err)) + updater.Done() + return + + case client.MessageTypeError: + data, _ := json.Marshal(msg.Payload) + updater.Error(data) updater.Done() return - } - if msg.Payload != nil { + case client.MessageTypeData: data, err := json.Marshal(msg.Payload) if err != nil { updater.Error(formatSubscriptionError(err)) updater.Done() return } - if msg.Done { - updater.Error(data) - updater.Done() - return - } updater.Update(data) - } - if msg.Done { + case client.MessageTypeComplete: updater.Complete() updater.Done() return @@ -208,9 +209,13 @@ func (c *subscriptionClientV2) readLoop(ctx context.Context, msgCh <-chan *clien } } -func isConnectionError(err error) bool { +// isUpstreamError reports whether err is a connection-level upstream error +// that should be reported to the client as an UPSTREAM_SERVICE_ERROR. +func isUpstreamError(err error) bool { return errors.Is(err, client.ErrConnectionClosed) || errors.Is(err, client.ErrConnectionError) || + errors.Is(err, client.ErrInitFailed) || + errors.Is(err, client.ErrDialFailed) || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) } @@ -267,8 +272,9 @@ func mapWSSubprotocol(proto string) client.WSSubprotocol { } // formatUpstreamServiceError formats a connection-level error as a GraphQL error -// response with the UPSTREAM_SERVICE_ERROR extension code. If the error is a -// WebSocket close error, the close code and reason are included in extensions. +// response with the UPSTREAM_SERVICE_ERROR extension code. If the error chain +// contains a WebSocket close error, the close code and reason are included in +// extensions. func formatUpstreamServiceError(err error) []byte { type errorExtensions struct { Code string `json:"code"` @@ -281,18 +287,21 @@ func formatUpstreamServiceError(err error) []byte { Extensions errorExtensions `json:"extensions"` } - ext := errorExtensions{Code: "UPSTREAM_SERVICE_ERROR"} + gqlErr := graphqlError{ + Message: "upstream service error", + Extensions: errorExtensions{Code: "UPSTREAM_SERVICE_ERROR"}, + } var closeErr websocket.CloseError if errors.As(err, &closeErr) { - ext.CloseCode = int(closeErr.Code) - ext.Reason = closeErr.Reason + gqlErr.Extensions.CloseCode = int(closeErr.Code) + gqlErr.Extensions.Reason = closeErr.Reason } resp := struct { Errors []graphqlError `json:"errors"` }{ - Errors: []graphqlError{{Message: "upstream service closed the connection", Extensions: ext}}, + Errors: []graphqlError{gqlErr}, } data, _ := json.Marshal(resp) return data diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 956cd80b73..1e3be99306 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -47,7 +47,7 @@ func TestReadLoopErrorHandling(t *testing.T) { t.Run("connection errors deliver error and done without updates", func(t *testing.T) { updater := &testBridgeUpdater{} msgCh := make(chan *client.Message, 1) - msgCh <- &client.Message{Err: client.ErrConnectionClosed} + msgCh <- &client.Message{Type: client.MessageTypeConnectionError, Err: client.ErrConnectionClosed} close(msgCh) subClient := &subscriptionClientV2{} @@ -62,7 +62,7 @@ func TestReadLoopErrorHandling(t *testing.T) { t.Run("non-connection errors deliver error and done without updates", func(t *testing.T) { updater := &testBridgeUpdater{} msgCh := make(chan *client.Message, 1) - msgCh <- &client.Message{Err: errors.New("validation failed")} + msgCh <- &client.Message{Type: client.MessageTypeConnectionError, Err: errors.New("validation failed")} close(msgCh) subClient := &subscriptionClientV2{} @@ -102,7 +102,7 @@ func TestReadLoopErrorHandling(t *testing.T) { t.Run("done message calls complete then done", func(t *testing.T) { updater := &testBridgeUpdater{} msgCh := make(chan *client.Message, 1) - msgCh <- &client.Message{Done: true} + msgCh <- &client.Message{Type: client.MessageTypeComplete} close(msgCh) subClient := &subscriptionClientV2{} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go index 70c48654d7..8500f09669 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go @@ -7,10 +7,26 @@ import ( var ErrConnectionClosed = errors.New("connection closed") +// MessageType identifies the kind of message delivered on a subscription channel. +type MessageType int + +const ( + MessageTypeUnknown MessageType = iota + MessageTypeData // normal data payload + MessageTypeError // GraphQL-level error from server (has Payload) + MessageTypeComplete // subscription completed normally + MessageTypeConnectionError // connection-level error (has Err) +) + +// IsTerminal reports whether the message type signals end-of-stream. +func (t MessageType) IsTerminal() bool { + return t == MessageTypeError || t == MessageTypeComplete || t == MessageTypeConnectionError +} + type Message struct { + Type MessageType Payload *ExecutionResult - Err error - Done bool + Err error // only set when Type == MessageTypeConnectionError } type ExecutionResult struct { diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go index 4338508522..44e8f788a5 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go @@ -10,6 +10,7 @@ import ( type ( Message = common.Message + MessageType = common.MessageType ExecutionResult = common.ExecutionResult Request = common.Request Options = common.Options @@ -21,6 +22,12 @@ type ( // Re-export constants. const ( + MessageTypeUnknown = common.MessageTypeUnknown + MessageTypeData = common.MessageTypeData + MessageTypeError = common.MessageTypeError + MessageTypeComplete = common.MessageTypeComplete + MessageTypeConnectionError = common.MessageTypeConnectionError + TransportWS = common.TransportWS TransportSSE = common.TransportSSE @@ -48,4 +55,6 @@ var ( ErrAckTimeout = protocol.ErrAckTimeout ErrAckNotReceived = protocol.ErrAckNotReceived ErrSubscriptionExists = transport.ErrSubscriptionExists + ErrDialFailed = transport.ErrDialFailed + ErrInitFailed = transport.ErrInitFailed ) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go index e54cfb6a30..4daddd2c24 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go @@ -39,16 +39,16 @@ type Message struct { func (m *Message) IntoClientMessage() *common.Message { switch m.Type { case MessageData: - return &common.Message{Payload: m.Payload} + return &common.Message{Type: common.MessageTypeData, Payload: m.Payload} case MessageError: if m.Payload != nil { - return &common.Message{Payload: m.Payload, Done: true} + return &common.Message{Type: common.MessageTypeError, Payload: m.Payload} } - return &common.Message{Err: m.Err, Done: true} + return &common.Message{Type: common.MessageTypeConnectionError, Err: m.Err} case MessageComplete: - return &common.Message{Done: true} + return &common.Message{Type: common.MessageTypeComplete} default: - return &common.Message{} + return &common.Message{Type: common.MessageTypeUnknown} } } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go index 3e3b6892dc..59697b1490 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go @@ -71,7 +71,7 @@ func (c *sseConnection) readLoop() { return } - if msg.Done { + if msg.Type.IsTerminal() { return } } @@ -129,36 +129,30 @@ func (c *sseConnection) parseEvent(eventType string, data []byte) *common.Messag case "next": var resp common.ExecutionResult if err := json.Unmarshal(data, &resp); err != nil { - return &common.Message{ - Err: err, - Done: true, - } + return &common.Message{Type: common.MessageTypeConnectionError, Err: err} } - return &common.Message{Payload: &resp} + return &common.Message{Type: common.MessageTypeData, Payload: &resp} case "error": return &common.Message{ + Type: common.MessageTypeError, Payload: &common.ExecutionResult{Errors: data}, - Done: true, } case "complete": - return &common.Message{Done: true} + return &common.Message{Type: common.MessageTypeComplete} default: // Unknown event type or no event type specified - treat as data // This handles servers that send data without an event type if len(data) == 0 { - return &common.Message{Done: true} + return &common.Message{Type: common.MessageTypeComplete} } var resp common.ExecutionResult if err := json.Unmarshal(data, &resp); err != nil { - return &common.Message{ - Err: err, - Done: true, - } + return &common.Message{Type: common.MessageTypeConnectionError, Err: err} } - return &common.Message{Payload: &resp} + return &common.Message{Type: common.MessageTypeData, Payload: &resp} } } @@ -167,7 +161,7 @@ func (c *sseConnection) sendError(err error) { return } select { - case c.ch <- &common.Message{Err: err, Done: true}: + case c.ch <- &common.Message{Type: common.MessageTypeConnectionError, Err: err}: case <-c.done: } } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go index c4f4cf3650..0d92981a34 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" ) func TestSSEConnection_ReadLoop(t *testing.T) { @@ -52,7 +54,7 @@ func TestSSEConnection_ReadLoop(t *testing.T) { msg := <-conn.ch require.Error(t, msg.Err) - require.True(t, msg.Done) + assert.Equal(t, common.MessageTypeConnectionError, msg.Type) }) t.Run("stops on complete event", func(t *testing.T) { @@ -69,11 +71,11 @@ func TestSSEConnection_ReadLoop(t *testing.T) { // First message msg1 := <-conn.ch assert.NotNil(t, msg1.Payload) - assert.False(t, msg1.Done) + assert.Equal(t, common.MessageTypeData, msg1.Type) // Complete message msg2 := <-conn.ch - assert.True(t, msg2.Done) + assert.Equal(t, common.MessageTypeComplete, msg2.Type) // Channel should close, no third message select { diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go index a5236f7059..430767c746 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go @@ -77,7 +77,7 @@ func TestSSETransport_Subscribe(t *testing.T) { // Receive complete message msg = receiveWithTimeout(t, ch, time.Second) - assert.True(t, msg.Done) + assert.Equal(t, common.MessageTypeComplete, msg.Type) }) t.Run("passes custom headers", func(t *testing.T) { @@ -142,7 +142,7 @@ func TestSSETransport_Subscribe(t *testing.T) { msg := receiveWithTimeout(t, ch, time.Second) require.NotNil(t, msg.Payload) assert.Contains(t, string(msg.Payload.Data), "Alice") - assert.False(t, msg.Done) + assert.Equal(t, common.MessageTypeData, msg.Type) }) t.Run("handles error event", func(t *testing.T) { @@ -164,7 +164,7 @@ func TestSSETransport_Subscribe(t *testing.T) { defer cancel() msg := receiveWithTimeout(t, ch, time.Second) - assert.True(t, msg.Done) + assert.Equal(t, common.MessageTypeError, msg.Type) require.NotNil(t, msg.Payload) assert.Contains(t, string(msg.Payload.Errors), "Something went wrong") }) @@ -188,7 +188,7 @@ func TestSSETransport_Subscribe(t *testing.T) { defer cancel() msg := receiveWithTimeout(t, ch, time.Second) - assert.True(t, msg.Done) + assert.Equal(t, common.MessageTypeComplete, msg.Type) assert.Nil(t, msg.Err) assert.Nil(t, msg.Payload) }) @@ -263,7 +263,7 @@ func TestSSETransport_Subscribe(t *testing.T) { // Should only receive 2 messages (next + complete), not comments for msg := range ch { messageCount.Add(1) - if msg.Done { + if msg.Type.IsTerminal() { break } } @@ -616,7 +616,7 @@ func TestSSETransport_ContentTypeValidation(t *testing.T) { defer cancel() msg := receiveWithTimeout(t, ch, time.Second) - assert.True(t, msg.Done) + assert.Equal(t, common.MessageTypeComplete, msg.Type) }) t.Run("rejects non-SSE content type", func(t *testing.T) { @@ -701,7 +701,7 @@ func TestSSETransport_GETMethod(t *testing.T) { // Receive complete message msg = receiveWithTimeout(t, ch, time.Second) - assert.True(t, msg.Done) + assert.Equal(t, common.MessageTypeComplete, msg.Type) }) t.Run("GET preserves existing query parameters", func(t *testing.T) { diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go index e1627e54a9..3251d2461a 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go @@ -246,7 +246,7 @@ func (c *wsConnection) shutdown(err error) { c.subs = make(map[string]chan<- *common.Message) c.subsMu.Unlock() - errMsg := &common.Message{Err: err, Done: true} + errMsg := &common.Message{Type: common.MessageTypeConnectionError, Err: err} for _, ch := range subs { select { case ch <- errMsg: diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go index f21fc14911..21886942c8 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go @@ -129,7 +129,7 @@ func TestWSConnection_ReadLoop(t *testing.T) { // Consume the message (blocking send requires consumer) msg := receiveWithTimeout(t, ch, time.Second) - assert.True(t, msg.Done) + assert.Equal(t, common.MessageTypeComplete, msg.Type) assertChannelClosed(t, ch) }) @@ -362,11 +362,11 @@ func TestWSConnection_Close(t *testing.T) { // Consume messages (blocking send requires consumer) msg1 := receiveWithTimeout(t, ch1, 100*time.Millisecond) assert.Error(t, msg1.Err) - assert.True(t, msg1.Done) + assert.Equal(t, common.MessageTypeConnectionError, msg1.Type) msg2 := receiveWithTimeout(t, ch2, 100*time.Millisecond) assert.Error(t, msg2.Err) - assert.True(t, msg2.Done) + assert.Equal(t, common.MessageTypeConnectionError, msg2.Type) assertChannelClosed(t, ch1) assertChannelClosed(t, ch2) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go index 73b6932d6d..e4389fd407 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go @@ -18,6 +18,15 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) +// ErrDialFailed indicates that the WebSocket dial (TCP + HTTP upgrade) failed. +// The underlying cause is available via errors.Unwrap. +var ErrDialFailed = errors.New("websocket dial failed") + +// ErrInitFailed indicates that the GraphQL protocol init (connection_init / +// connection_ack handshake) failed after a successful WebSocket dial. The +// underlying cause (e.g. protocol.ErrAckTimeout) is available via errors.Unwrap. +var ErrInitFailed = errors.New("protocol init failed") + type ErrFailedUpgrade struct { URL string StatusCode int @@ -295,7 +304,7 @@ func (t *WSTransport) dial(ctx context.Context, key uint64, opts common.Options) return nil, ErrFailedUpgrade{URL: opts.Endpoint, StatusCode: resp.StatusCode} } - return nil, err + return nil, fmt.Errorf("%w: %w", ErrDialFailed, err) } wsConn.SetReadLimit(t.opts.readLimit) @@ -318,7 +327,7 @@ func (t *WSTransport) dial(ctx context.Context, key uint64, opts common.Options) abstractlogger.Error(err), ) wsConn.Close(websocket.StatusProtocolError, "init failed") - return nil, err + return nil, fmt.Errorf("%w: %w", ErrInitFailed, err) } t.opts.logger.Debug("wsTransport.dial", diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go index 414d17124d..4b62d4add2 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go @@ -58,7 +58,7 @@ func TestWSTransport_Subscribe(t *testing.T) { assert.Contains(t, string(msg.Payload.Data), "42") msg = receiveWithTimeout(t, ch, time.Second) - assert.True(t, msg.Done) + assert.Equal(t, common.MessageTypeComplete, msg.Type) }) t.Run("reuses connection for same endpoint", func(t *testing.T) { @@ -796,7 +796,7 @@ func TestWSTransport_LegacyProtocol(t *testing.T) { assert.Contains(t, string(msg.Payload.Data), "42") msg = receiveWithTimeout(t, ch, time.Second) - assert.True(t, msg.Done) + assert.Equal(t, common.MessageTypeComplete, msg.Type) }) t.Run("handles keep-alive messages", func(t *testing.T) { @@ -841,7 +841,7 @@ func TestWSTransport_LegacyProtocol(t *testing.T) { assert.NotNil(t, msg.Payload) msg = receiveWithTimeout(t, ch, time.Second) - assert.True(t, msg.Done) + assert.Equal(t, common.MessageTypeComplete, msg.Type) }) t.Run("auto-negotiates to legacy when modern unavailable", func(t *testing.T) { @@ -952,7 +952,7 @@ func TestWSTransport_Heartbeat(t *testing.T) { // Connection should be closed due to pong timeout, subscriber gets notified msg := receiveWithTimeout(t, ch, time.Second) - assert.True(t, msg.Done) + assert.Equal(t, common.MessageTypeConnectionError, msg.Type) assert.Error(t, msg.Err) assert.Eventually(t, func() bool { From a5e12fbb1772092f457c70ffe8ab7919e41dce5b Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 18 Mar 2026 15:05:18 +0000 Subject: [PATCH 03/52] collapse adapter readloop into a handler system instead of channels --- .../graphql_subscription_client.go | 77 +++---- .../graphql_subscription_client_test.go | 76 +++---- .../subscriptionclient/client.go | 16 +- .../subscriptionclient/client_test.go | 19 +- .../subscriptionclient/common/message.go | 4 + .../subscriptionclient/exports.go | 1 + .../subscriptionclient/transport/sse_conn.go | 37 ++-- .../transport/sse_conn_test.go | 75 +++---- .../transport/sse_transport.go | 18 +- .../transport/sse_transport_test.go | 196 ++++++++++-------- .../transport/transport_test.go | 58 +++++- .../subscriptionclient/transport/ws_conn.go | 41 ++-- .../transport/ws_conn_test.go | 88 ++++---- .../transport/ws_transport.go | 6 +- .../transport/ws_transport_test.go | 153 ++++++++------ v2/pkg/engine/resolve/resolve.go | 32 ++- 16 files changed, 462 insertions(+), 435 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 7da4a34f15..7b13b6f547 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -141,14 +141,36 @@ func NewGraphQLSubscriptionClient(ctx context.Context, opts ...SubscriptionClien } // Subscribe implements GraphQLSubscriptionClient. -// It bridges the channel-based new client API to the callback-based updater interface. func (c *subscriptionClientV2) Subscribe(ctx *resolve.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { opts, req, err := convertToClientOptions(options) if err != nil { return err } - msgCh, cancel, err := c.client.Subscribe(ctx.Context(), req, opts) + handler := func(msg *client.Message) { + switch msg.Type { + case client.MessageTypeConnectionError: + updater.Error(formatUpstreamServiceError(msg.Err)) + updater.Done() + case client.MessageTypeError: + data, _ := json.Marshal(msg.Payload) + updater.Error(data) + updater.Done() + case client.MessageTypeData: + data, err := json.Marshal(msg.Payload) + if err != nil { + updater.Error(formatSubscriptionError(err)) + updater.Done() + return + } + updater.Update(data) + case client.MessageTypeComplete: + updater.Complete() + updater.Done() + } + } + + cancel, err := c.client.Subscribe(ctx.Context(), req, opts, handler) if err != nil { if isUpstreamError(err) { updater.Error(formatUpstreamServiceError(err)) @@ -158,57 +180,14 @@ func (c *subscriptionClientV2) Subscribe(ctx *resolve.Context, options GraphQLSu return err } - go c.readLoop(ctx.Context(), msgCh, cancel, updater) + context.AfterFunc(ctx.Context(), func() { + cancel() + updater.Done() + }) return nil } -// readLoop bridges the channel-based API to the callback-based updater. -func (c *subscriptionClientV2) readLoop(ctx context.Context, msgCh <-chan *client.Message, cancel func(), updater resolve.SubscriptionUpdater) { - defer cancel() - - for { - select { - case <-ctx.Done(): - updater.Done() - return - - case msg, ok := <-msgCh: - if !ok { - updater.Done() - return - } - - switch msg.Type { - case client.MessageTypeConnectionError: - updater.Error(formatUpstreamServiceError(msg.Err)) - updater.Done() - return - - case client.MessageTypeError: - data, _ := json.Marshal(msg.Payload) - updater.Error(data) - updater.Done() - return - - case client.MessageTypeData: - data, err := json.Marshal(msg.Payload) - if err != nil { - updater.Error(formatSubscriptionError(err)) - updater.Done() - return - } - updater.Update(data) - - case client.MessageTypeComplete: - updater.Complete() - updater.Done() - return - } - } - } -} - // isUpstreamError reports whether err is a connection-level upstream error // that should be reported to the client as an UPSTREAM_SERVICE_ERROR. func isUpstreamError(err error) bool { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 1e3be99306..1b4a209b7d 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -2,6 +2,7 @@ package graphql_datasource import ( "context" + "encoding/json" "errors" "testing" @@ -43,15 +44,37 @@ func (t *testBridgeUpdater) Subscriptions() map[context.Context]resolve.Subscrip return map[context.Context]resolve.SubscriptionIdentifier{} } -func TestReadLoopErrorHandling(t *testing.T) { +func TestHandlerDeliversCorrectMessageForEachType(t *testing.T) { + buildHandler := func(updater *testBridgeUpdater) client.Handler { + return func(msg *client.Message) { + switch msg.Type { + case client.MessageTypeConnectionError: + updater.Error(formatUpstreamServiceError(msg.Err)) + updater.Done() + case client.MessageTypeError: + data, _ := json.Marshal(msg.Payload) + updater.Error(data) + updater.Done() + case client.MessageTypeData: + data, err := json.Marshal(msg.Payload) + if err != nil { + updater.Error(formatSubscriptionError(err)) + updater.Done() + return + } + updater.Update(data) + case client.MessageTypeComplete: + updater.Complete() + updater.Done() + } + } + } + t.Run("connection errors deliver error and done without updates", func(t *testing.T) { updater := &testBridgeUpdater{} - msgCh := make(chan *client.Message, 1) - msgCh <- &client.Message{Type: client.MessageTypeConnectionError, Err: client.ErrConnectionClosed} - close(msgCh) + handler := buildHandler(updater) - subClient := &subscriptionClientV2{} - subClient.readLoop(context.Background(), msgCh, func() {}, updater) + handler(&client.Message{Type: client.MessageTypeConnectionError, Err: client.ErrConnectionClosed}) require.True(t, updater.done) require.Len(t, updater.errors, 1) @@ -61,12 +84,9 @@ func TestReadLoopErrorHandling(t *testing.T) { t.Run("non-connection errors deliver error and done without updates", func(t *testing.T) { updater := &testBridgeUpdater{} - msgCh := make(chan *client.Message, 1) - msgCh <- &client.Message{Type: client.MessageTypeConnectionError, Err: errors.New("validation failed")} - close(msgCh) + handler := buildHandler(updater) - subClient := &subscriptionClientV2{} - subClient.readLoop(context.Background(), msgCh, func() {}, updater) + handler(&client.Message{Type: client.MessageTypeConnectionError, Err: errors.New("validation failed")}) require.True(t, updater.done) require.Len(t, updater.errors, 1) @@ -74,39 +94,11 @@ func TestReadLoopErrorHandling(t *testing.T) { require.False(t, updater.completed) }) - t.Run("context cancellation calls done without complete", func(t *testing.T) { - updater := &testBridgeUpdater{} - ctx, cancel := context.WithCancel(context.Background()) - cancel() - msgCh := make(chan *client.Message) - - subClient := &subscriptionClientV2{} - subClient.readLoop(ctx, msgCh, func() {}, updater) - - require.True(t, updater.done) - require.False(t, updater.completed) - }) - - t.Run("channel close calls done without complete", func(t *testing.T) { - updater := &testBridgeUpdater{} - msgCh := make(chan *client.Message) - close(msgCh) - - subClient := &subscriptionClientV2{} - subClient.readLoop(context.Background(), msgCh, func() {}, updater) - - require.True(t, updater.done) - require.False(t, updater.completed) - }) - - t.Run("done message calls complete then done", func(t *testing.T) { + t.Run("complete message calls complete then done", func(t *testing.T) { updater := &testBridgeUpdater{} - msgCh := make(chan *client.Message, 1) - msgCh <- &client.Message{Type: client.MessageTypeComplete} - close(msgCh) + handler := buildHandler(updater) - subClient := &subscriptionClientV2{} - subClient.readLoop(context.Background(), msgCh, func() {}, updater) + handler(&client.Message{Type: client.MessageTypeComplete}) require.True(t, updater.done) require.True(t, updater.completed) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go index 5fa58ee0c3..f4b3619d5f 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go @@ -74,23 +74,15 @@ func New(ctx context.Context, cfg Config) *Client { } // Subscribe creates a new upstream via the appropriate transport. -func (c *Client) Subscribe(ctx context.Context, req *common.Request, opts common.Options) (<-chan *common.Message, func(), error) { +func (c *Client) Subscribe(ctx context.Context, req *common.Request, opts common.Options, handler common.Handler) (func(), error) { if c.ctx.Err() != nil { - return nil, nil, ErrClientClosed + return nil, ErrClientClosed } - // Route to transport - var source <-chan *common.Message - var cancel func() - var err error - if opts.Transport == common.TransportSSE { - source, cancel, err = c.sse.Subscribe(ctx, req, opts) - } else { - source, cancel, err = c.ws.Subscribe(ctx, req, opts) + return c.sse.Subscribe(ctx, req, opts, handler) } - - return source, cancel, err + return c.ws.Subscribe(ctx, req, opts, handler) } // Stats returns client statistics. diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client_test.go index c42b508e87..be71af528b 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client_test.go @@ -38,9 +38,9 @@ func TestClient(t *testing.T) { c := New(ctx, Config{}) cancel() - _, _, err := c.Subscribe(t.Context(), &Request{Query: "subscription { a }"}, Options{ + _, err := c.Subscribe(t.Context(), &Request{Query: "subscription { a }"}, Options{ Endpoint: "ws://localhost/graphql", - }) + }, func(_ *common.Message) {}) assert.Equal(t, ErrClientClosed, err) }) @@ -57,11 +57,14 @@ func TestClient_SubscriberDrain(t *testing.T) { c := New(t.Context(), Config{}) - ch, subCancel, err := c.Subscribe(context.Background(), &common.Request{ + ch := make(chan *common.Message, 1) + subCancel, err := c.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, + }, func(msg *common.Message) { + ch <- msg }) require.NoError(t, err) @@ -92,12 +95,15 @@ func TestClient_SubscriberDrain(t *testing.T) { // Start subscriptions with different headers (forces multiple connections) for i := range 3 { headers := http.Header{"X-Request-ID": []string{string(rune('A' + i))}} - ch, subCancel, err := c.Subscribe(context.Background(), &common.Request{ + ch := make(chan *common.Message, 1) + subCancel, err := c.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, Headers: headers, + }, func(msg *common.Message) { + ch <- msg }) require.NoError(t, err) cancels[i] = subCancel @@ -185,11 +191,14 @@ func TestClient_CancelSendsComplete(t *testing.T) { c := New(t.Context(), Config{}) - ch, cancel, err := c.Subscribe(t.Context(), &common.Request{ + ch := make(chan *common.Message, 1) + cancel, err := c.Subscribe(t.Context(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, + }, func(msg *common.Message) { + ch <- msg }) require.NoError(t, err) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go index 8500f09669..fc121bc509 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go @@ -29,6 +29,10 @@ type Message struct { Err error // only set when Type == MessageTypeConnectionError } +// Handler receives subscription messages. It is called synchronously on the +// transport's read goroutine; a slow handler blocks message delivery. +type Handler func(msg *Message) + type ExecutionResult struct { Data json.RawMessage `json:"data,omitempty"` Errors json.RawMessage `json:"errors,omitempty"` diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go index 44e8f788a5..0d24753b3e 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go @@ -14,6 +14,7 @@ type ( ExecutionResult = common.ExecutionResult Request = common.Request Options = common.Options + Handler = common.Handler TransportType = common.TransportType WSSubprotocol = common.WSSubprotocol SSEMethod = common.SSEMethod diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go index 59697b1490..4ba2f48174 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go @@ -3,7 +3,6 @@ package transport import ( "bytes" "encoding/json" - "io" "net/http" "sync/atomic" @@ -19,21 +18,21 @@ var ( // sseConnection handles a single SSE subscription stream. type sseConnection struct { - resp *http.Response - ch chan *common.Message - done chan struct{} - closed atomic.Bool + resp *http.Response + handler common.Handler + closed atomic.Bool } -func newSSEConnection(resp *http.Response) *sseConnection { +func newSSEConnection(resp *http.Response, handler common.Handler) *sseConnection { return &sseConnection{ - resp: resp, - ch: make(chan *common.Message, 8), - done: make(chan struct{}), + resp: resp, + handler: handler, } } -// readLoop reads SSE events from the response body and sends them to the channel. +// readLoop reads SSE events from the response body and delivers them to the handler. +// Every exit path delivers a terminal message to the handler unless the connection +// was closed by the consumer. func (c *sseConnection) readLoop() { defer c.cleanup() @@ -46,9 +45,10 @@ func (c *sseConnection) readLoop() { eventBytes, err := reader.ReadEvent() if err != nil { - if err != io.EOF { - c.sendError(err) + if c.closed.Load() { + return } + c.sendError(err) return } @@ -65,11 +65,7 @@ func (c *sseConnection) readLoop() { if c.closed.Load() { return } - select { - case c.ch <- msg: - case <-c.done: - return - } + c.handler(msg) if msg.Type.IsTerminal() { return @@ -160,17 +156,13 @@ func (c *sseConnection) sendError(err error) { if c.closed.Load() { return } - select { - case c.ch <- &common.Message{Type: common.MessageTypeConnectionError, Err: err}: - case <-c.done: - } + c.handler(&common.Message{Type: common.MessageTypeConnectionError, Err: err}) } func (c *sseConnection) cleanup() { c.closed.Store(true) c.resp.Body.Close() - close(c.ch) // Close channel so fanout exits } // closeConn terminates the SSE connection. @@ -179,6 +171,5 @@ func (c *sseConnection) closeConn() { return } - close(c.done) c.resp.Body.Close() } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go index 0d92981a34..147ec5e182 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go @@ -20,39 +20,39 @@ func TestSSEConnection_ReadLoop(t *testing.T) { "event: next\ndata: {\"data\":{\"time\":\"12:00\"}}\n\n", )) resp := &http.Response{Body: body} - conn := newSSEConnection(resp) + handler, receive := collectingHandler() + conn := newSSEConnection(resp, handler) go conn.readLoop() - msg := <-conn.ch + msg := receive(t, 1*time.Second) require.NotNil(t, msg.Payload) // Data field contains the raw "data" value from GraphQL response assert.JSONEq(t, `{"time":"12:00"}`, string(msg.Payload.Data)) }) - t.Run("closes channel on EOF", func(t *testing.T) { + t.Run("delivers connection error on EOF", func(t *testing.T) { body := io.NopCloser(strings.NewReader("")) resp := &http.Response{Body: body} - conn := newSSEConnection(resp) + handler, receive := collectingHandler() + conn := newSSEConnection(resp, handler) go conn.readLoop() - select { - case _, ok := <-conn.ch: - assert.False(t, ok, "channel should be closed") - case <-time.After(100 * time.Millisecond): - t.Fatal("channel not closed on EOF") - } + msg := receive(t, 1*time.Second) + assert.Equal(t, common.MessageTypeConnectionError, msg.Type) + assert.Error(t, msg.Err) }) t.Run("sends error on read failure", func(t *testing.T) { body := &errorReader{err: io.ErrUnexpectedEOF} resp := &http.Response{Body: io.NopCloser(body)} - conn := newSSEConnection(resp) + handler, receive := collectingHandler() + conn := newSSEConnection(resp, handler) go conn.readLoop() - msg := <-conn.ch + msg := receive(t, 1*time.Second) require.Error(t, msg.Err) assert.Equal(t, common.MessageTypeConnectionError, msg.Type) }) @@ -64,74 +64,57 @@ func TestSSEConnection_ReadLoop(t *testing.T) { "event: next\ndata: {\"data\":{}}\n\n", // Should not receive this )) resp := &http.Response{Body: body} - conn := newSSEConnection(resp) + handler, receive := collectingHandler() + conn := newSSEConnection(resp, handler) go conn.readLoop() + var messages []*common.Message + // First message - msg1 := <-conn.ch + msg1 := receive(t, 1*time.Second) + messages = append(messages, msg1) assert.NotNil(t, msg1.Payload) assert.Equal(t, common.MessageTypeData, msg1.Type) // Complete message - msg2 := <-conn.ch + msg2 := receive(t, 1*time.Second) + messages = append(messages, msg2) assert.Equal(t, common.MessageTypeComplete, msg2.Type) - // Channel should close, no third message - select { - case _, ok := <-conn.ch: - assert.False(t, ok, "channel should be closed after complete") - case <-time.After(100 * time.Millisecond): - t.Fatal("channel not closed after complete") - } + assert.Len(t, messages, 2, "should receive exactly 2 messages before stopping") }) } func TestSSEConnection_Close(t *testing.T) { - t.Run("closes channel and body", func(t *testing.T) { + t.Run("closes body", func(t *testing.T) { pr, pw := io.Pipe() body := &trackingCloser{Reader: pr} resp := &http.Response{Body: body} - conn := newSSEConnection(resp) + handler, _ := collectingHandler() + conn := newSSEConnection(resp, handler) go conn.readLoop() conn.closeConn() pw.Close() // Ensure pipe is fully closed - // Channel close signals cleanup completed - select { - case _, ok := <-conn.ch: - require.False(t, ok, "channel should be closed") - case <-time.After(100 * time.Millisecond): - t.Fatal("channel should be closed (timeout)") - } - - assert.True(t, body.closed.Load(), "body should be closed") + assert.Eventually(t, func() bool { + return body.closed.Load() + }, 1*time.Second, 10*time.Millisecond, "body should be closed") }) t.Run("is idempotent", func(t *testing.T) { body := io.NopCloser(strings.NewReader("")) resp := &http.Response{Body: body} - conn := newSSEConnection(resp) + handler, _ := collectingHandler() + conn := newSSEConnection(resp, handler) conn.closeConn() conn.closeConn() // second call is a no-op }) } -func TestSSEConnection_Channel(t *testing.T) { - t.Run("returns buffered channel", func(t *testing.T) { - body := io.NopCloser(strings.NewReader("")) - resp := &http.Response{Body: body} - conn := newSSEConnection(resp) - - ch := conn.ch - assert.NotNil(t, ch) - assert.Equal(t, 8, cap(ch)) - }) -} - // errorReader always returns an error type errorReader struct { err error diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go index 4080d9367b..5e672ebccd 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go @@ -56,7 +56,7 @@ func NewSSETransport(ctx context.Context, client *http.Client, log abstractlogge // The HTTP method is determined by opts.SSEMethod: // - SSEMethodAuto or SSEMethodPOST: POST with JSON body (graphql-sse spec) // - SSEMethodGET: GET with query parameters (traditional SSE) -func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts common.Options) (<-chan *common.Message, func(), error) { +func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts common.Options, handler common.Handler) (func(), error) { var httpReq *http.Request var err error @@ -83,11 +83,11 @@ func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts case common.SSEMethodGET: httpReq, err = buildGETRequest(requestCtx, req, opts) default: - return nil, nil, fmt.Errorf("unsupported SSE method: %s", method) + return nil, fmt.Errorf("unsupported SSE method: %s", method) } if err != nil { - return nil, nil, err + return nil, err } // Execute request @@ -97,7 +97,7 @@ func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts abstractlogger.String("endpoint", opts.Endpoint), abstractlogger.Error(err), ) - return nil, nil, fmt.Errorf("execute request: %w", err) + return nil, fmt.Errorf("execute request: %w", err) } if resp.StatusCode != http.StatusOK { @@ -108,15 +108,15 @@ func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts abstractlogger.Int("status", resp.StatusCode), ) if len(body) > 0 { - return nil, nil, fmt.Errorf("unexpected status %d: %s", resp.StatusCode, string(body)) + return nil, fmt.Errorf("unexpected status %d: %s", resp.StatusCode, string(body)) } - return nil, nil, fmt.Errorf("unexpected status: %d", resp.StatusCode) + return nil, fmt.Errorf("unexpected status: %d", resp.StatusCode) } // Verify content type (should be text/event-stream) if err := t.validateContentType(resp); err != nil { resp.Body.Close() - return nil, nil, err + return nil, err } t.log.Debug("sseTransport.Subscribe", @@ -125,7 +125,7 @@ func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts ) // Create connection - conn := newSSEConnection(resp) + conn := newSSEConnection(resp, handler) t.mu.Lock() t.conns[conn] = struct{}{} @@ -138,7 +138,7 @@ func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts t.removeConn(conn) } - return conn.ch, cancelFn, nil + return cancelFn, nil } // buildPOSTRequest creates a POST request with JSON body (graphql-sse spec). diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go index 430767c746..6d8514ac68 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go @@ -54,14 +54,15 @@ func TestSSETransport_Subscribe(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", Variables: []byte(`{"id": 123}`), OperationName: "TestSub", }, common.Options{ Endpoint: server.URL, Transport: common.TransportSSE, - }) + }, handler) require.NoError(t, err) defer cancel() @@ -71,12 +72,12 @@ func TestSSETransport_Subscribe(t *testing.T) { assert.Equal(t, "TestSub", receivedBody["operationName"]) // Receive data message - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) require.NotNil(t, msg.Payload) assert.Contains(t, string(msg.Payload.Data), "42") // Receive complete message - msg = receiveWithTimeout(t, ch, time.Second) + msg = receive(t, time.Second) assert.Equal(t, common.MessageTypeComplete, msg.Type) }) @@ -101,16 +102,17 @@ func TestSSETransport_Subscribe(t *testing.T) { "X-Custom-Header": []string{"custom-value"}, } - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, Headers: headers, - }) + }, handler) require.NoError(t, err) defer cancel() - receiveWithTimeout(t, ch, time.Second) + receive(t, time.Second) assert.Equal(t, "Bearer token123", receivedAuth) assert.Equal(t, "custom-value", receivedCustom) @@ -133,13 +135,14 @@ func TestSSETransport_Subscribe(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { user { name } }", - }, common.Options{Endpoint: server.URL}) + }, common.Options{Endpoint: server.URL}, handler) require.NoError(t, err) defer cancel() - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) require.NotNil(t, msg.Payload) assert.Contains(t, string(msg.Payload.Data), "Alice") assert.Equal(t, common.MessageTypeData, msg.Type) @@ -157,13 +160,14 @@ func TestSSETransport_Subscribe(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}) + }, common.Options{Endpoint: server.URL}, handler) require.NoError(t, err) defer cancel() - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) assert.Equal(t, common.MessageTypeError, msg.Type) require.NotNil(t, msg.Payload) assert.Contains(t, string(msg.Payload.Errors), "Something went wrong") @@ -181,13 +185,14 @@ func TestSSETransport_Subscribe(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}) + }, common.Options{Endpoint: server.URL}, handler) require.NoError(t, err) defer cancel() - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) assert.Equal(t, common.MessageTypeComplete, msg.Type) assert.Nil(t, msg.Err) assert.Nil(t, msg.Payload) @@ -215,13 +220,14 @@ func TestSSETransport_Subscribe(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}) + }, common.Options{Endpoint: server.URL}, handler) require.NoError(t, err) defer cancel() - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) require.NotNil(t, msg.Payload) // The multi-line data is joined with newlines assert.Contains(t, string(msg.Payload.Data), "42") @@ -230,7 +236,6 @@ func TestSSETransport_Subscribe(t *testing.T) { t.Run("ignores SSE comments", func(t *testing.T) { t.Parallel() - var messageCount atomic.Int32 server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.WriteHeader(http.StatusOK) @@ -254,21 +259,18 @@ func TestSSETransport_Subscribe(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, _ := collectingHandler() + wrappedHandler, collect := waitForMessages(handler) + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}) + }, common.Options{Endpoint: server.URL}, wrappedHandler) require.NoError(t, err) defer cancel() - // Should only receive 2 messages (next + complete), not comments - for msg := range ch { - messageCount.Add(1) - if msg.Type.IsTerminal() { - break - } - } + msgs := collect(time.Second) - assert.Equal(t, int32(2), messageCount.Load()) + // Should only receive 2 messages (next + complete), not comments + assert.Len(t, msgs, 2) }) t.Run("cancel closes connection", func(t *testing.T) { @@ -290,13 +292,14 @@ func TestSSETransport_Subscribe(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}) + }, common.Options{Endpoint: server.URL}, handler) require.NoError(t, err) // Receive first message - receiveWithTimeout(t, ch, time.Second) + receive(t, time.Second) assert.Equal(t, 1, tr.ConnCount()) @@ -310,7 +313,9 @@ func TestSSETransport_Subscribe(t *testing.T) { t.Fatal("server did not detect disconnect") } - assert.Equal(t, 0, tr.ConnCount()) + assert.Eventually(t, func() bool { + return tr.ConnCount() == 0 + }, time.Second, 10*time.Millisecond) }) t.Run("context cancellation stops subscription", func(t *testing.T) { @@ -333,13 +338,14 @@ func TestSSETransport_Subscribe(t *testing.T) { tr := NewSSETransport(transportCtx, http.DefaultClient, nil) - ch, cancel, err := tr.Subscribe(transportCtx, &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(transportCtx, &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}) + }, common.Options{Endpoint: server.URL}, handler) require.NoError(t, err) defer cancel() - _ = receiveWithTimeout(t, ch, time.Second) + _ = receive(t, time.Second) // Cancel context transportCancel() @@ -360,9 +366,9 @@ func TestSSETransport_Subscribe(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - _, _, err := tr.Subscribe(context.Background(), &common.Request{ + _, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}) + }, common.Options{Endpoint: server.URL}, func(_ *common.Message) {}) require.Error(t, err) assert.Contains(t, err.Error(), "401") @@ -378,9 +384,9 @@ func TestSSETransport_Subscribe(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - _, _, err := tr.Subscribe(context.Background(), &common.Request{ + _, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}) + }, common.Options{Endpoint: server.URL}, func(_ *common.Message) {}) require.Error(t, err) assert.Contains(t, err.Error(), "500") @@ -408,14 +414,16 @@ func TestSSETransport_Subscribe(t *testing.T) { opts := common.Options{Endpoint: server.URL} - ch1, cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts) + handler1, receive1 := collectingHandler() + cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts, handler1) require.NoError(t, err) - ch2, cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts) + handler2, receive2 := collectingHandler() + cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts, handler2) require.NoError(t, err) - receiveWithTimeout(t, ch1, time.Second) - receiveWithTimeout(t, ch2, time.Second) + receive1(t, time.Second) + receive2(t, time.Second) // SSE creates separate HTTP requests (no multiplexing) assert.Equal(t, int32(2), reqCount.Load()) @@ -441,22 +449,21 @@ func TestSSETransport_Subscribe(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}) + }, common.Options{Endpoint: server.URL}, handler) require.NoError(t, err) defer cancel() - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) assert.NotNil(t, msg.Payload) - // Channel should close when server closes stream - select { - case _, ok := <-ch: - assert.False(t, ok, "channel should be closed") - case <-time.After(time.Second): - t.Fatal("channel should have been closed") - } + // Server closes without sending complete — this is an abnormal + // disconnection per graphql-sse protocol, delivered as a connection error. + msg = receive(t, time.Second) + assert.Equal(t, common.MessageTypeConnectionError, msg.Type) + assert.Error(t, msg.Err) }) t.Run("handles data without event type", func(t *testing.T) { @@ -477,13 +484,14 @@ func TestSSETransport_Subscribe(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}) + }, common.Options{Endpoint: server.URL}, handler) require.NoError(t, err) defer cancel() - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) require.NotNil(t, msg.Payload) assert.Contains(t, string(msg.Payload.Data), "99") }) @@ -513,14 +521,16 @@ func TestSSETransport_ContextCancellation(t *testing.T) { opts := common.Options{Endpoint: server.URL} - ch1, _, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts) + handler1, receive1 := collectingHandler() + _, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts, handler1) require.NoError(t, err) - ch2, _, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts) + handler2, receive2 := collectingHandler() + _, err = tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts, handler2) require.NoError(t, err) - receiveWithTimeout(t, ch1, time.Second) - receiveWithTimeout(t, ch2, time.Second) + receive1(t, time.Second) + receive2(t, time.Second) assert.Equal(t, 2, tr.ConnCount()) @@ -561,13 +571,14 @@ func TestSSETransport_CustomClient(t *testing.T) { tr := NewSSETransport(t.Context(), customClient, nil) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}) + }, common.Options{Endpoint: server.URL}, handler) require.NoError(t, err) defer cancel() - receiveWithTimeout(t, ch, time.Second) + receive(t, time.Second) assert.Equal(t, "test-client", customHeaderReceived) }) @@ -609,13 +620,14 @@ func TestSSETransport_ContentTypeValidation(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}) + }, common.Options{Endpoint: server.URL}, handler) require.NoError(t, err) defer cancel() - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) assert.Equal(t, common.MessageTypeComplete, msg.Type) }) @@ -630,9 +642,9 @@ func TestSSETransport_ContentTypeValidation(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - _, _, err := tr.Subscribe(context.Background(), &common.Request{ + _, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}) + }, common.Options{Endpoint: server.URL}, func(_ *common.Message) {}) require.Error(t, err) assert.True(t, strings.Contains(err.Error(), "content-type") || strings.Contains(err.Error(), "Content-Type")) @@ -677,14 +689,15 @@ func TestSSETransport_GETMethod(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", Variables: []byte(`{"id": 123}`), OperationName: "TestSub", }, common.Options{ Endpoint: server.URL, SSEMethod: common.SSEMethodGET, - }) + }, handler) require.NoError(t, err) defer cancel() @@ -695,12 +708,12 @@ func TestSSETransport_GETMethod(t *testing.T) { assert.Equal(t, "TestSub", receivedOperationName) // Receive data message - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) require.NotNil(t, msg.Payload) assert.Contains(t, string(msg.Payload.Data), "42") // Receive complete message - msg = receiveWithTimeout(t, ch, time.Second) + msg = receive(t, time.Second) assert.Equal(t, common.MessageTypeComplete, msg.Type) }) @@ -720,16 +733,17 @@ func TestSSETransport_GETMethod(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL + "?token=abc123", SSEMethod: common.SSEMethodGET, - }) + }, handler) require.NoError(t, err) defer cancel() - receiveWithTimeout(t, ch, time.Second) + receive(t, time.Second) assert.Equal(t, "abc123", receivedToken) assert.Equal(t, "subscription { test }", receivedQuery) @@ -753,17 +767,18 @@ func TestSSETransport_GETMethod(t *testing.T) { "Authorization": []string{"Bearer token123"}, } - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, SSEMethod: common.SSEMethodGET, Headers: headers, - }) + }, handler) require.NoError(t, err) defer cancel() - receiveWithTimeout(t, ch, time.Second) + receive(t, time.Second) assert.Equal(t, "Bearer token123", receivedAuth) }) @@ -784,17 +799,18 @@ func TestSSETransport_GETMethod(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", // No variables or operationName }, common.Options{ Endpoint: server.URL, SSEMethod: common.SSEMethodGET, - }) + }, handler) require.NoError(t, err) defer cancel() - receiveWithTimeout(t, ch, time.Second) + receive(t, time.Second) assert.False(t, hasVariables, "variables should not be in query params") assert.False(t, hasOperationName, "operationName should not be in query params") @@ -818,16 +834,17 @@ func TestSSETransport_MethodDefault(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, SSEMethod: common.SSEMethodAuto, // or just omit it - }) + }, handler) require.NoError(t, err) defer cancel() - receiveWithTimeout(t, ch, time.Second) + receive(t, time.Second) assert.Equal(t, http.MethodPost, receivedMethod) }) @@ -846,16 +863,17 @@ func TestSSETransport_MethodDefault(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, SSEMethod: common.SSEMethodPOST, - }) + }, handler) require.NoError(t, err) defer cancel() - receiveWithTimeout(t, ch, time.Second) + receive(t, time.Second) assert.Equal(t, http.MethodPost, receivedMethod) }) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport_test.go index 530f46c96e..f336dfe522 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport_test.go @@ -1,19 +1,61 @@ package transport import ( + "sync" "testing" "time" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" ) -func receiveWithTimeout(t *testing.T, ch <-chan *common.Message, timeout time.Duration) *common.Message { - t.Helper() - select { - case msg := <-ch: - return msg - case <-time.After(timeout): - t.Fatal("timeout waiting for message") - return nil +// collectingHandler returns a handler that appends messages to a channel, +// plus a helper to receive with timeout (for use in tests). +func collectingHandler() (common.Handler, func(t *testing.T, timeout time.Duration) *common.Message) { + ch := make(chan *common.Message, 64) + handler := func(msg *common.Message) { + ch <- msg } + receive := func(t *testing.T, timeout time.Duration) *common.Message { + t.Helper() + select { + case msg := <-ch: + return msg + case <-time.After(timeout): + t.Fatal("timeout waiting for message") + return nil + } + } + return handler, receive +} + +// waitForMessages collects messages from a handler until a terminal message or timeout. +func waitForMessages(handler common.Handler) (common.Handler, func(timeout time.Duration) []*common.Message) { + var mu sync.Mutex + var msgs []*common.Message + done := make(chan struct{}, 1) + + wrappedHandler := func(msg *common.Message) { + mu.Lock() + msgs = append(msgs, msg) + mu.Unlock() + handler(msg) + if msg.Type.IsTerminal() { + select { + case done <- struct{}{}: + default: + } + } + } + + collect := func(timeout time.Duration) []*common.Message { + select { + case <-done: + case <-time.After(timeout): + } + mu.Lock() + defer mu.Unlock() + return append([]*common.Message{}, msgs...) + } + + return wrappedHandler, collect } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go index 3251d2461a..672fc2fa5e 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go @@ -67,7 +67,7 @@ type wsConnection struct { ctx context.Context subsMu sync.RWMutex - subs map[string]chan<- *common.Message + subs map[string]common.Handler closed atomic.Bool @@ -98,7 +98,7 @@ func newWSConnection(conn *websocket.Conn, proto protocol.Protocol, opts ...wsCo log: o.logger, cancel: cancel, ctx: ctx, - subs: make(map[string]chan<- *common.Message), + subs: make(map[string]common.Handler), onEmpty: o.onEmpty, writeTimeout: o.writeTimeout, @@ -109,22 +109,19 @@ func newWSConnection(conn *websocket.Conn, proto protocol.Protocol, opts ...wsCo return c } -func (c *wsConnection) subscribe(ctx context.Context, id string, req *common.Request) (<-chan *common.Message, func(), error) { +func (c *wsConnection) subscribe(ctx context.Context, id string, req *common.Request, handler common.Handler) (func(), error) { if c.closed.Load() { - return nil, nil, common.ErrConnectionClosed + return nil, common.ErrConnectionClosed } - // Small buffer to absorb bursts - ch := make(chan *common.Message, 8) - c.subsMu.Lock() if _, exists := c.subs[id]; exists { c.subsMu.Unlock() - return nil, nil, ErrSubscriptionExists + return nil, ErrSubscriptionExists } - c.subs[id] = ch + c.subs[id] = handler c.subsMu.Unlock() if err := c.protocol.Subscribe(ctx, c.conn, id, req); err != nil { @@ -133,7 +130,7 @@ func (c *wsConnection) subscribe(ctx context.Context, id string, req *common.Req abstractlogger.Error(err), ) c.removeSub(id) - return nil, nil, err + return nil, err } c.log.Debug("wsConnection.Subscribe", @@ -143,20 +140,15 @@ func (c *wsConnection) subscribe(ctx context.Context, id string, req *common.Req cancel := func() { c.unsubscribe(id) } - return ch, cancel, nil + return cancel, nil } func (c *wsConnection) removeSub(id string) { c.subsMu.Lock() - ch, exists := c.subs[id] delete(c.subs, id) isEmpty := len(c.subs) == 0 c.subsMu.Unlock() - if exists { - close(ch) - } - if isEmpty { c.closeConn() } @@ -216,17 +208,17 @@ func (c *wsConnection) readLoop() { func (c *wsConnection) dispatch(msg *protocol.Message) { c.subsMu.RLock() - ch, exists := c.subs[msg.ID] + handler, exists := c.subs[msg.ID] c.subsMu.RUnlock() if !exists { return } - ch <- msg.IntoClientMessage() + handler(msg.IntoClientMessage()) if msg.Type == protocol.MessageComplete || msg.Type == protocol.MessageError { - c.unsubscribe(msg.ID) + c.removeSub(msg.ID) } } @@ -243,17 +235,12 @@ func (c *wsConnection) shutdown(err error) { c.subsMu.Lock() subs := c.subs - c.subs = make(map[string]chan<- *common.Message) + c.subs = make(map[string]common.Handler) c.subsMu.Unlock() errMsg := &common.Message{Type: common.MessageTypeConnectionError, Err: err} - for _, ch := range subs { - select { - case ch <- errMsg: - case <-time.After(100 * time.Millisecond): - // dead consumer - } - close(ch) + for _, handler := range subs { + handler(errMsg) } // Cancel after dispatching errors so readLoop consumers still have a live diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go index 21886942c8..69d2fa0659 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go @@ -21,20 +21,20 @@ import ( func TestWSConnection_Subscribe(t *testing.T) { t.Parallel() - t.Run("returns channel and calls protocol subscribe", func(t *testing.T) { + t.Run("calls protocol subscribe and handler can receive", func(t *testing.T) { t.Parallel() conn, _ := newTestConn(t) proto := newMockProtocol() wsc := newWSConnection(conn, proto) - ch, cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{ + handler, _ := collectingHandler() + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{ Query: "subscription { test }", - }) + }, handler) defer cancel() require.NoError(t, err) - assert.NotNil(t, ch) assert.Len(t, proto.SubscribeCalls(), 1) assert.Equal(t, "sub-1", proto.SubscribeCalls()[0].ID) }) @@ -46,11 +46,11 @@ func TestWSConnection_Subscribe(t *testing.T) { proto := newMockProtocol() wsc := newWSConnection(conn, proto) - _, cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) require.NoError(t, err) defer cancel() - _, _, err = wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + _, err = wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) assert.ErrorIs(t, err, ErrSubscriptionExists) }) @@ -63,7 +63,7 @@ func TestWSConnection_Subscribe(t *testing.T) { wsc := newWSConnection(conn, proto) wsc.closeConn() - _, _, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + _, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) assert.ErrorIs(t, err, common.ErrConnectionClosed) }) @@ -76,7 +76,7 @@ func TestWSConnection_Subscribe(t *testing.T) { proto.subscribeErr = assert.AnError wsc := newWSConnection(conn, proto) - _, _, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + _, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) assert.Error(t, err) assert.Equal(t, 0, wsc.subCount(), "failed subscription should not be registered") @@ -86,14 +86,15 @@ func TestWSConnection_Subscribe(t *testing.T) { func TestWSConnection_ReadLoop(t *testing.T) { t.Parallel() - t.Run("dispatches data message to subscription channel", func(t *testing.T) { + t.Run("dispatches data message to subscription handler", func(t *testing.T) { t.Parallel() conn, _ := newTestConn(t) proto := newMockProtocol() wsc := newWSConnection(conn, proto) - ch, cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + handler, receive := collectingHandler() + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler) require.NoError(t, err) defer cancel() @@ -105,19 +106,20 @@ func TestWSConnection_ReadLoop(t *testing.T) { Payload: &common.ExecutionResult{Data: json.RawMessage(`{"value": 42}`)}, }) - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) require.NotNil(t, msg.Payload) assert.Contains(t, string(msg.Payload.Data), "42") }) - t.Run("closes channel on complete message", func(t *testing.T) { + t.Run("delivers complete message to handler", func(t *testing.T) { t.Parallel() conn, _ := newTestConn(t) proto := newMockProtocol() wsc := newWSConnection(conn, proto) - ch, _, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + handler, receive := collectingHandler() + _, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler) require.NoError(t, err) go wsc.readLoop() @@ -127,11 +129,8 @@ func TestWSConnection_ReadLoop(t *testing.T) { Type: protocol.MessageComplete, }) - // Consume the message (blocking send requires consumer) - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) assert.Equal(t, common.MessageTypeComplete, msg.Type) - - assertChannelClosed(t, ch) }) t.Run("responds to ping with pong", func(t *testing.T) { @@ -157,7 +156,8 @@ func TestWSConnection_ReadLoop(t *testing.T) { proto := newMockProtocol() wsc := newWSConnection(conn, proto) - ch, cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + handler, receive := collectingHandler() + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler) require.NoError(t, err) defer cancel() @@ -175,7 +175,7 @@ func TestWSConnection_ReadLoop(t *testing.T) { Payload: &common.ExecutionResult{Data: json.RawMessage(`{"right": true}`)}, }) - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) assert.Contains(t, string(msg.Payload.Data), "right") }) } @@ -183,21 +183,21 @@ func TestWSConnection_ReadLoop(t *testing.T) { func TestWSConnection_Unsubscribe(t *testing.T) { t.Parallel() - t.Run("calls protocol unsubscribe and closes channel", func(t *testing.T) { + t.Run("calls protocol unsubscribe and removes subscription", func(t *testing.T) { t.Parallel() conn, _ := newTestConn(t) proto := newMockProtocol() wsc := newWSConnection(conn, proto) - ch, cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) require.NoError(t, err) cancel() assert.Len(t, proto.UnsubscribeCalls(), 1) assert.Equal(t, "sub-1", proto.UnsubscribeCalls()[0]) - assertChannelClosed(t, ch) + assert.Equal(t, 0, wsc.subCount()) }) t.Run("is idempotent", func(t *testing.T) { @@ -207,7 +207,7 @@ func TestWSConnection_Unsubscribe(t *testing.T) { proto := newMockProtocol() wsc := newWSConnection(conn, proto) - _, cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) require.NoError(t, err) cancel() @@ -227,7 +227,7 @@ func TestWSConnection_Unsubscribe(t *testing.T) { withConnWriteTimeout(50*time.Millisecond), ) - _, cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) require.NoError(t, err) start := time.Now() @@ -252,7 +252,7 @@ func TestWSConnection_OnEmpty(t *testing.T) { withOnEmpty(func() { emptyCalled <- struct{}{} }), ) - _, cancel, _ := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + cancel, _ := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) cancel() select { @@ -276,8 +276,8 @@ func TestWSConnection_OnEmpty(t *testing.T) { withOnEmpty(func() { emptyCalled <- struct{}{} }), ) - _, cancel1, _ := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) - _, cancel2, _ := wsc.subscribe(context.Background(), "sub-2", &common.Request{}) + cancel1, _ := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) + cancel2, _ := wsc.subscribe(context.Background(), "sub-2", &common.Request{}, func(_ *common.Message) {}) cancel1() @@ -354,22 +354,21 @@ func TestWSConnection_Close(t *testing.T) { proto := newMockProtocol() wsc := newWSConnection(conn, proto) - ch1, _, _ := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) - ch2, _, _ := wsc.subscribe(context.Background(), "sub-2", &common.Request{}) + handler1, receive1 := collectingHandler() + _, _ = wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler1) + + handler2, receive2 := collectingHandler() + _, _ = wsc.subscribe(context.Background(), "sub-2", &common.Request{}, handler2) wsc.closeConn() - // Consume messages (blocking send requires consumer) - msg1 := receiveWithTimeout(t, ch1, 100*time.Millisecond) + msg1 := receive1(t, 100*time.Millisecond) assert.Error(t, msg1.Err) assert.Equal(t, common.MessageTypeConnectionError, msg1.Type) - msg2 := receiveWithTimeout(t, ch2, 100*time.Millisecond) + msg2 := receive2(t, 100*time.Millisecond) assert.Error(t, msg2.Err) assert.Equal(t, common.MessageTypeConnectionError, msg2.Type) - - assertChannelClosed(t, ch1) - assertChannelClosed(t, ch2) }) t.Run("is idempotent", func(t *testing.T) { @@ -399,10 +398,10 @@ func TestWSConnection_SubCount(t *testing.T) { assert.Equal(t, 0, wsc.subCount()) - _, cancel1, _ := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + cancel1, _ := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) assert.Equal(t, 1, wsc.subCount()) - _, cancel2, _ := wsc.subscribe(context.Background(), "sub-2", &common.Request{}) + cancel2, _ := wsc.subscribe(context.Background(), "sub-2", &common.Request{}, func(_ *common.Message) {}) assert.Equal(t, 2, wsc.subCount()) cancel1() @@ -426,7 +425,8 @@ func TestWSConnection_WriteTimeout(t *testing.T) { withConnWriteTimeout(50*time.Millisecond), ) - ch, cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}) + handler, receive := collectingHandler() + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler) require.NoError(t, err) defer cancel() @@ -444,7 +444,7 @@ func TestWSConnection_WriteTimeout(t *testing.T) { // Should receive data within timeout + small buffer // If pong blocked for 500ms, this would timeout - msg := receiveWithTimeout(t, ch, 150*time.Millisecond) + msg := receive(t, 150*time.Millisecond) assert.NotNil(t, msg.Payload) }) } @@ -536,16 +536,6 @@ func newTestConn(t *testing.T) (*websocket.Conn, *websocket.Conn) { return clientConn, srvConn } -func assertChannelClosed(t *testing.T, ch <-chan *common.Message) { - t.Helper() - select { - case _, ok := <-ch: - assert.False(t, ok, "channel should be closed") - case <-time.After(100 * time.Millisecond): - t.Error("timeout waiting for channel to close") - } -} - // mockProtocol implements protocol.Protocol for testing. type mockProtocol struct { mu sync.Mutex diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go index e4389fd407..143c975c2c 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go @@ -169,14 +169,14 @@ func NewWSTransport(ctx context.Context, opts ...WSTransportOption) *WSTransport return t } -func (t *WSTransport) Subscribe(ctx context.Context, req *common.Request, opts common.Options) (<-chan *common.Message, func(), error) { +func (t *WSTransport) Subscribe(ctx context.Context, req *common.Request, opts common.Options, handler common.Handler) (func(), error) { conn, err := t.getOrDial(ctx, opts) if err != nil { - return nil, nil, err + return nil, err } id := xid.New().String() - return conn.subscribe(ctx, id, req) + return conn.subscribe(ctx, id, req, handler) } // pingLoop sends periodic pings to all active connections and shuts down diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go index 4b62d4add2..7f8e5227a4 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go @@ -45,19 +45,20 @@ func TestWSTransport_Subscribe(t *testing.T) { tr := NewWSTransport(t.Context()) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, - }) + }, handler) require.NoError(t, err) defer cancel() - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) assert.Contains(t, string(msg.Payload.Data), "42") - msg = receiveWithTimeout(t, ch, time.Second) + msg = receive(t, time.Second) assert.Equal(t, common.MessageTypeComplete, msg.Type) }) @@ -91,17 +92,19 @@ func TestWSTransport_Subscribe(t *testing.T) { Transport: common.TransportWS, } - ch1, cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts) + handler1, receive1 := collectingHandler() + cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts, handler1) require.NoError(t, err) defer cancel1() - ch2, cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts) + handler2, receive2 := collectingHandler() + cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts, handler2) require.NoError(t, err) defer cancel2() // Both should receive messages - receiveWithTimeout(t, ch1, time.Second) - receiveWithTimeout(t, ch2, time.Second) + receive1(t, time.Second) + receive2(t, time.Second) // Only one connection should have been made assert.Equal(t, int32(1), dialCount.Load()) @@ -136,24 +139,26 @@ func TestWSTransport_Subscribe(t *testing.T) { headers1 := http.Header{"Authorization": []string{"Bearer token1"}} headers2 := http.Header{"Authorization": []string{"Bearer token2"}} - ch1, cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, common.Options{ + handler1, receive1 := collectingHandler() + cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, Headers: headers1, - }) + }, handler1) require.NoError(t, err) defer cancel1() - ch2, cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, common.Options{ + handler2, receive2 := collectingHandler() + cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, Headers: headers2, - }) + }, handler2) require.NoError(t, err) defer cancel2() - receiveWithTimeout(t, ch1, time.Second) - receiveWithTimeout(t, ch2, time.Second) + receive1(t, time.Second) + receive2(t, time.Second) // Two connections due to different headers assert.Equal(t, int32(2), dialCount.Load()) @@ -185,24 +190,26 @@ func TestWSTransport_Subscribe(t *testing.T) { tr := NewWSTransport(t.Context()) - ch1, cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, common.Options{ + handler1, receive1 := collectingHandler() + cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, InitPayload: map[string]any{"token": "abc"}, - }) + }, handler1) require.NoError(t, err) defer cancel1() - ch2, cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, common.Options{ + handler2, receive2 := collectingHandler() + cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, InitPayload: map[string]any{"token": "xyz"}, - }) + }, handler2) require.NoError(t, err) defer cancel2() - receiveWithTimeout(t, ch1, time.Second) - receiveWithTimeout(t, ch2, time.Second) + receive1(t, time.Second) + receive2(t, time.Second) // Two connections due to different init payload assert.Equal(t, int32(2), dialCount.Load()) @@ -228,10 +235,10 @@ func TestWSTransport_Subscribe(t *testing.T) { Transport: common.TransportWS, } - _, cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts) + cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts, func(_ *common.Message) {}) require.NoError(t, err) - _, cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts) + cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts, func(_ *common.Message) {}) require.NoError(t, err) assert.Equal(t, 1, tr.ConnCount()) @@ -278,9 +285,10 @@ func TestWSTransport_Subscribe(t *testing.T) { } // First subscription - ch1, cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts) + handler1, receive1 := collectingHandler() + cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts, handler1) require.NoError(t, err) - receiveWithTimeout(t, ch1, time.Second) + receive1(t, time.Second) cancel1() // Wait for connection to be removed @@ -289,10 +297,11 @@ func TestWSTransport_Subscribe(t *testing.T) { }, time.Second, 10*time.Millisecond) // Second subscription should redial - ch2, cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts) + handler2, receive2 := collectingHandler() + cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts, handler2) require.NoError(t, err) defer cancel2() - receiveWithTimeout(t, ch2, time.Second) + receive2(t, time.Second) assert.Equal(t, int32(2), dialCount.Load()) }) @@ -315,10 +324,10 @@ func TestWSTransport_SubscriberDrain(t *testing.T) { tr := NewWSTransport(t.Context()) - _, cancel, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, common.Options{ + cancel, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, - }) + }, func(_ *common.Message) {}) require.NoError(t, err) assert.Equal(t, 1, tr.ConnCount()) @@ -346,10 +355,10 @@ func TestWSTransport_SubscriberDrain(t *testing.T) { opts := common.Options{Endpoint: server.URL, Transport: common.TransportWS} - _, cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts) + cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts, func(_ *common.Message) {}) require.NoError(t, err) - _, cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts) + cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts, func(_ *common.Message) {}) require.NoError(t, err) assert.Equal(t, 1, tr.ConnCount()) @@ -402,13 +411,14 @@ func TestWSTransport_ConcurrentSubscribe(t *testing.T) { var wg sync.WaitGroup for range 10 { wg.Go(func() { - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { test }"}, opts) + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { test }"}, opts, handler) if err != nil { return } defer cancel() - receiveWithTimeout(t, ch, time.Second) + receive(t, time.Second) }) } @@ -477,14 +487,15 @@ func TestWSTransport_InitPayloadForwarding(t *testing.T) { }, } - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, WSSubprotocol: common.SubprotocolGraphQLTransportWS, InitPayload: initPayload, - }) + }, handler) require.NoError(t, err) defer cancel() @@ -502,7 +513,7 @@ func TestWSTransport_InitPayloadForwarding(t *testing.T) { } // Subscription should work - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) assert.NotNil(t, msg.Payload) }) @@ -558,14 +569,15 @@ func TestWSTransport_InitPayloadForwarding(t *testing.T) { "version": float64(2), // JSON numbers are float64 } - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, WSSubprotocol: common.SubprotocolGraphQLWS, // Legacy protocol InitPayload: initPayload, - }) + }, handler) require.NoError(t, err) defer cancel() @@ -580,7 +592,7 @@ func TestWSTransport_InitPayloadForwarding(t *testing.T) { } // Subscription should work - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) assert.NotNil(t, msg.Payload) }) @@ -631,14 +643,15 @@ func TestWSTransport_InitPayloadForwarding(t *testing.T) { tr := NewWSTransport(t.Context()) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, WSSubprotocol: common.SubprotocolGraphQLTransportWS, InitPayload: nil, // No init payload - }) + }, handler) require.NoError(t, err) defer cancel() @@ -651,7 +664,7 @@ func TestWSTransport_InitPayloadForwarding(t *testing.T) { } // Subscription should still work - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) assert.NotNil(t, msg.Payload) }) @@ -710,32 +723,34 @@ func TestWSTransport_InitPayloadForwarding(t *testing.T) { tr := NewWSTransport(t.Context()) // First subscription with user1 token - ch1, cancel1, err := tr.Subscribe(context.Background(), &common.Request{ + handler1, receive1 := collectingHandler() + cancel1, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, WSSubprotocol: common.SubprotocolGraphQLTransportWS, InitPayload: map[string]any{"user": "user1"}, - }) + }, handler1) require.NoError(t, err) defer cancel1() - receiveWithTimeout(t, ch1, time.Second) + receive1(t, time.Second) // Second subscription with user2 token - should create new connection - ch2, cancel2, err := tr.Subscribe(context.Background(), &common.Request{ + handler2, receive2 := collectingHandler() + cancel2, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, WSSubprotocol: common.SubprotocolGraphQLTransportWS, InitPayload: map[string]any{"user": "user2"}, - }) + }, handler2) require.NoError(t, err) defer cancel2() - receiveWithTimeout(t, ch2, time.Second) + receive2(t, time.Second) // Verify two separate connections were made with different payloads assert.Equal(t, 2, tr.ConnCount()) @@ -782,20 +797,21 @@ func TestWSTransport_LegacyProtocol(t *testing.T) { tr := NewWSTransport(t.Context()) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, WSSubprotocol: common.SubprotocolGraphQLWS, // Request legacy protocol - }) + }, handler) require.NoError(t, err) defer cancel() - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) assert.Contains(t, string(msg.Payload.Data), "42") - msg = receiveWithTimeout(t, ch, time.Second) + msg = receive(t, time.Second) assert.Equal(t, common.MessageTypeComplete, msg.Type) }) @@ -826,21 +842,22 @@ func TestWSTransport_LegacyProtocol(t *testing.T) { tr := NewWSTransport(t.Context()) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, WSSubprotocol: common.SubprotocolGraphQLWS, - }) + }, handler) require.NoError(t, err) defer cancel() // Should receive data (keep-alive is handled internally) - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) assert.NotNil(t, msg.Payload) - msg = receiveWithTimeout(t, ch, time.Second) + msg = receive(t, time.Second) assert.Equal(t, common.MessageTypeComplete, msg.Type) }) @@ -866,17 +883,18 @@ func TestWSTransport_LegacyProtocol(t *testing.T) { tr := NewWSTransport(t.Context()) - ch, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, WSSubprotocol: common.SubprotocolAuto, // Auto-negotiate - }) + }, handler) require.NoError(t, err) defer cancel() - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) assert.Contains(t, string(msg.Payload.Data), "99") }) } @@ -908,12 +926,12 @@ func TestWSTransport_Heartbeat(t *testing.T) { tr := NewWSTransport(t.Context(), WithPingInterval(50*time.Millisecond)) - _, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, - }) + }, func(_ *common.Message) {}) require.NoError(t, err) defer cancel() @@ -942,16 +960,17 @@ func TestWSTransport_Heartbeat(t *testing.T) { tr := NewWSTransport(t.Context(), WithPingInterval(100*time.Millisecond), WithPingTimeout(50*time.Millisecond)) - ch, _, err := tr.Subscribe(context.Background(), &common.Request{ + handler, receive := collectingHandler() + _, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, - }) + }, handler) require.NoError(t, err) // Connection should be closed due to pong timeout, subscriber gets notified - msg := receiveWithTimeout(t, ch, time.Second) + msg := receive(t, time.Second) assert.Equal(t, common.MessageTypeConnectionError, msg.Type) assert.Error(t, msg.Err) @@ -984,12 +1003,12 @@ func TestWSTransport_Heartbeat(t *testing.T) { // PingInterval set, PingTimeout left at zero (disabled) tr := NewWSTransport(t.Context(), WithPingInterval(50*time.Millisecond)) - _, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, - }) + }, func(_ *common.Message) {}) require.NoError(t, err) defer cancel() @@ -1024,12 +1043,12 @@ func TestWSTransport_Heartbeat(t *testing.T) { tr := NewWSTransport(t.Context(), WithPingInterval(50*time.Millisecond), WithPingTimeout(200*time.Millisecond)) - _, cancel, err := tr.Subscribe(context.Background(), &common.Request{ + cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ Endpoint: server.URL, Transport: common.TransportWS, - }) + }, func(_ *common.Message) {}) require.NoError(t, err) defer cancel() diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 32230cf07a..310da9d9d0 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -1449,6 +1449,8 @@ func (r *Resolver) subscriptionInput(ctx *Context, subscription *GraphQLSubscrip } type subscriptionUpdater struct { + mu sync.Mutex + done bool debug bool triggerID uint64 resolver *Resolver @@ -1457,7 +1459,9 @@ type subscriptionUpdater struct { } func (s *subscriptionUpdater) Update(data []byte) { - if s.ctx.Err() != nil { + s.mu.Lock() + defer s.mu.Unlock() + if s.done || s.ctx.Err() != nil { return } if s.debug { @@ -1467,14 +1471,18 @@ func (s *subscriptionUpdater) Update(data []byte) { } func (s *subscriptionUpdater) Heartbeat() { - if s.ctx.Err() != nil { + s.mu.Lock() + defer s.mu.Unlock() + if s.done || s.ctx.Err() != nil { return } s.resolver.heartbeatTriggerSubscriptions(s.triggerID) } func (s *subscriptionUpdater) UpdateSubscription(id SubscriptionIdentifier, data []byte) { - if s.ctx.Err() != nil { + s.mu.Lock() + defer s.mu.Unlock() + if s.done || s.ctx.Err() != nil { return } if s.debug { @@ -1488,7 +1496,9 @@ func (s *subscriptionUpdater) Subscriptions() map[context.Context]SubscriptionId } func (s *subscriptionUpdater) Complete() { - if s.ctx.Err() != nil { + s.mu.Lock() + defer s.mu.Unlock() + if s.done || s.ctx.Err() != nil { if s.debug { fmt.Printf("resolver:subscription_updater:complete:skip:%d\n", s.triggerID) } @@ -1501,7 +1511,9 @@ func (s *subscriptionUpdater) Complete() { } func (s *subscriptionUpdater) Error(data []byte) { - if s.ctx.Err() != nil { + s.mu.Lock() + defer s.mu.Unlock() + if s.done || s.ctx.Err() != nil { if s.debug { fmt.Printf("resolver:subscription_updater:error:skip:%d\n", s.triggerID) } @@ -1514,6 +1526,12 @@ func (s *subscriptionUpdater) Error(data []byte) { } func (s *subscriptionUpdater) Done() { + s.mu.Lock() + defer s.mu.Unlock() + if s.done { + return + } + s.done = true if s.debug { fmt.Printf("resolver:subscription_updater:done:%d\n", s.triggerID) } @@ -1521,7 +1539,9 @@ func (s *subscriptionUpdater) Done() { } func (s *subscriptionUpdater) CloseSubscription(id SubscriptionIdentifier) { - if s.ctx.Err() != nil { + s.mu.Lock() + defer s.mu.Unlock() + if s.done || s.ctx.Err() != nil { if s.debug { fmt.Printf("resolver:subscription_updater:close:skip:%d\n", s.triggerID) } From 2a3e4f68aa610913e96a1747cabd994c3c73f00e Mon Sep 17 00:00:00 2001 From: endigma Date: Mon, 23 Mar 2026 13:38:33 +0000 Subject: [PATCH 04/52] add serena config --- .serena/.gitignore | 2 + .serena/project.yml | 138 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+) create mode 100644 .serena/.gitignore create mode 100644 .serena/project.yml diff --git a/.serena/.gitignore b/.serena/.gitignore new file mode 100644 index 0000000000..2e510aff58 --- /dev/null +++ b/.serena/.gitignore @@ -0,0 +1,2 @@ +/cache +/project.local.yml diff --git a/.serena/project.yml b/.serena/project.yml new file mode 100644 index 0000000000..cfa9063fb9 --- /dev/null +++ b/.serena/project.yml @@ -0,0 +1,138 @@ +# the name by which the project can be referenced within Serena +project_name: "graphql-go-tools" + + +# list of languages for which language servers are started; choose from: +# al bash clojure cpp csharp +# csharp_omnisharp dart elixir elm erlang +# fortran fsharp go groovy haskell +# java julia kotlin lua markdown +# matlab nix pascal perl php +# php_phpactor powershell python python_jedi r +# rego ruby ruby_solargraph rust scala +# swift terraform toml typescript typescript_vts +# vue yaml zig +# (This list may be outdated. For the current list, see values of Language enum here: +# https://github.com/oraios/serena/blob/main/src/solidlsp/ls_config.py +# For some languages, there are alternative language servers, e.g. csharp_omnisharp, ruby_solargraph.) +# Note: +# - For C, use cpp +# - For JavaScript, use typescript +# - For Free Pascal/Lazarus, use pascal +# Special requirements: +# Some languages require additional setup/installations. +# See here for details: https://oraios.github.io/serena/01-about/020_programming-languages.html#language-servers +# When using multiple languages, the first language server that supports a given file will be used for that file. +# The first language is the default language and the respective language server will be used as a fallback. +# Note that when using the JetBrains backend, language servers are not used and this list is correspondingly ignored. +languages: +- go + +# the encoding used by text files in the project +# For a list of possible encodings, see https://docs.python.org/3.11/library/codecs.html#standard-encodings +encoding: "utf-8" + +# line ending convention to use when writing source files. +# Possible values: unset (use global setting), "lf", "crlf", or "native" (platform default) +# This does not affect Serena's own files (e.g. memories and configuration files), which always use native line endings. +line_ending: + +# The language backend to use for this project. +# If not set, the global setting from serena_config.yml is used. +# Valid values: LSP, JetBrains +# Note: the backend is fixed at startup. If a project with a different backend +# is activated post-init, an error will be returned. +language_backend: + +# whether to use project's .gitignore files to ignore files +ignore_all_files_in_gitignore: true + +# list of additional paths to ignore in this project. +# Same syntax as gitignore, so you can use * and **. +# Note: global ignored_paths from serena_config.yml are also applied additively. +ignored_paths: [] + +# whether the project is in read-only mode +# If set to true, all editing tools will be disabled and attempts to use them will result in an error +# Added on 2025-04-18 +read_only: false + +# list of tool names to exclude. +# This extends the existing exclusions (e.g. from the global configuration) +# +# Below is the complete list of tools for convenience. +# To make sure you have the latest list of tools, and to view their descriptions, +# execute `uv run scripts/print_tool_overview.py`. +# +# * `activate_project`: Activates a project by name. +# * `check_onboarding_performed`: Checks whether project onboarding was already performed. +# * `create_text_file`: Creates/overwrites a file in the project directory. +# * `delete_lines`: Deletes a range of lines within a file. +# * `delete_memory`: Deletes a memory from Serena's project-specific memory store. +# * `execute_shell_command`: Executes a shell command. +# * `find_referencing_code_snippets`: Finds code snippets in which the symbol at the given location is referenced. +# * `find_referencing_symbols`: Finds symbols that reference the symbol at the given location (optionally filtered by type). +# * `find_symbol`: Performs a global (or local) search for symbols with/containing a given name/substring (optionally filtered by type). +# * `get_current_config`: Prints the current configuration of the agent, including the active and available projects, tools, contexts, and modes. +# * `get_symbols_overview`: Gets an overview of the top-level symbols defined in a given file. +# * `initial_instructions`: Gets the initial instructions for the current project. +# Should only be used in settings where the system prompt cannot be set, +# e.g. in clients you have no control over, like Claude Desktop. +# * `insert_after_symbol`: Inserts content after the end of the definition of a given symbol. +# * `insert_at_line`: Inserts content at a given line in a file. +# * `insert_before_symbol`: Inserts content before the beginning of the definition of a given symbol. +# * `list_dir`: Lists files and directories in the given directory (optionally with recursion). +# * `list_memories`: Lists memories in Serena's project-specific memory store. +# * `onboarding`: Performs onboarding (identifying the project structure and essential tasks, e.g. for testing or building). +# * `prepare_for_new_conversation`: Provides instructions for preparing for a new conversation (in order to continue with the necessary context). +# * `read_file`: Reads a file within the project directory. +# * `read_memory`: Reads the memory with the given name from Serena's project-specific memory store. +# * `remove_project`: Removes a project from the Serena configuration. +# * `replace_lines`: Replaces a range of lines within a file with new content. +# * `replace_symbol_body`: Replaces the full definition of a symbol. +# * `restart_language_server`: Restarts the language server, may be necessary when edits not through Serena happen. +# * `search_for_pattern`: Performs a search for a pattern in the project. +# * `summarize_changes`: Provides instructions for summarizing the changes made to the codebase. +# * `switch_modes`: Activates modes by providing a list of their names +# * `think_about_collected_information`: Thinking tool for pondering the completeness of collected information. +# * `think_about_task_adherence`: Thinking tool for determining whether the agent is still on track with the current task. +# * `think_about_whether_you_are_done`: Thinking tool for determining whether the task is truly completed. +# * `write_memory`: Writes a named memory (for future reference) to Serena's project-specific memory store. +excluded_tools: [] + +# list of tools to include that would otherwise be disabled (particularly optional tools that are disabled by default). +# This extends the existing inclusions (e.g. from the global configuration). +included_optional_tools: [] + +# fixed set of tools to use as the base tool set (if non-empty), replacing Serena's default set of tools. +# This cannot be combined with non-empty excluded_tools or included_optional_tools. +fixed_tools: [] + +# list of mode names to that are always to be included in the set of active modes +# The full set of modes to be activated is base_modes + default_modes. +# If the setting is undefined, the base_modes from the global configuration (serena_config.yml) apply. +# Otherwise, this setting overrides the global configuration. +# Set this to [] to disable base modes for this project. +# Set this to a list of mode names to always include the respective modes for this project. +base_modes: + +# list of mode names that are to be activated by default. +# The full set of modes to be activated is base_modes + default_modes. +# If the setting is undefined, the default_modes from the global configuration (serena_config.yml) apply. +# Otherwise, this overrides the setting from the global configuration (serena_config.yml). +# This setting can, in turn, be overridden by CLI parameters (--mode). +default_modes: + +# initial prompt for the project. It will always be given to the LLM upon activating the project +# (contrary to the memories, which are loaded on demand). +initial_prompt: "" + +# time budget (seconds) per tool call for the retrieval of additional symbol information +# such as docstrings or parameter information. +# This overrides the corresponding setting in the global configuration; see the documentation there. +# If null or missing, use the setting from the global configuration. +symbol_info_budget: + +# list of regex patterns which, when matched, mark a memory entry as read‑only. +# Extends the list from the global configuration, merging the two lists. +read_only_memory_patterns: [] From 4c8ddc30f5d4a7a9bc261b8e9d4ad5fbc9629a65 Mon Sep 17 00:00:00 2001 From: endigma Date: Mon, 23 Mar 2026 16:37:46 +0000 Subject: [PATCH 05/52] fix review issues --- .../graphql_subscription_client.go | 68 +++++++++++------ .../graphql_subscription_client_test.go | 74 +++++++++---------- .../protocol/graphql_transport_ws.go | 10 +-- .../subscriptionclient/protocol/graphql_ws.go | 8 +- .../transport/sse_transport.go | 19 +++-- .../transport/ws_transport.go | 9 ++- v2/pkg/engine/resolve/resolve.go | 26 ++++++- v2/pkg/engine/resolve/resolve_test.go | 17 ++--- .../resolve/resolver_subscription_test.go | 4 +- 9 files changed, 144 insertions(+), 91 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 7b13b6f547..2914fb7474 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -147,13 +147,40 @@ func (c *subscriptionClientV2) Subscribe(ctx *resolve.Context, options GraphQLSu return err } - handler := func(msg *client.Message) { + handler := buildMessageHandler(updater) + + cancel, err := c.client.Subscribe(ctx.Context(), req, opts, handler) + if err != nil { + if isUpstreamError(err) { + updater.Error(formatUpstreamServiceError(err)) + updater.Done() + return nil + } + return err + } + + context.AfterFunc(ctx.Context(), func() { + cancel() + updater.Done() + }) + + return nil +} + +// buildMessageHandler creates the handler that bridges client.Message → resolve.SubscriptionUpdater. +func buildMessageHandler(updater resolve.SubscriptionUpdater) client.Handler { + return func(msg *client.Message) { switch msg.Type { case client.MessageTypeConnectionError: updater.Error(formatUpstreamServiceError(msg.Err)) updater.Done() case client.MessageTypeError: - data, _ := json.Marshal(msg.Payload) + data, err := json.Marshal(msg.Payload) + if err != nil { + updater.Error(formatSubscriptionError(err)) + updater.Done() + return + } updater.Error(data) updater.Done() case client.MessageTypeData: @@ -169,34 +196,33 @@ func (c *subscriptionClientV2) Subscribe(ctx *resolve.Context, options GraphQLSu updater.Done() } } - - cancel, err := c.client.Subscribe(ctx.Context(), req, opts, handler) - if err != nil { - if isUpstreamError(err) { - updater.Error(formatUpstreamServiceError(err)) - updater.Done() - return nil - } - return err - } - - context.AfterFunc(ctx.Context(), func() { - cancel() - updater.Done() - }) - - return nil } // isUpstreamError reports whether err is a connection-level upstream error // that should be reported to the client as an UPSTREAM_SERVICE_ERROR. func isUpstreamError(err error) bool { - return errors.Is(err, client.ErrConnectionClosed) || + if errors.Is(err, client.ErrConnectionClosed) || errors.Is(err, client.ErrConnectionError) || errors.Is(err, client.ErrInitFailed) || errors.Is(err, client.ErrDialFailed) || + errors.Is(err, client.ErrAckTimeout) || + errors.Is(err, client.ErrAckNotReceived) || errors.Is(err, context.Canceled) || - errors.Is(err, context.DeadlineExceeded) + errors.Is(err, context.DeadlineExceeded) { + return true + } + + var failedUpgrade client.ErrFailedUpgrade + if errors.As(err, &failedUpgrade) { + return true + } + + var invalidSubprotocol client.ErrInvalidSubprotocol + if errors.As(err, &invalidSubprotocol) { + return true + } + + return false } // convertToClientOptions converts GraphQLSubscriptionOptions to the new client's types. diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 1b4a209b7d..c3066296e4 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -3,9 +3,9 @@ package graphql_datasource import ( "context" "encoding/json" - "errors" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" client "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient" @@ -44,64 +44,64 @@ func (t *testBridgeUpdater) Subscriptions() map[context.Context]resolve.Subscrip return map[context.Context]resolve.SubscriptionIdentifier{} } -func TestHandlerDeliversCorrectMessageForEachType(t *testing.T) { - buildHandler := func(updater *testBridgeUpdater) client.Handler { - return func(msg *client.Message) { - switch msg.Type { - case client.MessageTypeConnectionError: - updater.Error(formatUpstreamServiceError(msg.Err)) - updater.Done() - case client.MessageTypeError: - data, _ := json.Marshal(msg.Payload) - updater.Error(data) - updater.Done() - case client.MessageTypeData: - data, err := json.Marshal(msg.Payload) - if err != nil { - updater.Error(formatSubscriptionError(err)) - updater.Done() - return - } - updater.Update(data) - case client.MessageTypeComplete: - updater.Complete() - updater.Done() - } - } - } - - t.Run("connection errors deliver error and done without updates", func(t *testing.T) { +func TestBuildMessageHandlerRoutesEachMessageTypeCorrectly(t *testing.T) { + t.Run("error is upstream service error for connection error", func(t *testing.T) { updater := &testBridgeUpdater{} - handler := buildHandler(updater) + handler := buildMessageHandler(updater) handler(&client.Message{Type: client.MessageTypeConnectionError, Err: client.ErrConnectionClosed}) require.True(t, updater.done) require.Len(t, updater.errors, 1) - require.Len(t, updater.updates, 0) + assert.Contains(t, string(updater.errors[0]), "UPSTREAM_SERVICE_ERROR") + require.Empty(t, updater.updates) require.False(t, updater.completed) }) - t.Run("non-connection errors deliver error and done without updates", func(t *testing.T) { + t.Run("error contains payload for graphql error", func(t *testing.T) { updater := &testBridgeUpdater{} - handler := buildHandler(updater) + handler := buildMessageHandler(updater) - handler(&client.Message{Type: client.MessageTypeConnectionError, Err: errors.New("validation failed")}) + handler(&client.Message{ + Type: client.MessageTypeError, + Payload: &client.ExecutionResult{ + Errors: json.RawMessage(`[{"message":"field not found"}]`), + }, + }) require.True(t, updater.done) require.Len(t, updater.errors, 1) - require.Len(t, updater.updates, 0) + assert.Contains(t, string(updater.errors[0]), "field not found") + require.Empty(t, updater.updates) require.False(t, updater.completed) }) - t.Run("complete message calls complete then done", func(t *testing.T) { + t.Run("update is delivered without completing for data message", func(t *testing.T) { updater := &testBridgeUpdater{} - handler := buildHandler(updater) + handler := buildMessageHandler(updater) + + handler(&client.Message{ + Type: client.MessageTypeData, + Payload: &client.ExecutionResult{ + Data: json.RawMessage(`{"foo":"bar"}`), + }, + }) + + require.Len(t, updater.updates, 1) + assert.JSONEq(t, `{"data":{"foo":"bar"}}`, string(updater.updates[0])) + require.False(t, updater.done) + require.False(t, updater.completed) + require.Empty(t, updater.errors) + }) + + t.Run("complete and done are set for complete message", func(t *testing.T) { + updater := &testBridgeUpdater{} + handler := buildMessageHandler(updater) handler(&client.Message{Type: client.MessageTypeComplete}) require.True(t, updater.done) require.True(t, updater.completed) - require.Len(t, updater.errors, 0) + require.Empty(t, updater.errors) }) } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go index 80bd884f92..a6b5f808ec 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go @@ -54,10 +54,6 @@ func (p *GraphQLTransportWS) Init(ctx context.Context, conn *websocket.Conn, pay if payload != nil { initMsg.Payload = payload } - if err := wsjson.Write(ctx, conn, initMsg); err != nil { - return fmt.Errorf("write connection_init: %w", err) - } - timeout := p.AckTimeout if timeout == 0 { timeout = 30 * time.Second @@ -66,6 +62,10 @@ func (p *GraphQLTransportWS) Init(ctx context.Context, conn *websocket.Conn, pay ackCtx, ackCancel := context.WithTimeout(ctx, timeout) defer ackCancel() + if err := wsjson.Write(ackCtx, conn, initMsg); err != nil { + return fmt.Errorf("write connection_init: %w", err) + } + for { var ackMessage incomingMessage if err := wsjson.Read(ackCtx, conn, &ackMessage); err != nil { @@ -79,7 +79,7 @@ func (p *GraphQLTransportWS) Init(ctx context.Context, conn *websocket.Conn, pay case gtwsTypeConnectionAck: return nil case gtwsTypePing: - if err := p.Pong(ctx, conn); err != nil { + if err := p.Pong(ackCtx, conn); err != nil { return fmt.Errorf("pre-init pong: %w", err) } continue diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go index aad78508e3..ea103326b1 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go @@ -45,10 +45,6 @@ func (p *GraphQLWS) Init(ctx context.Context, conn *websocket.Conn, payload map[ if payload != nil { initMsg.Payload = payload } - if err := wsjson.Write(ctx, conn, initMsg); err != nil { - return fmt.Errorf("write connection_init: %w", err) - } - timeout := p.AckTimeout if timeout == 0 { timeout = 30 * time.Second @@ -57,6 +53,10 @@ func (p *GraphQLWS) Init(ctx context.Context, conn *websocket.Conn, payload map[ ackCtx, ackCancel := context.WithTimeout(ctx, timeout) defer ackCancel() + if err := wsjson.Write(ackCtx, conn, initMsg); err != nil { + return fmt.Errorf("write connection_init: %w", err) + } + for { var ackMessage incomingMessage if err := wsjson.Read(ackCtx, conn, &ackMessage); err != nil { diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go index 5e672ebccd..95e79cd965 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "maps" + "mime" "net/http" "net/url" "strings" @@ -17,6 +18,8 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" ) +const maxErrorBodySize = 4096 + // SSETransport implements the Transport interface using Server-Sent Events. // Unlike WebSocket, each subscription creates a separate HTTP request. // TCP connection reuse is handled by http.Client's connection pool. @@ -72,6 +75,7 @@ func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts // Use request context, but with transport requestCancel requestCtx, requestCancel := context.WithCancel(context.WithoutCancel(ctx)) + defer requestCancel() // Attach cancel to transport context context.AfterFunc(t.ctx, requestCancel) @@ -101,7 +105,7 @@ func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts } if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) + body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodySize)) resp.Body.Close() t.log.Error("sseTransport.Subscribe", abstractlogger.String("endpoint", opts.Endpoint), @@ -134,6 +138,7 @@ func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts go conn.readLoop() cancelFn := func() { + requestCancel() conn.closeConn() t.removeConn(conn) } @@ -218,12 +223,16 @@ func (t *SSETransport) validateContentType(resp *http.Response) error { return nil // Allow missing content-type } - // Check if it starts with text/event-stream (may include charset) - if strings.HasPrefix(contentType, "text/event-stream") { - return nil + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + return fmt.Errorf("invalid content-type %q: %w", contentType, err) + } + + if !strings.EqualFold(mediaType, "text/event-stream") { + return fmt.Errorf("unexpected content-type: %s", contentType) } - return fmt.Errorf("unexpected content-type: %s", contentType) + return nil } func (t *SSETransport) removeConn(conn *sseConnection) { diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go index 143c975c2c..8536748c4e 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go @@ -150,6 +150,7 @@ func NewWSTransport(ctx context.Context, opts ...WSTransportOption) *WSTransport upgradeClient: http.DefaultClient, logger: abstractlogger.NoopLogger, readLimit: defaultReadLimit, + writeTimeout: defaultWriteTimeout, } for _, apply := range opts { apply(&o) @@ -208,7 +209,7 @@ func (t *WSTransport) pingLoop() { continue } - if err := conn.sendPing(defaultWriteTimeout); err != nil { + if err := conn.sendPing(t.opts.writeTimeout); err != nil { t.opts.logger.Debug("wsTransport.pingLoop", abstractlogger.String("action", "ping_failed"), abstractlogger.Error(err), @@ -248,7 +249,11 @@ func (t *WSTransport) getOrDial(ctx context.Context, opts common.Options) (*wsCo if result, ok := t.dialing[key]; ok { t.mu.Unlock() - <-result.done + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-result.done: + } if result.err != nil { return nil, result.err diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 310da9d9d0..a6321e2643 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -720,12 +720,13 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:added:%d:%d\n", triggerID, add.id.SubscriptionID) } - // Add the subscription to the registry so it can receive events + // Register first so startup hooks can deliver initial data via UpdateSubscription. trig.mu.Lock() trig.subscriptions[add.id] = s trig.mu.Unlock() r.addSubscriptionIndex(s) - // Execute the startup hooks in a goroutine to avoid holding the lock + // Execute the startup hooks in a goroutine to avoid holding the lock. + // On failure, executeStartupHooks calls UnsubscribeSubscription to clean up. go func() { _ = r.executeStartupHooks(add, trig.updater) }() @@ -986,6 +987,13 @@ type pendingSubscriptionWrite struct { sub *subscriptionState } +type pendingFilterError struct { + ctx *Context + err error + response *GraphQLResponse + writer SubscriptionResponseWriter +} + // handleTriggerUpdate sends data to all subscriptions of a trigger using snapshot-and-release. // The lock is released before performing I/O to avoid deadlocks when executeSubscriptionUpdate // calls AsyncUnsubscribeSubscription on flush failure. @@ -1001,6 +1009,7 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { } var pending []pendingSubscriptionWrite + var filterErrors []pendingFilterError trig.mu.Lock() for _, s := range trig.subscriptions { if s.ctx.ctx.Err() != nil { @@ -1008,7 +1017,7 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { } skip, err := s.resolve.Filter.SkipEvent(s.ctx, data, trig.updateBuf) if err != nil { - r.asyncErrorWriter.WriteError(s.ctx, err, s.resolve.Response, s.writer) + filterErrors = append(filterErrors, pendingFilterError{s.ctx, err, s.resolve.Response, s.writer}) continue } if skip { @@ -1018,6 +1027,10 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { } trig.mu.Unlock() + for _, fe := range filterErrors { + r.asyncErrorWriter.WriteError(fe.ctx, fe.err, fe.response, fe.writer) + } + var wg sync.WaitGroup for _, pw := range pending { if pw.sub.removed.Load() { @@ -1044,13 +1057,14 @@ func (r *Resolver) handleUpdateSubscription(id uint64, data []byte, subIdentifie } var target *subscriptionState + var filterErr *pendingFilterError trig.mu.Lock() s, ok := trig.subscriptions[subIdentifier] if ok { if s.ctx.ctx.Err() == nil { skip, err := s.resolve.Filter.SkipEvent(s.ctx, data, trig.updateBuf) if err != nil { - r.asyncErrorWriter.WriteError(s.ctx, err, s.resolve.Response, s.writer) + filterErr = &pendingFilterError{s.ctx, err, s.resolve.Response, s.writer} } else if !skip { target = s } @@ -1058,6 +1072,10 @@ func (r *Resolver) handleUpdateSubscription(id uint64, data []byte, subIdentifie } trig.mu.Unlock() + if filterErr != nil { + r.asyncErrorWriter.WriteError(filterErr.ctx, filterErr.err, filterErr.response, filterErr.writer) + } + if target != nil && !target.removed.Load() { r.executeSubscriptionUpdate(target.ctx, target, data) } diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 91123f7953..170f9cd0bf 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -8,6 +8,7 @@ import ( "io" "net" "net/http" + "slices" "sync" "sync/atomic" "testing" @@ -5435,7 +5436,7 @@ func (s *SubscriptionRecorder) Heartbeat() error { func (s *SubscriptionRecorder) Error(data []byte) { s.mux.Lock() - s.errors = append(s.errors, data) + s.errors = append(s.errors, slices.Clone(data)) s.mux.Unlock() } @@ -5957,11 +5958,9 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { resolver, plan, recorder, id := setup(c, fakeStream) - ctx := Context{ - ctx: context.Background(), - } + ctx := NewContext(context.Background()) - err := resolver.AsyncResolveGraphQLSubscription(&ctx, plan, recorder, id) + err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id) assert.NoError(t, err) recorder.AwaitAnyMessageCount(t, defaultTimeout) err = resolver.UnsubscribeClient(id.ConnectionID) @@ -6052,12 +6051,8 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { // Each subscription needs its own Context so they get separate entries // in the trigger's subscriptions map (keyed by *Context). - subCtx := &Context{ - ctx: context.Background(), - ExecutionOptions: ExecutionOptions{ - SendHeartbeat: true, - }, - } + subCtx := NewContext(context.Background()) + subCtx.ExecutionOptions.SendHeartbeat = true go func() { defer recorderCompleted.Add(1) diff --git a/v2/pkg/engine/resolve/resolver_subscription_test.go b/v2/pkg/engine/resolve/resolver_subscription_test.go index 9e8ac319fd..dbaed8901d 100644 --- a/v2/pkg/engine/resolve/resolver_subscription_test.go +++ b/v2/pkg/engine/resolve/resolver_subscription_test.go @@ -209,10 +209,10 @@ func TestResolver_HeartbeatError_DoesNotDeadlockOnUnsubscribe(t *testing.T) { resolver := New(resolverCtx, ResolverOptions{ MaxConcurrency: 1, AsyncErrorWriter: &FakeErrorWriter{}, - SubscriptionHeartbeatInterval: time.Millisecond, + SubscriptionHeartbeatInterval: time.Hour, // Long interval to prevent background heartbeat loop from competing }) - subCtx := (&Context{}).WithContext(context.Background()) + subCtx := NewContext(context.Background()) subID := SubscriptionIdentifier{ ConnectionID: 1, SubscriptionID: 1, From e91088033681e3b3f8832832bee0ed98d90ac8fa Mon Sep 17 00:00:00 2001 From: endigma Date: Mon, 23 Mar 2026 17:26:24 +0000 Subject: [PATCH 06/52] simplify connection and transport construction --- .../subscriptionclient/client.go | 18 +- .../transport/sse_transport.go | 2 - .../transport/sse_transport_test.go | 4 +- .../subscriptionclient/transport/ws_conn.go | 45 +---- .../transport/ws_conn_test.go | 82 ++++----- .../transport/ws_transport.go | 163 ++++++------------ .../transport/ws_transport_test.go | 57 +++--- 7 files changed, 137 insertions(+), 234 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go index f4b3619d5f..d04e494db6 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go @@ -56,15 +56,15 @@ func New(ctx context.Context, cfg Config) *Client { ctx: ctx, log: cfg.Logger, - ws: transport.NewWSTransport(ctx, - transport.WithUpgradeClient(cfg.UpgradeClient), - transport.WithLogger(cfg.Logger), - transport.WithPingInterval(cfg.PingInterval), - transport.WithPingTimeout(cfg.PingTimeout), - transport.WithAckTimeout(cfg.AckTimeout), - transport.WithWriteTimeout(cfg.WriteTimeout), - transport.WithReadLimit(cfg.ReadLimit), - ), + ws: transport.NewWSTransport(ctx, transport.WSTransportOptions{ + UpgradeClient: cfg.UpgradeClient, + Logger: cfg.Logger, + PingInterval: cfg.PingInterval, + PingTimeout: cfg.PingTimeout, + AckTimeout: cfg.AckTimeout, + WriteTimeout: cfg.WriteTimeout, + ReadLimit: cfg.ReadLimit, + }), sse: transport.NewSSETransport(ctx, cfg.StreamingClient, cfg.Logger), } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go index 95e79cd965..9b8a51d0da 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go @@ -73,9 +73,7 @@ func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts abstractlogger.String("method", string(method)), ) - // Use request context, but with transport requestCancel requestCtx, requestCancel := context.WithCancel(context.WithoutCancel(ctx)) - defer requestCancel() // Attach cancel to transport context context.AfterFunc(t.ctx, requestCancel) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go index 6d8514ac68..8cd7104b79 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go @@ -540,7 +540,9 @@ func TestSSETransport_ContextCancellation(t *testing.T) { return closedCount.Load() == 2 }, time.Second, 10*time.Millisecond) - assert.Equal(t, 0, tr.ConnCount()) + assert.Eventually(t, func() bool { + return tr.ConnCount() == 0 + }, time.Second, 10*time.Millisecond) }) } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go index 672fc2fa5e..f07a33d3e9 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go @@ -28,34 +28,6 @@ type wsConnectionOptions struct { onEmpty func() } -// wsConnectionOption configures a wsConnection. -type wsConnectionOption func(*wsConnectionOptions) - -// withConnLogger sets the logger for connection-level debug output. -func withConnLogger(l abstractlogger.Logger) wsConnectionOption { - return func(o *wsConnectionOptions) { - if l != nil { - o.logger = l - } - } -} - -// withConnWriteTimeout sets the timeout for write operations (subscribe, unsubscribe, pong). -func withConnWriteTimeout(d time.Duration) wsConnectionOption { - return func(o *wsConnectionOptions) { - if d > 0 { - o.writeTimeout = d - } - } -} - -// withOnEmpty sets a callback invoked when the last subscription is removed or the connection shuts down. -func withOnEmpty(f func()) wsConnectionOption { - return func(o *wsConnectionOptions) { - o.onEmpty = f - } -} - type wsConnection struct { conn *websocket.Conn protocol protocol.Protocol @@ -81,13 +53,12 @@ type wsConnection struct { lastPongAt atomic.Int64 } -func newWSConnection(conn *websocket.Conn, proto protocol.Protocol, opts ...wsConnectionOption) *wsConnection { - o := wsConnectionOptions{ - logger: abstractlogger.NoopLogger, - writeTimeout: defaultWriteTimeout, +func newWSConnection(conn *websocket.Conn, proto protocol.Protocol, opts wsConnectionOptions) *wsConnection { + if opts.logger == nil { + opts.logger = abstractlogger.NoopLogger } - for _, apply := range opts { - apply(&o) + if opts.writeTimeout <= 0 { + opts.writeTimeout = defaultWriteTimeout } ctx, cancel := context.WithCancel(context.Background()) @@ -95,13 +66,13 @@ func newWSConnection(conn *websocket.Conn, proto protocol.Protocol, opts ...wsCo c := &wsConnection{ conn: conn, protocol: proto, - log: o.logger, + log: opts.logger, cancel: cancel, ctx: ctx, subs: make(map[string]common.Handler), - onEmpty: o.onEmpty, + onEmpty: opts.onEmpty, - writeTimeout: o.writeTimeout, + writeTimeout: opts.writeTimeout, } c.lastPongAt.Store(time.Now().UnixNano()) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go index 69d2fa0659..fd30c6cf8f 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go @@ -26,7 +26,7 @@ func TestWSConnection_Subscribe(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto) + wsc := newWSConnection(conn, proto, wsConnectionOptions{}) handler, _ := collectingHandler() cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{ @@ -44,7 +44,7 @@ func TestWSConnection_Subscribe(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto) + wsc := newWSConnection(conn, proto, wsConnectionOptions{}) cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) require.NoError(t, err) @@ -60,7 +60,7 @@ func TestWSConnection_Subscribe(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto) + wsc := newWSConnection(conn, proto, wsConnectionOptions{}) wsc.closeConn() _, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) @@ -74,7 +74,7 @@ func TestWSConnection_Subscribe(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() proto.subscribeErr = assert.AnError - wsc := newWSConnection(conn, proto) + wsc := newWSConnection(conn, proto, wsConnectionOptions{}) _, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) @@ -91,7 +91,7 @@ func TestWSConnection_ReadLoop(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto) + wsc := newWSConnection(conn, proto, wsConnectionOptions{}) handler, receive := collectingHandler() cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler) @@ -116,7 +116,7 @@ func TestWSConnection_ReadLoop(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto) + wsc := newWSConnection(conn, proto, wsConnectionOptions{}) handler, receive := collectingHandler() _, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler) @@ -138,7 +138,7 @@ func TestWSConnection_ReadLoop(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto) + wsc := newWSConnection(conn, proto, wsConnectionOptions{}) go wsc.readLoop() @@ -154,7 +154,7 @@ func TestWSConnection_ReadLoop(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto) + wsc := newWSConnection(conn, proto, wsConnectionOptions{}) handler, receive := collectingHandler() cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler) @@ -188,7 +188,7 @@ func TestWSConnection_Unsubscribe(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto) + wsc := newWSConnection(conn, proto, wsConnectionOptions{}) cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) require.NoError(t, err) @@ -205,7 +205,7 @@ func TestWSConnection_Unsubscribe(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto) + wsc := newWSConnection(conn, proto, wsConnectionOptions{}) cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) require.NoError(t, err) @@ -223,9 +223,9 @@ func TestWSConnection_Unsubscribe(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() proto.unsubscribeDelay = 500 * time.Millisecond - wsc := newWSConnection(conn, proto, - withConnWriteTimeout(50*time.Millisecond), - ) + wsc := newWSConnection(conn, proto, wsConnectionOptions{ + writeTimeout: 50 * time.Millisecond, + }) cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) require.NoError(t, err) @@ -248,9 +248,9 @@ func TestWSConnection_OnEmpty(t *testing.T) { proto := newMockProtocol() emptyCalled := make(chan struct{}, 1) - wsc := newWSConnection(conn, proto, - withOnEmpty(func() { emptyCalled <- struct{}{} }), - ) + wsc := newWSConnection(conn, proto, wsConnectionOptions{ + onEmpty: func() { emptyCalled <- struct{}{} }, + }) cancel, _ := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) cancel() @@ -272,9 +272,9 @@ func TestWSConnection_OnEmpty(t *testing.T) { proto := newMockProtocol() emptyCalled := make(chan struct{}, 1) - wsc := newWSConnection(conn, proto, - withOnEmpty(func() { emptyCalled <- struct{}{} }), - ) + wsc := newWSConnection(conn, proto, wsConnectionOptions{ + onEmpty: func() { emptyCalled <- struct{}{} }, + }) cancel1, _ := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) cancel2, _ := wsc.subscribe(context.Background(), "sub-2", &common.Request{}, func(_ *common.Message) {}) @@ -305,9 +305,9 @@ func TestWSConnection_OnEmpty(t *testing.T) { proto := newMockProtocol() emptyCalled := make(chan struct{}, 1) - wsc := newWSConnection(conn, proto, - withOnEmpty(func() { emptyCalled <- struct{}{} }), - ) + wsc := newWSConnection(conn, proto, wsConnectionOptions{ + onEmpty: func() { emptyCalled <- struct{}{} }, + }) wsc.closeConn() @@ -326,9 +326,9 @@ func TestWSConnection_OnEmpty(t *testing.T) { proto := newMockProtocol() emptyCalled := make(chan struct{}, 1) - wsc := newWSConnection(conn, proto, - withOnEmpty(func() { emptyCalled <- struct{}{} }), - ) + wsc := newWSConnection(conn, proto, wsConnectionOptions{ + onEmpty: func() { emptyCalled <- struct{}{} }, + }) go wsc.readLoop() @@ -352,7 +352,7 @@ func TestWSConnection_Close(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto) + wsc := newWSConnection(conn, proto, wsConnectionOptions{}) handler1, receive1 := collectingHandler() _, _ = wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler1) @@ -376,7 +376,7 @@ func TestWSConnection_Close(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto) + wsc := newWSConnection(conn, proto, wsConnectionOptions{}) assert.NotPanics(t, func() { wsc.closeConn() @@ -394,7 +394,7 @@ func TestWSConnection_SubCount(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto) + wsc := newWSConnection(conn, proto, wsConnectionOptions{}) assert.Equal(t, 0, wsc.subCount()) @@ -421,9 +421,9 @@ func TestWSConnection_WriteTimeout(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() proto.pongDelay = 500 * time.Millisecond - wsc := newWSConnection(conn, proto, - withConnWriteTimeout(50*time.Millisecond), - ) + wsc := newWSConnection(conn, proto, wsConnectionOptions{ + writeTimeout: 50 * time.Millisecond, + }) handler, receive := collectingHandler() cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler) @@ -457,7 +457,7 @@ func TestWSConnection_Defaults(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto) + wsc := newWSConnection(conn, proto, wsConnectionOptions{}) assert.Equal(t, defaultWriteTimeout, wsc.writeTimeoutDuration()) }) @@ -467,9 +467,9 @@ func TestWSConnection_Defaults(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto, - withConnWriteTimeout(0), - ) + wsc := newWSConnection(conn, proto, wsConnectionOptions{ + writeTimeout: 0, + }) assert.Equal(t, defaultWriteTimeout, wsc.writeTimeoutDuration()) }) @@ -479,9 +479,9 @@ func TestWSConnection_Defaults(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto, - withConnWriteTimeout(10*time.Second), - ) + wsc := newWSConnection(conn, proto, wsConnectionOptions{ + writeTimeout: 10 * time.Second, + }) assert.Equal(t, 10*time.Second, wsc.writeTimeoutDuration()) }) @@ -491,9 +491,9 @@ func TestWSConnection_Defaults(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto, - withConnWriteTimeout(-1*time.Second), - ) + wsc := newWSConnection(conn, proto, wsConnectionOptions{ + writeTimeout: -1 * time.Second, + }) assert.Equal(t, defaultWriteTimeout, wsc.writeTimeoutDuration()) }) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go index 8536748c4e..2712ed0187 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go @@ -42,90 +42,20 @@ func (e ErrInvalidSubprotocol) Error() string { return fmt.Sprintf("provided websocket subprotocol '%s' is not supported. The supported subprotocols are graphql-ws and graphql-transport-ws. Please configure your subscriptions with the mentioned subprotocols", string(e)) } -type wsTransportOptions struct { - upgradeClient *http.Client - logger abstractlogger.Logger - pingInterval time.Duration - pingTimeout time.Duration - ackTimeout time.Duration - writeTimeout time.Duration - readLimit int64 -} - -// WSTransportOption configures a WSTransport. -type WSTransportOption func(*wsTransportOptions) - -// WithUpgradeClient sets the HTTP client used for WebSocket upgrade requests. -func WithUpgradeClient(c *http.Client) WSTransportOption { - return func(o *wsTransportOptions) { - if c != nil { - o.upgradeClient = c - } - } -} - -// WithLogger sets the logger for transport-level debug output. -func WithLogger(l abstractlogger.Logger) WSTransportOption { - return func(o *wsTransportOptions) { - if l != nil { - o.logger = l - } - } -} - -// WithPingInterval sets how often protocol-level pings are sent to all connections. -// Zero disables pinging. -func WithPingInterval(d time.Duration) WSTransportOption { - return func(o *wsTransportOptions) { - if d > 0 { - o.pingInterval = d - } - } -} - -// WithPingTimeout sets how long a connection may go without a pong before being closed. -// Zero disables the timeout (pings are sent but unresponsive connections are not killed). -func WithPingTimeout(d time.Duration) WSTransportOption { - return func(o *wsTransportOptions) { - if d > 0 { - o.pingTimeout = d - } - } -} - -// WithAckTimeout sets the maximum time to wait for a connection_ack after sending -// connection_init. Zero uses the protocol default (30s). -func WithAckTimeout(d time.Duration) WSTransportOption { - return func(o *wsTransportOptions) { - if d > 0 { - o.ackTimeout = d - } - } -} - -// WithWriteTimeout sets the timeout for WebSocket write operations on new connections. -// Zero uses defaultWriteTimeout (5s) at the connection level. -func WithWriteTimeout(d time.Duration) WSTransportOption { - return func(o *wsTransportOptions) { - if d > 0 { - o.writeTimeout = d - } - } -} - -// WithReadLimit sets the maximum size in bytes for incoming WebSocket messages. -// Zero uses the defaultReadLimit (1MB). -func WithReadLimit(n int64) WSTransportOption { - return func(o *wsTransportOptions) { - if n > 0 { - o.readLimit = n - } - } +// WSTransportOptions configures a WSTransport. +type WSTransportOptions struct { + UpgradeClient *http.Client + Logger abstractlogger.Logger + PingInterval time.Duration + PingTimeout time.Duration + AckTimeout time.Duration + WriteTimeout time.Duration + ReadLimit int64 } type WSTransport struct { ctx context.Context - opts wsTransportOptions + opts WSTransportOptions mu sync.Mutex dialing map[uint64]*dialResult @@ -142,28 +72,31 @@ type dialResult struct { // is cancelled; instead they close themselves when their last subscriber is // removed via the resolver's drain chain. The ping loop exits on ctx cancellation. // -// If WithPingInterval is set, a single goroutine sends protocol-level pings to all -// connections at that cadence. If WithPingTimeout is also set, connections that fail +// If PingInterval is set, a single goroutine sends protocol-level pings to all +// connections at that cadence. If PingTimeout is also set, connections that fail // to respond with a pong within that window are shut down. -func NewWSTransport(ctx context.Context, opts ...WSTransportOption) *WSTransport { - o := wsTransportOptions{ - upgradeClient: http.DefaultClient, - logger: abstractlogger.NoopLogger, - readLimit: defaultReadLimit, - writeTimeout: defaultWriteTimeout, +func NewWSTransport(ctx context.Context, opts WSTransportOptions) *WSTransport { + if opts.UpgradeClient == nil { + opts.UpgradeClient = http.DefaultClient } - for _, apply := range opts { - apply(&o) + if opts.Logger == nil { + opts.Logger = abstractlogger.NoopLogger + } + if opts.ReadLimit <= 0 { + opts.ReadLimit = defaultReadLimit + } + if opts.WriteTimeout <= 0 { + opts.WriteTimeout = defaultWriteTimeout } t := &WSTransport{ ctx: ctx, - opts: o, + opts: opts, conns: make(map[uint64]*wsConnection), dialing: make(map[uint64]*dialResult), } - if o.pingInterval > 0 { + if opts.PingInterval > 0 { go t.pingLoop() } @@ -183,7 +116,7 @@ func (t *WSTransport) Subscribe(ctx context.Context, req *common.Request, opts c // pingLoop sends periodic pings to all active connections and shuts down // any that have not responded with a pong in time. func (t *WSTransport) pingLoop() { - tick := time.Tick(t.opts.pingInterval) + tick := time.Tick(t.opts.PingInterval) for { select { case <-t.ctx.Done(): @@ -201,16 +134,16 @@ func (t *WSTransport) pingLoop() { continue } - if t.opts.pingTimeout > 0 && conn.pongOverdue(t.opts.pingTimeout) { - t.opts.logger.Debug("wsTransport.pingLoop", + if t.opts.PingTimeout > 0 && conn.pongOverdue(t.opts.PingTimeout) { + t.opts.Logger.Debug("wsTransport.pingLoop", abstractlogger.String("action", "pong_timeout"), ) conn.closeConn() continue } - if err := conn.sendPing(t.opts.writeTimeout); err != nil { - t.opts.logger.Debug("wsTransport.pingLoop", + if err := conn.sendPing(t.opts.WriteTimeout); err != nil { + t.opts.Logger.Debug("wsTransport.pingLoop", abstractlogger.String("action", "ping_failed"), abstractlogger.Error(err), ) @@ -222,12 +155,12 @@ func (t *WSTransport) pingLoop() { // ReadLimit returns the configured read limit. func (t *WSTransport) ReadLimit() int64 { - return t.opts.readLimit + return t.opts.ReadLimit } // WriteTimeout returns the configured write timeout for new connections. func (t *WSTransport) WriteTimeout() time.Duration { - return t.opts.writeTimeout + return t.opts.WriteTimeout } func (t *WSTransport) ConnCount() int { @@ -284,13 +217,13 @@ func (t *WSTransport) getOrDial(ctx context.Context, opts common.Options) (*wsCo } func (t *WSTransport) dial(ctx context.Context, key uint64, opts common.Options) (*wsConnection, error) { - t.opts.logger.Debug("wsTransport.dial", + t.opts.Logger.Debug("wsTransport.dial", abstractlogger.String("endpoint", opts.Endpoint), abstractlogger.String("subprotocol", string(opts.WSSubprotocol)), ) wsConn, resp, err := websocket.Dial(ctx, opts.Endpoint, &websocket.DialOptions{ //nolint:bodyclose - HTTPClient: t.opts.upgradeClient, + HTTPClient: t.opts.UpgradeClient, Subprotocols: opts.WSSubprotocol.Subprotocols(), HTTPHeader: opts.Headers, }) @@ -299,7 +232,7 @@ func (t *WSTransport) dial(ctx context.Context, key uint64, opts common.Options) return nil, err } - t.opts.logger.Error("wsTransport.dial", + t.opts.Logger.Error("wsTransport.dial", abstractlogger.String("endpoint", opts.Endpoint), abstractlogger.Error(err), ) @@ -312,11 +245,11 @@ func (t *WSTransport) dial(ctx context.Context, key uint64, opts common.Options) return nil, fmt.Errorf("%w: %w", ErrDialFailed, err) } - wsConn.SetReadLimit(t.opts.readLimit) + wsConn.SetReadLimit(t.opts.ReadLimit) proto, err := t.negotiateSubprotocol(opts.WSSubprotocol, wsConn.Subprotocol()) if err != nil { - t.opts.logger.Error("wsTransport.dial", + t.opts.Logger.Error("wsTransport.dial", abstractlogger.String("endpoint", opts.Endpoint), abstractlogger.String("error", "subprotocol negotiation failed"), abstractlogger.Error(err), @@ -326,7 +259,7 @@ func (t *WSTransport) dial(ctx context.Context, key uint64, opts common.Options) } if err := proto.Init(ctx, wsConn, opts.InitPayload); err != nil { - t.opts.logger.Error("wsTransport.dial", + t.opts.Logger.Error("wsTransport.dial", abstractlogger.String("endpoint", opts.Endpoint), abstractlogger.String("error", "protocol init failed"), abstractlogger.Error(err), @@ -335,17 +268,17 @@ func (t *WSTransport) dial(ctx context.Context, key uint64, opts common.Options) return nil, fmt.Errorf("%w: %w", ErrInitFailed, err) } - t.opts.logger.Debug("wsTransport.dial", + t.opts.Logger.Debug("wsTransport.dial", abstractlogger.String("endpoint", opts.Endpoint), abstractlogger.String("status", "connected"), abstractlogger.String("negotiated_subprotocol", wsConn.Subprotocol()), ) - conn := newWSConnection(wsConn, proto, - withConnLogger(t.opts.logger), - withConnWriteTimeout(t.opts.writeTimeout), - withOnEmpty(func() { t.removeConn(key) }), - ) + conn := newWSConnection(wsConn, proto, wsConnectionOptions{ + logger: t.opts.Logger, + writeTimeout: t.opts.WriteTimeout, + onEmpty: func() { t.removeConn(key) }, + }) go conn.readLoop() @@ -362,14 +295,14 @@ func (t *WSTransport) negotiateSubprotocol(requested common.WSSubprotocol, accep switch common.WSSubprotocol(accepted) { case common.SubprotocolGraphQLTransportWS: p := protocol.NewGraphQLTransportWS() - if t.opts.ackTimeout > 0 { - p.AckTimeout = t.opts.ackTimeout + if t.opts.AckTimeout > 0 { + p.AckTimeout = t.opts.AckTimeout } return p, nil case common.SubprotocolGraphQLWS: p := protocol.NewGraphQLWS() - if t.opts.ackTimeout > 0 { - p.AckTimeout = t.opts.ackTimeout + if t.opts.AckTimeout > 0 { + p.AckTimeout = t.opts.AckTimeout } return p, nil default: diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go index 7f8e5227a4..49ceb5c090 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go @@ -43,7 +43,7 @@ func TestWSTransport_Subscribe(t *testing.T) { }) }) - tr := NewWSTransport(t.Context()) + tr := NewWSTransport(t.Context(), WSTransportOptions{}) handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ @@ -85,7 +85,7 @@ func TestWSTransport_Subscribe(t *testing.T) { } }) - tr := NewWSTransport(t.Context()) + tr := NewWSTransport(t.Context(), WSTransportOptions{}) opts := common.Options{ Endpoint: server.URL, @@ -134,7 +134,7 @@ func TestWSTransport_Subscribe(t *testing.T) { } }) - tr := NewWSTransport(t.Context()) + tr := NewWSTransport(t.Context(), WSTransportOptions{}) headers1 := http.Header{"Authorization": []string{"Bearer token1"}} headers2 := http.Header{"Authorization": []string{"Bearer token2"}} @@ -188,7 +188,7 @@ func TestWSTransport_Subscribe(t *testing.T) { } }) - tr := NewWSTransport(t.Context()) + tr := NewWSTransport(t.Context(), WSTransportOptions{}) handler1, receive1 := collectingHandler() cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, common.Options{ @@ -228,7 +228,7 @@ func TestWSTransport_Subscribe(t *testing.T) { } }) - tr := NewWSTransport(t.Context()) + tr := NewWSTransport(t.Context(), WSTransportOptions{}) opts := common.Options{ Endpoint: server.URL, @@ -277,7 +277,7 @@ func TestWSTransport_Subscribe(t *testing.T) { } }) - tr := NewWSTransport(t.Context()) + tr := NewWSTransport(t.Context(), WSTransportOptions{}) opts := common.Options{ Endpoint: server.URL, @@ -322,7 +322,7 @@ func TestWSTransport_SubscriberDrain(t *testing.T) { } }) - tr := NewWSTransport(t.Context()) + tr := NewWSTransport(t.Context(), WSTransportOptions{}) cancel, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, common.Options{ Endpoint: server.URL, @@ -351,7 +351,7 @@ func TestWSTransport_SubscriberDrain(t *testing.T) { } }) - tr := NewWSTransport(t.Context()) + tr := NewWSTransport(t.Context(), WSTransportOptions{}) opts := common.Options{Endpoint: server.URL, Transport: common.TransportWS} @@ -401,7 +401,7 @@ func TestWSTransport_ConcurrentSubscribe(t *testing.T) { } }) - tr := NewWSTransport(t.Context()) + tr := NewWSTransport(t.Context(), WSTransportOptions{}) opts := common.Options{ Endpoint: server.URL, @@ -477,7 +477,7 @@ func TestWSTransport_InitPayloadForwarding(t *testing.T) { })) t.Cleanup(server.Close) - tr := NewWSTransport(t.Context()) + tr := NewWSTransport(t.Context(), WSTransportOptions{}) initPayload := map[string]any{ "Authorization": "Bearer secret-token", @@ -562,7 +562,7 @@ func TestWSTransport_InitPayloadForwarding(t *testing.T) { })) t.Cleanup(server.Close) - tr := NewWSTransport(t.Context()) + tr := NewWSTransport(t.Context(), WSTransportOptions{}) initPayload := map[string]any{ "token": "legacy-auth-token", @@ -641,7 +641,7 @@ func TestWSTransport_InitPayloadForwarding(t *testing.T) { })) t.Cleanup(server.Close) - tr := NewWSTransport(t.Context()) + tr := NewWSTransport(t.Context(), WSTransportOptions{}) handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ @@ -720,7 +720,7 @@ func TestWSTransport_InitPayloadForwarding(t *testing.T) { })) t.Cleanup(server.Close) - tr := NewWSTransport(t.Context()) + tr := NewWSTransport(t.Context(), WSTransportOptions{}) // First subscription with user1 token handler1, receive1 := collectingHandler() @@ -795,7 +795,7 @@ func TestWSTransport_LegacyProtocol(t *testing.T) { }) }) - tr := NewWSTransport(t.Context()) + tr := NewWSTransport(t.Context(), WSTransportOptions{}) handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ @@ -840,7 +840,7 @@ func TestWSTransport_LegacyProtocol(t *testing.T) { }) }) - tr := NewWSTransport(t.Context()) + tr := NewWSTransport(t.Context(), WSTransportOptions{}) handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ @@ -881,7 +881,7 @@ func TestWSTransport_LegacyProtocol(t *testing.T) { }) }) - tr := NewWSTransport(t.Context()) + tr := NewWSTransport(t.Context(), WSTransportOptions{}) handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ @@ -924,7 +924,7 @@ func TestWSTransport_Heartbeat(t *testing.T) { } }) - tr := NewWSTransport(t.Context(), WithPingInterval(50*time.Millisecond)) + tr := NewWSTransport(t.Context(), WSTransportOptions{PingInterval: 50 * time.Millisecond}) cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", @@ -958,7 +958,7 @@ func TestWSTransport_Heartbeat(t *testing.T) { } }) - tr := NewWSTransport(t.Context(), WithPingInterval(100*time.Millisecond), WithPingTimeout(50*time.Millisecond)) + tr := NewWSTransport(t.Context(), WSTransportOptions{PingInterval: 100 * time.Millisecond, PingTimeout: 50 * time.Millisecond}) handler, receive := collectingHandler() _, err := tr.Subscribe(context.Background(), &common.Request{ @@ -1001,7 +1001,7 @@ func TestWSTransport_Heartbeat(t *testing.T) { }) // PingInterval set, PingTimeout left at zero (disabled) - tr := NewWSTransport(t.Context(), WithPingInterval(50*time.Millisecond)) + tr := NewWSTransport(t.Context(), WSTransportOptions{PingInterval: 50 * time.Millisecond}) cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", @@ -1041,7 +1041,7 @@ func TestWSTransport_Heartbeat(t *testing.T) { } }) - tr := NewWSTransport(t.Context(), WithPingInterval(50*time.Millisecond), WithPingTimeout(200*time.Millisecond)) + tr := NewWSTransport(t.Context(), WSTransportOptions{PingInterval: 50 * time.Millisecond, PingTimeout: 200 * time.Millisecond}) cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", @@ -1064,7 +1064,7 @@ func TestWSTransport_Defaults(t *testing.T) { t.Run("applies default read limit when omitted", func(t *testing.T) { t.Parallel() - tr := NewWSTransport(t.Context()) + tr := NewWSTransport(t.Context(), WSTransportOptions{}) assert.Equal(t, defaultReadLimit, tr.ReadLimit()) }) @@ -1072,7 +1072,7 @@ func TestWSTransport_Defaults(t *testing.T) { t.Run("applies default read limit for zero value", func(t *testing.T) { t.Parallel() - tr := NewWSTransport(t.Context(), WithReadLimit(0)) + tr := NewWSTransport(t.Context(), WSTransportOptions{ReadLimit: 0}) assert.Equal(t, defaultReadLimit, tr.ReadLimit()) }) @@ -1080,7 +1080,7 @@ func TestWSTransport_Defaults(t *testing.T) { t.Run("overrides read limit when provided", func(t *testing.T) { t.Parallel() - tr := NewWSTransport(t.Context(), WithReadLimit(2*1024*1024)) + tr := NewWSTransport(t.Context(), WSTransportOptions{ReadLimit: 2 * 1024 * 1024}) assert.Equal(t, int64(2*1024*1024), tr.ReadLimit()) }) @@ -1088,24 +1088,23 @@ func TestWSTransport_Defaults(t *testing.T) { t.Run("ignores negative read limit", func(t *testing.T) { t.Parallel() - tr := NewWSTransport(t.Context(), WithReadLimit(-1)) + tr := NewWSTransport(t.Context(), WSTransportOptions{ReadLimit: -1}) assert.Equal(t, defaultReadLimit, tr.ReadLimit()) }) - t.Run("applies zero write timeout by default", func(t *testing.T) { + t.Run("applies default write timeout", func(t *testing.T) { t.Parallel() - tr := NewWSTransport(t.Context()) + tr := NewWSTransport(t.Context(), WSTransportOptions{}) - // Zero means connections use their own DefaultWriteTimeout - assert.Equal(t, time.Duration(0), tr.WriteTimeout()) + assert.Equal(t, defaultWriteTimeout, tr.WriteTimeout()) }) t.Run("overrides write timeout when provided", func(t *testing.T) { t.Parallel() - tr := NewWSTransport(t.Context(), WithWriteTimeout(10*time.Second)) + tr := NewWSTransport(t.Context(), WSTransportOptions{WriteTimeout: 10 * time.Second}) assert.Equal(t, 10*time.Second, tr.WriteTimeout()) }) From 97596857a236c49f64ea57e21b9cd86c3477a76c Mon Sep 17 00:00:00 2001 From: endigma Date: Mon, 23 Mar 2026 17:52:16 +0000 Subject: [PATCH 07/52] fix lint issue --- .../graphql_datasource/graphql_subscription_client.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 2914fb7474..8c509d5580 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -218,11 +218,7 @@ func isUpstreamError(err error) bool { } var invalidSubprotocol client.ErrInvalidSubprotocol - if errors.As(err, &invalidSubprotocol) { - return true - } - - return false + return errors.As(err, &invalidSubprotocol) } // convertToClientOptions converts GraphQLSubscriptionOptions to the new client's types. From cb2348a61a492df70dcb915161a7eda377d092be Mon Sep 17 00:00:00 2001 From: endigma Date: Mon, 23 Mar 2026 18:25:26 +0000 Subject: [PATCH 08/52] add idle timeout before closing unused connections --- .../subscriptionclient/client.go | 2 + .../subscriptionclient/transport/ws_conn.go | 18 ++++- .../transport/ws_conn_test.go | 76 +++++++++++++++++++ .../transport/ws_transport.go | 5 ++ .../transport/ws_transport_test.go | 7 +- 5 files changed, 103 insertions(+), 5 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go index d04e494db6..6f18505408 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go @@ -38,6 +38,7 @@ type Config struct { AckTimeout time.Duration WriteTimeout time.Duration ReadLimit int64 + WSIdleTimeout time.Duration } // New creates a new subscription client with the provided config. @@ -64,6 +65,7 @@ func New(ctx context.Context, cfg Config) *Client { AckTimeout: cfg.AckTimeout, WriteTimeout: cfg.WriteTimeout, ReadLimit: cfg.ReadLimit, + IdleTimeout: cfg.WSIdleTimeout, }), sse: transport.NewSSETransport(ctx, cfg.StreamingClient, cfg.Logger), } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go index f07a33d3e9..4d9b50512a 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go @@ -25,6 +25,7 @@ var ( type wsConnectionOptions struct { logger abstractlogger.Logger writeTimeout time.Duration + idleTimeout time.Duration onEmpty func() } @@ -43,7 +44,8 @@ type wsConnection struct { closed atomic.Bool - onEmpty func() + onEmpty func() + idleTimeout time.Duration writeTimeout time.Duration @@ -73,6 +75,7 @@ func newWSConnection(conn *websocket.Conn, proto protocol.Protocol, opts wsConne onEmpty: opts.onEmpty, writeTimeout: opts.writeTimeout, + idleTimeout: opts.idleTimeout, } c.lastPongAt.Store(time.Now().UnixNano()) @@ -121,7 +124,18 @@ func (c *wsConnection) removeSub(id string) { c.subsMu.Unlock() if isEmpty { - c.closeConn() + if c.idleTimeout > 0 { + time.AfterFunc(c.idleTimeout, func() { + c.subsMu.RLock() + stillEmpty := len(c.subs) == 0 + c.subsMu.RUnlock() + if stillEmpty { + c.closeConn() + } + }) + } else { + c.closeConn() + } } } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go index fd30c6cf8f..5e2b874446 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go @@ -344,6 +344,82 @@ func TestWSConnection_OnEmpty(t *testing.T) { }) } +func TestWSConnection_IdleTimeout(t *testing.T) { + t.Parallel() + + t.Run("removeSub defers close for idle timeout duration when subs are empty", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + + wsc := newWSConnection(conn, proto, wsConnectionOptions{ + idleTimeout: 200 * time.Millisecond, + }) + + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) + require.NoError(t, err) + + cancel() + + assert.False(t, wsc.isClosed(), "connection should not be closed immediately") + + assert.Eventually(t, func() bool { + return wsc.isClosed() + }, time.Second, 10*time.Millisecond, "connection should close after idle timeout") + }) + + t.Run("removeSub does not close when new subscription arrives before timeout", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + + wsc := newWSConnection(conn, proto, wsConnectionOptions{ + idleTimeout: 200 * time.Millisecond, + }) + + cancel1, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) + require.NoError(t, err) + + cancel1() // starts idle timer + + _, err = wsc.subscribe(context.Background(), "sub-2", &common.Request{}, func(_ *common.Message) {}) + require.NoError(t, err) + + time.Sleep(300 * time.Millisecond) + + assert.False(t, wsc.isClosed(), "connection should stay open while subscription exists") + }) + + t.Run("removeSub closes immediately when idle timeout is zero", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + + emptyCalled := make(chan struct{}, 1) + wsc := newWSConnection(conn, proto, wsConnectionOptions{ + onEmpty: func() { emptyCalled <- struct{}{} }, + }) + + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) + require.NoError(t, err) + + cancel() + + select { + case <-emptyCalled: + // success + case <-time.After(100 * time.Millisecond): + t.Error("connection should close immediately with zero idle timeout") + } + + assert.True(t, wsc.isClosed()) + }) + +} + func TestWSConnection_Close(t *testing.T) { t.Parallel() diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go index 2712ed0187..973eb4ffbe 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go @@ -51,6 +51,10 @@ type WSTransportOptions struct { AckTimeout time.Duration WriteTimeout time.Duration ReadLimit int64 + // IdleTimeout is the duration a connection stays open after its last + // subscription is removed, allowing new subscriptions to reuse it + // without re-dialing. Zero means close immediately. + IdleTimeout time.Duration } type WSTransport struct { @@ -277,6 +281,7 @@ func (t *WSTransport) dial(ctx context.Context, key uint64, opts common.Options) conn := newWSConnection(wsConn, proto, wsConnectionOptions{ logger: t.opts.Logger, writeTimeout: t.opts.WriteTimeout, + idleTimeout: t.opts.IdleTimeout, onEmpty: func() { t.removeConn(key) }, }) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go index 49ceb5c090..e952fe9512 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go @@ -401,7 +401,9 @@ func TestWSTransport_ConcurrentSubscribe(t *testing.T) { } }) - tr := NewWSTransport(t.Context(), WSTransportOptions{}) + tr := NewWSTransport(t.Context(), WSTransportOptions{ + IdleTimeout: 30 * time.Second, + }) opts := common.Options{ Endpoint: server.URL, @@ -424,8 +426,7 @@ func TestWSTransport_ConcurrentSubscribe(t *testing.T) { wg.Wait() - // Should have only dialed once (or maybe twice due to race, but not 10 times) - assert.LessOrEqual(t, dialCount.Load(), int32(2)) + assert.Equal(t, int32(1), dialCount.Load()) }) } From d78d83c68a6338a00300317d05921e46d0248ea2 Mon Sep 17 00:00:00 2001 From: endigma Date: Mon, 23 Mar 2026 19:04:06 +0000 Subject: [PATCH 09/52] restore backwards compatible errors --- .serena/project.yml | 14 ++++++++++++++ .../graphql_subscription_client.go | 17 +++++------------ 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/.serena/project.yml b/.serena/project.yml index cfa9063fb9..22f8f0c261 100644 --- a/.serena/project.yml +++ b/.serena/project.yml @@ -136,3 +136,17 @@ symbol_info_budget: # list of regex patterns which, when matched, mark a memory entry as read‑only. # Extends the list from the global configuration, merging the two lists. read_only_memory_patterns: [] + +# list of regex patterns for memories to completely ignore. +# Matching memories will not appear in list_memories or activate_project output +# and cannot be accessed via read_memory or write_memory. +# To access ignored memory files, use the read_file tool on the raw file path. +# Extends the list from the global configuration, merging the two lists. +# Example: ["_archive/.*", "_episodes/.*"] +ignored_memory_patterns: [] + +# advanced configuration option allowing to configure language server-specific options. +# Maps the language key to the options. +# Have a look at the docstring of the constructors of the LS implementations within solidlsp (e.g., for C# or PHP) to see which options are available. +# No documentation on options means no options are available. +ls_specific_settings: {} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 8c509d5580..7ca8ebab64 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -200,25 +200,18 @@ func buildMessageHandler(updater resolve.SubscriptionUpdater) client.Handler { // isUpstreamError reports whether err is a connection-level upstream error // that should be reported to the client as an UPSTREAM_SERVICE_ERROR. +// ErrFailedUpgrade and ErrInvalidSubprotocol are intentionally excluded so +// they propagate to the router, which formats detailed error messages +// (e.g. including the subgraph name and HTTP status code). func isUpstreamError(err error) bool { - if errors.Is(err, client.ErrConnectionClosed) || + return errors.Is(err, client.ErrConnectionClosed) || errors.Is(err, client.ErrConnectionError) || errors.Is(err, client.ErrInitFailed) || errors.Is(err, client.ErrDialFailed) || errors.Is(err, client.ErrAckTimeout) || errors.Is(err, client.ErrAckNotReceived) || errors.Is(err, context.Canceled) || - errors.Is(err, context.DeadlineExceeded) { - return true - } - - var failedUpgrade client.ErrFailedUpgrade - if errors.As(err, &failedUpgrade) { - return true - } - - var invalidSubprotocol client.ErrInvalidSubprotocol - return errors.As(err, &invalidSubprotocol) + errors.Is(err, context.DeadlineExceeded) } // convertToClientOptions converts GraphQLSubscriptionOptions to the new client's types. From 4cc007ce23cf61c1dd8f26b1283ffc48219ec5d5 Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 2 Apr 2026 12:10:30 +0100 Subject: [PATCH 10/52] remove unused test helper --- .../graphql_datasource_test.go | 92 ------------------- 1 file changed, 92 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index 9432dab865..f1ab22708c 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -8381,98 +8381,6 @@ func (f *FailingSubscriptionClient) Subscribe(ctx *resolve.Context, options Grap return errSubscriptionClientFail } -type testSubscriptionUpdaterChan struct { - updates chan string - complete chan struct{} - errors chan []byte - done chan struct{} -} - -func newTestSubscriptionUpdaterChan() *testSubscriptionUpdaterChan { - return &testSubscriptionUpdaterChan{ - updates: make(chan string), - complete: make(chan struct{}), - errors: make(chan []byte, 1), - done: make(chan struct{}), - } -} - -func (t *testSubscriptionUpdaterChan) Heartbeat() { - t.updates <- "{}" -} - -func (t *testSubscriptionUpdaterChan) Update(data []byte) { - t.updates <- string(data) -} - -// empty method to satisfy the interface, not used in this tests -func (t *testSubscriptionUpdaterChan) UpdateSubscription(id resolve.SubscriptionIdentifier, data []byte) { -} - -// empty method to satisfy the interface, not used in this tests -func (t *testSubscriptionUpdaterChan) CloseSubscription(id resolve.SubscriptionIdentifier) { -} - -// empty method to satisfy the interface, not used in this tests -func (t *testSubscriptionUpdaterChan) Subscriptions() map[context.Context]resolve.SubscriptionIdentifier { - return make(map[context.Context]resolve.SubscriptionIdentifier) -} - -func (t *testSubscriptionUpdaterChan) Complete() { - close(t.complete) -} - -func (t *testSubscriptionUpdaterChan) Error(data []byte) { - t.errors <- data -} - -func (t *testSubscriptionUpdaterChan) Done() { - close(t.done) -} - -func (t *testSubscriptionUpdaterChan) AwaitUpdateWithT(tt *testing.T, timeout time.Duration, f func(t *testing.T, update string), msgAndArgs ...any) { - tt.Helper() - - select { - case args := <-t.updates: - f(tt, args) - case <-time.After(timeout): - require.Fail(tt, "unable to receive update before timeout", msgAndArgs...) - } -} - -func (t *testSubscriptionUpdaterChan) AwaitError(tt *testing.T, timeout time.Duration, msgAndArgs ...any) []byte { - tt.Helper() - - select { - case data := <-t.errors: - return data - case <-time.After(timeout): - require.Fail(tt, "updater error not received before timeout", msgAndArgs...) - return nil - } -} - -func (t *testSubscriptionUpdaterChan) AwaitDone(tt *testing.T, timeout time.Duration, msgAndArgs ...any) { - tt.Helper() - - select { - case <-t.done: - case <-time.After(timeout): - require.Fail(tt, "updater not done before timeout", msgAndArgs...) - } -} - -func (t *testSubscriptionUpdaterChan) AwaitComplete(tt *testing.T, timeout time.Duration, msgAndArgs ...any) { - tt.Helper() - - select { - case <-t.complete: - case <-time.After(timeout): - require.Fail(tt, "updater not completed before timeout", msgAndArgs...) - } -} - // !! If you see this in a test you're working on, please replace it with the new testSubscriptionUpdaterChan // It's faster, more ergonomic and more reliable. See SSE handler tests for usage examples. type testSubscriptionUpdater struct { From 3eb0a1be4d9f0437c19ca4dfa7faa101a61ac332 Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 2 Apr 2026 12:12:31 +0100 Subject: [PATCH 11/52] remove serena from git --- .gitignore | 3 +- .serena/.gitignore | 2 - .serena/project.yml | 152 -------------------------------------------- 3 files changed, 2 insertions(+), 155 deletions(-) delete mode 100644 .serena/.gitignore delete mode 100644 .serena/project.yml diff --git a/.gitignore b/.gitignore index 960ca0c8d5..17c4571439 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ .DS_Store pkg/parser/testdata/lotto.graphql *node_modules* -*vendor* \ No newline at end of file +*vendor* +.serena \ No newline at end of file diff --git a/.serena/.gitignore b/.serena/.gitignore deleted file mode 100644 index 2e510aff58..0000000000 --- a/.serena/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -/cache -/project.local.yml diff --git a/.serena/project.yml b/.serena/project.yml deleted file mode 100644 index 22f8f0c261..0000000000 --- a/.serena/project.yml +++ /dev/null @@ -1,152 +0,0 @@ -# the name by which the project can be referenced within Serena -project_name: "graphql-go-tools" - - -# list of languages for which language servers are started; choose from: -# al bash clojure cpp csharp -# csharp_omnisharp dart elixir elm erlang -# fortran fsharp go groovy haskell -# java julia kotlin lua markdown -# matlab nix pascal perl php -# php_phpactor powershell python python_jedi r -# rego ruby ruby_solargraph rust scala -# swift terraform toml typescript typescript_vts -# vue yaml zig -# (This list may be outdated. For the current list, see values of Language enum here: -# https://github.com/oraios/serena/blob/main/src/solidlsp/ls_config.py -# For some languages, there are alternative language servers, e.g. csharp_omnisharp, ruby_solargraph.) -# Note: -# - For C, use cpp -# - For JavaScript, use typescript -# - For Free Pascal/Lazarus, use pascal -# Special requirements: -# Some languages require additional setup/installations. -# See here for details: https://oraios.github.io/serena/01-about/020_programming-languages.html#language-servers -# When using multiple languages, the first language server that supports a given file will be used for that file. -# The first language is the default language and the respective language server will be used as a fallback. -# Note that when using the JetBrains backend, language servers are not used and this list is correspondingly ignored. -languages: -- go - -# the encoding used by text files in the project -# For a list of possible encodings, see https://docs.python.org/3.11/library/codecs.html#standard-encodings -encoding: "utf-8" - -# line ending convention to use when writing source files. -# Possible values: unset (use global setting), "lf", "crlf", or "native" (platform default) -# This does not affect Serena's own files (e.g. memories and configuration files), which always use native line endings. -line_ending: - -# The language backend to use for this project. -# If not set, the global setting from serena_config.yml is used. -# Valid values: LSP, JetBrains -# Note: the backend is fixed at startup. If a project with a different backend -# is activated post-init, an error will be returned. -language_backend: - -# whether to use project's .gitignore files to ignore files -ignore_all_files_in_gitignore: true - -# list of additional paths to ignore in this project. -# Same syntax as gitignore, so you can use * and **. -# Note: global ignored_paths from serena_config.yml are also applied additively. -ignored_paths: [] - -# whether the project is in read-only mode -# If set to true, all editing tools will be disabled and attempts to use them will result in an error -# Added on 2025-04-18 -read_only: false - -# list of tool names to exclude. -# This extends the existing exclusions (e.g. from the global configuration) -# -# Below is the complete list of tools for convenience. -# To make sure you have the latest list of tools, and to view their descriptions, -# execute `uv run scripts/print_tool_overview.py`. -# -# * `activate_project`: Activates a project by name. -# * `check_onboarding_performed`: Checks whether project onboarding was already performed. -# * `create_text_file`: Creates/overwrites a file in the project directory. -# * `delete_lines`: Deletes a range of lines within a file. -# * `delete_memory`: Deletes a memory from Serena's project-specific memory store. -# * `execute_shell_command`: Executes a shell command. -# * `find_referencing_code_snippets`: Finds code snippets in which the symbol at the given location is referenced. -# * `find_referencing_symbols`: Finds symbols that reference the symbol at the given location (optionally filtered by type). -# * `find_symbol`: Performs a global (or local) search for symbols with/containing a given name/substring (optionally filtered by type). -# * `get_current_config`: Prints the current configuration of the agent, including the active and available projects, tools, contexts, and modes. -# * `get_symbols_overview`: Gets an overview of the top-level symbols defined in a given file. -# * `initial_instructions`: Gets the initial instructions for the current project. -# Should only be used in settings where the system prompt cannot be set, -# e.g. in clients you have no control over, like Claude Desktop. -# * `insert_after_symbol`: Inserts content after the end of the definition of a given symbol. -# * `insert_at_line`: Inserts content at a given line in a file. -# * `insert_before_symbol`: Inserts content before the beginning of the definition of a given symbol. -# * `list_dir`: Lists files and directories in the given directory (optionally with recursion). -# * `list_memories`: Lists memories in Serena's project-specific memory store. -# * `onboarding`: Performs onboarding (identifying the project structure and essential tasks, e.g. for testing or building). -# * `prepare_for_new_conversation`: Provides instructions for preparing for a new conversation (in order to continue with the necessary context). -# * `read_file`: Reads a file within the project directory. -# * `read_memory`: Reads the memory with the given name from Serena's project-specific memory store. -# * `remove_project`: Removes a project from the Serena configuration. -# * `replace_lines`: Replaces a range of lines within a file with new content. -# * `replace_symbol_body`: Replaces the full definition of a symbol. -# * `restart_language_server`: Restarts the language server, may be necessary when edits not through Serena happen. -# * `search_for_pattern`: Performs a search for a pattern in the project. -# * `summarize_changes`: Provides instructions for summarizing the changes made to the codebase. -# * `switch_modes`: Activates modes by providing a list of their names -# * `think_about_collected_information`: Thinking tool for pondering the completeness of collected information. -# * `think_about_task_adherence`: Thinking tool for determining whether the agent is still on track with the current task. -# * `think_about_whether_you_are_done`: Thinking tool for determining whether the task is truly completed. -# * `write_memory`: Writes a named memory (for future reference) to Serena's project-specific memory store. -excluded_tools: [] - -# list of tools to include that would otherwise be disabled (particularly optional tools that are disabled by default). -# This extends the existing inclusions (e.g. from the global configuration). -included_optional_tools: [] - -# fixed set of tools to use as the base tool set (if non-empty), replacing Serena's default set of tools. -# This cannot be combined with non-empty excluded_tools or included_optional_tools. -fixed_tools: [] - -# list of mode names to that are always to be included in the set of active modes -# The full set of modes to be activated is base_modes + default_modes. -# If the setting is undefined, the base_modes from the global configuration (serena_config.yml) apply. -# Otherwise, this setting overrides the global configuration. -# Set this to [] to disable base modes for this project. -# Set this to a list of mode names to always include the respective modes for this project. -base_modes: - -# list of mode names that are to be activated by default. -# The full set of modes to be activated is base_modes + default_modes. -# If the setting is undefined, the default_modes from the global configuration (serena_config.yml) apply. -# Otherwise, this overrides the setting from the global configuration (serena_config.yml). -# This setting can, in turn, be overridden by CLI parameters (--mode). -default_modes: - -# initial prompt for the project. It will always be given to the LLM upon activating the project -# (contrary to the memories, which are loaded on demand). -initial_prompt: "" - -# time budget (seconds) per tool call for the retrieval of additional symbol information -# such as docstrings or parameter information. -# This overrides the corresponding setting in the global configuration; see the documentation there. -# If null or missing, use the setting from the global configuration. -symbol_info_budget: - -# list of regex patterns which, when matched, mark a memory entry as read‑only. -# Extends the list from the global configuration, merging the two lists. -read_only_memory_patterns: [] - -# list of regex patterns for memories to completely ignore. -# Matching memories will not appear in list_memories or activate_project output -# and cannot be accessed via read_memory or write_memory. -# To access ignored memory files, use the read_file tool on the raw file path. -# Extends the list from the global configuration, merging the two lists. -# Example: ["_archive/.*", "_episodes/.*"] -ignored_memory_patterns: [] - -# advanced configuration option allowing to configure language server-specific options. -# Maps the language key to the options. -# Have a look at the docstring of the constructors of the LS implementations within solidlsp (e.g., for C# or PHP) to see which options are available. -# No documentation on options means no options are available. -ls_specific_settings: {} From 33bca4cb40eee3b2935c4df445016f0f672a4fc9 Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 2 Apr 2026 15:05:36 +0100 Subject: [PATCH 12/52] remove redundant switch in subscription client factory --- execution/engine/config_factory_federation.go | 23 +++++-------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/execution/engine/config_factory_federation.go b/execution/engine/config_factory_federation.go index 1c70f3d8a6..dab9c708a0 100644 --- a/execution/engine/config_factory_federation.go +++ b/execution/engine/config_factory_federation.go @@ -451,25 +451,14 @@ func (f *FederationEngineConfigFactory) graphqlDataSourceFactory() (plan.Planner func (f *FederationEngineConfigFactory) subscriptionClient( httpClient *http.Client, streamingClient *http.Client, - subscriptionType SubscriptionType, + _ SubscriptionType, subscriptionClientFactory graphql_datasource.GraphQLSubscriptionClientFactory, ) (graphql_datasource.GraphQLSubscriptionClient, error) { - var graphqlSubscriptionClient graphql_datasource.GraphQLSubscriptionClient - switch subscriptionType { - case SubscriptionTypeGraphQLTransportWS: - graphqlSubscriptionClient = subscriptionClientFactory.NewSubscriptionClient( - f.engineCtx, - graphql_datasource.WithUpgradeClient(httpClient), - graphql_datasource.WithStreamingClient(streamingClient), - ) - default: - // for compatibility reasons we fall back to graphql-ws protocol - graphqlSubscriptionClient = subscriptionClientFactory.NewSubscriptionClient( - f.engineCtx, - graphql_datasource.WithUpgradeClient(httpClient), - graphql_datasource.WithStreamingClient(streamingClient), - ) - } + graphqlSubscriptionClient := subscriptionClientFactory.NewSubscriptionClient( + f.engineCtx, + graphql_datasource.WithUpgradeClient(httpClient), + graphql_datasource.WithStreamingClient(streamingClient), + ) ok := graphql_datasource.IsDefaultGraphQLSubscriptionClient(graphqlSubscriptionClient) if !ok { From e986f349e5071a787822e384b40ec160b04f7ce0 Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 2 Apr 2026 15:05:58 +0100 Subject: [PATCH 13/52] Fix SSE header apply order --- .../subscriptionclient/transport/sse_transport.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go index 9b8a51d0da..0a9b1eed69 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go @@ -156,13 +156,13 @@ func buildPOSTRequest(ctx context.Context, req *common.Request, opts common.Opti return nil, fmt.Errorf("create request: %w", err) } + // Add custom headers first, then set SSE-required headers so they cannot be overwritten + maps.Copy(httpReq.Header, opts.Headers) + httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Accept", "text/event-stream") httpReq.Header.Set("Cache-Control", "no-cache") - // Add custom headers - maps.Copy(httpReq.Header, opts.Headers) - return httpReq, nil } @@ -205,12 +205,12 @@ func buildGETRequest(ctx context.Context, req *common.Request, opts common.Optio return nil, fmt.Errorf("create request: %w", err) } + // Add custom headers first, then set SSE-required headers so they cannot be overwritten + maps.Copy(httpReq.Header, opts.Headers) + httpReq.Header.Set("Accept", "text/event-stream") httpReq.Header.Set("Cache-Control", "no-cache") - // Add custom headers - maps.Copy(httpReq.Header, opts.Headers) - return httpReq, nil } From 4fd554f8a3f4cde7c2f9b4b8dbb34ee8a963d49e Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 2 Apr 2026 15:06:08 +0100 Subject: [PATCH 14/52] fix defer --- .../subscriptionclient/transport/ws_conn_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go index 5e2b874446..478cf1fbde 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go @@ -32,9 +32,8 @@ func TestWSConnection_Subscribe(t *testing.T) { cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{ Query: "subscription { test }", }, handler) - defer cancel() - require.NoError(t, err) + defer cancel() assert.Len(t, proto.SubscribeCalls(), 1) assert.Equal(t, "sub-1", proto.SubscribeCalls()[0].ID) }) From 54060a29cbd591f7317d3dd99262c63e2f4397a0 Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 2 Apr 2026 15:06:30 +0100 Subject: [PATCH 15/52] fix timeout default check --- .../subscriptionclient/protocol/graphql_transport_ws.go | 2 +- .../subscriptionclient/protocol/graphql_ws.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go index a6b5f808ec..c9086f097e 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go @@ -55,7 +55,7 @@ func (p *GraphQLTransportWS) Init(ctx context.Context, conn *websocket.Conn, pay initMsg.Payload = payload } timeout := p.AckTimeout - if timeout == 0 { + if timeout <= 0 { timeout = 30 * time.Second } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go index ea103326b1..4c1406386c 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go @@ -46,7 +46,7 @@ func (p *GraphQLWS) Init(ctx context.Context, conn *websocket.Conn, payload map[ initMsg.Payload = payload } timeout := p.AckTimeout - if timeout == 0 { + if timeout <= 0 { timeout = 30 * time.Second } From 76e9d07883b9c1ff393a2c08eca8bd0a1c25261b Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 2 Apr 2026 15:06:51 +0100 Subject: [PATCH 16/52] allow pongs during init --- .../subscriptionclient/protocol/graphql_transport_ws.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go index c9086f097e..899c5ef000 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go @@ -83,6 +83,8 @@ func (p *GraphQLTransportWS) Init(ctx context.Context, conn *websocket.Conn, pay return fmt.Errorf("pre-init pong: %w", err) } continue + case gtwsTypePong: + continue default: return fmt.Errorf("%w: got %q", ErrAckNotReceived, ackMessage.Type) } From 162135c6d390917484c09583c9150e4777ac7eda Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 2 Apr 2026 15:07:03 +0100 Subject: [PATCH 17/52] nil check in withlogger --- .../graphql_datasource/graphql_subscription_client.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 7ca8ebab64..73c8d4aa03 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -67,7 +67,9 @@ func WithStreamingClient(c *http.Client) SubscriptionClientOption { // If not set, logging is disabled (silent operation). func WithLogger(log abstractlogger.Logger) SubscriptionClientOption { return func(cfg *SubscriptionClientConfig) { - cfg.Logger = log + if log != nil { + cfg.Logger = log + } } } From 914d974a36b11fcbc039a1b42afd9803329acc23 Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 2 Apr 2026 16:45:12 +0100 Subject: [PATCH 18/52] fixup tests to be clearer --- .../subscriptionclient/transport/sse_conn_test.go | 14 ++++++++------ .../transport/ws_transport_test.go | 1 + 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go index 147ec5e182..8aaca5fdc5 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go @@ -65,23 +65,22 @@ func TestSSEConnection_ReadLoop(t *testing.T) { )) resp := &http.Response{Body: body} handler, receive := collectingHandler() - conn := newSSEConnection(resp, handler) + wrappedHandler, collect := waitForMessages(handler) + conn := newSSEConnection(resp, wrappedHandler) go conn.readLoop() - var messages []*common.Message - // First message msg1 := receive(t, 1*time.Second) - messages = append(messages, msg1) assert.NotNil(t, msg1.Payload) assert.Equal(t, common.MessageTypeData, msg1.Type) // Complete message msg2 := receive(t, 1*time.Second) - messages = append(messages, msg2) assert.Equal(t, common.MessageTypeComplete, msg2.Type) + // Wait and verify no more messages arrive after complete + messages := collect(100 * time.Millisecond) assert.Len(t, messages, 2, "should receive exactly 2 messages before stopping") }) } @@ -124,7 +123,7 @@ func (r *errorReader) Read(_ []byte) (int, error) { return 0, r.err } -// trackingCloser tracks if Close was called +// trackingCloser tracks if Close was called and forwards to the underlying reader if it implements io.Closer. type trackingCloser struct { io.Reader @@ -133,5 +132,8 @@ type trackingCloser struct { func (c *trackingCloser) Close() error { c.closed.Store(true) + if closer, ok := c.Reader.(io.Closer); ok { + return closer.Close() + } return nil } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go index e952fe9512..5943c8c31d 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go @@ -416,6 +416,7 @@ func TestWSTransport_ConcurrentSubscribe(t *testing.T) { handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { test }"}, opts, handler) if err != nil { + t.Errorf("subscribe error: %v", err) return } defer cancel() From 95ea70e14bd944405fd76c76b44ca4b856eab7cb Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 2 Apr 2026 16:48:33 +0100 Subject: [PATCH 19/52] remove finalizer abstraction --- v2/pkg/engine/resolve/resolve.go | 72 ++++++++++++--------------- v2/pkg/engine/resolve/resolve_test.go | 15 ------ 2 files changed, 32 insertions(+), 55 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index a6321e2643..abd6a5a9ee 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -478,13 +478,9 @@ type subscriptionState struct { lastWriteTime atomic.Int64 } -type subscriptionFinalizer struct { - sub *subscriptionState -} - -func runSubscriptionFinalizers(finalizers []subscriptionFinalizer) { - for _, f := range finalizers { - f.sub.done() +func closeSubs(subs []*subscriptionState) { + for _, s := range subs { + s.done() } } @@ -497,7 +493,7 @@ func (s *subscriptionState) done() { } // complete delivers a "subscription done" signal to the downstream writer. -// Called by handleTriggerComplete, not through finalizers. +// Called by handleTriggerComplete, not through toClose. func (s *subscriptionState) complete() { s.writeMu.Lock() defer s.writeMu.Unlock() @@ -505,7 +501,7 @@ func (s *subscriptionState) complete() { } // error delivers a terminal error payload to the downstream writer. -// Called by handleTriggerError, not through finalizers. +// Called by handleTriggerError, not through toClose. func (s *subscriptionState) error(data []byte) { s.writeMu.Lock() defer s.writeMu.Unlock() @@ -799,13 +795,13 @@ func (r *Resolver) markTriggerInitialized(triggerID uint64) { } // doneTriggerFromUpdater performs cleanup for a trigger from a datasource/updater goroutine. -// It detaches the trigger, runs done finalizers (close completed channels), and cancels the trigger context. +// It detaches the trigger, runs done toClose (close completed channels), and cancels the trigger context. func (r *Resolver) doneTriggerFromUpdater(triggerID uint64) { if r.options.Debug { fmt.Printf("resolver:trigger:shutdown:%d\n", triggerID) } r.mu.Lock() - removed, finalizers, cancel, initialized := r.detachTriggerLocked(triggerID) + removed, toClose, cancel, initialized := r.detachTriggerLocked(triggerID) if r.reporter != nil { r.reporter.SubscriptionCountDec(removed) if initialized { @@ -813,7 +809,7 @@ func (r *Resolver) doneTriggerFromUpdater(triggerID uint64) { } } r.mu.Unlock() - runSubscriptionFinalizers(finalizers) + closeSubs(toClose) if cancel != nil { cancel() } @@ -867,26 +863,26 @@ func (r *Resolver) handleTriggerError(triggerID uint64, data []byte) { } } -func (r *Resolver) handleCompleteSubscription(id SubscriptionIdentifier) (int, []subscriptionFinalizer, context.CancelFunc, bool) { +func (r *Resolver) handleCompleteSubscription(id SubscriptionIdentifier) (int, []*subscriptionState, context.CancelFunc, bool) { if r.options.Debug { fmt.Printf("resolver:trigger:subscription:remove:%d:%d\n", id.ConnectionID, id.SubscriptionID) } return r.removeSubscriptionByID(id) } -func (r *Resolver) handleRemoveSubscription(id SubscriptionIdentifier) (int, []subscriptionFinalizer, context.CancelFunc, bool) { +func (r *Resolver) handleRemoveSubscription(id SubscriptionIdentifier) (int, []*subscriptionState, context.CancelFunc, bool) { if r.options.Debug { fmt.Printf("resolver:trigger:subscription:remove:%d:%d\n", id.ConnectionID, id.SubscriptionID) } return r.removeSubscriptionByID(id) } -func (r *Resolver) handleRemoveClient(id int64) (int, []subscriptionFinalizer, []context.CancelFunc, int) { +func (r *Resolver) handleRemoveClient(id int64) (int, []*subscriptionState, []context.CancelFunc, int) { if r.options.Debug { fmt.Printf("resolver:trigger:subscription:remove:client:%d\n", id) } removed := 0 - finalizers := make([]subscriptionFinalizer, 0) + toClose := make([]*subscriptionState, 0) cancels := make([]context.CancelFunc, 0) triggerDec := 0 idsForConn := r.subscriptionsByConnection[id] @@ -897,7 +893,7 @@ func (r *Resolver) handleRemoveClient(id int64) (int, []subscriptionFinalizer, [ for _, sid := range ids { rem, fz, cancel, initialized := r.removeSubscriptionByID(sid) removed += rem - finalizers = append(finalizers, fz...) + toClose = append(toClose, fz...) if cancel != nil { cancels = append(cancels, cancel) if initialized { @@ -905,12 +901,12 @@ func (r *Resolver) handleRemoveClient(id int64) (int, []subscriptionFinalizer, [ } } } - return removed, finalizers, cancels, triggerDec + return removed, toClose, cancels, triggerDec } // removeSubscriptionByID removes a single subscription by id. // r.mu must be held by the caller. -func (r *Resolver) removeSubscriptionByID(id SubscriptionIdentifier) (int, []subscriptionFinalizer, context.CancelFunc, bool) { +func (r *Resolver) removeSubscriptionByID(id SubscriptionIdentifier) (int, []*subscriptionState, context.CancelFunc, bool) { s, ok := r.subscriptionsByID[id] if !ok { return 0, nil, nil, false @@ -930,11 +926,9 @@ func (r *Resolver) removeSubscriptionByID(id SubscriptionIdentifier) (int, []sub return 0, nil, nil, false } - var finalizers []subscriptionFinalizer + var toClose []*subscriptionState if s.removed.CompareAndSwap(false, true) { - finalizers = append(finalizers, subscriptionFinalizer{ - sub: s, - }) + toClose = append(toClose, s) } delete(trig.subscriptions, id) empty := len(trig.subscriptions) == 0 @@ -950,26 +944,24 @@ func (r *Resolver) removeSubscriptionByID(id SubscriptionIdentifier) (int, []sub initialized = trig.initialized } - return 1, finalizers, cancel, initialized + return 1, toClose, cancel, initialized } // detachTriggerLocked removes all subscriptions for the trigger and removes the trigger from resolver maps. // r.mu must be held by the caller. -func (r *Resolver) detachTriggerLocked(id uint64) (int, []subscriptionFinalizer, context.CancelFunc, bool) { +func (r *Resolver) detachTriggerLocked(id uint64) (int, []*subscriptionState, context.CancelFunc, bool) { trig, ok := r.triggers[id] if !ok { return 0, nil, nil, false } - finalizers := make([]subscriptionFinalizer, 0, len(trig.subscriptions)) + toClose := make([]*subscriptionState, 0, len(trig.subscriptions)) removed := 0 trig.mu.Lock() for sid, s := range trig.subscriptions { if s.removed.CompareAndSwap(false, true) { - finalizers = append(finalizers, subscriptionFinalizer{ - sub: s, - }) + toClose = append(toClose, s) } delete(trig.subscriptions, sid) r.removeSubscriptionIndex(sid) @@ -979,7 +971,7 @@ func (r *Resolver) detachTriggerLocked(id uint64) (int, []subscriptionFinalizer, delete(r.triggers, id) - return removed, finalizers, trig.cancel, trig.initialized + return removed, toClose, trig.cancel, trig.initialized } // pendingWrite holds the context and subscription for a deferred write outside the lock. @@ -1123,15 +1115,15 @@ func (r *Resolver) shutdownResolver() { triggerIDs = append(triggerIDs, id) } - allFinalizers := make([]subscriptionFinalizer, 0) + allToClose := make([]*subscriptionState, 0) cancels := make([]context.CancelFunc, 0, len(triggerIDs)) removedTotal := 0 triggerDec := 0 for _, id := range triggerIDs { - removed, finalizers, cancel, initialized := r.detachTriggerLocked(id) + removed, toClose, cancel, initialized := r.detachTriggerLocked(id) removedTotal += removed - allFinalizers = append(allFinalizers, finalizers...) + allToClose = append(allToClose, toClose...) if cancel != nil { cancels = append(cancels, cancel) } @@ -1152,7 +1144,7 @@ func (r *Resolver) shutdownResolver() { r.subscriptionsByConnection = make(map[int64]map[SubscriptionIdentifier]*subscriptionState) r.mu.Unlock() - runSubscriptionFinalizers(allFinalizers) + closeSubs(allToClose) for _, cancel := range cancels { cancel() } @@ -1201,7 +1193,7 @@ func (r *Resolver) CompleteSubscription(id SubscriptionIdentifier) error { } // Grab the sub before removal so we can send a "complete" frame after releasing r.mu. sub := r.subscriptionsByID[id] - removed, finalizers, cancel, initialized := r.handleCompleteSubscription(id) + removed, toClose, cancel, initialized := r.handleCompleteSubscription(id) if r.reporter != nil { r.reporter.SubscriptionCountDec(removed) if cancel != nil && initialized { @@ -1215,7 +1207,7 @@ func (r *Resolver) CompleteSubscription(id SubscriptionIdentifier) error { if sub != nil { sub.complete() } - runSubscriptionFinalizers(finalizers) + closeSubs(toClose) if cancel != nil { cancel() } @@ -1228,7 +1220,7 @@ func (r *Resolver) UnsubscribeSubscription(id SubscriptionIdentifier) error { r.mu.Unlock() return r.ctx.Err() } - removed, finalizers, cancel, initialized := r.handleRemoveSubscription(id) + removed, toClose, cancel, initialized := r.handleRemoveSubscription(id) if r.reporter != nil { r.reporter.SubscriptionCountDec(removed) if cancel != nil && initialized { @@ -1236,7 +1228,7 @@ func (r *Resolver) UnsubscribeSubscription(id SubscriptionIdentifier) error { } } r.mu.Unlock() - runSubscriptionFinalizers(finalizers) + closeSubs(toClose) if cancel != nil { cancel() } @@ -1249,7 +1241,7 @@ func (r *Resolver) UnsubscribeClient(connectionID int64) error { r.mu.Unlock() return r.ctx.Err() } - removed, finalizers, cancels, triggerDec := r.handleRemoveClient(connectionID) + removed, toClose, cancels, triggerDec := r.handleRemoveClient(connectionID) if r.reporter != nil { r.reporter.SubscriptionCountDec(removed) if triggerDec > 0 { @@ -1257,7 +1249,7 @@ func (r *Resolver) UnsubscribeClient(connectionID int64) error { } } r.mu.Unlock() - runSubscriptionFinalizers(finalizers) + closeSubs(toClose) for _, cancel := range cancels { cancel() } diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 170f9cd0bf..c18a0dc418 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -5339,7 +5339,6 @@ type SubscriptionRecorder struct { buf *bytes.Buffer messages []string complete atomic.Bool - done atomic.Bool errors [][]byte mux sync.Mutex onFlush func(p []byte) @@ -5395,20 +5394,6 @@ func (s *SubscriptionRecorder) AwaitComplete(t *testing.T, timeout time.Duration } } -func (s *SubscriptionRecorder) AwaitDone(t *testing.T, timeout time.Duration) { - t.Helper() - deadline := time.Now().Add(timeout) - for { - if s.done.Load() { - return - } - if time.Now().After(deadline) { - t.Fatalf("timed out waiting for done") - } - time.Sleep(time.Millisecond * 10) - } -} - func (s *SubscriptionRecorder) Write(p []byte) (n int, err error) { s.mux.Lock() defer s.mux.Unlock() From 13ccb44f7479b9f62e7a0059e22a97c6e7129ba6 Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 2 Apr 2026 16:54:39 +0100 Subject: [PATCH 20/52] go mod tidy in execution --- execution/go.mod | 1 - execution/go.sum | 16 ++++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/execution/go.mod b/execution/go.mod index 85fa1b2538..31ccb51c24 100644 --- a/execution/go.mod +++ b/execution/go.mod @@ -42,7 +42,6 @@ require ( github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect github.com/golang/protobuf v1.5.4 // indirect - github.com/google/go-cmp v0.7.0 // indirect github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/hashicorp/go-hclog v1.6.3 // indirect diff --git a/execution/go.sum b/execution/go.sum index 195d3b5f22..5d29f427f1 100644 --- a/execution/go.sum +++ b/execution/go.sum @@ -19,6 +19,7 @@ github.com/chzyer/logex v1.2.0/go.mod h1:9+9sk7u7pGNWYMkh0hdiL++6OeibzJccyQU4p4M github.com/chzyer/readline v1.5.0/go.mod h1:x22KAscuvRqlLoK9CsoYsmxoXZMMFVyOl86cAH8qUic= github.com/chzyer/test v0.0.0-20210722231415-061457976a23/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -43,6 +44,10 @@ github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= @@ -59,8 +64,6 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= -github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U= github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= @@ -165,11 +168,13 @@ github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AO github.com/vektah/gqlparser/v2 v2.5.30 h1:EqLwGAFLIzt1wpx1IPpY67DwUujF1OfzgEyDsLrN6kE= github.com/vektah/gqlparser/v2 v2.5.30/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo= github.com/wundergraph/astjson v1.0.0 h1:rETLJuQkMWWW03HCF6WBttEBOu8gi5vznj5KEUPVV2Q= +github.com/wundergraph/astjson v1.0.0/go.mod h1:h12D/dxxnedtLzsKyBLK7/Oe4TAoGpRVC9nDpDrZSWw= github.com/wundergraph/cosmo/composition-go v0.0.0-20241020204711-78f240a77c99 h1:TGXDYfDhwFLFTuNuCwkuqXT5aXGz47zcurXLfTBS9w4= github.com/wundergraph/cosmo/composition-go v0.0.0-20241020204711-78f240a77c99/go.mod h1:fUuOAUAXUFB/mlSkAaImGeE4A841AKR5dTMWhV4ibxI= github.com/wundergraph/cosmo/router v0.0.0-20251013094319-c611abf26b17 h1:GjO2E8LTf3U5JiQJCY4MmlRcAjVt7IvAbWFSgEjQdl8= github.com/wundergraph/cosmo/router v0.0.0-20251013094319-c611abf26b17/go.mod h1:7kt64e0LOLMBqOzrfu9PuLRn9cVT9YN1Bb3EennVtws= github.com/wundergraph/go-arena v1.1.0 h1:9+wSRkJAkA2vbYHp6s8tEGhPViRGQNGXqPHT0QzhdIc= +github.com/wundergraph/go-arena v1.1.0/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.231 h1:2C8LNFGs8MtI2yPy2/a2WRf9/X2FoMqXlEJkpTjvsTg= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.231/go.mod h1:ErOQH1ki2+SZB8JjpTyGVnoBpg5picIyjvuWQJP4abg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= @@ -177,11 +182,17 @@ github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBi github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/otel v1.36.0 h1:UumtzIklRBY6cI/lllNZlALOF5nNIzJVb16APdvgTXg= +go.opentelemetry.io/otel v1.36.0/go.mod h1:/TcFMXYjyRNh8khOAO9ybYkqaDBb/70aVwkNML4pP8E= go.opentelemetry.io/otel/metric v1.36.0 h1:MoWPKVhQvJ+eeXWHFBOPoBOi20jh6Iq2CcCREuTYufE= +go.opentelemetry.io/otel/metric v1.36.0/go.mod h1:zC7Ks+yeyJt4xig9DEw9kuUFe5C3zLbVjV2PzT6qzbs= go.opentelemetry.io/otel/sdk v1.36.0 h1:b6SYIuLRs88ztox4EyrvRti80uXIFy+Sqzoh9kFULbs= +go.opentelemetry.io/otel/sdk v1.36.0/go.mod h1:+lC+mTgD+MUWfjJubi2vvXWcVxyr9rmlshZni72pXeY= go.opentelemetry.io/otel/sdk/metric v1.36.0 h1:r0ntwwGosWGaa0CrSt8cuNuTcccMXERFwHX4dThiPis= +go.opentelemetry.io/otel/sdk/metric v1.36.0/go.mod h1:qTNOhFDfKRwX0yXOqJYegL5WRaW376QbB7P4Pb0qva4= go.opentelemetry.io/otel/trace v1.36.0 h1:ahxWNuqZjpdiFAyrIoQ4GIiAIhxAunQR6MUoKrsNd4w= +go.opentelemetry.io/otel/trace v1.36.0/go.mod h1:gQ+OnDZzrybY4k4seLzPAWNwVBBVlF2szhehOBB/tGA= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= @@ -266,6 +277,7 @@ gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f h1:OxYkA3wjPsZyBylwymxSHa7ViiW1Sml4ToBrncvFehI= google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:+2Yz8+CLJbIfL9z73EW45avw8Lmge3xVElCP9zEKi50= google.golang.org/grpc v1.71.0 h1:kF77BGdPTQ4/JZWMlb9VpJ5pa25aqvVqogsxNHHdeBg= +google.golang.org/grpc v1.71.0/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw= google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/cenkalti/backoff.v1 v1.1.0 h1:Arh75ttbsvlpVA7WtVpH4u9h6Zl46xuptxqLxPiSo4Y= From 9d2cbd1479f4746508eaa99ede9617943a3ecb21 Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 8 Apr 2026 08:47:36 +0100 Subject: [PATCH 21/52] Error implementation for EngineResultWriter --- execution/graphql/result_writer.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/execution/graphql/result_writer.go b/execution/graphql/result_writer.go index 7e9b3f24a5..e3905207d3 100644 --- a/execution/graphql/result_writer.go +++ b/execution/graphql/result_writer.go @@ -39,8 +39,9 @@ func (e *EngineResultWriter) Heartbeat() error { return nil } -func (e *EngineResultWriter) Error(_ []byte) { - +func (e *EngineResultWriter) Error(data []byte) { + e.buf.Write(data) + e.Flush() } func (e *EngineResultWriter) SetFlushCallback(flushCb func(data []byte)) { From 825d611956fd4c9a6d69b6b4abfc612ec512bdff Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 8 Apr 2026 09:56:48 +0100 Subject: [PATCH 22/52] clarify protocol being implemented --- .../protocol/graphql_transport_ws.go | 10 ++++++---- .../subscriptionclient/protocol/graphql_ws.go | 12 ++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go index 899c5ef000..4095e355c4 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go @@ -13,6 +13,12 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" ) +// GraphQLTransportWS implements the graphql-transport-ws protocol. +// See: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md +type GraphQLTransportWS struct { + AckTimeout time.Duration +} + const ( gtwsTypeConnectionInit = "connection_init" gtwsTypeConnectionAck = "connection_ack" @@ -36,10 +42,6 @@ type incomingMessage struct { Payload json.RawMessage `json:"payload,omitempty"` } -type GraphQLTransportWS struct { - AckTimeout time.Duration -} - func NewGraphQLTransportWS() *GraphQLTransportWS { return &GraphQLTransportWS{ AckTimeout: 30 * time.Second, diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go index 4c1406386c..5944c09a95 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go @@ -13,6 +13,12 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" ) +// GraphQLWS implements the legacy graphql-ws protocol. +// See: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md +type GraphQLWS struct { + AckTimeout time.Duration +} + const ( gwsTypeConnectionInit = "connection_init" gwsTypeConnectionAck = "connection_ack" @@ -25,12 +31,6 @@ const ( gwsTypeStop = "stop" ) -// GraphQLWS implements the legacy graphql-ws protocol. -// See: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md -type GraphQLWS struct { - AckTimeout time.Duration -} - func NewGraphQLWS() *GraphQLWS { return &GraphQLWS{ AckTimeout: 30 * time.Second, From f949b24e722f5d2a5a7a25b1706ae424a1df255d Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 8 Apr 2026 10:54:50 +0100 Subject: [PATCH 23/52] clean up timeout default placement --- .../graphql_subscription_client.go | 3 +- .../protocol/graphql_transport_ws.go | 23 ++---- .../protocol/graphql_transport_ws_test.go | 6 +- .../subscriptionclient/protocol/graphql_ws.go | 21 +---- .../protocol/graphql_ws_test.go | 6 +- .../subscriptionclient/transport/ws_conn.go | 10 ++- .../transport/ws_transport.go | 79 +++++++++++-------- .../transport/ws_transport_test.go | 23 +----- 8 files changed, 74 insertions(+), 97 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 73c8d4aa03..34c4ccc1d3 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -21,7 +21,7 @@ type SubscriptionClientConfig struct { StreamingClient *http.Client Logger abstractlogger.Logger - // Timeouts + // Timeouts and limits PingInterval time.Duration PingTimeout time.Duration AckTimeout time.Duration @@ -37,7 +37,6 @@ func defaultSubscriptionClientConfig() *SubscriptionClientConfig { PingInterval: 30 * time.Second, PingTimeout: 10 * time.Second, - AckTimeout: 30 * time.Second, } } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go index 4095e355c4..bd87506f67 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "time" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" @@ -15,9 +14,7 @@ import ( // GraphQLTransportWS implements the graphql-transport-ws protocol. // See: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md -type GraphQLTransportWS struct { - AckTimeout time.Duration -} +type GraphQLTransportWS struct{} const ( gtwsTypeConnectionInit = "connection_init" @@ -43,9 +40,7 @@ type incomingMessage struct { } func NewGraphQLTransportWS() *GraphQLTransportWS { - return &GraphQLTransportWS{ - AckTimeout: 30 * time.Second, - } + return &GraphQLTransportWS{} } // Init implements Protocol. @@ -56,21 +51,13 @@ func (p *GraphQLTransportWS) Init(ctx context.Context, conn *websocket.Conn, pay if payload != nil { initMsg.Payload = payload } - timeout := p.AckTimeout - if timeout <= 0 { - timeout = 30 * time.Second - } - - ackCtx, ackCancel := context.WithTimeout(ctx, timeout) - defer ackCancel() - - if err := wsjson.Write(ackCtx, conn, initMsg); err != nil { + if err := wsjson.Write(ctx, conn, initMsg); err != nil { return fmt.Errorf("write connection_init: %w", err) } for { var ackMessage incomingMessage - if err := wsjson.Read(ackCtx, conn, &ackMessage); err != nil { + if err := wsjson.Read(ctx, conn, &ackMessage); err != nil { if errors.Is(err, context.DeadlineExceeded) { return ErrAckTimeout } @@ -81,7 +68,7 @@ func (p *GraphQLTransportWS) Init(ctx context.Context, conn *websocket.Conn, pay case gtwsTypeConnectionAck: return nil case gtwsTypePing: - if err := p.Pong(ackCtx, conn); err != nil { + if err := p.Pong(ctx, conn); err != nil { return fmt.Errorf("pre-init pong: %w", err) } continue diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws_test.go index ca4649ef53..951109c4ff 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws_test.go @@ -53,8 +53,10 @@ func TestGraphQLTransportWS_Init(t *testing.T) { conn := dialGTWS(t, server) - p := &GraphQLTransportWS{AckTimeout: 50 * time.Millisecond} - err := p.Init(t.Context(), conn, nil) + p := NewGraphQLTransportWS() + ctx, cancel := context.WithTimeout(t.Context(), 50*time.Millisecond) + defer cancel() + err := p.Init(ctx, conn, nil) require.ErrorIs(t, err, ErrAckTimeout) }) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go index 5944c09a95..4959cb482f 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "time" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" @@ -15,9 +14,7 @@ import ( // GraphQLWS implements the legacy graphql-ws protocol. // See: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md -type GraphQLWS struct { - AckTimeout time.Duration -} +type GraphQLWS struct{} const ( gwsTypeConnectionInit = "connection_init" @@ -32,9 +29,7 @@ const ( ) func NewGraphQLWS() *GraphQLWS { - return &GraphQLWS{ - AckTimeout: 30 * time.Second, - } + return &GraphQLWS{} } // Init implements Protocol. @@ -45,21 +40,13 @@ func (p *GraphQLWS) Init(ctx context.Context, conn *websocket.Conn, payload map[ if payload != nil { initMsg.Payload = payload } - timeout := p.AckTimeout - if timeout <= 0 { - timeout = 30 * time.Second - } - - ackCtx, ackCancel := context.WithTimeout(ctx, timeout) - defer ackCancel() - - if err := wsjson.Write(ackCtx, conn, initMsg); err != nil { + if err := wsjson.Write(ctx, conn, initMsg); err != nil { return fmt.Errorf("write connection_init: %w", err) } for { var ackMessage incomingMessage - if err := wsjson.Read(ackCtx, conn, &ackMessage); err != nil { + if err := wsjson.Read(ctx, conn, &ackMessage); err != nil { if errors.Is(err, context.DeadlineExceeded) { return ErrAckTimeout } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws_test.go index 04e5fe51a7..66c85b98f7 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws_test.go @@ -53,8 +53,10 @@ func TestGraphQLWS_Init(t *testing.T) { conn := dialGWS(t, server) - p := &GraphQLWS{AckTimeout: 50 * time.Millisecond} - err := p.Init(t.Context(), conn, nil) + p := NewGraphQLWS() + ctx, cancel := context.WithTimeout(t.Context(), 50*time.Millisecond) + defer cancel() + err := p.Init(ctx, conn, nil) require.ErrorIs(t, err, ErrAckTimeout) }) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go index 4d9b50512a..90824b5caa 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go @@ -19,7 +19,6 @@ var ( ErrSubscriptionExists = errors.New("subscription ID already exists") defaultWriteTimeout = 5 * time.Second - defaultReadLimit = int64(1024 * 1024) // 1MB ) type wsConnectionOptions struct { @@ -98,7 +97,10 @@ func (c *wsConnection) subscribe(ctx context.Context, id string, req *common.Req c.subs[id] = handler c.subsMu.Unlock() - if err := c.protocol.Subscribe(ctx, c.conn, id, req); err != nil { + subscribeCtx, subscribeCancel := context.WithTimeout(ctx, c.writeTimeout) + defer subscribeCancel() + + if err := c.protocol.Subscribe(subscribeCtx, c.conn, id, req); err != nil { c.log.Error("wsConnection.Subscribe", abstractlogger.String("id", id), abstractlogger.Error(err), @@ -253,8 +255,8 @@ func (c *wsConnection) subCount() int { } // sendPing sends a protocol-level ping message and records the timestamp. -func (c *wsConnection) sendPing(timeout time.Duration) error { - pingCtx, cancel := context.WithTimeout(c.ctx, timeout) +func (c *wsConnection) sendPing() error { + pingCtx, cancel := context.WithTimeout(c.ctx, c.writeTimeout) defer cancel() err := c.protocol.Ping(pingCtx, c.conn) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go index 973eb4ffbe..1a3fbb4863 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go @@ -20,7 +20,12 @@ import ( // ErrDialFailed indicates that the WebSocket dial (TCP + HTTP upgrade) failed. // The underlying cause is available via errors.Unwrap. -var ErrDialFailed = errors.New("websocket dial failed") +var ( + ErrDialFailed = errors.New("websocket dial failed") + + defaultReadLimit = int64(1024 * 1024) // 1MB + defaultAckTimeout = 30 * time.Second +) // ErrInitFailed indicates that the GraphQL protocol init (connection_init / // connection_ack handshake) failed after a successful WebSocket dial. The @@ -44,16 +49,36 @@ func (e ErrInvalidSubprotocol) Error() string { // WSTransportOptions configures a WSTransport. type WSTransportOptions struct { + // UpgradeClient is the HTTP client used for the WebSocket upgrade request. UpgradeClient *http.Client - Logger abstractlogger.Logger - PingInterval time.Duration - PingTimeout time.Duration - AckTimeout time.Duration - WriteTimeout time.Duration - ReadLimit int64 + + // Logger is the logger used for transport and connection-level events. + Logger abstractlogger.Logger + + // ReadLimit is the maximum message size in bytes the WebSocket connection + // will accept. Default: 1MB. + ReadLimit int64 + + // PingInterval is how often the transport sends a ping to each connection. + // Zero disables pinging. + PingInterval time.Duration + + // PingTimeout is how long a connection may go without a pong before it is + // considered dead. Only meaningful when PingInterval is set. + PingTimeout time.Duration + + // AckTimeout is the maximum time to wait for a connection_ack during the + // protocol init handshake. Passed to the protocol at construction. + // Default: 30s. + AckTimeout time.Duration + + // WriteTimeout is the deadline applied to each WebSocket write (subscribe, + // unsubscribe, ping, pong). Passed to each connection. Default: 5s. + WriteTimeout time.Duration + // IdleTimeout is the duration a connection stays open after its last - // subscription is removed, allowing new subscriptions to reuse it - // without re-dialing. Zero means close immediately. + // subscription is removed, allowing new subscriptions to reuse it without + // re-dialing. Zero means close immediately. IdleTimeout time.Duration } @@ -83,14 +108,17 @@ func NewWSTransport(ctx context.Context, opts WSTransportOptions) *WSTransport { if opts.UpgradeClient == nil { opts.UpgradeClient = http.DefaultClient } + if opts.Logger == nil { opts.Logger = abstractlogger.NoopLogger } + if opts.ReadLimit <= 0 { opts.ReadLimit = defaultReadLimit } - if opts.WriteTimeout <= 0 { - opts.WriteTimeout = defaultWriteTimeout + + if opts.AckTimeout <= 0 { + opts.AckTimeout = defaultAckTimeout } t := &WSTransport{ @@ -146,7 +174,7 @@ func (t *WSTransport) pingLoop() { continue } - if err := conn.sendPing(t.opts.WriteTimeout); err != nil { + if err := conn.sendPing(); err != nil { t.opts.Logger.Debug("wsTransport.pingLoop", abstractlogger.String("action", "ping_failed"), abstractlogger.Error(err), @@ -157,16 +185,6 @@ func (t *WSTransport) pingLoop() { } } -// ReadLimit returns the configured read limit. -func (t *WSTransport) ReadLimit() int64 { - return t.opts.ReadLimit -} - -// WriteTimeout returns the configured write timeout for new connections. -func (t *WSTransport) WriteTimeout() time.Duration { - return t.opts.WriteTimeout -} - func (t *WSTransport) ConnCount() int { t.mu.Lock() defer t.mu.Unlock() @@ -262,7 +280,10 @@ func (t *WSTransport) dial(ctx context.Context, key uint64, opts common.Options) return nil, err } - if err := proto.Init(ctx, wsConn, opts.InitPayload); err != nil { + initCtx, initCancel := context.WithTimeout(ctx, t.opts.AckTimeout) + defer initCancel() + + if err := proto.Init(initCtx, wsConn, opts.InitPayload); err != nil { t.opts.Logger.Error("wsTransport.dial", abstractlogger.String("endpoint", opts.Endpoint), abstractlogger.String("error", "protocol init failed"), @@ -299,17 +320,9 @@ func (t *WSTransport) negotiateSubprotocol(requested common.WSSubprotocol, accep switch common.WSSubprotocol(accepted) { case common.SubprotocolGraphQLTransportWS: - p := protocol.NewGraphQLTransportWS() - if t.opts.AckTimeout > 0 { - p.AckTimeout = t.opts.AckTimeout - } - return p, nil + return protocol.NewGraphQLTransportWS(), nil case common.SubprotocolGraphQLWS: - p := protocol.NewGraphQLWS() - if t.opts.AckTimeout > 0 { - p.AckTimeout = t.opts.AckTimeout - } - return p, nil + return protocol.NewGraphQLWS(), nil default: return nil, ErrInvalidSubprotocol(accepted) } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go index 5943c8c31d..edf588c16f 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go @@ -1068,7 +1068,7 @@ func TestWSTransport_Defaults(t *testing.T) { tr := NewWSTransport(t.Context(), WSTransportOptions{}) - assert.Equal(t, defaultReadLimit, tr.ReadLimit()) + assert.Equal(t, defaultReadLimit, tr.opts.ReadLimit) }) t.Run("applies default read limit for zero value", func(t *testing.T) { @@ -1076,7 +1076,7 @@ func TestWSTransport_Defaults(t *testing.T) { tr := NewWSTransport(t.Context(), WSTransportOptions{ReadLimit: 0}) - assert.Equal(t, defaultReadLimit, tr.ReadLimit()) + assert.Equal(t, defaultReadLimit, tr.opts.ReadLimit) }) t.Run("overrides read limit when provided", func(t *testing.T) { @@ -1084,7 +1084,7 @@ func TestWSTransport_Defaults(t *testing.T) { tr := NewWSTransport(t.Context(), WSTransportOptions{ReadLimit: 2 * 1024 * 1024}) - assert.Equal(t, int64(2*1024*1024), tr.ReadLimit()) + assert.Equal(t, int64(2*1024*1024), tr.opts.ReadLimit) }) t.Run("ignores negative read limit", func(t *testing.T) { @@ -1092,24 +1092,9 @@ func TestWSTransport_Defaults(t *testing.T) { tr := NewWSTransport(t.Context(), WSTransportOptions{ReadLimit: -1}) - assert.Equal(t, defaultReadLimit, tr.ReadLimit()) + assert.Equal(t, defaultReadLimit, tr.opts.ReadLimit) }) - t.Run("applies default write timeout", func(t *testing.T) { - t.Parallel() - - tr := NewWSTransport(t.Context(), WSTransportOptions{}) - - assert.Equal(t, defaultWriteTimeout, tr.WriteTimeout()) - }) - - t.Run("overrides write timeout when provided", func(t *testing.T) { - t.Parallel() - - tr := NewWSTransport(t.Context(), WSTransportOptions{WriteTimeout: 10 * time.Second}) - - assert.Equal(t, 10*time.Second, tr.WriteTimeout()) - }) } // Test helpers From 3879306890ba9a4d07e82153a2b278a915eae5fc Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 8 Apr 2026 11:18:39 +0100 Subject: [PATCH 24/52] renaming message types for clarity, adding doc comments to protocol --- .../subscriptionclient/common/message.go | 3 +++ .../subscriptionclient/common/options.go | 12 ++++++--- .../protocol/graphql_transport_ws.go | 6 ++--- .../subscriptionclient/protocol/graphql_ws.go | 6 ++--- .../subscriptionclient/protocol/protocol.go | 26 +++++++++++++------ .../subscriptionclient/transport/ws_conn.go | 2 +- .../transport/ws_conn_test.go | 22 ++++++++-------- 7 files changed, 47 insertions(+), 30 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go index fc121bc509..9e02e8197c 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go @@ -23,6 +23,7 @@ func (t MessageType) IsTerminal() bool { return t == MessageTypeError || t == MessageTypeComplete || t == MessageTypeConnectionError } +// Message is a single subscription event delivered to a Handler. type Message struct { Type MessageType Payload *ExecutionResult @@ -33,12 +34,14 @@ type Message struct { // transport's read goroutine; a slow handler blocks message delivery. type Handler func(msg *Message) +// ExecutionResult is the GraphQL response payload for data and error messages. type ExecutionResult struct { Data json.RawMessage `json:"data,omitempty"` Errors json.RawMessage `json:"errors,omitempty"` Extensions json.RawMessage `json:"extensions,omitempty"` } +// Request is a GraphQL operation sent to the server when subscribing. type Request struct { Query string `json:"query"` OperationName string `json:"operationName,omitempty"` diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/options.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/options.go index ab706d02f0..9177be370f 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/options.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/options.go @@ -4,21 +4,24 @@ import ( "net/http" ) +// TransportType selects the subscription transport mechanism. type TransportType string const ( - TransportWS TransportType = "ws" - TransportSSE TransportType = "sse" + TransportWS TransportType = "ws" // WebSocket connection + TransportSSE TransportType = "sse" // Server-Sent Events over HTTP ) +// WSSubprotocol selects the GraphQL-over-WebSocket subprotocol. type WSSubprotocol string const ( SubprotocolAuto WSSubprotocol = "" // Auto, negotiated with the server - SubprotocolGraphQLTransportWS WSSubprotocol = "graphql-transport-ws" // Modern subprotocol - SubprotocolGraphQLWS WSSubprotocol = "graphql-ws" // Legacy subprotocol, deprecated + SubprotocolGraphQLTransportWS WSSubprotocol = "graphql-transport-ws" // Modern protocol from The Guild + SubprotocolGraphQLWS WSSubprotocol = "graphql-ws" // Legacy Apollo protocol, deprecated ) +// Subprotocols returns the WebSocket subprotocol strings to offer during the upgrade handshake. func (s WSSubprotocol) Subprotocols() []string { switch s { case SubprotocolAuto: @@ -40,6 +43,7 @@ const ( SSEMethodGET SSEMethod = "GET" // GET with query parameters (traditional SSE) ) +// Options configures a single subscription request (endpoint, headers, transport selection). type Options struct { Endpoint string Headers http.Header diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go index bd87506f67..23cbcf119c 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go @@ -97,7 +97,7 @@ func (p *GraphQLTransportWS) Pong(ctx context.Context, conn *websocket.Conn) err } // Read implements Protocol. -func (p *GraphQLTransportWS) Read(ctx context.Context, conn *websocket.Conn) (*Message, error) { +func (p *GraphQLTransportWS) Read(ctx context.Context, conn *websocket.Conn) (*WireMessage, error) { var raw incomingMessage if err := wsjson.Read(ctx, conn, &raw); err != nil { return nil, fmt.Errorf("read message: %w", err) @@ -125,8 +125,8 @@ func (p *GraphQLTransportWS) Unsubscribe(ctx context.Context, conn *websocket.Co return wsjson.Write(ctx, conn, msg) } -func (p *GraphQLTransportWS) decode(raw incomingMessage) (*Message, error) { - msg := &Message{ +func (p *GraphQLTransportWS) decode(raw incomingMessage) (*WireMessage, error) { + msg := &WireMessage{ ID: raw.ID, } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go index 4959cb482f..854154bfd3 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go @@ -92,7 +92,7 @@ func (p *GraphQLWS) Unsubscribe(ctx context.Context, conn *websocket.Conn, id st } // Read implements Protocol. -func (p *GraphQLWS) Read(ctx context.Context, conn *websocket.Conn) (*Message, error) { +func (p *GraphQLWS) Read(ctx context.Context, conn *websocket.Conn) (*WireMessage, error) { var raw incomingMessage if err := wsjson.Read(ctx, conn, &raw); err != nil { return nil, fmt.Errorf("read message: %w", err) @@ -115,8 +115,8 @@ func (p *GraphQLWS) Pong(ctx context.Context, conn *websocket.Conn) error { return nil } -func (p *GraphQLWS) decode(raw incomingMessage) (*Message, error) { - msg := &Message{ +func (p *GraphQLWS) decode(raw incomingMessage) (*WireMessage, error) { + msg := &WireMessage{ ID: raw.ID, } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go index 4daddd2c24..70f84ddf10 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go @@ -9,17 +9,24 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" ) +// Protocol defines the message framing and behaviour used on a WS connection. type Protocol interface { + // Init performs the connection handshake with the server. Init(ctx context.Context, conn *websocket.Conn, payload map[string]any) error + // Subscribe starts a subscription for the given operation. Subscribe(ctx context.Context, conn *websocket.Conn, id string, req *common.Request) error + // Unsubscribe ends a subscription. Unsubscribe(ctx context.Context, conn *websocket.Conn, id string) error - Read(ctx context.Context, conn *websocket.Conn) (*Message, error) + // Read blocks until the next message arrives and decodes it. + Read(ctx context.Context, conn *websocket.Conn) (*WireMessage, error) + // Ping requests a liveness check from the server. No-op for protocols that don't support it. Ping(ctx context.Context, conn *websocket.Conn) error + // Pong responds to a server liveness check. No-op for protocols that don't support it. Pong(ctx context.Context, conn *websocket.Conn) error } @@ -29,14 +36,17 @@ var ( ErrConnectionError = errors.New("connection error from server") ) -type Message struct { +// WireMessage is a decoded wire-level protocol message. +// It is different from the common message format because it still contains the ID and internal type, +// which is not exposed to consumers. +type WireMessage struct { ID string - Type MessageType + Type WireMessageType Payload *common.ExecutionResult Err error } -func (m *Message) IntoClientMessage() *common.Message { +func (m *WireMessage) IntoClientMessage() *common.Message { switch m.Type { case MessageData: return &common.Message{Type: common.MessageTypeData, Payload: m.Payload} @@ -52,18 +62,18 @@ func (m *Message) IntoClientMessage() *common.Message { } } -// MessageType identifies the message type. -type MessageType int +// WireMessageType identifies the message type. +type WireMessageType int const ( - MessageData MessageType = iota + MessageData WireMessageType = iota MessageError MessageComplete MessagePing MessagePong ) -func (t MessageType) String() string { +func (t WireMessageType) String() string { switch t { case MessageData: return "data" diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go index 90824b5caa..a5737eb40f 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go @@ -193,7 +193,7 @@ func (c *wsConnection) readLoop() { } } -func (c *wsConnection) dispatch(msg *protocol.Message) { +func (c *wsConnection) dispatch(msg *protocol.WireMessage) { c.subsMu.RLock() handler, exists := c.subs[msg.ID] c.subsMu.RUnlock() diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go index 478cf1fbde..e2d81a6907 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go @@ -99,7 +99,7 @@ func TestWSConnection_ReadLoop(t *testing.T) { go wsc.readLoop() - proto.PushMessage(&protocol.Message{ + proto.PushMessage(&protocol.WireMessage{ ID: "sub-1", Type: protocol.MessageData, Payload: &common.ExecutionResult{Data: json.RawMessage(`{"value": 42}`)}, @@ -123,7 +123,7 @@ func TestWSConnection_ReadLoop(t *testing.T) { go wsc.readLoop() - proto.PushMessage(&protocol.Message{ + proto.PushMessage(&protocol.WireMessage{ ID: "sub-1", Type: protocol.MessageComplete, }) @@ -141,7 +141,7 @@ func TestWSConnection_ReadLoop(t *testing.T) { go wsc.readLoop() - proto.PushMessage(&protocol.Message{Type: protocol.MessagePing}) + proto.PushMessage(&protocol.WireMessage{Type: protocol.MessagePing}) assert.Eventually(t, func() bool { return proto.PongCount() > 0 @@ -162,13 +162,13 @@ func TestWSConnection_ReadLoop(t *testing.T) { go wsc.readLoop() - proto.PushMessage(&protocol.Message{ + proto.PushMessage(&protocol.WireMessage{ ID: "unknown-sub", Type: protocol.MessageData, Payload: &common.ExecutionResult{Data: json.RawMessage(`{"wrong": true}`)}, }) - proto.PushMessage(&protocol.Message{ + proto.PushMessage(&protocol.WireMessage{ ID: "sub-1", Type: protocol.MessageData, Payload: &common.ExecutionResult{Data: json.RawMessage(`{"right": true}`)}, @@ -508,10 +508,10 @@ func TestWSConnection_WriteTimeout(t *testing.T) { go wsc.readLoop() // Send ping (will trigger slow pong) - proto.PushMessage(&protocol.Message{Type: protocol.MessagePing}) + proto.PushMessage(&protocol.WireMessage{Type: protocol.MessagePing}) // Send data message right after - proto.PushMessage(&protocol.Message{ + proto.PushMessage(&protocol.WireMessage{ ID: "sub-1", Type: protocol.MessageData, Payload: &common.ExecutionResult{Data: json.RawMessage(`{"test": true}`)}, @@ -621,7 +621,7 @@ type mockProtocol struct { unsubscribeDelay time.Duration pongDelay time.Duration - messages chan *protocol.Message + messages chan *protocol.WireMessage } type subscribeCall struct { @@ -631,7 +631,7 @@ type subscribeCall struct { func newMockProtocol() *mockProtocol { return &mockProtocol{ - messages: make(chan *protocol.Message, 100), + messages: make(chan *protocol.WireMessage, 100), } } @@ -663,7 +663,7 @@ func (m *mockProtocol) Unsubscribe(ctx context.Context, conn *websocket.Conn, id return nil } -func (m *mockProtocol) Read(ctx context.Context, conn *websocket.Conn) (*protocol.Message, error) { +func (m *mockProtocol) Read(ctx context.Context, conn *websocket.Conn) (*protocol.WireMessage, error) { select { case msg := <-m.messages: return msg, nil @@ -691,7 +691,7 @@ func (m *mockProtocol) Pong(ctx context.Context, conn *websocket.Conn) error { return nil } -func (m *mockProtocol) PushMessage(msg *protocol.Message) { +func (m *mockProtocol) PushMessage(msg *protocol.WireMessage) { m.messages <- msg } From 3412101cc00778bfc9560578c10d81c9db59fa7b Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 8 Apr 2026 11:33:41 +0100 Subject: [PATCH 25/52] lighter ints for enums --- .../graphql_datasource/subscriptionclient/common/message.go | 2 +- .../graphql_datasource/subscriptionclient/protocol/protocol.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go index 9e02e8197c..d99f12a470 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go @@ -8,7 +8,7 @@ import ( var ErrConnectionClosed = errors.New("connection closed") // MessageType identifies the kind of message delivered on a subscription channel. -type MessageType int +type MessageType uint8 const ( MessageTypeUnknown MessageType = iota diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go index 70f84ddf10..3b5378e505 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go @@ -63,7 +63,7 @@ func (m *WireMessage) IntoClientMessage() *common.Message { } // WireMessageType identifies the message type. -type WireMessageType int +type WireMessageType uint8 const ( MessageData WireMessageType = iota From fb6d012f88c1e791139b0059e1d2582199814d61 Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 8 Apr 2026 11:33:58 +0100 Subject: [PATCH 26/52] use consts directly in map function --- .../graphql_datasource/graphql_subscription_client.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 34c4ccc1d3..62e082c6aa 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -257,9 +257,9 @@ func convertToClientOptions(options GraphQLSubscriptionOptions) (client.Options, // mapWSSubprotocol maps the string subprotocol to the client.WSSubprotocol type. func mapWSSubprotocol(proto string) client.WSSubprotocol { switch proto { - case "graphql-ws": + case string(client.SubprotocolGraphQLWS): return client.SubprotocolGraphQLWS - case "graphql-transport-ws": + case string(client.SubprotocolGraphQLTransportWS): return client.SubprotocolGraphQLTransportWS default: return client.SubprotocolAuto From 6e167e2191d44c1704638045c497e12c153df4ec Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 8 Apr 2026 11:36:26 +0100 Subject: [PATCH 27/52] clarify re-export file --- .../graphql_datasource/subscriptionclient/exports.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go index 0d24753b3e..51db35954e 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go @@ -6,7 +6,8 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport" ) -// Re-export common types for single-import convenience. +// Re-exports from internal sub-packages (common, protocol, transport) so that +// callers can import this package alone instead of reaching into internals. type ( Message = common.Message From cf2504b25f976f8a47c1d037bd43ddc7ab9dea87 Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 8 Apr 2026 11:48:44 +0100 Subject: [PATCH 28/52] fix toctou race and add a test thanks @ysmolski, good eyes --- .../subscriptionclient/transport/ws_conn.go | 5 +- .../transport/ws_conn_test.go | 60 +++++++++++++++++++ 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go index a5737eb40f..b3e12542e7 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go @@ -83,12 +83,13 @@ func newWSConnection(conn *websocket.Conn, proto protocol.Protocol, opts wsConne } func (c *wsConnection) subscribe(ctx context.Context, id string, req *common.Request, handler common.Handler) (func(), error) { + c.subsMu.Lock() + if c.closed.Load() { + c.subsMu.Unlock() return nil, common.ErrConnectionClosed } - c.subsMu.Lock() - if _, exists := c.subs[id]; exists { c.subsMu.Unlock() return nil, ErrSubscriptionExists diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go index e2d81a6907..5cf5ae8997 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "net/http" + "runtime" "net/http/httptest" "strings" "sync" @@ -67,6 +68,65 @@ func TestWSConnection_Subscribe(t *testing.T) { assert.ErrorIs(t, err, common.ErrConnectionClosed) }) + t.Run("returns ErrConnectionClosed when shutdown races before lock acquisition", func(t *testing.T) { + t.Parallel() + + // Test for TOCTOU between closed.Load() and subsMu.Lock() + // in subscribe(). + // + // Without a closed re-check under the lock, this sequence is possible: + // 1. subscribe: closed.Load() → false (check) + // 2. shutdown: closed.CAS(false,true) (invalidates the check) + // 3. shutdown: swaps c.subs, dispatches errors to old handlers + // 4. subscribe: subsMu.Lock(), adds handler to NEW empty map (use) + // 5. subscribe: protocol.Subscribe → nil (mock doesn't check conn) + // 6. subscribe returns (cancel, nil) — looks successful + // + // The handler is now orphaned: closed=true prevents any future + // shutdown from running (CAS fails), so the handler will never + // receive a terminal message. This test forces the exact interleaving + // deterministically. + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newWSConnection(conn, proto, wsConnectionOptions{}) + + // Step 1: Hold subsMu so subscribe() blocks after its closed check. + wsc.subsMu.Lock() + + subscribeResult := make(chan error, 1) + go func() { + // Passes closed.Load() (still false), then blocks on subsMu.Lock(). + _, err := wsc.subscribe(context.Background(), "sub-race", &common.Request{}, func(_ *common.Message) {}) + subscribeResult <- err + }() + + // Let the goroutine reach the blocked Lock(). + runtime.Gosched() + time.Sleep(5 * time.Millisecond) + + // Step 2: Simulate what shutdown does first — set closed=true. + // We set it directly rather than calling shutdown() because shutdown + // also needs subsMu (which we hold), and we want to control ordering. + wsc.closed.Store(true) + + // Step 3: Release the lock. subscribe() now enters the critical + // section with a stale view: it saw closed=false, but closed is now true. + wsc.subsMu.Unlock() + + select { + case err := <-subscribeResult: + // With the fix: subscribe re-checks closed under the lock → ErrConnectionClosed. + // Without the fix: subscribe succeeds (nil error) — the handler is orphaned + // because closed=true means no future shutdown() can deliver a terminal message. + require.ErrorIs(t, err, common.ErrConnectionClosed, + "subscribe must detect closed state under the lock; without this check "+ + "the handler is orphaned (closed=true prevents future shutdown)") + case <-time.After(time.Second): + t.Fatal("subscribe did not return within timeout") + } + }) + t.Run("returns error when protocol subscribe fails", func(t *testing.T) { t.Parallel() From 0a28c72e7ae0febcb6e60d7a30ad1bc7482f7fe7 Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 8 Apr 2026 21:34:26 +0100 Subject: [PATCH 29/52] move filter buffer handling inside, use pool --- v2/pkg/engine/resolve/resolve.go | 6 +- .../resolve/resolver_subscription_test.go | 2 - v2/pkg/engine/resolve/subscription_filter.go | 18 ++-- .../resolve/subscription_filter_test.go | 91 ++++++------------- 4 files changed, 44 insertions(+), 73 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index abd6a5a9ee..0762c31811 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -445,7 +445,6 @@ type trigger struct { id uint64 cancel context.CancelFunc subscriptions map[SubscriptionIdentifier]*subscriptionState - updateBuf *bytes.Buffer // initialized is set to true when the trigger is started and initialized initialized bool updater *subscriptionUpdater @@ -743,7 +742,6 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) trig = &trigger{ id: triggerID, subscriptions: make(map[SubscriptionIdentifier]*subscriptionState), - updateBuf: bytes.NewBuffer(make([]byte, 0, 1024)), cancel: cancel, updater: updater, } @@ -1007,7 +1005,7 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { if s.ctx.ctx.Err() != nil { continue } - skip, err := s.resolve.Filter.SkipEvent(s.ctx, data, trig.updateBuf) + skip, err := s.resolve.Filter.SkipEvent(s.ctx, data) if err != nil { filterErrors = append(filterErrors, pendingFilterError{s.ctx, err, s.resolve.Response, s.writer}) continue @@ -1054,7 +1052,7 @@ func (r *Resolver) handleUpdateSubscription(id uint64, data []byte, subIdentifie s, ok := trig.subscriptions[subIdentifier] if ok { if s.ctx.ctx.Err() == nil { - skip, err := s.resolve.Filter.SkipEvent(s.ctx, data, trig.updateBuf) + skip, err := s.resolve.Filter.SkipEvent(s.ctx, data) if err != nil { filterErr = &pendingFilterError{s.ctx, err, s.resolve.Response, s.writer} } else if !skip { diff --git a/v2/pkg/engine/resolve/resolver_subscription_test.go b/v2/pkg/engine/resolve/resolver_subscription_test.go index dbaed8901d..0421316454 100644 --- a/v2/pkg/engine/resolve/resolver_subscription_test.go +++ b/v2/pkg/engine/resolve/resolver_subscription_test.go @@ -1,7 +1,6 @@ package resolve import ( - "bytes" "context" "errors" "io" @@ -232,7 +231,6 @@ func TestResolver_HeartbeatError_DoesNotDeadlockOnUnsubscribe(t *testing.T) { id: triggerID, cancel: func() {}, subscriptions: map[SubscriptionIdentifier]*subscriptionState{subID: s}, - updateBuf: bytes.NewBuffer(make([]byte, 0, 1024)), } resolver.subscriptionsByID[subID] = s resolver.subscriptionsByConnection[subID.ConnectionID] = map[SubscriptionIdentifier]*subscriptionState{subID: s} diff --git a/v2/pkg/engine/resolve/subscription_filter.go b/v2/pkg/engine/resolve/subscription_filter.go index 25e0df5d65..2bffe480a0 100644 --- a/v2/pkg/engine/resolve/subscription_filter.go +++ b/v2/pkg/engine/resolve/subscription_filter.go @@ -11,6 +11,7 @@ import ( "github.com/wundergraph/astjson" "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/literal" + "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) type SubscriptionFilter struct { @@ -25,14 +26,14 @@ type SubscriptionFieldFilter struct { Values []InputTemplate } -func (f *SubscriptionFilter) SkipEvent(ctx *Context, data []byte, buf *bytes.Buffer) (bool, error) { +func (f *SubscriptionFilter) SkipEvent(ctx *Context, data []byte) (bool, error) { if f == nil { return false, nil } if f.And != nil { for _, filter := range f.And { - skip, err := filter.SkipEvent(ctx, data, buf) + skip, err := filter.SkipEvent(ctx, data) if err != nil { return false, err } @@ -47,7 +48,7 @@ func (f *SubscriptionFilter) SkipEvent(ctx *Context, data []byte, buf *bytes.Buf if f.Or != nil { for _, filter := range f.Or { - skip, err := filter.SkipEvent(ctx, data, buf) + skip, err := filter.SkipEvent(ctx, data) if err != nil { return false, err } @@ -61,7 +62,7 @@ func (f *SubscriptionFilter) SkipEvent(ctx *Context, data []byte, buf *bytes.Buf } if f.Not != nil { - skip, err := f.Not.SkipEvent(ctx, data, buf) + skip, err := f.Not.SkipEvent(ctx, data) if err != nil { return false, err } @@ -69,7 +70,7 @@ func (f *SubscriptionFilter) SkipEvent(ctx *Context, data []byte, buf *bytes.Buf } if f.In != nil { - return f.In.SkipEvent(ctx, data, buf) + return f.In.SkipEvent(ctx, data) } return false, nil @@ -84,7 +85,7 @@ var ( ErrInvalidSubscriptionFilterTemplate = errors.New("invalid subscription filter template") ) -func (f *SubscriptionFieldFilter) SkipEvent(ctx *Context, data []byte, buf *bytes.Buffer) (bool, error) { +func (f *SubscriptionFieldFilter) SkipEvent(ctx *Context, data []byte) (bool, error) { if f == nil { return false, nil } @@ -94,6 +95,11 @@ func (f *SubscriptionFieldFilter) SkipEvent(ctx *Context, data []byte, buf *byte return true, nil } + // Scratch buffer for rendering filter template values. Pooled to avoid + // per-event allocations. + buf := pool.BytesBuffer.Get() + defer pool.BytesBuffer.Put(buf) + for i := range f.Values { buf.Reset() err := f.Values[i].Render(ctx, nil, buf) diff --git a/v2/pkg/engine/resolve/subscription_filter_test.go b/v2/pkg/engine/resolve/subscription_filter_test.go index 807ef706e6..98c6eb8068 100644 --- a/v2/pkg/engine/resolve/subscription_filter_test.go +++ b/v2/pkg/engine/resolve/subscription_filter_test.go @@ -1,7 +1,6 @@ package resolve import ( - "bytes" "testing" "github.com/stretchr/testify/assert" @@ -31,9 +30,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":true}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":true}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -58,9 +56,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":"false"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":true}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -85,18 +82,16 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":"true"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":true}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) c = &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":true}`)), } - buf = &bytes.Buffer{} data = []byte(`{"event":"true"}`) - skip, err = filter.SkipEvent(c, data, buf) + skip, err = filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -121,9 +116,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":1.13}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":1.13}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -148,18 +142,16 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":"1.13"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":1.13}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) c = &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":1.13}`)), } - buf = &bytes.Buffer{} data = []byte(`{"event":"1.13"}`) - skip, err = filter.SkipEvent(c, data, buf) + skip, err = filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -184,9 +176,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":49}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":49}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -211,18 +202,16 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":"49"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":49}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) c = &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":49}`)), } - buf = &bytes.Buffer{} data = []byte(`{"event":"49"}`) - skip, err = filter.SkipEvent(c, data, buf) + skip, err = filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -247,9 +236,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":"9.77"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":8.01}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -274,9 +262,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":123}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":321}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -301,9 +288,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":true}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":true}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -328,9 +314,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":["a","b"]}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -355,9 +340,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":[1,"2"]}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":2}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -382,9 +366,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":["a","b","c"]}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -411,9 +394,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":"b"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":"b"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -440,9 +422,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":"b"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -488,9 +469,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"first":"b","second":"c"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"eventX":"b","eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -530,9 +510,8 @@ func TestSubscriptionFilter(t *testing.T) { }, } c := &Context{} - buf := &bytes.Buffer{} data := []byte(`{"eventX":"b","eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -557,9 +536,8 @@ func TestSubscriptionFilter(t *testing.T) { }, } c := &Context{} - buf := &bytes.Buffer{} data := []byte(`{"eventX":true,"eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -586,9 +564,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"id":1}`)), } - buf := &bytes.Buffer{} data := []byte(`{"eventX":1,"eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -613,9 +590,8 @@ func TestSubscriptionFilter(t *testing.T) { }, } c := &Context{} - buf := &bytes.Buffer{} data := []byte(`{"eventX":null,"eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -661,9 +637,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"first":"d","second":"c"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"eventX":"b","eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -709,9 +684,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"first":"b","unused":"c"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"eventX":"b","eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -757,9 +731,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"first":"b","second":"c"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"eventX":"b","eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -805,9 +778,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"first":"b","unused":"c"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"eventX":"b","eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -853,9 +825,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"third":"b","second":"c","fourth":1}`)), } - buf := &bytes.Buffer{} data := []byte(`{"eventX":"b","eventY":"c","fourth":1}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -901,9 +872,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"third":"b","second":"c"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"eventX":"b","eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -956,9 +926,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"third":"b","second":"c","fourth":1}`)), } - buf := &bytes.Buffer{} data := []byte(`{"eventX":"b","eventY":"c1","fourth":1}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) From 83073b07975f080010f1d727e9759228d192041f Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 8 Apr 2026 21:38:33 +0100 Subject: [PATCH 30/52] doc comments improvement --- v2/pkg/engine/resolve/resolve.go | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 0762c31811..a279937ac4 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -57,6 +57,8 @@ type Resolver struct { options ResolverOptions maxConcurrency chan struct{} + // mu protects: shutdown, triggers, subscriptionsByID, subscriptionsByConnection. + // Lock ordering: subscriptionUpdater.mu > Resolver.mu > trigger.mu (then subscriptionState.writeMu outside those locks). mu sync.Mutex shutdown bool triggers map[uint64]*trigger @@ -440,7 +442,10 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe return resp, err } +// trigger groups subscriptions that share a data source and input. type trigger struct { + // mu protects subscriptions and initialized. + // Uses snapshot-and-release: held only during map access, released before I/O. mu sync.RWMutex id uint64 cancel context.CancelFunc @@ -461,6 +466,7 @@ func (t *trigger) subscriptionIds() map[context.Context]SubscriptionIdentifier { return subs } +// subscriptionState tracks a single active subscription. type subscriptionState struct { triggerID uint64 resolve *GraphQLSubscription @@ -469,7 +475,9 @@ type subscriptionState struct { id SubscriptionIdentifier heartbeat bool completed chan struct{} - writeMu sync.Mutex + // writeMu protects all writes to writer (Complete, Error, Write, Flush, Heartbeat). + // Paired with the removed atomic to prevent writes after removal. + writeMu sync.Mutex // removed guards against writes after the subscription has been removed. // Uses CompareAndSwap to prevent double-close of the completed channel. removed atomic.Bool @@ -667,6 +675,7 @@ func (r *Resolver) executeStartupHooks(add *addSubscription, updater *subscripti return nil } +// addSubscriptionIndex updates the by-ID and by-connection indexes. Resolver.mu must be held. func (r *Resolver) addSubscriptionIndex(s *subscriptionState) { id := s.id r.subscriptionsByID[id] = s @@ -678,6 +687,7 @@ func (r *Resolver) addSubscriptionIndex(s *subscriptionState) { byConn[id] = s } +// removeSubscriptionIndex removes from the by-ID and by-connection indexes. Resolver.mu must be held. func (r *Resolver) removeSubscriptionIndex(id SubscriptionIdentifier) { delete(r.subscriptionsByID, id) byConn, ok := r.subscriptionsByConnection[id.ConnectionID] @@ -1456,7 +1466,15 @@ func (r *Resolver) subscriptionInput(ctx *Context, subscription *GraphQLSubscrip return input, nil } +// subscriptionUpdater implements SubscriptionUpdater, the callback API for data sources. type subscriptionUpdater struct { + // mu serves two roles: + // + // 1. Event serialization gate -- held across the entire Update() call including + // wg.Wait(), ensuring event A fully completes before event B begins. + // + // 2. Lifecycle guard -- the done flag prevents callbacks after Done() has torn down + // the trigger. Every method checks done || ctx.Err() under the lock before proceeding. mu sync.Mutex done bool debug bool From 039b62d16bbaa2a574547ae13fa05f5db4678c29 Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 8 Apr 2026 21:39:01 +0100 Subject: [PATCH 31/52] clean up complicated 3/4 value returns --- v2/pkg/engine/resolve/resolve.go | 129 +++++++++++++++++++------------ 1 file changed, 79 insertions(+), 50 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index a279937ac4..7f0cdd7b6e 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -809,17 +809,17 @@ func (r *Resolver) doneTriggerFromUpdater(triggerID uint64) { fmt.Printf("resolver:trigger:shutdown:%d\n", triggerID) } r.mu.Lock() - removed, toClose, cancel, initialized := r.detachTriggerLocked(triggerID) + res := r.detachTriggerLocked(triggerID) if r.reporter != nil { - r.reporter.SubscriptionCountDec(removed) - if initialized { + r.reporter.SubscriptionCountDec(res.removed) + if res.initialized { r.reporter.TriggerCountDec(1) } } r.mu.Unlock() - closeSubs(toClose) - if cancel != nil { - cancel() + closeSubs(res.toClose) + if res.triggerCancel != nil { + res.triggerCancel() } } @@ -871,21 +871,21 @@ func (r *Resolver) handleTriggerError(triggerID uint64, data []byte) { } } -func (r *Resolver) handleCompleteSubscription(id SubscriptionIdentifier) (int, []*subscriptionState, context.CancelFunc, bool) { +func (r *Resolver) handleCompleteSubscription(id SubscriptionIdentifier) removeResult { if r.options.Debug { fmt.Printf("resolver:trigger:subscription:remove:%d:%d\n", id.ConnectionID, id.SubscriptionID) } return r.removeSubscriptionByID(id) } -func (r *Resolver) handleRemoveSubscription(id SubscriptionIdentifier) (int, []*subscriptionState, context.CancelFunc, bool) { +func (r *Resolver) handleRemoveSubscription(id SubscriptionIdentifier) removeResult { if r.options.Debug { fmt.Printf("resolver:trigger:subscription:remove:%d:%d\n", id.ConnectionID, id.SubscriptionID) } return r.removeSubscriptionByID(id) } -func (r *Resolver) handleRemoveClient(id int64) (int, []*subscriptionState, []context.CancelFunc, int) { +func (r *Resolver) handleRemoveClient(id int64) removeClientResult { if r.options.Debug { fmt.Printf("resolver:trigger:subscription:remove:client:%d\n", id) } @@ -899,31 +899,36 @@ func (r *Resolver) handleRemoveClient(id int64) (int, []*subscriptionState, []co ids = append(ids, sid) } for _, sid := range ids { - rem, fz, cancel, initialized := r.removeSubscriptionByID(sid) - removed += rem - toClose = append(toClose, fz...) - if cancel != nil { - cancels = append(cancels, cancel) - if initialized { + res := r.removeSubscriptionByID(sid) + removed += res.removed + toClose = append(toClose, res.toClose...) + if res.triggerCancel != nil { + cancels = append(cancels, res.triggerCancel) + if res.initialized { triggerDec++ } } } - return removed, toClose, cancels, triggerDec + return removeClientResult{ + removed: removed, + toClose: toClose, + cancels: cancels, + triggerDec: triggerDec, + } } // removeSubscriptionByID removes a single subscription by id. // r.mu must be held by the caller. -func (r *Resolver) removeSubscriptionByID(id SubscriptionIdentifier) (int, []*subscriptionState, context.CancelFunc, bool) { +func (r *Resolver) removeSubscriptionByID(id SubscriptionIdentifier) removeResult { s, ok := r.subscriptionsByID[id] if !ok { - return 0, nil, nil, false + return removeResult{} } trig, ok := r.triggers[s.triggerID] if !ok { r.removeSubscriptionIndex(id) - return 0, nil, nil, false + return removeResult{} } trig.mu.Lock() @@ -931,7 +936,7 @@ func (r *Resolver) removeSubscriptionByID(id SubscriptionIdentifier) (int, []*su if !ok { trig.mu.Unlock() r.removeSubscriptionIndex(id) - return 0, nil, nil, false + return removeResult{} } var toClose []*subscriptionState @@ -944,23 +949,28 @@ func (r *Resolver) removeSubscriptionByID(id SubscriptionIdentifier) (int, []*su r.removeSubscriptionIndex(id) - var cancel context.CancelFunc + var triggerCancel context.CancelFunc initialized := false if empty { delete(r.triggers, trig.id) - cancel = trig.cancel + triggerCancel = trig.cancel initialized = trig.initialized } - return 1, toClose, cancel, initialized + return removeResult{ + removed: 1, + toClose: toClose, + triggerCancel: triggerCancel, + initialized: initialized, + } } // detachTriggerLocked removes all subscriptions for the trigger and removes the trigger from resolver maps. // r.mu must be held by the caller. -func (r *Resolver) detachTriggerLocked(id uint64) (int, []*subscriptionState, context.CancelFunc, bool) { +func (r *Resolver) detachTriggerLocked(id uint64) removeResult { trig, ok := r.triggers[id] if !ok { - return 0, nil, nil, false + return removeResult{} } toClose := make([]*subscriptionState, 0, len(trig.subscriptions)) @@ -979,7 +989,26 @@ func (r *Resolver) detachTriggerLocked(id uint64) (int, []*subscriptionState, co delete(r.triggers, id) - return removed, toClose, trig.cancel, trig.initialized + return removeResult{ + removed: removed, + toClose: toClose, + triggerCancel: trig.cancel, + initialized: trig.initialized, + } +} + +type removeResult struct { + removed int + toClose []*subscriptionState + triggerCancel context.CancelFunc // non-nil if trigger became empty + initialized bool // whether the removed trigger was initialized +} + +type removeClientResult struct { + removed int + toClose []*subscriptionState + cancels []context.CancelFunc + triggerDec int } // pendingWrite holds the context and subscription for a deferred write outside the lock. @@ -1129,13 +1158,13 @@ func (r *Resolver) shutdownResolver() { triggerDec := 0 for _, id := range triggerIDs { - removed, toClose, cancel, initialized := r.detachTriggerLocked(id) - removedTotal += removed - allToClose = append(allToClose, toClose...) - if cancel != nil { - cancels = append(cancels, cancel) + res := r.detachTriggerLocked(id) + removedTotal += res.removed + allToClose = append(allToClose, res.toClose...) + if res.triggerCancel != nil { + cancels = append(cancels, res.triggerCancel) } - if initialized { + if res.initialized { triggerDec++ } } @@ -1201,10 +1230,10 @@ func (r *Resolver) CompleteSubscription(id SubscriptionIdentifier) error { } // Grab the sub before removal so we can send a "complete" frame after releasing r.mu. sub := r.subscriptionsByID[id] - removed, toClose, cancel, initialized := r.handleCompleteSubscription(id) + res := r.handleCompleteSubscription(id) if r.reporter != nil { - r.reporter.SubscriptionCountDec(removed) - if cancel != nil && initialized { + r.reporter.SubscriptionCountDec(res.removed) + if res.triggerCancel != nil && res.initialized { r.reporter.TriggerCountDec(1) } } @@ -1215,9 +1244,9 @@ func (r *Resolver) CompleteSubscription(id SubscriptionIdentifier) error { if sub != nil { sub.complete() } - closeSubs(toClose) - if cancel != nil { - cancel() + closeSubs(res.toClose) + if res.triggerCancel != nil { + res.triggerCancel() } return nil } @@ -1228,17 +1257,17 @@ func (r *Resolver) UnsubscribeSubscription(id SubscriptionIdentifier) error { r.mu.Unlock() return r.ctx.Err() } - removed, toClose, cancel, initialized := r.handleRemoveSubscription(id) + res := r.handleRemoveSubscription(id) if r.reporter != nil { - r.reporter.SubscriptionCountDec(removed) - if cancel != nil && initialized { + r.reporter.SubscriptionCountDec(res.removed) + if res.triggerCancel != nil && res.initialized { r.reporter.TriggerCountDec(1) } } r.mu.Unlock() - closeSubs(toClose) - if cancel != nil { - cancel() + closeSubs(res.toClose) + if res.triggerCancel != nil { + res.triggerCancel() } return nil } @@ -1249,16 +1278,16 @@ func (r *Resolver) UnsubscribeClient(connectionID int64) error { r.mu.Unlock() return r.ctx.Err() } - removed, toClose, cancels, triggerDec := r.handleRemoveClient(connectionID) + res := r.handleRemoveClient(connectionID) if r.reporter != nil { - r.reporter.SubscriptionCountDec(removed) - if triggerDec > 0 { - r.reporter.TriggerCountDec(triggerDec) + r.reporter.SubscriptionCountDec(res.removed) + if res.triggerDec > 0 { + r.reporter.TriggerCountDec(res.triggerDec) } } r.mu.Unlock() - closeSubs(toClose) - for _, cancel := range cancels { + closeSubs(res.toClose) + for _, cancel := range res.cancels { cancel() } return nil From 1f492df8b26667cda5f5745acd016dc25f0c15f5 Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 8 Apr 2026 22:32:01 +0100 Subject: [PATCH 32/52] resolver method renaming and cleanup --- v2/pkg/engine/resolve/resolve.go | 100 ++++++++----------------------- 1 file changed, 26 insertions(+), 74 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 7f0cdd7b6e..e646dd764b 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -675,8 +675,8 @@ func (r *Resolver) executeStartupHooks(add *addSubscription, updater *subscripti return nil } -// addSubscriptionIndex updates the by-ID and by-connection indexes. Resolver.mu must be held. -func (r *Resolver) addSubscriptionIndex(s *subscriptionState) { +// registerSubscriptionLocked updates the by-ID and by-connection indexes. +func (r *Resolver) registerSubscriptionLocked(s *subscriptionState) { id := s.id r.subscriptionsByID[id] = s byConn, ok := r.subscriptionsByConnection[id.ConnectionID] @@ -687,8 +687,8 @@ func (r *Resolver) addSubscriptionIndex(s *subscriptionState) { byConn[id] = s } -// removeSubscriptionIndex removes from the by-ID and by-connection indexes. Resolver.mu must be held. -func (r *Resolver) removeSubscriptionIndex(id SubscriptionIdentifier) { +// unregisterSubscriptionLocked removes from the by-ID and by-connection indexes. +func (r *Resolver) unregisterSubscriptionLocked(id SubscriptionIdentifier) { delete(r.subscriptionsByID, id) byConn, ok := r.subscriptionsByConnection[id.ConnectionID] if !ok { @@ -700,8 +700,8 @@ func (r *Resolver) removeSubscriptionIndex(id SubscriptionIdentifier) { } } -// handleAddSubscription must be called with r.mu held. -func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) { +// addSubscriptionLocked registers a new subscription under the given trigger. +func (r *Resolver) addSubscriptionLocked(triggerID uint64, add *addSubscription) { if r.options.Debug { fmt.Printf("resolver:trigger:subscription:add:%d:%d\n", triggerID, add.id.SubscriptionID) } @@ -729,7 +729,7 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) trig.mu.Lock() trig.subscriptions[add.id] = s trig.mu.Unlock() - r.addSubscriptionIndex(s) + r.registerSubscriptionLocked(s) // Execute the startup hooks in a goroutine to avoid holding the lock. // On failure, executeStartupHooks calls UnsubscribeSubscription to clean up. go func() { @@ -760,7 +760,7 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) trig.subscriptions[add.id] = s trig.mu.Unlock() updater.subsFn = trig.subscriptionIds - r.addSubscriptionIndex(s) + r.registerSubscriptionLocked(s) if r.reporter != nil { r.reporter.SubscriptionCountInc(1) @@ -871,21 +871,7 @@ func (r *Resolver) handleTriggerError(triggerID uint64, data []byte) { } } -func (r *Resolver) handleCompleteSubscription(id SubscriptionIdentifier) removeResult { - if r.options.Debug { - fmt.Printf("resolver:trigger:subscription:remove:%d:%d\n", id.ConnectionID, id.SubscriptionID) - } - return r.removeSubscriptionByID(id) -} - -func (r *Resolver) handleRemoveSubscription(id SubscriptionIdentifier) removeResult { - if r.options.Debug { - fmt.Printf("resolver:trigger:subscription:remove:%d:%d\n", id.ConnectionID, id.SubscriptionID) - } - return r.removeSubscriptionByID(id) -} - -func (r *Resolver) handleRemoveClient(id int64) removeClientResult { +func (r *Resolver) removeClientLocked(id int64) removeClientResult { if r.options.Debug { fmt.Printf("resolver:trigger:subscription:remove:client:%d\n", id) } @@ -899,7 +885,7 @@ func (r *Resolver) handleRemoveClient(id int64) removeClientResult { ids = append(ids, sid) } for _, sid := range ids { - res := r.removeSubscriptionByID(sid) + res := r.removeSubscriptionLocked(sid) removed += res.removed toClose = append(toClose, res.toClose...) if res.triggerCancel != nil { @@ -917,9 +903,9 @@ func (r *Resolver) handleRemoveClient(id int64) removeClientResult { } } -// removeSubscriptionByID removes a single subscription by id. +// removeSubscriptionLocked removes a single subscription by id. // r.mu must be held by the caller. -func (r *Resolver) removeSubscriptionByID(id SubscriptionIdentifier) removeResult { +func (r *Resolver) removeSubscriptionLocked(id SubscriptionIdentifier) removeResult { s, ok := r.subscriptionsByID[id] if !ok { return removeResult{} @@ -927,7 +913,7 @@ func (r *Resolver) removeSubscriptionByID(id SubscriptionIdentifier) removeResul trig, ok := r.triggers[s.triggerID] if !ok { - r.removeSubscriptionIndex(id) + r.unregisterSubscriptionLocked(id) return removeResult{} } @@ -935,7 +921,7 @@ func (r *Resolver) removeSubscriptionByID(id SubscriptionIdentifier) removeResul _, ok = trig.subscriptions[id] if !ok { trig.mu.Unlock() - r.removeSubscriptionIndex(id) + r.unregisterSubscriptionLocked(id) return removeResult{} } @@ -947,7 +933,7 @@ func (r *Resolver) removeSubscriptionByID(id SubscriptionIdentifier) removeResul empty := len(trig.subscriptions) == 0 trig.mu.Unlock() - r.removeSubscriptionIndex(id) + r.unregisterSubscriptionLocked(id) var triggerCancel context.CancelFunc initialized := false @@ -982,7 +968,7 @@ func (r *Resolver) detachTriggerLocked(id uint64) removeResult { toClose = append(toClose, s) } delete(trig.subscriptions, sid) - r.removeSubscriptionIndex(sid) + r.unregisterSubscriptionLocked(sid) removed++ } trig.mu.Unlock() @@ -1011,11 +997,6 @@ type removeClientResult struct { triggerDec int } -// pendingWrite holds the context and subscription for a deferred write outside the lock. -type pendingSubscriptionWrite struct { - sub *subscriptionState -} - type pendingFilterError struct { ctx *Context err error @@ -1037,7 +1018,7 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { fmt.Printf("resolver:trigger:update:%d\n", id) } - var pending []pendingSubscriptionWrite + var pending []*subscriptionState var filterErrors []pendingFilterError trig.mu.Lock() for _, s := range trig.subscriptions { @@ -1052,7 +1033,7 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { if skip { continue } - pending = append(pending, pendingSubscriptionWrite{s}) + pending = append(pending, s) } trig.mu.Unlock() @@ -1061,12 +1042,12 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { } var wg sync.WaitGroup - for _, pw := range pending { - if pw.sub.removed.Load() { + for _, sub := range pending { + if sub.removed.Load() { continue } wg.Go(func() { - r.executeSubscriptionUpdate(pw.sub.ctx, pw.sub, data) + r.executeSubscriptionUpdate(sub.ctx, sub, data) }) } wg.Wait() @@ -1222,42 +1203,13 @@ type SubscriptionIdentifier struct { SubscriptionID int64 } -func (r *Resolver) CompleteSubscription(id SubscriptionIdentifier) error { - r.mu.Lock() - if r.shutdown { - r.mu.Unlock() - return r.ctx.Err() - } - // Grab the sub before removal so we can send a "complete" frame after releasing r.mu. - sub := r.subscriptionsByID[id] - res := r.handleCompleteSubscription(id) - if r.reporter != nil { - r.reporter.SubscriptionCountDec(res.removed) - if res.triggerCancel != nil && res.initialized { - r.reporter.TriggerCountDec(1) - } - } - r.mu.Unlock() - // Send "complete" to the downstream writer under writeMu. - // This ensures any in-flight data write finishes before the complete is sent, - // matching the old behavior where the worker goroutine called sub.complete(). - if sub != nil { - sub.complete() - } - closeSubs(res.toClose) - if res.triggerCancel != nil { - res.triggerCancel() - } - return nil -} - func (r *Resolver) UnsubscribeSubscription(id SubscriptionIdentifier) error { r.mu.Lock() if r.shutdown { r.mu.Unlock() return r.ctx.Err() } - res := r.handleRemoveSubscription(id) + res := r.removeSubscriptionLocked(id) if r.reporter != nil { r.reporter.SubscriptionCountDec(res.removed) if res.triggerCancel != nil && res.initialized { @@ -1278,7 +1230,7 @@ func (r *Resolver) UnsubscribeClient(connectionID int64) error { r.mu.Unlock() return r.ctx.Err() } - res := r.handleRemoveClient(connectionID) + res := r.removeClientLocked(connectionID) if r.reporter != nil { r.reporter.SubscriptionCountDec(res.removed) if res.triggerDec > 0 { @@ -1366,7 +1318,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ r.mu.Unlock() return r.ctx.Err() } - r.handleAddSubscription(triggerID, &addSubscription{ + r.addSubscriptionLocked(triggerID, &addSubscription{ ctx: ctx, input: input, resolve: subscription, @@ -1459,7 +1411,7 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G r.mu.Unlock() return r.ctx.Err() } - r.handleAddSubscription(triggerID, &addSubscription{ + r.addSubscriptionLocked(triggerID, &addSubscription{ ctx: ctx, input: input, resolve: subscription, @@ -1500,7 +1452,7 @@ type subscriptionUpdater struct { // mu serves two roles: // // 1. Event serialization gate -- held across the entire Update() call including - // wg.Wait(), ensuring event A fully completes before event B begins. + // wg.Wait(), ensuring event A fully completes before event B begins. // // 2. Lifecycle guard -- the done flag prevents callbacks after Done() has torn down // the trigger. Every method checks done || ctx.Err() under the lock before proceeding. From 1336b634c8ab1b98908e9f16e27831c17a894022 Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 8 Apr 2026 23:21:42 +0100 Subject: [PATCH 33/52] clean up locking a bit --- v2/pkg/engine/resolve/resolve.go | 82 ++++++++++++++------------------ 1 file changed, 35 insertions(+), 47 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index e646dd764b..754c4ce482 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -676,7 +676,10 @@ func (r *Resolver) executeStartupHooks(add *addSubscription, updater *subscripti } // registerSubscriptionLocked updates the by-ID and by-connection indexes. -func (r *Resolver) registerSubscriptionLocked(s *subscriptionState) { +func (r *Resolver) registerSubscriptionLocked(trig *trigger, s *subscriptionState) { + trig.mu.Lock() + trig.subscriptions[s.id] = s + trig.mu.Unlock() id := s.id r.subscriptionsByID[id] = s byConn, ok := r.subscriptionsByConnection[id.ConnectionID] @@ -700,8 +703,13 @@ func (r *Resolver) unregisterSubscriptionLocked(id SubscriptionIdentifier) { } } -// addSubscriptionLocked registers a new subscription under the given trigger. -func (r *Resolver) addSubscriptionLocked(triggerID uint64, add *addSubscription) { +// addSubscription registers a new subscription under the given trigger. +func (r *Resolver) addSubscription(triggerID uint64, add *addSubscription) error { + r.mu.Lock() + defer r.mu.Unlock() + if r.shutdown { + return r.ctx.Err() + } if r.options.Debug { fmt.Printf("resolver:trigger:subscription:add:%d:%d\n", triggerID, add.id.SubscriptionID) } @@ -726,16 +734,13 @@ func (r *Resolver) addSubscriptionLocked(triggerID uint64, add *addSubscription) fmt.Printf("resolver:trigger:subscription:added:%d:%d\n", triggerID, add.id.SubscriptionID) } // Register first so startup hooks can deliver initial data via UpdateSubscription. - trig.mu.Lock() - trig.subscriptions[add.id] = s - trig.mu.Unlock() - r.registerSubscriptionLocked(s) + r.registerSubscriptionLocked(trig, s) // Execute the startup hooks in a goroutine to avoid holding the lock. // On failure, executeStartupHooks calls UnsubscribeSubscription to clean up. go func() { _ = r.executeStartupHooks(add, trig.updater) }() - return + return nil } if r.options.Debug { @@ -756,11 +761,8 @@ func (r *Resolver) addSubscriptionLocked(triggerID uint64, add *addSubscription) updater: updater, } r.triggers[triggerID] = trig - trig.mu.Lock() - trig.subscriptions[add.id] = s - trig.mu.Unlock() updater.subsFn = trig.subscriptionIds - r.registerSubscriptionLocked(s) + r.registerSubscriptionLocked(trig, s) if r.reporter != nil { r.reporter.SubscriptionCountInc(1) @@ -793,6 +795,7 @@ func (r *Resolver) addSubscriptionLocked(triggerID uint64, add *addSubscription) fmt.Printf("resolver:trigger:started:%d\n", triggerID) } }() + return nil } // markTriggerInitialized marks a trigger as initialized under the lock. @@ -871,7 +874,12 @@ func (r *Resolver) handleTriggerError(triggerID uint64, data []byte) { } } -func (r *Resolver) removeClientLocked(id int64) removeClientResult { +func (r *Resolver) removeClient(id int64) removeClientResult { + r.mu.Lock() + defer r.mu.Unlock() + if r.shutdown { + return removeClientResult{} + } if r.options.Debug { fmt.Printf("resolver:trigger:subscription:remove:client:%d\n", id) } @@ -895,12 +903,19 @@ func (r *Resolver) removeClientLocked(id int64) removeClientResult { } } } - return removeClientResult{ + res := removeClientResult{ removed: removed, toClose: toClose, cancels: cancels, triggerDec: triggerDec, } + if r.reporter != nil { + r.reporter.SubscriptionCountDec(res.removed) + if res.triggerDec > 0 { + r.reporter.TriggerCountDec(res.triggerDec) + } + } + return res } // removeSubscriptionLocked removes a single subscription by id. @@ -1225,19 +1240,7 @@ func (r *Resolver) UnsubscribeSubscription(id SubscriptionIdentifier) error { } func (r *Resolver) UnsubscribeClient(connectionID int64) error { - r.mu.Lock() - if r.shutdown { - r.mu.Unlock() - return r.ctx.Err() - } - res := r.removeClientLocked(connectionID) - if r.reporter != nil { - r.reporter.SubscriptionCountDec(res.removed) - if res.triggerDec > 0 { - r.reporter.TriggerCountDec(res.triggerDec) - } - } - r.mu.Unlock() + res := r.removeClient(connectionID) closeSubs(res.toClose) for _, cancel := range res.cancels { cancel() @@ -1313,12 +1316,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ completed := make(chan struct{}) - r.mu.Lock() - if r.shutdown { - r.mu.Unlock() - return r.ctx.Err() - } - r.addSubscriptionLocked(triggerID, &addSubscription{ + if err := r.addSubscription(triggerID, &addSubscription{ ctx: ctx, input: input, resolve: subscription, @@ -1327,8 +1325,9 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ completed: completed, sourceName: subscription.Trigger.SourceName, headers: headers, - }) - r.mu.Unlock() + }); err != nil { + return err + } // This will immediately block until one of the following conditions is met: select { @@ -1402,16 +1401,7 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G headers, triggerID := r.prepareTrigger(ctx, subscription.Trigger.SourceName, input) - r.mu.Lock() - if err := ctx.ctx.Err(); err != nil { - r.mu.Unlock() - return err - } - if r.shutdown { - r.mu.Unlock() - return r.ctx.Err() - } - r.addSubscriptionLocked(triggerID, &addSubscription{ + return r.addSubscription(triggerID, &addSubscription{ ctx: ctx, input: input, resolve: subscription, @@ -1421,8 +1411,6 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G sourceName: subscription.Trigger.SourceName, headers: headers, }) - r.mu.Unlock() - return nil } func (r *Resolver) subscriptionInput(ctx *Context, subscription *GraphQLSubscription) (input []byte, err error) { From 9ce6cd2484f18d20a4e7c80f7ee81a752308f85e Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 8 Apr 2026 23:40:20 +0100 Subject: [PATCH 34/52] reusable helpers for filtering and snapshotting and getting to make mutexes clearer --- v2/pkg/engine/resolve/resolve.go | 188 ++++++++++++++++--------------- 1 file changed, 99 insertions(+), 89 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 754c4ce482..b7b733fd41 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -444,14 +444,14 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe // trigger groups subscriptions that share a data source and input. type trigger struct { - // mu protects subscriptions and initialized. + // mu protects subscriptions. // Uses snapshot-and-release: held only during map access, released before I/O. mu sync.RWMutex id uint64 cancel context.CancelFunc subscriptions map[SubscriptionIdentifier]*subscriptionState - // initialized is set to true when the trigger is started and initialized - initialized bool + // initialized is set to true when the trigger is started and initialized. + initialized atomic.Bool updater *subscriptionUpdater } @@ -466,6 +466,70 @@ func (t *trigger) subscriptionIds() map[context.Context]SubscriptionIdentifier { return subs } +// snapshotSubscriptions returns a point-in-time copy of all subscriptions. +func (t *trigger) snapshotSubscriptions() []*subscriptionState { + t.mu.RLock() + defer t.mu.RUnlock() + subs := make([]*subscriptionState, 0, len(t.subscriptions)) + for _, s := range t.subscriptions { + subs = append(subs, s) + } + return subs +} + +// evalFilter runs SkipEvent for a single subscription. Must be called under t.mu. +func (t *trigger) evalFilter(s *subscriptionState, data []byte) (*subscriptionState, *pendingFilterError) { + if s.ctx.ctx.Err() != nil { + return nil, nil + } + skip, err := s.resolve.Filter.SkipEvent(s.ctx, data) + if err != nil { + fe := pendingFilterError{s.ctx, err, s.resolve.Response, s.writer} + return nil, &fe + } + if skip { + return nil, nil + } + return s, nil +} + +// filterSubscriptions evaluates SkipEvent for every active subscription and +// partitions them into pending updates and filter errors. +func (t *trigger) filterSubscriptions(data []byte) ([]*subscriptionState, []pendingFilterError) { + t.mu.Lock() + defer t.mu.Unlock() + + var subs []*subscriptionState + var filterErrors []pendingFilterError + + for _, s := range t.subscriptions { + pending, filterErr := t.evalFilter(s, data) + if pending != nil { + subs = append(subs, pending) + } + if filterErr != nil { + filterErrors = append(filterErrors, *filterErr) + } + } + + return subs, filterErrors +} + +// filterSubscription evaluates SkipEvent for a single subscription by ID. +func (t *trigger) filterSubscription(id SubscriptionIdentifier, data []byte) (*subscriptionState, *pendingFilterError) { + t.mu.Lock() + defer t.mu.Unlock() + + s, ok := t.subscriptions[id] + if !ok { + return nil, nil + } + + sub, filterErr := t.evalFilter(s, data) + + return sub, filterErr +} + // subscriptionState tracks a single active subscription. type subscriptionState struct { triggerID uint64 @@ -635,18 +699,6 @@ func (r *Resolver) executeSubscriptionHeartbeat(sub *subscriptionState) { } } -func (r *Resolver) handleTriggerInitialized(triggerID uint64) { - trig, ok := r.triggers[triggerID] - if !ok { - return - } - trig.initialized = true - - if r.reporter != nil { - r.reporter.TriggerCountInc(1) - } -} - type StartupHookContext struct { Context context.Context Updater func(data []byte) @@ -798,11 +850,23 @@ func (r *Resolver) addSubscription(triggerID uint64, add *addSubscription) error return nil } -// markTriggerInitialized marks a trigger as initialized under the lock. -func (r *Resolver) markTriggerInitialized(triggerID uint64) { +func (r *Resolver) getTrigger(id uint64) (*trigger, bool) { r.mu.Lock() defer r.mu.Unlock() - r.handleTriggerInitialized(triggerID) + trig, ok := r.triggers[id] + return trig, ok +} + +// markTriggerInitialized marks a trigger as initialized and reports it. +func (r *Resolver) markTriggerInitialized(triggerID uint64) { + trig, ok := r.getTrigger(triggerID) + if !ok { + return + } + trig.initialized.Store(true) + if r.reporter != nil { + r.reporter.TriggerCountInc(1) + } } // doneTriggerFromUpdater performs cleanup for a trigger from a datasource/updater goroutine. @@ -829,19 +893,11 @@ func (r *Resolver) doneTriggerFromUpdater(triggerID uint64) { // handleTriggerComplete delivers a complete signal to all subscriptions on the trigger. // Does NOT detach the trigger — Done() does that. func (r *Resolver) handleTriggerComplete(triggerID uint64) { - r.mu.Lock() - trig, ok := r.triggers[triggerID] + trig, ok := r.getTrigger(triggerID) if !ok { - r.mu.Unlock() return } - trig.mu.Lock() - subs := make([]*subscriptionState, 0, len(trig.subscriptions)) - for _, s := range trig.subscriptions { - subs = append(subs, s) - } - trig.mu.Unlock() - r.mu.Unlock() + subs := trig.snapshotSubscriptions() for _, s := range subs { if !s.removed.Load() { @@ -853,19 +909,11 @@ func (r *Resolver) handleTriggerComplete(triggerID uint64) { // handleTriggerError delivers a terminal error to all subscriptions on the trigger, // bypassing the resolve pipeline. Does NOT detach the trigger — Done() does that. func (r *Resolver) handleTriggerError(triggerID uint64, data []byte) { - r.mu.Lock() - trig, ok := r.triggers[triggerID] + trig, ok := r.getTrigger(triggerID) if !ok { - r.mu.Unlock() return } - trig.mu.Lock() - subs := make([]*subscriptionState, 0, len(trig.subscriptions)) - for _, s := range trig.subscriptions { - subs = append(subs, s) - } - trig.mu.Unlock() - r.mu.Unlock() + subs := trig.snapshotSubscriptions() for _, s := range subs { if !s.removed.Load() { @@ -955,7 +1003,7 @@ func (r *Resolver) removeSubscriptionLocked(id SubscriptionIdentifier) removeRes if empty { delete(r.triggers, trig.id) triggerCancel = trig.cancel - initialized = trig.initialized + initialized = trig.initialized.Load() } return removeResult{ @@ -994,7 +1042,7 @@ func (r *Resolver) detachTriggerLocked(id uint64) removeResult { removed: removed, toClose: toClose, triggerCancel: trig.cancel, - initialized: trig.initialized, + initialized: trig.initialized.Load(), } } @@ -1023,9 +1071,7 @@ type pendingFilterError struct { // The lock is released before performing I/O to avoid deadlocks when executeSubscriptionUpdate // calls AsyncUnsubscribeSubscription on flush failure. func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { - r.mu.Lock() - trig, ok := r.triggers[id] - r.mu.Unlock() + trig, ok := r.getTrigger(id) if !ok { return } @@ -1033,31 +1079,14 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { fmt.Printf("resolver:trigger:update:%d\n", id) } - var pending []*subscriptionState - var filterErrors []pendingFilterError - trig.mu.Lock() - for _, s := range trig.subscriptions { - if s.ctx.ctx.Err() != nil { - continue - } - skip, err := s.resolve.Filter.SkipEvent(s.ctx, data) - if err != nil { - filterErrors = append(filterErrors, pendingFilterError{s.ctx, err, s.resolve.Response, s.writer}) - continue - } - if skip { - continue - } - pending = append(pending, s) - } - trig.mu.Unlock() + subs, filterErrors := trig.filterSubscriptions(data) for _, fe := range filterErrors { r.asyncErrorWriter.WriteError(fe.ctx, fe.err, fe.response, fe.writer) } var wg sync.WaitGroup - for _, sub := range pending { + for _, sub := range subs { if sub.removed.Load() { continue } @@ -1070,9 +1099,7 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { // handleUpdateSubscription sends data to a single subscription using snapshot-and-release. func (r *Resolver) handleUpdateSubscription(id uint64, data []byte, subIdentifier SubscriptionIdentifier) { - r.mu.Lock() - trig, ok := r.triggers[id] - r.mu.Unlock() + trig, ok := r.getTrigger(id) if !ok { return } @@ -1081,42 +1108,26 @@ func (r *Resolver) handleUpdateSubscription(id uint64, data []byte, subIdentifie fmt.Printf("resolver:trigger:subscription:update:%d:%d,%d\n", id, subIdentifier.ConnectionID, subIdentifier.SubscriptionID) } - var target *subscriptionState - var filterErr *pendingFilterError - trig.mu.Lock() - s, ok := trig.subscriptions[subIdentifier] - if ok { - if s.ctx.ctx.Err() == nil { - skip, err := s.resolve.Filter.SkipEvent(s.ctx, data) - if err != nil { - filterErr = &pendingFilterError{s.ctx, err, s.resolve.Response, s.writer} - } else if !skip { - target = s - } - } - } - trig.mu.Unlock() + sub, filterErr := trig.filterSubscription(subIdentifier, data) if filterErr != nil { r.asyncErrorWriter.WriteError(filterErr.ctx, filterErr.err, filterErr.response, filterErr.writer) } - if target != nil && !target.removed.Load() { - r.executeSubscriptionUpdate(target.ctx, target, data) + if sub != nil && !sub.removed.Load() { + r.executeSubscriptionUpdate(sub.ctx, sub, data) } } func (r *Resolver) heartbeatTriggerSubscriptions(id uint64) { - r.mu.Lock() - trig, ok := r.triggers[id] - r.mu.Unlock() + trig, ok := r.getTrigger(id) if !ok { return } - targets := make([]*subscriptionState, 0, len(trig.subscriptions)) - trig.mu.RLock() - for _, s := range trig.subscriptions { + subs := trig.snapshotSubscriptions() + targets := make([]*subscriptionState, 0, len(subs)) + for _, s := range subs { if !s.heartbeat || s.removed.Load() { continue } @@ -1125,7 +1136,6 @@ func (r *Resolver) heartbeatTriggerSubscriptions(id uint64) { } targets = append(targets, s) } - trig.mu.RUnlock() for _, s := range targets { r.executeSubscriptionHeartbeat(s) From 6c54a760175909dbd5a40b293c57f431e8fee233 Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 9 Apr 2026 21:30:03 +0100 Subject: [PATCH 35/52] more comments, rename misleading resolver field, encapsulate writer mutex usage --- v2/pkg/engine/resolve/resolve.go | 75 +++++++++++++++++--------------- 1 file changed, 41 insertions(+), 34 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index b7b733fd41..86cad1f479 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -68,8 +68,11 @@ type Resolver struct { allowedErrorExtensionFields map[string]struct{} allowedErrorFields map[string]struct{} - reporter Reporter - asyncErrorWriter AsyncErrorWriter + reporter Reporter + + // errorFormatter is a function provided by the router that formats Go errors into GraphQL responses and writes them to a writer. + // It's not really async, and it needs to be done under the writer's mutex. This is complex and should be resolved in the future. + errorFormatter AsyncErrorWriter propagateSubgraphErrors bool propagateSubgraphStatusCodes bool @@ -94,7 +97,7 @@ type Resolver struct { } func (r *Resolver) SetAsyncErrorWriter(w AsyncErrorWriter) { - r.asyncErrorWriter = w + r.errorFormatter = w } type tools struct { @@ -271,7 +274,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { subscriptionsByID: make(map[SubscriptionIdentifier]*subscriptionState), subscriptionsByConnection: make(map[int64]map[SubscriptionIdentifier]*subscriptionState), reporter: options.Reporter, - asyncErrorWriter: options.AsyncErrorWriter, + errorFormatter: options.AsyncErrorWriter, allowedErrorExtensionFields: allowedExtensionFields, allowedErrorFields: allowedErrorFields, heartbeatInterval: options.SubscriptionHeartbeatInterval, @@ -484,7 +487,7 @@ func (t *trigger) evalFilter(s *subscriptionState, data []byte) (*subscriptionSt } skip, err := s.resolve.Filter.SkipEvent(s.ctx, data) if err != nil { - fe := pendingFilterError{s.ctx, err, s.resolve.Response, s.writer} + fe := pendingFilterError{s.ctx, err, s.resolve.Response, s} return nil, &fe } if skip { @@ -579,6 +582,27 @@ func (s *subscriptionState) error(data []byte) { s.writer.Error(data) } +// writeError delivers a formatted error to the downstream writer under writeMu. +func (s *subscriptionState) writeError(w AsyncErrorWriter, ctx *Context, err error, response *GraphQLResponse) { + s.writeMu.Lock() + defer s.writeMu.Unlock() + if s.removed.Load() { + return + } + w.WriteError(ctx, err, response, s.writer) +} + +// sendHeartbeat sends a keep-alive frame to the downstream writer under writeMu. +// @TODO: this is bad, see ENG-9356 +func (s *subscriptionState) sendHeartbeat() error { + s.writeMu.Lock() + defer s.writeMu.Unlock() + if s.removed.Load() { + return nil + } + return s.writer.Heartbeat() +} + func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *subscriptionState, sharedInput []byte) { if r.options.Debug { fmt.Printf("resolver:trigger:subscription:update:%d\n", sub.id.SubscriptionID) @@ -598,11 +622,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *subscript if err := t.resolvable.InitSubscription(resolveCtx, input, sub.resolve.Trigger.PostProcessing); err != nil { r.resolveArenaPool.Release(resolveArena) - sub.writeMu.Lock() - if !sub.removed.Load() { - r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) - } - sub.writeMu.Unlock() + sub.writeError(r.errorFormatter, resolveCtx, err, sub.resolve.Response) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:init:failed:%d\n", sub.id.SubscriptionID) } @@ -614,11 +634,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *subscript if err := t.loader.LoadGraphQLResponseData(resolveCtx, sub.resolve.Response, t.resolvable); err != nil { r.resolveArenaPool.Release(resolveArena) - sub.writeMu.Lock() - if !sub.removed.Load() { - r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) - } - sub.writeMu.Unlock() + sub.writeError(r.errorFormatter, resolveCtx, err, sub.resolve.Response) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:load:failed:%d\n", sub.id.SubscriptionID) } @@ -637,7 +653,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *subscript if err := t.resolvable.Resolve(resolveCtx.ctx, sub.resolve.Response.Data, sub.resolve.Response.Fetches, sub.writer); err != nil { r.resolveArenaPool.Release(resolveArena) - r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) + r.errorFormatter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) sub.writeMu.Unlock() if r.options.Debug { fmt.Printf("resolver:trigger:subscription:resolve:failed:%d\n", sub.id.SubscriptionID) @@ -680,19 +696,10 @@ func (r *Resolver) executeSubscriptionHeartbeat(sub *subscriptionState) { return } - sub.writeMu.Lock() - - if sub.removed.Load() { - sub.writeMu.Unlock() - return - } - - if err := sub.writer.Heartbeat(); err != nil { - sub.writeMu.Unlock() + if err := sub.sendHeartbeat(); err != nil { _ = r.UnsubscribeSubscription(sub.id) return } - sub.writeMu.Unlock() if r.reporter != nil { r.reporter.SubscriptionUpdateSent() @@ -704,7 +711,7 @@ type StartupHookContext struct { Updater func(data []byte) } -func (r *Resolver) executeStartupHooks(add *addSubscription, updater *subscriptionUpdater) error { +func (r *Resolver) executeStartupHooks(add *addSubscription, sub *subscriptionState, updater *subscriptionUpdater) error { hook, ok := add.resolve.Trigger.Source.(HookableSubscriptionDataSource) if ok { hookCtx := StartupHookContext{ @@ -719,7 +726,7 @@ func (r *Resolver) executeStartupHooks(add *addSubscription, updater *subscripti if r.options.Debug { fmt.Printf("resolver:trigger:subscription:startup:failed:%d\n", add.id.SubscriptionID) } - r.asyncErrorWriter.WriteError(add.ctx, err, add.resolve.Response, add.writer) + sub.writeError(r.errorFormatter, add.ctx, err, add.resolve.Response) _ = r.UnsubscribeSubscription(add.id) return err } @@ -790,7 +797,7 @@ func (r *Resolver) addSubscription(triggerID uint64, add *addSubscription) error // Execute the startup hooks in a goroutine to avoid holding the lock. // On failure, executeStartupHooks calls UnsubscribeSubscription to clean up. go func() { - _ = r.executeStartupHooks(add, trig.updater) + _ = r.executeStartupHooks(add, s, trig.updater) }() return nil } @@ -826,7 +833,7 @@ func (r *Resolver) addSubscription(triggerID uint64, add *addSubscription) error } // This is blocking so the startup hook can decide if a subscription should be started or not by returning an error - err := r.executeStartupHooks(add, trig.updater) + err := r.executeStartupHooks(add, s, trig.updater) if err != nil { return } @@ -836,7 +843,7 @@ func (r *Resolver) addSubscription(triggerID uint64, add *addSubscription) error if r.options.Debug { fmt.Printf("resolver:trigger:failed:%d\n", triggerID) } - r.asyncErrorWriter.WriteError(add.ctx, err, add.resolve.Response, add.writer) + s.writeError(r.errorFormatter, add.ctx, err, add.resolve.Response) r.doneTriggerFromUpdater(triggerID) return } @@ -1064,7 +1071,7 @@ type pendingFilterError struct { ctx *Context err error response *GraphQLResponse - writer SubscriptionResponseWriter + sub *subscriptionState } // handleTriggerUpdate sends data to all subscriptions of a trigger using snapshot-and-release. @@ -1082,7 +1089,7 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { subs, filterErrors := trig.filterSubscriptions(data) for _, fe := range filterErrors { - r.asyncErrorWriter.WriteError(fe.ctx, fe.err, fe.response, fe.writer) + fe.sub.writeError(r.errorFormatter, fe.ctx, fe.err, fe.response) } var wg sync.WaitGroup @@ -1111,7 +1118,7 @@ func (r *Resolver) handleUpdateSubscription(id uint64, data []byte, subIdentifie sub, filterErr := trig.filterSubscription(subIdentifier, data) if filterErr != nil { - r.asyncErrorWriter.WriteError(filterErr.ctx, filterErr.err, filterErr.response, filterErr.writer) + filterErr.sub.writeError(r.errorFormatter, filterErr.ctx, filterErr.err, filterErr.response) } if sub != nil && !sub.removed.Load() { From 7d96ff25dac3a1fe660fe80ca9803de742cc24d7 Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 9 Apr 2026 21:40:59 +0100 Subject: [PATCH 36/52] fix gci --- .../subscriptionclient/transport/ws_conn_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go index 5cf5ae8997..3cd213b39d 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go @@ -4,8 +4,8 @@ import ( "context" "encoding/json" "net/http" - "runtime" "net/http/httptest" + "runtime" "strings" "sync" "testing" From e37eeb18f6ea57640c4b516cac625b361df60124 Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 9 Apr 2026 22:28:39 +0100 Subject: [PATCH 37/52] comments and connection ids --- .../transport/sse_transport.go | 7 +++-- v2/pkg/engine/resolve/resolve.go | 28 +++++++++++-------- v2/pkg/engine/resolve/resolve_test.go | 2 +- 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go index 0a9b1eed69..9b5630630e 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go @@ -73,9 +73,12 @@ func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts abstractlogger.String("method", string(method)), ) + // Derive a request context that outlives ctx (via WithoutCancel) so we can + // control its lifetime independently. Two AfterFunc registrations tie the + // request to both shutdown paths: + // - t.ctx cancel: transport-wide shutdown, tears down all in-flight requests. + // - ctx cancel: individual subscription cancelled by the caller. requestCtx, requestCancel := context.WithCancel(context.WithoutCancel(ctx)) - - // Attach cancel to transport context context.AfterFunc(t.ctx, requestCancel) context.AfterFunc(ctx, requestCancel) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 86cad1f479..1ac2688983 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -27,10 +27,16 @@ const ( DefaultHeartbeatInterval = 5 * time.Second ) -// ConnectionIDs is used to create unique connection IDs for each subscription -// Whenever a new connection is created, use this to generate a new ID -// It is public because it can be used in more high level packages to instantiate a new connection -var ConnectionIDs atomic.Int64 +// ConnectionID identifies a client connection for subscription routing. +type ConnectionID int64 + +// connectionIDs is the monotonic counter backing NewConnectionID. +var connectionIDs atomic.Int64 + +// NewConnectionID returns a unique ConnectionID via an atomic increment. +func NewConnectionID() ConnectionID { + return ConnectionID(connectionIDs.Add(1)) +} type Reporter interface { // SubscriptionUpdateSent called when a new subscription update is sent @@ -63,7 +69,7 @@ type Resolver struct { shutdown bool triggers map[uint64]*trigger subscriptionsByID map[SubscriptionIdentifier]*subscriptionState - subscriptionsByConnection map[int64]map[SubscriptionIdentifier]*subscriptionState + subscriptionsByConnection map[ConnectionID]map[SubscriptionIdentifier]*subscriptionState allowedErrorExtensionFields map[string]struct{} allowedErrorFields map[string]struct{} @@ -272,7 +278,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes, triggers: make(map[uint64]*trigger), subscriptionsByID: make(map[SubscriptionIdentifier]*subscriptionState), - subscriptionsByConnection: make(map[int64]map[SubscriptionIdentifier]*subscriptionState), + subscriptionsByConnection: make(map[ConnectionID]map[SubscriptionIdentifier]*subscriptionState), reporter: options.Reporter, errorFormatter: options.AsyncErrorWriter, allowedErrorExtensionFields: allowedExtensionFields, @@ -929,7 +935,7 @@ func (r *Resolver) handleTriggerError(triggerID uint64, data []byte) { } } -func (r *Resolver) removeClient(id int64) removeClientResult { +func (r *Resolver) removeClient(id ConnectionID) removeClientResult { r.mu.Lock() defer r.mu.Unlock() if r.shutdown { @@ -1191,7 +1197,7 @@ func (r *Resolver) shutdownResolver() { r.triggers = make(map[uint64]*trigger) r.subscriptionsByID = make(map[SubscriptionIdentifier]*subscriptionState) - r.subscriptionsByConnection = make(map[int64]map[SubscriptionIdentifier]*subscriptionState) + r.subscriptionsByConnection = make(map[ConnectionID]map[SubscriptionIdentifier]*subscriptionState) r.mu.Unlock() closeSubs(allToClose) @@ -1231,7 +1237,7 @@ func (r *Resolver) sendTriggerHeartbeats() { } type SubscriptionIdentifier struct { - ConnectionID int64 + ConnectionID ConnectionID SubscriptionID int64 } @@ -1256,7 +1262,7 @@ func (r *Resolver) UnsubscribeSubscription(id SubscriptionIdentifier) error { return nil } -func (r *Resolver) UnsubscribeClient(connectionID int64) error { +func (r *Resolver) UnsubscribeClient(connectionID ConnectionID) error { res := r.removeClient(connectionID) closeSubs(res.toClose) for _, cancel := range res.cancels { @@ -1324,7 +1330,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ headers, triggerID := r.prepareTrigger(ctx, subscription.Trigger.SourceName, input) id := SubscriptionIdentifier{ - ConnectionID: ConnectionIDs.Add(1), + ConnectionID: NewConnectionID(), SubscriptionID: 0, } if r.options.Debug { diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 909efbdd81..03be74c501 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -6109,7 +6109,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { } for i := 1; i <= 10; i++ { - id.ConnectionID = int64(i) + id.ConnectionID = ConnectionID(i) id.SubscriptionID = int64(i) recorder.complete.Store(false) err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id) From 8ea3957d94316e1885b60a34d40a431b6aff5b78 Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 16 Apr 2026 13:44:45 +0100 Subject: [PATCH 38/52] fix: issue with legacy ws not being pingable --- .../protocol/graphql_transport_ws.go | 9 +- .../subscriptionclient/protocol/graphql_ws.go | 14 --- .../protocol/graphql_ws_test.go | 26 ----- .../subscriptionclient/protocol/protocol.go | 8 +- .../transport/sse_transport.go | 35 ++++--- .../subscriptionclient/transport/transport.go | 14 --- .../subscriptionclient/transport/ws_conn.go | 17 +++- .../transport/ws_conn_test.go | 2 + .../transport/ws_transport_test.go | 96 +++++++++++++++++++ 9 files changed, 142 insertions(+), 79 deletions(-) delete mode 100644 v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport.go diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go index 23cbcf119c..521698495d 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go @@ -80,7 +80,7 @@ func (p *GraphQLTransportWS) Init(ctx context.Context, conn *websocket.Conn, pay } } -// Ping implements Protocol. +// Ping implements Pinger. func (p *GraphQLTransportWS) Ping(ctx context.Context, conn *websocket.Conn) error { msg := outgoingMessage{ Type: gtwsTypePing, @@ -88,7 +88,7 @@ func (p *GraphQLTransportWS) Ping(ctx context.Context, conn *websocket.Conn) err return wsjson.Write(ctx, conn, msg) } -// Pong implements Protocol. +// Pong implements Pinger. func (p *GraphQLTransportWS) Pong(ctx context.Context, conn *websocket.Conn) error { msg := outgoingMessage{ Type: gtwsTypePong, @@ -162,4 +162,7 @@ func (p *GraphQLTransportWS) decode(raw incomingMessage) (*WireMessage, error) { return msg, nil } -var _ Protocol = (*GraphQLTransportWS)(nil) +var ( + _ Protocol = (*GraphQLTransportWS)(nil) + _ Pinger = (*GraphQLTransportWS)(nil) +) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go index 854154bfd3..1cb15f516b 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go @@ -101,20 +101,6 @@ func (p *GraphQLWS) Read(ctx context.Context, conn *websocket.Conn) (*WireMessag return p.decode(raw) } -// Ping implements Protocol. -// Legacy protocol doesn't support client-initiated ping, this is a no-op. -func (p *GraphQLWS) Ping(ctx context.Context, conn *websocket.Conn) error { - // Legacy protocol doesn't have client ping - only server sends ka - return nil -} - -// Pong implements Protocol. -// Legacy protocol doesn't support pong messages, this is a no-op. -func (p *GraphQLWS) Pong(ctx context.Context, conn *websocket.Conn) error { - // Legacy protocol doesn't have pong - return nil -} - func (p *GraphQLWS) decode(raw incomingMessage) (*WireMessage, error) { msg := &WireMessage{ ID: raw.ID, diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws_test.go index 66c85b98f7..f19ebe095b 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws_test.go @@ -312,32 +312,6 @@ func TestGraphQLWSLegacy_Read(t *testing.T) { }) } -func TestGraphQLWSLegacy_PingPong(t *testing.T) { - t.Parallel() - - t.Run("ping is a no-op for legacy protocol", func(t *testing.T) { - t.Parallel() - - // Legacy protocol doesn't support client-initiated ping - p := NewGraphQLWS() - - // This should not error, just be a no-op - err := p.Ping(context.Background(), nil) - require.NoError(t, err) - }) - - t.Run("pong is a no-op for legacy protocol", func(t *testing.T) { - t.Parallel() - - // Legacy protocol doesn't support pong - p := NewGraphQLWS() - - // This should not error, just be a no-op - err := p.Pong(context.Background(), nil) - require.NoError(t, err) - }) -} - func newGWSTestServer(t *testing.T, handler func(ctx context.Context, conn *websocket.Conn)) *httptest.Server { t.Helper() diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go index 3b5378e505..f4941c313a 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go @@ -22,11 +22,13 @@ type Protocol interface { // Read blocks until the next message arrives and decodes it. Read(ctx context.Context, conn *websocket.Conn) (*WireMessage, error) +} - // Ping requests a liveness check from the server. No-op for protocols that don't support it. +// Pinger is an optional interface for protocols that support client-initiated +// ping/pong (e.g. graphql-transport-ws). Protocols that only have server-initiated +// keep-alive (e.g. legacy graphql-ws with ka messages) do not implement this. +type Pinger interface { Ping(ctx context.Context, conn *websocket.Conn) error - - // Pong responds to a server liveness check. No-op for protocols that don't support it. Pong(ctx context.Context, conn *websocket.Conn) error } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go index 9b5630630e..2cdad60351 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go @@ -73,20 +73,11 @@ func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts abstractlogger.String("method", string(method)), ) - // Derive a request context that outlives ctx (via WithoutCancel) so we can - // control its lifetime independently. Two AfterFunc registrations tie the - // request to both shutdown paths: - // - t.ctx cancel: transport-wide shutdown, tears down all in-flight requests. - // - ctx cancel: individual subscription cancelled by the caller. - requestCtx, requestCancel := context.WithCancel(context.WithoutCancel(ctx)) - context.AfterFunc(t.ctx, requestCancel) - context.AfterFunc(ctx, requestCancel) - switch method { case common.SSEMethodPOST: - httpReq, err = buildPOSTRequest(requestCtx, req, opts) + httpReq, err = buildPOSTRequest(req, opts) case common.SSEMethodGET: - httpReq, err = buildGETRequest(requestCtx, req, opts) + httpReq, err = buildGETRequest(req, opts) default: return nil, fmt.Errorf("unsupported SSE method: %s", method) } @@ -95,9 +86,21 @@ func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts return nil, err } + // Derive a request context that outlives ctx (via WithoutCancel) so we can + // control its lifetime independently. Two AfterFunc registrations tie the + // request to both shutdown paths: + // - t.ctx cancel: transport-wide shutdown, tears down all in-flight requests. + // - ctx cancel: individual subscription cancelled by the caller. + requestCtx, requestCancel := context.WithCancel(context.WithoutCancel(ctx)) + context.AfterFunc(t.ctx, requestCancel) + context.AfterFunc(ctx, requestCancel) + + httpReq = httpReq.WithContext(requestCtx) + // Execute request resp, err := t.client.Do(httpReq) if err != nil { + requestCancel() t.log.Error("sseTransport.Subscribe", abstractlogger.String("endpoint", opts.Endpoint), abstractlogger.Error(err), @@ -106,6 +109,7 @@ func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts } if resp.StatusCode != http.StatusOK { + requestCancel() body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodySize)) resp.Body.Close() t.log.Error("sseTransport.Subscribe", @@ -120,6 +124,7 @@ func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts // Verify content type (should be text/event-stream) if err := t.validateContentType(resp); err != nil { + requestCancel() resp.Body.Close() return nil, err } @@ -148,13 +153,13 @@ func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts } // buildPOSTRequest creates a POST request with JSON body (graphql-sse spec). -func buildPOSTRequest(ctx context.Context, req *common.Request, opts common.Options) (*http.Request, error) { +func buildPOSTRequest(req *common.Request, opts common.Options) (*http.Request, error) { body, err := json.Marshal(req) if err != nil { return nil, fmt.Errorf("marshal request: %w", err) } - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, opts.Endpoint, bytes.NewReader(body)) + httpReq, err := http.NewRequest(http.MethodPost, opts.Endpoint, bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("create request: %w", err) } @@ -170,7 +175,7 @@ func buildPOSTRequest(ctx context.Context, req *common.Request, opts common.Opti } // buildGETRequest creates a GET request with query parameters (traditional SSE). -func buildGETRequest(ctx context.Context, req *common.Request, opts common.Options) (*http.Request, error) { +func buildGETRequest(req *common.Request, opts common.Options) (*http.Request, error) { // Parse the endpoint URL u, err := url.Parse(opts.Endpoint) if err != nil { @@ -203,7 +208,7 @@ func buildGETRequest(ctx context.Context, req *common.Request, opts common.Optio u.RawQuery = q.Encode() - httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + httpReq, err := http.NewRequest(http.MethodGet, u.String(), nil) if err != nil { return nil, fmt.Errorf("create request: %w", err) } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport.go deleted file mode 100644 index a73cd376c1..0000000000 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport.go +++ /dev/null @@ -1,14 +0,0 @@ -package transport - -import ( - "context" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" -) - -// Transport defines the interface for subscription transports. -// A transport is responsible for managing the full connection to the upstream server. -type Transport interface { - Subscribe(ctx context.Context, req *common.Request, opts common.Options) (results <-chan *common.Message, cancel func(), err error) - Close() error -} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go index b3e12542e7..4eb964daf0 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go @@ -182,9 +182,11 @@ func (c *wsConnection) readLoop() { switch msg.Type { case protocol.MessagePing: c.log.Debug("wsConnection.ReadLoop", abstractlogger.String("message", "ping")) - pongCtx, cancel := context.WithTimeout(c.ctx, c.writeTimeout) - _ = c.protocol.Pong(pongCtx, c.conn) - cancel() + if pinger, ok := c.protocol.(protocol.Pinger); ok { + pongCtx, cancel := context.WithTimeout(c.ctx, c.writeTimeout) + _ = pinger.Pong(pongCtx, c.conn) + cancel() + } case protocol.MessagePong: c.lastPongAt.Store(time.Now().UnixNano()) c.log.Debug("wsConnection.ReadLoop", abstractlogger.String("message", "pong")) @@ -256,11 +258,18 @@ func (c *wsConnection) subCount() int { } // sendPing sends a protocol-level ping message and records the timestamp. +// For protocols that don't implement Pinger (e.g. legacy graphql-ws), +// this is a no-op — lastPingSentAt stays zero so pongOverdue never triggers. func (c *wsConnection) sendPing() error { + pinger, ok := c.protocol.(protocol.Pinger) + if !ok { + return nil + } + pingCtx, cancel := context.WithTimeout(c.ctx, c.writeTimeout) defer cancel() - err := c.protocol.Ping(pingCtx, c.conn) + err := pinger.Ping(pingCtx, c.conn) if err != nil { return err } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go index 3cd213b39d..152944680b 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go @@ -732,6 +732,8 @@ func (m *mockProtocol) Read(ctx context.Context, conn *websocket.Conn) (*protoco } } +// Ping and Pong implement protocol.Pinger — the mock simulates graphql-transport-ws. + func (m *mockProtocol) Ping(ctx context.Context, conn *websocket.Conn) error { return nil } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go index edf588c16f..01bd0b92f3 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go @@ -1058,6 +1058,102 @@ func TestWSTransport_Heartbeat(t *testing.T) { time.Sleep(250 * time.Millisecond) assert.Equal(t, 1, tr.ConnCount()) }) + + t.Run("legacy graphql-ws survives ping timeout with ka messages", func(t *testing.T) { + t.Parallel() + + // Server sends periodic ka (keep-alive) messages, never expects client pings. + server := newLegacyGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + + // Read in the background so the close handshake completes promptly. + closed := make(chan struct{}) + go func() { + defer close(closed) + var discard map[string]any + _ = wsjson.Read(ctx, conn, &discard) + }() + + for { + select { + case <-closed: + return + case <-ctx.Done(): + return + case <-time.After(30 * time.Millisecond): + if err := wsjson.Write(ctx, conn, map[string]string{"type": "ka"}); err != nil { + return + } + } + } + }) + + // Enable ping loop with a tight timeout. Legacy connections are + // unaffected because sendPing is a no-op for non-Pinger protocols, + // so lastPingSentAt stays zero and pongOverdue never triggers. + tr := NewWSTransport(t.Context(), WSTransportOptions{ + PingInterval: 50 * time.Millisecond, + PingTimeout: 150 * time.Millisecond, + WriteTimeout: 100 * time.Millisecond, + }) + + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolGraphQLWS, + }, func(_ *common.Message) {}) + require.NoError(t, err) + defer cancel() + + // Survive well past the ping timeout — several cycles. + time.Sleep(400 * time.Millisecond) + assert.Equal(t, 1, tr.ConnCount()) + }) + + t.Run("legacy graphql-ws does not send client pings", func(t *testing.T) { + t.Parallel() + + // Track any messages the server receives after the subscribe. + var extraMessages atomic.Int32 + server := newLegacyGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + // Read subscribe + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + + // Any further messages from the client are unexpected — legacy + // clients should never send ping. + for { + var incoming map[string]any + if err := wsjson.Read(ctx, conn, &incoming); err != nil { + return + } + extraMessages.Add(1) + } + }) + + tr := NewWSTransport(t.Context(), WSTransportOptions{ + PingInterval: 50 * time.Millisecond, + }) + + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolGraphQLWS, + }, func(_ *common.Message) {}) + require.NoError(t, err) + defer cancel() + + // Wait long enough for several ping cycles to pass. + time.Sleep(200 * time.Millisecond) + + // Server should not have received any messages (no pings sent). + assert.Equal(t, int32(0), extraMessages.Load()) + }) } func TestWSTransport_Defaults(t *testing.T) { From 8367e22f42c744e59f13635ae56b86275ba2783a Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 16 Apr 2026 15:52:12 +0100 Subject: [PATCH 39/52] use default error extension config --- .../graphql_subscription_client.go | 36 +++++++++++++------ .../graphql_subscription_client_test.go | 10 +++--- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 62e082c6aa..a44c83ddee 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -27,6 +27,11 @@ type SubscriptionClientConfig struct { AckTimeout time.Duration WriteTimeout time.Duration ReadLimit int64 + + // DefaultErrorExtensionCode is the extension code attached to GraphQL + // errors produced by upstream connection failures. Should match the + // resolve package's setting for consistent error formatting. + DefaultErrorExtensionCode string } func defaultSubscriptionClientConfig() *SubscriptionClientConfig { @@ -106,6 +111,14 @@ func WithWriteTimeout(d time.Duration) SubscriptionClientOption { } } +// WithDefaultErrorExtensionCode sets the extension code attached to GraphQL +// errors produced by upstream connection failures. +func WithDefaultErrorExtensionCode(code string) SubscriptionClientOption { + return func(cfg *SubscriptionClientConfig) { + cfg.DefaultErrorExtensionCode = code + } +} + // WithReadLimit sets the maximum size in bytes for incoming WebSocket messages. // Default: 1MB. func WithReadLimit(n int64) SubscriptionClientOption { @@ -117,7 +130,8 @@ func WithReadLimit(n int64) SubscriptionClientOption { // subscriptionClientV2 implements GraphQLSubscriptionClient using the new // channel-based subscription client. type subscriptionClientV2 struct { - client *client.Client + client *client.Client + defaultErrorExtensionCode string } // NewGraphQLSubscriptionClient creates a new subscription client. @@ -128,6 +142,7 @@ func NewGraphQLSubscriptionClient(ctx context.Context, opts ...SubscriptionClien } return &subscriptionClientV2{ + defaultErrorExtensionCode: cfg.DefaultErrorExtensionCode, client: client.New(ctx, client.Config{ UpgradeClient: cfg.UpgradeClient, StreamingClient: cfg.StreamingClient, @@ -148,12 +163,12 @@ func (c *subscriptionClientV2) Subscribe(ctx *resolve.Context, options GraphQLSu return err } - handler := buildMessageHandler(updater) + handler := buildMessageHandler(updater, c.defaultErrorExtensionCode) cancel, err := c.client.Subscribe(ctx.Context(), req, opts, handler) if err != nil { if isUpstreamError(err) { - updater.Error(formatUpstreamServiceError(err)) + updater.Error(formatUpstreamServiceError(err, c.defaultErrorExtensionCode)) updater.Done() return nil } @@ -169,11 +184,11 @@ func (c *subscriptionClientV2) Subscribe(ctx *resolve.Context, options GraphQLSu } // buildMessageHandler creates the handler that bridges client.Message → resolve.SubscriptionUpdater. -func buildMessageHandler(updater resolve.SubscriptionUpdater) client.Handler { +func buildMessageHandler(updater resolve.SubscriptionUpdater, errorCode string) client.Handler { return func(msg *client.Message) { switch msg.Type { case client.MessageTypeConnectionError: - updater.Error(formatUpstreamServiceError(msg.Err)) + updater.Error(formatUpstreamServiceError(msg.Err, errorCode)) updater.Done() case client.MessageTypeError: data, err := json.Marshal(msg.Payload) @@ -200,7 +215,7 @@ func buildMessageHandler(updater resolve.SubscriptionUpdater) client.Handler { } // isUpstreamError reports whether err is a connection-level upstream error -// that should be reported to the client as an UPSTREAM_SERVICE_ERROR. +// that should be surfaced as a GraphQL error to the client. // ErrFailedUpgrade and ErrInvalidSubprotocol are intentionally excluded so // they propagate to the router, which formats detailed error messages // (e.g. including the subgraph name and HTTP status code). @@ -267,10 +282,9 @@ func mapWSSubprotocol(proto string) client.WSSubprotocol { } // formatUpstreamServiceError formats a connection-level error as a GraphQL error -// response with the UPSTREAM_SERVICE_ERROR extension code. If the error chain -// contains a WebSocket close error, the close code and reason are included in -// extensions. -func formatUpstreamServiceError(err error) []byte { +// response with the configured error extension code. If the error chain contains +// a WebSocket close error, the close code and reason are included in extensions. +func formatUpstreamServiceError(err error, code string) []byte { type errorExtensions struct { Code string `json:"code"` CloseCode int `json:"closeCode,omitempty"` @@ -284,7 +298,7 @@ func formatUpstreamServiceError(err error) []byte { gqlErr := graphqlError{ Message: "upstream service error", - Extensions: errorExtensions{Code: "UPSTREAM_SERVICE_ERROR"}, + Extensions: errorExtensions{Code: code}, } var closeErr websocket.CloseError diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index c3066296e4..6ee1b013d0 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -47,20 +47,20 @@ func (t *testBridgeUpdater) Subscriptions() map[context.Context]resolve.Subscrip func TestBuildMessageHandlerRoutesEachMessageTypeCorrectly(t *testing.T) { t.Run("error is upstream service error for connection error", func(t *testing.T) { updater := &testBridgeUpdater{} - handler := buildMessageHandler(updater) + handler := buildMessageHandler(updater, "DOWNSTREAM_SERVICE_ERROR") handler(&client.Message{Type: client.MessageTypeConnectionError, Err: client.ErrConnectionClosed}) require.True(t, updater.done) require.Len(t, updater.errors, 1) - assert.Contains(t, string(updater.errors[0]), "UPSTREAM_SERVICE_ERROR") + assert.Contains(t, string(updater.errors[0]), "DOWNSTREAM_SERVICE_ERROR") require.Empty(t, updater.updates) require.False(t, updater.completed) }) t.Run("error contains payload for graphql error", func(t *testing.T) { updater := &testBridgeUpdater{} - handler := buildMessageHandler(updater) + handler := buildMessageHandler(updater, "DOWNSTREAM_SERVICE_ERROR") handler(&client.Message{ Type: client.MessageTypeError, @@ -78,7 +78,7 @@ func TestBuildMessageHandlerRoutesEachMessageTypeCorrectly(t *testing.T) { t.Run("update is delivered without completing for data message", func(t *testing.T) { updater := &testBridgeUpdater{} - handler := buildMessageHandler(updater) + handler := buildMessageHandler(updater, "DOWNSTREAM_SERVICE_ERROR") handler(&client.Message{ Type: client.MessageTypeData, @@ -96,7 +96,7 @@ func TestBuildMessageHandlerRoutesEachMessageTypeCorrectly(t *testing.T) { t.Run("complete and done are set for complete message", func(t *testing.T) { updater := &testBridgeUpdater{} - handler := buildMessageHandler(updater) + handler := buildMessageHandler(updater, "DOWNSTREAM_SERVICE_ERROR") handler(&client.Message{Type: client.MessageTypeComplete}) From 8d98128b8f4348a80d4ccd272416a4b1d7e15685 Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 16 Apr 2026 16:02:39 +0100 Subject: [PATCH 40/52] refactor: rename connectionIDs to connectionIDCounter --- v2/pkg/engine/resolve/resolve.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 1ac2688983..002dd2a8ff 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -30,12 +30,12 @@ const ( // ConnectionID identifies a client connection for subscription routing. type ConnectionID int64 -// connectionIDs is the monotonic counter backing NewConnectionID. -var connectionIDs atomic.Int64 +// connectionIDCounter is the monotonic counter backing NewConnectionID. +var connectionIDCounter atomic.Int64 // NewConnectionID returns a unique ConnectionID via an atomic increment. func NewConnectionID() ConnectionID { - return ConnectionID(connectionIDs.Add(1)) + return ConnectionID(connectionIDCounter.Add(1)) } type Reporter interface { From 7725628f0ce3415a70317b14883632b5b6419f11 Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 16 Apr 2026 16:26:22 +0100 Subject: [PATCH 41/52] refactor: unexport protocol types, remove trivial getter --- .../protocol/graphql_transport_ws.go | 26 +++++++++---------- .../subscriptionclient/protocol/graphql_ws.go | 20 +++++++------- .../subscriptionclient/transport/ws_conn.go | 5 ---- .../transport/ws_conn_test.go | 8 +++--- 4 files changed, 27 insertions(+), 32 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go index 521698495d..330dd969a3 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go @@ -12,9 +12,9 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" ) -// GraphQLTransportWS implements the graphql-transport-ws protocol. +// graphqlTransportWS implements the graphql-transport-ws protocol. // See: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md -type GraphQLTransportWS struct{} +type graphqlTransportWS struct{} const ( gtwsTypeConnectionInit = "connection_init" @@ -39,12 +39,12 @@ type incomingMessage struct { Payload json.RawMessage `json:"payload,omitempty"` } -func NewGraphQLTransportWS() *GraphQLTransportWS { - return &GraphQLTransportWS{} +func NewGraphQLTransportWS() *graphqlTransportWS { + return &graphqlTransportWS{} } // Init implements Protocol. -func (p *GraphQLTransportWS) Init(ctx context.Context, conn *websocket.Conn, payload map[string]any) error { +func (p *graphqlTransportWS) Init(ctx context.Context, conn *websocket.Conn, payload map[string]any) error { initMsg := outgoingMessage{ Type: gtwsTypeConnectionInit, } @@ -81,7 +81,7 @@ func (p *GraphQLTransportWS) Init(ctx context.Context, conn *websocket.Conn, pay } // Ping implements Pinger. -func (p *GraphQLTransportWS) Ping(ctx context.Context, conn *websocket.Conn) error { +func (p *graphqlTransportWS) Ping(ctx context.Context, conn *websocket.Conn) error { msg := outgoingMessage{ Type: gtwsTypePing, } @@ -89,7 +89,7 @@ func (p *GraphQLTransportWS) Ping(ctx context.Context, conn *websocket.Conn) err } // Pong implements Pinger. -func (p *GraphQLTransportWS) Pong(ctx context.Context, conn *websocket.Conn) error { +func (p *graphqlTransportWS) Pong(ctx context.Context, conn *websocket.Conn) error { msg := outgoingMessage{ Type: gtwsTypePong, } @@ -97,7 +97,7 @@ func (p *GraphQLTransportWS) Pong(ctx context.Context, conn *websocket.Conn) err } // Read implements Protocol. -func (p *GraphQLTransportWS) Read(ctx context.Context, conn *websocket.Conn) (*WireMessage, error) { +func (p *graphqlTransportWS) Read(ctx context.Context, conn *websocket.Conn) (*WireMessage, error) { var raw incomingMessage if err := wsjson.Read(ctx, conn, &raw); err != nil { return nil, fmt.Errorf("read message: %w", err) @@ -107,7 +107,7 @@ func (p *GraphQLTransportWS) Read(ctx context.Context, conn *websocket.Conn) (*W } // Subscribe implements Protocol. -func (p *GraphQLTransportWS) Subscribe(ctx context.Context, conn *websocket.Conn, id string, req *common.Request) error { +func (p *graphqlTransportWS) Subscribe(ctx context.Context, conn *websocket.Conn, id string, req *common.Request) error { msg := outgoingMessage{ ID: id, Type: gtwsTypeSubscribe, @@ -117,7 +117,7 @@ func (p *GraphQLTransportWS) Subscribe(ctx context.Context, conn *websocket.Conn } // Unsubscribe implements Protocol. -func (p *GraphQLTransportWS) Unsubscribe(ctx context.Context, conn *websocket.Conn, id string) error { +func (p *graphqlTransportWS) Unsubscribe(ctx context.Context, conn *websocket.Conn, id string) error { msg := outgoingMessage{ ID: id, Type: gtwsTypeComplete, @@ -125,7 +125,7 @@ func (p *GraphQLTransportWS) Unsubscribe(ctx context.Context, conn *websocket.Co return wsjson.Write(ctx, conn, msg) } -func (p *GraphQLTransportWS) decode(raw incomingMessage) (*WireMessage, error) { +func (p *graphqlTransportWS) decode(raw incomingMessage) (*WireMessage, error) { msg := &WireMessage{ ID: raw.ID, } @@ -163,6 +163,6 @@ func (p *GraphQLTransportWS) decode(raw incomingMessage) (*WireMessage, error) { } var ( - _ Protocol = (*GraphQLTransportWS)(nil) - _ Pinger = (*GraphQLTransportWS)(nil) + _ Protocol = (*graphqlTransportWS)(nil) + _ Pinger = (*graphqlTransportWS)(nil) ) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go index 1cb15f516b..ee9f24e38b 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go @@ -12,9 +12,9 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" ) -// GraphQLWS implements the legacy graphql-ws protocol. +// graphqlWS implements the legacy graphql-ws protocol. // See: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md -type GraphQLWS struct{} +type graphqlWS struct{} const ( gwsTypeConnectionInit = "connection_init" @@ -28,12 +28,12 @@ const ( gwsTypeStop = "stop" ) -func NewGraphQLWS() *GraphQLWS { - return &GraphQLWS{} +func NewGraphQLWS() *graphqlWS { + return &graphqlWS{} } // Init implements Protocol. -func (p *GraphQLWS) Init(ctx context.Context, conn *websocket.Conn, payload map[string]any) error { +func (p *graphqlWS) Init(ctx context.Context, conn *websocket.Conn, payload map[string]any) error { initMsg := outgoingMessage{ Type: gwsTypeConnectionInit, } @@ -73,7 +73,7 @@ func (p *GraphQLWS) Init(ctx context.Context, conn *websocket.Conn, payload map[ } // Subscribe implements Protocol. -func (p *GraphQLWS) Subscribe(ctx context.Context, conn *websocket.Conn, id string, req *common.Request) error { +func (p *graphqlWS) Subscribe(ctx context.Context, conn *websocket.Conn, id string, req *common.Request) error { msg := outgoingMessage{ ID: id, Type: gwsTypeStart, @@ -83,7 +83,7 @@ func (p *GraphQLWS) Subscribe(ctx context.Context, conn *websocket.Conn, id stri } // Unsubscribe implements Protocol. -func (p *GraphQLWS) Unsubscribe(ctx context.Context, conn *websocket.Conn, id string) error { +func (p *graphqlWS) Unsubscribe(ctx context.Context, conn *websocket.Conn, id string) error { msg := outgoingMessage{ ID: id, Type: gwsTypeStop, @@ -92,7 +92,7 @@ func (p *GraphQLWS) Unsubscribe(ctx context.Context, conn *websocket.Conn, id st } // Read implements Protocol. -func (p *GraphQLWS) Read(ctx context.Context, conn *websocket.Conn) (*WireMessage, error) { +func (p *graphqlWS) Read(ctx context.Context, conn *websocket.Conn) (*WireMessage, error) { var raw incomingMessage if err := wsjson.Read(ctx, conn, &raw); err != nil { return nil, fmt.Errorf("read message: %w", err) @@ -101,7 +101,7 @@ func (p *GraphQLWS) Read(ctx context.Context, conn *websocket.Conn) (*WireMessag return p.decode(raw) } -func (p *GraphQLWS) decode(raw incomingMessage) (*WireMessage, error) { +func (p *graphqlWS) decode(raw incomingMessage) (*WireMessage, error) { msg := &WireMessage{ ID: raw.ID, } @@ -145,4 +145,4 @@ func (p *GraphQLWS) decode(raw incomingMessage) (*WireMessage, error) { return msg, nil } -var _ Protocol = (*GraphQLWS)(nil) +var _ Protocol = (*graphqlWS)(nil) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go index 4eb964daf0..27ade678bd 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go @@ -246,11 +246,6 @@ func (c *wsConnection) closeConn() { c.shutdown(common.ErrConnectionClosed) } -// writeTimeoutDuration returns the configured write timeout. -func (c *wsConnection) writeTimeoutDuration() time.Duration { - return c.writeTimeout -} - func (c *wsConnection) subCount() int { c.subsMu.RLock() defer c.subsMu.RUnlock() diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go index 152944680b..827ac0451e 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go @@ -594,7 +594,7 @@ func TestWSConnection_Defaults(t *testing.T) { proto := newMockProtocol() wsc := newWSConnection(conn, proto, wsConnectionOptions{}) - assert.Equal(t, defaultWriteTimeout, wsc.writeTimeoutDuration()) + assert.Equal(t, defaultWriteTimeout, wsc.writeTimeout) }) t.Run("applies default write timeout for zero value", func(t *testing.T) { @@ -606,7 +606,7 @@ func TestWSConnection_Defaults(t *testing.T) { writeTimeout: 0, }) - assert.Equal(t, defaultWriteTimeout, wsc.writeTimeoutDuration()) + assert.Equal(t, defaultWriteTimeout, wsc.writeTimeout) }) t.Run("overrides write timeout when provided", func(t *testing.T) { @@ -618,7 +618,7 @@ func TestWSConnection_Defaults(t *testing.T) { writeTimeout: 10 * time.Second, }) - assert.Equal(t, 10*time.Second, wsc.writeTimeoutDuration()) + assert.Equal(t, 10*time.Second, wsc.writeTimeout) }) t.Run("ignores negative write timeout", func(t *testing.T) { @@ -630,7 +630,7 @@ func TestWSConnection_Defaults(t *testing.T) { writeTimeout: -1 * time.Second, }) - assert.Equal(t, defaultWriteTimeout, wsc.writeTimeoutDuration()) + assert.Equal(t, defaultWriteTimeout, wsc.writeTimeout) }) } From a198549c93d4da0d11662dbf0b6eac7f00938d58 Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 16 Apr 2026 16:34:24 +0100 Subject: [PATCH 42/52] docs: fix stale comments on handleTriggerUpdate --- v2/pkg/engine/resolve/resolve.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 002dd2a8ff..fc6ed47cee 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -1080,9 +1080,7 @@ type pendingFilterError struct { sub *subscriptionState } -// handleTriggerUpdate sends data to all subscriptions of a trigger using snapshot-and-release. -// The lock is released before performing I/O to avoid deadlocks when executeSubscriptionUpdate -// calls AsyncUnsubscribeSubscription on flush failure. +// handleTriggerUpdate sends data to all subscriptions of a trigger. func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { trig, ok := r.getTrigger(id) if !ok { @@ -1110,7 +1108,7 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { wg.Wait() } -// handleUpdateSubscription sends data to a single subscription using snapshot-and-release. +// handleUpdateSubscription sends data to a single subscription. func (r *Resolver) handleUpdateSubscription(id uint64, data []byte, subIdentifier SubscriptionIdentifier) { trig, ok := r.getTrigger(id) if !ok { From ca18a45cd2a091667ceda451aed6b90875878547 Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 16 Apr 2026 17:01:45 +0100 Subject: [PATCH 43/52] fix: broadcast Source.Start() error to all trigger subscribers --- v2/pkg/engine/resolve/resolve.go | 6 +- v2/pkg/engine/resolve/resolve_test.go | 81 +++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index fc6ed47cee..b8b6f3d440 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -849,7 +849,9 @@ func (r *Resolver) addSubscription(triggerID uint64, add *addSubscription) error if r.options.Debug { fmt.Printf("resolver:trigger:failed:%d\n", triggerID) } - s.writeError(r.errorFormatter, add.ctx, err, add.resolve.Response) + for _, sub := range trig.snapshotSubscriptions() { + sub.writeError(r.errorFormatter, sub.ctx, err, sub.resolve.Response) + } r.doneTriggerFromUpdater(triggerID) return } @@ -865,8 +867,8 @@ func (r *Resolver) addSubscription(triggerID uint64, add *addSubscription) error func (r *Resolver) getTrigger(id uint64) (*trigger, bool) { r.mu.Lock() - defer r.mu.Unlock() trig, ok := r.triggers[id] + r.mu.Unlock() return trig, ok } diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 03be74c501..92e3f10647 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -8046,6 +8046,87 @@ func Benchmark_NestedBatchingArena(b *testing.B) { }) } +// startFailStream blocks until subBReady is closed, then returns an error from Start. +type startFailStream struct { + subBReady chan struct{} +} + +func (s *startFailStream) Start(_ *Context, _ http.Header, _ []byte, _ SubscriptionUpdater) error { + <-s.subBReady + return errors.New("connection refused") +} + +func TestSourceStartFailure(t *testing.T) { + t.Run("broadcasts error to all subscribers", func(t *testing.T) { + resolverCtx, cancelResolver := context.WithCancel(context.Background()) + defer cancelResolver() + + subBReady := make(chan struct{}) + + stream := &startFailStream{subBReady: subBReady} + + plan := &GraphQLSubscription{ + Trigger: GraphQLSubscriptionTrigger{ + Source: stream, + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + SegmentType: StaticSegmentType, + Data: []byte(`{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`), + }, + }, + }, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + SelectResponseErrorsPath: []string{"errors"}, + }, + }, + Response: &GraphQLResponse{ + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("counter"), + Value: &Integer{ + Path: []string{"counter"}, + }, + }, + }, + }, + }, + } + + resolver := New(resolverCtx, ResolverOptions{ + MaxConcurrency: 1024, + AsyncErrorWriter: &TestErrorWriter{}, + SubscriptionHeartbeatInterval: time.Hour, + }) + + writerA := &SubscriptionRecorder{buf: &bytes.Buffer{}} + writerB := &SubscriptionRecorder{buf: &bytes.Buffer{}} + + idA := SubscriptionIdentifier{ConnectionID: NewConnectionID(), SubscriptionID: 1} + idB := SubscriptionIdentifier{ConnectionID: NewConnectionID(), SubscriptionID: 2} + + ctxA := NewContext(context.Background()) + ctxB := NewContext(context.Background()) + + // Subscribe A — creates the trigger. + err := resolver.AsyncResolveGraphQLSubscription(ctxA, plan, writerA, idA) + require.NoError(t, err) + + // Subscribe B — joins the existing trigger. + err = resolver.AsyncResolveGraphQLSubscription(ctxB, plan, writerB, idB) + require.NoError(t, err) + + // Unblock Source.Start() so it returns an error. + close(subBReady) + + timeout := 2 * time.Second + writerA.AwaitAnyMessageCount(t, timeout) + writerB.AwaitAnyMessageCount(t, timeout) + }) +} + func Benchmark_NoCheckNestedBatching(b *testing.B) { rCtx, cancel := context.WithCancel(context.Background()) defer cancel() From af9077120fb6f980ed4c5adb1deb2d2b62c6478e Mon Sep 17 00:00:00 2001 From: endigma Date: Thu, 16 Apr 2026 17:36:30 +0100 Subject: [PATCH 44/52] fix: broadcast trigger startup errors to all subscribers --- v2/pkg/engine/resolve/resolve.go | 50 ++++++------ v2/pkg/engine/resolve/resolve_test.go | 112 ++++++++++++++++++++++++++ 2 files changed, 135 insertions(+), 27 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index b8b6f3d440..bea484483f 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -717,27 +717,22 @@ type StartupHookContext struct { Updater func(data []byte) } -func (r *Resolver) executeStartupHooks(add *addSubscription, sub *subscriptionState, updater *subscriptionUpdater) error { +func (r *Resolver) executeStartupHooks(add *addSubscription, updater *subscriptionUpdater) error { hook, ok := add.resolve.Trigger.Source.(HookableSubscriptionDataSource) - if ok { - hookCtx := StartupHookContext{ - Context: add.ctx.Context(), - Updater: func(data []byte) { - updater.UpdateSubscription(add.id, data) - }, - } - - err := hook.SubscriptionOnStart(hookCtx, add.input) - if err != nil { - if r.options.Debug { - fmt.Printf("resolver:trigger:subscription:startup:failed:%d\n", add.id.SubscriptionID) - } - sub.writeError(r.errorFormatter, add.ctx, err, add.resolve.Response) - _ = r.UnsubscribeSubscription(add.id) - return err - } + if !ok { + return nil } - return nil + hookCtx := StartupHookContext{ + Context: add.ctx.Context(), + Updater: func(data []byte) { + updater.UpdateSubscription(add.id, data) + }, + } + err := hook.SubscriptionOnStart(hookCtx, add.input) + if err != nil && r.options.Debug { + fmt.Printf("resolver:trigger:subscription:startup:failed:%d\n", add.id.SubscriptionID) + } + return err } // registerSubscriptionLocked updates the by-ID and by-connection indexes. @@ -801,9 +796,11 @@ func (r *Resolver) addSubscription(triggerID uint64, add *addSubscription) error // Register first so startup hooks can deliver initial data via UpdateSubscription. r.registerSubscriptionLocked(trig, s) // Execute the startup hooks in a goroutine to avoid holding the lock. - // On failure, executeStartupHooks calls UnsubscribeSubscription to clean up. go func() { - _ = r.executeStartupHooks(add, s, trig.updater) + if err := r.executeStartupHooks(add, trig.updater); err != nil { + s.writeError(r.errorFormatter, add.ctx, err, add.resolve.Response) + _ = r.UnsubscribeSubscription(add.id) + } }() return nil } @@ -838,13 +835,12 @@ func (r *Resolver) addSubscription(triggerID uint64, add *addSubscription) error fmt.Printf("resolver:trigger:start:%d\n", triggerID) } - // This is blocking so the startup hook can decide if a subscription should be started or not by returning an error - err := r.executeStartupHooks(add, s, trig.updater) - if err != nil { - return + // The startup hook is blocking so it can reject the subscription before Source.Start. + // If either step fails, broadcast the error to all subs and tear down the trigger. + err := r.executeStartupHooks(add, trig.updater) + if err == nil { + err = add.resolve.Trigger.Source.Start(cloneCtx, add.headers, add.input, trig.updater) } - - err = add.resolve.Trigger.Source.Start(cloneCtx, add.headers, add.input, trig.updater) if err != nil { if r.options.Debug { fmt.Printf("resolver:trigger:failed:%d\n", triggerID) diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 92e3f10647..8ce7f2066d 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -8127,6 +8127,118 @@ func TestSourceStartFailure(t *testing.T) { }) } +// hookFailStream is a SubscriptionDataSource + HookableSubscriptionDataSource. +// The hook for the subscription whose context matches failCtx blocks until +// subBRegistered is closed, then returns an error. All other hooks succeed. +type hookFailStream struct { + subBRegistered chan struct{} + failCtx context.Context + sourceStarted atomic.Bool +} + +func (s *hookFailStream) Start(_ *Context, _ http.Header, _ []byte, _ SubscriptionUpdater) error { + s.sourceStarted.Store(true) + select {} +} + +func (s *hookFailStream) SubscriptionOnStart(ctx StartupHookContext, _ []byte) error { + if ctx.Context == s.failCtx { + <-s.subBRegistered + return errors.New("startup hook failed") + } + return nil +} + +func TestStartupHookFailure(t *testing.T) { + t.Run("cleans up all subscribers when trigger creator hook fails", func(t *testing.T) { + resolverCtx, cancelResolver := context.WithCancel(context.Background()) + defer cancelResolver() + + ctxAInner, cancelA := context.WithCancel(context.Background()) + defer cancelA() + ctxBInner, cancelB := context.WithCancel(context.Background()) + defer cancelB() + + subBRegistered := make(chan struct{}) + + stream := &hookFailStream{ + subBRegistered: subBRegistered, + failCtx: ctxAInner, + } + + plan := &GraphQLSubscription{ + Trigger: GraphQLSubscriptionTrigger{ + Source: stream, + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + SegmentType: StaticSegmentType, + Data: []byte(`{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`), + }, + }, + }, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + SelectResponseErrorsPath: []string{"errors"}, + }, + }, + Response: &GraphQLResponse{ + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("counter"), + Value: &Integer{ + Path: []string{"counter"}, + }, + }, + }, + }, + }, + } + + resolver := New(resolverCtx, ResolverOptions{ + MaxConcurrency: 1024, + AsyncErrorWriter: &TestErrorWriter{}, + SubscriptionHeartbeatInterval: time.Hour, + }) + + writerA := &SubscriptionRecorder{buf: &bytes.Buffer{}} + writerB := &SubscriptionRecorder{buf: &bytes.Buffer{}} + + idA := SubscriptionIdentifier{ConnectionID: NewConnectionID(), SubscriptionID: 1} + idB := SubscriptionIdentifier{ConnectionID: NewConnectionID(), SubscriptionID: 2} + + ctxA := NewContext(ctxAInner) + ctxB := NewContext(ctxBInner) + + // Subscribe A — creates the trigger. + err := resolver.AsyncResolveGraphQLSubscription(ctxA, plan, writerA, idA) + require.NoError(t, err) + + // Subscribe B — joins the existing trigger. + err = resolver.AsyncResolveGraphQLSubscription(ctxB, plan, writerB, idB) + require.NoError(t, err) + + // Unblock A's startup hook so it fails. + close(subBRegistered) + + // Sub A should receive an error. + writerA.AwaitAnyMessageCount(t, time.Second) + + // Sub B should also be cleaned up — not left orphaned on a triggerless source. + require.Eventually(t, func() bool { + writerB.mux.Lock() + hasMessages := len(writerB.messages) > 0 + writerB.mux.Unlock() + return hasMessages || writerB.complete.Load() + }, 2*time.Second, 10*time.Millisecond, + "sub B was orphaned: no error, no complete — "+ + "trigger left without a data source after first subscription's startup hook failed") + + require.False(t, stream.sourceStarted.Load(), "Source.Start() should not have been called") + }) +} + func Benchmark_NoCheckNestedBatching(b *testing.B) { rCtx, cancel := context.WithCancel(context.Background()) defer cancel() From a13b074010953dcd22a02acd0cc605e33fcaa3a4 Mon Sep 17 00:00:00 2001 From: endigma Date: Mon, 20 Apr 2026 11:26:28 +0100 Subject: [PATCH 45/52] chore: Handle WS and invalid transport --- .../graphql_datasource/subscriptionclient/client.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go index 6f18505408..b787381593 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go @@ -3,6 +3,7 @@ package client import ( "context" "errors" + "fmt" "net/http" "time" @@ -81,10 +82,14 @@ func (c *Client) Subscribe(ctx context.Context, req *common.Request, opts common return nil, ErrClientClosed } - if opts.Transport == common.TransportSSE { + switch opts.Transport { + case common.TransportSSE: return c.sse.Subscribe(ctx, req, opts, handler) + case common.TransportWS: + return c.ws.Subscribe(ctx, req, opts, handler) + default: + return nil, fmt.Errorf("unsupported transport: %q", opts.Transport) } - return c.ws.Subscribe(ctx, req, opts, handler) } // Stats returns client statistics. From d237e3ceb83f708c03cef4efe6d6a0e9918e8379 Mon Sep 17 00:00:00 2001 From: endigma Date: Tue, 21 Apr 2026 11:12:59 +0100 Subject: [PATCH 46/52] fix(sse): remove SSEMethodAuto --- .../subscriptionclient/common/options.go | 2 - .../subscriptionclient/exports.go | 1 - .../transport/sse_transport.go | 13 +-- .../transport/sse_transport_test.go | 93 +++++++------------ 4 files changed, 38 insertions(+), 71 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/options.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/options.go index 9177be370f..fe0eb31c23 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/options.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/options.go @@ -38,7 +38,6 @@ func (s WSSubprotocol) Subprotocols() []string { type SSEMethod string const ( - SSEMethodAuto SSEMethod = "" // Auto: POST for graphql-sse (default) SSEMethodPOST SSEMethod = "POST" // POST with JSON body (graphql-sse spec) SSEMethodGET SSEMethod = "GET" // GET with query parameters (traditional SSE) ) @@ -54,6 +53,5 @@ type Options struct { WSSubprotocol WSSubprotocol // Only affects the SSE transport. - // Defaults to POST (graphql-sse spec). SSEMethod SSEMethod } diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go index 51db35954e..f1ff7e47ab 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go @@ -37,7 +37,6 @@ const ( SubprotocolGraphQLTransportWS = common.SubprotocolGraphQLTransportWS SubprotocolGraphQLWS = common.SubprotocolGraphQLWS - SSEMethodAuto = common.SSEMethodAuto SSEMethodPOST = common.SSEMethodPOST SSEMethodGET = common.SSEMethodGET ) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go index 2cdad60351..4e93a3fa4d 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go @@ -57,29 +57,24 @@ func NewSSETransport(ctx context.Context, client *http.Client, log abstractlogge // Each call creates a new HTTP request (no multiplexing). // // The HTTP method is determined by opts.SSEMethod: -// - SSEMethodAuto or SSEMethodPOST: POST with JSON body (graphql-sse spec) +// - SSEMethodPOST: POST with JSON body (graphql-sse spec) // - SSEMethodGET: GET with query parameters (traditional SSE) func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts common.Options, handler common.Handler) (func(), error) { var httpReq *http.Request var err error - method := opts.SSEMethod - if method == common.SSEMethodAuto { - method = common.SSEMethodPOST // Default to POST (graphql-sse spec) - } - t.log.Debug("sseTransport.Subscribe", abstractlogger.String("endpoint", opts.Endpoint), - abstractlogger.String("method", string(method)), + abstractlogger.String("method", string(opts.SSEMethod)), ) - switch method { + switch opts.SSEMethod { case common.SSEMethodPOST: httpReq, err = buildPOSTRequest(req, opts) case common.SSEMethodGET: httpReq, err = buildGETRequest(req, opts) default: - return nil, fmt.Errorf("unsupported SSE method: %s", method) + return nil, fmt.Errorf("unsupported SSE method: %s", opts.SSEMethod) } if err != nil { diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go index 8cd7104b79..19100241ae 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go @@ -62,6 +62,7 @@ func TestSSETransport_Subscribe(t *testing.T) { }, common.Options{ Endpoint: server.URL, Transport: common.TransportSSE, + SSEMethod: common.SSEMethodPOST, }, handler) require.NoError(t, err) defer cancel() @@ -106,8 +107,9 @@ func TestSSETransport_Subscribe(t *testing.T) { cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ - Endpoint: server.URL, - Headers: headers, + Endpoint: server.URL, + Headers: headers, + SSEMethod: common.SSEMethodPOST, }, handler) require.NoError(t, err) defer cancel() @@ -138,7 +140,7 @@ func TestSSETransport_Subscribe(t *testing.T) { handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { user { name } }", - }, common.Options{Endpoint: server.URL}, handler) + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) require.NoError(t, err) defer cancel() @@ -163,7 +165,7 @@ func TestSSETransport_Subscribe(t *testing.T) { handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}, handler) + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) require.NoError(t, err) defer cancel() @@ -188,7 +190,7 @@ func TestSSETransport_Subscribe(t *testing.T) { handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}, handler) + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) require.NoError(t, err) defer cancel() @@ -223,7 +225,7 @@ func TestSSETransport_Subscribe(t *testing.T) { handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}, handler) + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) require.NoError(t, err) defer cancel() @@ -263,7 +265,7 @@ func TestSSETransport_Subscribe(t *testing.T) { wrappedHandler, collect := waitForMessages(handler) cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}, wrappedHandler) + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, wrappedHandler) require.NoError(t, err) defer cancel() @@ -295,7 +297,7 @@ func TestSSETransport_Subscribe(t *testing.T) { handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}, handler) + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) require.NoError(t, err) // Receive first message @@ -341,7 +343,7 @@ func TestSSETransport_Subscribe(t *testing.T) { handler, receive := collectingHandler() cancel, err := tr.Subscribe(transportCtx, &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}, handler) + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) require.NoError(t, err) defer cancel() @@ -368,7 +370,7 @@ func TestSSETransport_Subscribe(t *testing.T) { _, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}, func(_ *common.Message) {}) + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, func(_ *common.Message) {}) require.Error(t, err) assert.Contains(t, err.Error(), "401") @@ -386,7 +388,7 @@ func TestSSETransport_Subscribe(t *testing.T) { _, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}, func(_ *common.Message) {}) + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, func(_ *common.Message) {}) require.Error(t, err) assert.Contains(t, err.Error(), "500") @@ -412,7 +414,7 @@ func TestSSETransport_Subscribe(t *testing.T) { tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - opts := common.Options{Endpoint: server.URL} + opts := common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST} handler1, receive1 := collectingHandler() cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts, handler1) @@ -452,7 +454,7 @@ func TestSSETransport_Subscribe(t *testing.T) { handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}, handler) + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) require.NoError(t, err) defer cancel() @@ -487,7 +489,7 @@ func TestSSETransport_Subscribe(t *testing.T) { handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}, handler) + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) require.NoError(t, err) defer cancel() @@ -519,7 +521,7 @@ func TestSSETransport_ContextCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) tr := NewSSETransport(ctx, http.DefaultClient, nil) - opts := common.Options{Endpoint: server.URL} + opts := common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST} handler1, receive1 := collectingHandler() _, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts, handler1) @@ -576,7 +578,7 @@ func TestSSETransport_CustomClient(t *testing.T) { handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}, handler) + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) require.NoError(t, err) defer cancel() @@ -625,7 +627,7 @@ func TestSSETransport_ContentTypeValidation(t *testing.T) { handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}, handler) + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) require.NoError(t, err) defer cancel() @@ -646,7 +648,7 @@ func TestSSETransport_ContentTypeValidation(t *testing.T) { _, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", - }, common.Options{Endpoint: server.URL}, func(_ *common.Message) {}) + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, func(_ *common.Message) {}) require.Error(t, err) assert.True(t, strings.Contains(err.Error(), "content-type") || strings.Contains(err.Error(), "Content-Type")) @@ -819,64 +821,37 @@ func TestSSETransport_GETMethod(t *testing.T) { }) } -func TestSSETransport_MethodDefault(t *testing.T) { +func TestSSETransport_Subscribe_UnrecognizedMethod(t *testing.T) { t.Parallel() - t.Run("defaults to POST when SSEMethod is auto", func(t *testing.T) { + t.Run("returns error for unrecognized SSE method", func(t *testing.T) { t.Parallel() - var receivedMethod string - server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { - receivedMethod = r.Method - - w.Header().Set("Content-Type", "text/event-stream") - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, "event: complete\ndata:\n\n") - }) - tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - handler, receive := collectingHandler() - cancel, err := tr.Subscribe(context.Background(), &common.Request{ + _, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ - Endpoint: server.URL, - SSEMethod: common.SSEMethodAuto, // or just omit it - }, handler) - require.NoError(t, err) - defer cancel() - - receive(t, time.Second) + Endpoint: "http://example.invalid", + SSEMethod: common.SSEMethod("PATCH"), + }, func(_ *common.Message) {}) - assert.Equal(t, http.MethodPost, receivedMethod) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported SSE method") }) - t.Run("explicit POST method works", func(t *testing.T) { + t.Run("returns error when SSE method is empty", func(t *testing.T) { t.Parallel() - var receivedMethod string - server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { - receivedMethod = r.Method - - w.Header().Set("Content-Type", "text/event-stream") - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, "event: complete\ndata:\n\n") - }) - tr := NewSSETransport(t.Context(), http.DefaultClient, nil) - handler, receive := collectingHandler() - cancel, err := tr.Subscribe(context.Background(), &common.Request{ + _, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", }, common.Options{ - Endpoint: server.URL, - SSEMethod: common.SSEMethodPOST, - }, handler) - require.NoError(t, err) - defer cancel() + Endpoint: "http://example.invalid", + }, func(_ *common.Message) {}) - receive(t, time.Second) - - assert.Equal(t, http.MethodPost, receivedMethod) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported SSE method") }) } From aef088f8d4509c6aa0770faf19e08b380169532b Mon Sep 17 00:00:00 2001 From: endigma Date: Tue, 21 Apr 2026 12:48:37 +0100 Subject: [PATCH 47/52] refactor: consolidate defaults to the client --- .../graphql_subscription_client.go | 5 +- .../subscriptionclient/client.go | 15 +++ .../transport/transport_test.go | 42 ++++++++ .../subscriptionclient/transport/ws_conn.go | 9 +- .../transport/ws_conn_test.go | 96 +++++-------------- .../transport/ws_transport.go | 20 +--- .../transport/ws_transport_test.go | 81 +++++----------- 7 files changed, 107 insertions(+), 161 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index a44c83ddee..882f9b3acf 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -88,7 +88,7 @@ func WithPingInterval(d time.Duration) SubscriptionClientOption { // WithPingTimeout sets the maximum time to wait for a pong response. // If no pong is received within this duration, the connection is considered dead. -// Default: 10s. +// Default: 10s. Set to 0 to disable the pong-timeout check. func WithPingTimeout(d time.Duration) SubscriptionClientOption { return func(cfg *SubscriptionClientConfig) { cfg.PingTimeout = d @@ -96,7 +96,6 @@ func WithPingTimeout(d time.Duration) SubscriptionClientOption { } // WithAckTimeout sets the maximum time to wait for connection_ack after connection_init. -// Default: 30s. func WithAckTimeout(d time.Duration) SubscriptionClientOption { return func(cfg *SubscriptionClientConfig) { cfg.AckTimeout = d @@ -104,7 +103,6 @@ func WithAckTimeout(d time.Duration) SubscriptionClientOption { } // WithWriteTimeout sets the timeout for WebSocket write operations (subscribe, unsubscribe, ping, pong). -// Default: 5s. func WithWriteTimeout(d time.Duration) SubscriptionClientOption { return func(cfg *SubscriptionClientConfig) { cfg.WriteTimeout = d @@ -120,7 +118,6 @@ func WithDefaultErrorExtensionCode(code string) SubscriptionClientOption { } // WithReadLimit sets the maximum size in bytes for incoming WebSocket messages. -// Default: 1MB. func WithReadLimit(n int64) SubscriptionClientOption { return func(cfg *SubscriptionClientConfig) { cfg.ReadLimit = n diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go index b787381593..164d34ca0e 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go @@ -15,6 +15,12 @@ import ( var ErrClientClosed = errors.New("client closed") +const ( + defaultReadLimit = 1 << 20 // 1MiB + defaultAckTimeout = 30 * time.Second + defaultWriteTimeout = 5 * time.Second +) + type Client struct { ctx context.Context log abstractlogger.Logger @@ -53,6 +59,15 @@ func New(ctx context.Context, cfg Config) *Client { if cfg.Logger == nil { cfg.Logger = abstractlogger.NoopLogger } + if cfg.ReadLimit <= 0 { + cfg.ReadLimit = defaultReadLimit + } + if cfg.AckTimeout <= 0 { + cfg.AckTimeout = defaultAckTimeout + } + if cfg.WriteTimeout <= 0 { + cfg.WriteTimeout = defaultWriteTimeout + } c := &Client{ ctx: ctx, diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport_test.go index f336dfe522..afb47bb390 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport_test.go @@ -1,13 +1,55 @@ package transport import ( + "net/http" "sync" "testing" "time" + "github.com/coder/websocket" + "github.com/jensneuse/abstractlogger" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol" ) +// newTestWSConnection mirrors the defaults applied by client.New so that tests +// can continue to pass wsConnectionOptions{...} literals without repeating +// default values at every call site. +func newTestWSConnection(t *testing.T, conn *websocket.Conn, proto protocol.Protocol, opts wsConnectionOptions) *wsConnection { + t.Helper() + if opts.logger == nil { + opts.logger = abstractlogger.NoopLogger + } + if opts.writeTimeout <= 0 { + opts.writeTimeout = 5 * time.Second + } + return newWSConnection(conn, proto, opts) +} + +// newTestWSTransport mirrors the defaults applied by client.New so that tests +// can continue to pass WSTransportOptions{...} literals without repeating +// default values at every call site. +func newTestWSTransport(t *testing.T, opts WSTransportOptions) *WSTransport { + t.Helper() + if opts.UpgradeClient == nil { + opts.UpgradeClient = http.DefaultClient + } + if opts.Logger == nil { + opts.Logger = abstractlogger.NoopLogger + } + if opts.ReadLimit <= 0 { + opts.ReadLimit = 1 << 20 + } + if opts.AckTimeout <= 0 { + opts.AckTimeout = 30 * time.Second + } + if opts.WriteTimeout <= 0 { + opts.WriteTimeout = 5 * time.Second + } + return NewWSTransport(t.Context(), opts) +} + // collectingHandler returns a handler that appends messages to a channel, // plus a helper to receive with timeout (for use in tests). func collectingHandler() (common.Handler, func(t *testing.T, timeout time.Duration) *common.Message) { diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go index 27ade678bd..9ce29b20f3 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go @@ -15,11 +15,7 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol" ) -var ( - ErrSubscriptionExists = errors.New("subscription ID already exists") - - defaultWriteTimeout = 5 * time.Second -) +var ErrSubscriptionExists = errors.New("subscription ID already exists") type wsConnectionOptions struct { logger abstractlogger.Logger @@ -58,9 +54,6 @@ func newWSConnection(conn *websocket.Conn, proto protocol.Protocol, opts wsConne if opts.logger == nil { opts.logger = abstractlogger.NoopLogger } - if opts.writeTimeout <= 0 { - opts.writeTimeout = defaultWriteTimeout - } ctx, cancel := context.WithCancel(context.Background()) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go index 827ac0451e..431180d125 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go @@ -27,7 +27,7 @@ func TestWSConnection_Subscribe(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{}) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) handler, _ := collectingHandler() cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{ @@ -44,7 +44,7 @@ func TestWSConnection_Subscribe(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{}) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) require.NoError(t, err) @@ -60,7 +60,7 @@ func TestWSConnection_Subscribe(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{}) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) wsc.closeConn() _, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) @@ -89,7 +89,7 @@ func TestWSConnection_Subscribe(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{}) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) // Step 1: Hold subsMu so subscribe() blocks after its closed check. wsc.subsMu.Lock() @@ -133,7 +133,7 @@ func TestWSConnection_Subscribe(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() proto.subscribeErr = assert.AnError - wsc := newWSConnection(conn, proto, wsConnectionOptions{}) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) _, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) @@ -150,7 +150,7 @@ func TestWSConnection_ReadLoop(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{}) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) handler, receive := collectingHandler() cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler) @@ -175,7 +175,7 @@ func TestWSConnection_ReadLoop(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{}) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) handler, receive := collectingHandler() _, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler) @@ -197,7 +197,7 @@ func TestWSConnection_ReadLoop(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{}) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) go wsc.readLoop() @@ -213,7 +213,7 @@ func TestWSConnection_ReadLoop(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{}) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) handler, receive := collectingHandler() cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler) @@ -247,7 +247,7 @@ func TestWSConnection_Unsubscribe(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{}) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) require.NoError(t, err) @@ -264,7 +264,7 @@ func TestWSConnection_Unsubscribe(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{}) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) require.NoError(t, err) @@ -282,7 +282,7 @@ func TestWSConnection_Unsubscribe(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() proto.unsubscribeDelay = 500 * time.Millisecond - wsc := newWSConnection(conn, proto, wsConnectionOptions{ + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{ writeTimeout: 50 * time.Millisecond, }) @@ -307,7 +307,7 @@ func TestWSConnection_OnEmpty(t *testing.T) { proto := newMockProtocol() emptyCalled := make(chan struct{}, 1) - wsc := newWSConnection(conn, proto, wsConnectionOptions{ + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{ onEmpty: func() { emptyCalled <- struct{}{} }, }) @@ -331,7 +331,7 @@ func TestWSConnection_OnEmpty(t *testing.T) { proto := newMockProtocol() emptyCalled := make(chan struct{}, 1) - wsc := newWSConnection(conn, proto, wsConnectionOptions{ + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{ onEmpty: func() { emptyCalled <- struct{}{} }, }) @@ -364,7 +364,7 @@ func TestWSConnection_OnEmpty(t *testing.T) { proto := newMockProtocol() emptyCalled := make(chan struct{}, 1) - wsc := newWSConnection(conn, proto, wsConnectionOptions{ + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{ onEmpty: func() { emptyCalled <- struct{}{} }, }) @@ -385,7 +385,7 @@ func TestWSConnection_OnEmpty(t *testing.T) { proto := newMockProtocol() emptyCalled := make(chan struct{}, 1) - wsc := newWSConnection(conn, proto, wsConnectionOptions{ + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{ onEmpty: func() { emptyCalled <- struct{}{} }, }) @@ -412,7 +412,7 @@ func TestWSConnection_IdleTimeout(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{ + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{ idleTimeout: 200 * time.Millisecond, }) @@ -434,7 +434,7 @@ func TestWSConnection_IdleTimeout(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{ + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{ idleTimeout: 200 * time.Millisecond, }) @@ -458,7 +458,7 @@ func TestWSConnection_IdleTimeout(t *testing.T) { proto := newMockProtocol() emptyCalled := make(chan struct{}, 1) - wsc := newWSConnection(conn, proto, wsConnectionOptions{ + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{ onEmpty: func() { emptyCalled <- struct{}{} }, }) @@ -487,7 +487,7 @@ func TestWSConnection_Close(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{}) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) handler1, receive1 := collectingHandler() _, _ = wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler1) @@ -511,7 +511,7 @@ func TestWSConnection_Close(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{}) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) assert.NotPanics(t, func() { wsc.closeConn() @@ -529,7 +529,7 @@ func TestWSConnection_SubCount(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{}) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) assert.Equal(t, 0, wsc.subCount()) @@ -556,7 +556,7 @@ func TestWSConnection_WriteTimeout(t *testing.T) { conn, _ := newTestConn(t) proto := newMockProtocol() proto.pongDelay = 500 * time.Millisecond - wsc := newWSConnection(conn, proto, wsConnectionOptions{ + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{ writeTimeout: 50 * time.Millisecond, }) @@ -584,56 +584,6 @@ func TestWSConnection_WriteTimeout(t *testing.T) { }) } -func TestWSConnection_Defaults(t *testing.T) { - t.Parallel() - - t.Run("applies default write timeout when omitted", func(t *testing.T) { - t.Parallel() - - conn, _ := newTestConn(t) - proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{}) - - assert.Equal(t, defaultWriteTimeout, wsc.writeTimeout) - }) - - t.Run("applies default write timeout for zero value", func(t *testing.T) { - t.Parallel() - - conn, _ := newTestConn(t) - proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{ - writeTimeout: 0, - }) - - assert.Equal(t, defaultWriteTimeout, wsc.writeTimeout) - }) - - t.Run("overrides write timeout when provided", func(t *testing.T) { - t.Parallel() - - conn, _ := newTestConn(t) - proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{ - writeTimeout: 10 * time.Second, - }) - - assert.Equal(t, 10*time.Second, wsc.writeTimeout) - }) - - t.Run("ignores negative write timeout", func(t *testing.T) { - t.Parallel() - - conn, _ := newTestConn(t) - proto := newMockProtocol() - wsc := newWSConnection(conn, proto, wsConnectionOptions{ - writeTimeout: -1 * time.Second, - }) - - assert.Equal(t, defaultWriteTimeout, wsc.writeTimeout) - }) -} - // Test helpers func newTestConn(t *testing.T) (*websocket.Conn, *websocket.Conn) { diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go index 1a3fbb4863..3cbc9e0ee1 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go @@ -20,12 +20,7 @@ import ( // ErrDialFailed indicates that the WebSocket dial (TCP + HTTP upgrade) failed. // The underlying cause is available via errors.Unwrap. -var ( - ErrDialFailed = errors.New("websocket dial failed") - - defaultReadLimit = int64(1024 * 1024) // 1MB - defaultAckTimeout = 30 * time.Second -) +var ErrDialFailed = errors.New("websocket dial failed") // ErrInitFailed indicates that the GraphQL protocol init (connection_init / // connection_ack handshake) failed after a successful WebSocket dial. The @@ -56,7 +51,7 @@ type WSTransportOptions struct { Logger abstractlogger.Logger // ReadLimit is the maximum message size in bytes the WebSocket connection - // will accept. Default: 1MB. + // will accept. ReadLimit int64 // PingInterval is how often the transport sends a ping to each connection. @@ -69,11 +64,10 @@ type WSTransportOptions struct { // AckTimeout is the maximum time to wait for a connection_ack during the // protocol init handshake. Passed to the protocol at construction. - // Default: 30s. AckTimeout time.Duration // WriteTimeout is the deadline applied to each WebSocket write (subscribe, - // unsubscribe, ping, pong). Passed to each connection. Default: 5s. + // unsubscribe, ping, pong). Passed to each connection. WriteTimeout time.Duration // IdleTimeout is the duration a connection stays open after its last @@ -113,14 +107,6 @@ func NewWSTransport(ctx context.Context, opts WSTransportOptions) *WSTransport { opts.Logger = abstractlogger.NoopLogger } - if opts.ReadLimit <= 0 { - opts.ReadLimit = defaultReadLimit - } - - if opts.AckTimeout <= 0 { - opts.AckTimeout = defaultAckTimeout - } - t := &WSTransport{ ctx: ctx, opts: opts, diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go index 01bd0b92f3..57f7405c01 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go @@ -43,7 +43,7 @@ func TestWSTransport_Subscribe(t *testing.T) { }) }) - tr := NewWSTransport(t.Context(), WSTransportOptions{}) + tr := newTestWSTransport(t, WSTransportOptions{}) handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ @@ -85,7 +85,7 @@ func TestWSTransport_Subscribe(t *testing.T) { } }) - tr := NewWSTransport(t.Context(), WSTransportOptions{}) + tr := newTestWSTransport(t, WSTransportOptions{}) opts := common.Options{ Endpoint: server.URL, @@ -134,7 +134,7 @@ func TestWSTransport_Subscribe(t *testing.T) { } }) - tr := NewWSTransport(t.Context(), WSTransportOptions{}) + tr := newTestWSTransport(t, WSTransportOptions{}) headers1 := http.Header{"Authorization": []string{"Bearer token1"}} headers2 := http.Header{"Authorization": []string{"Bearer token2"}} @@ -188,7 +188,7 @@ func TestWSTransport_Subscribe(t *testing.T) { } }) - tr := NewWSTransport(t.Context(), WSTransportOptions{}) + tr := newTestWSTransport(t, WSTransportOptions{}) handler1, receive1 := collectingHandler() cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, common.Options{ @@ -228,7 +228,7 @@ func TestWSTransport_Subscribe(t *testing.T) { } }) - tr := NewWSTransport(t.Context(), WSTransportOptions{}) + tr := newTestWSTransport(t, WSTransportOptions{}) opts := common.Options{ Endpoint: server.URL, @@ -277,7 +277,7 @@ func TestWSTransport_Subscribe(t *testing.T) { } }) - tr := NewWSTransport(t.Context(), WSTransportOptions{}) + tr := newTestWSTransport(t, WSTransportOptions{}) opts := common.Options{ Endpoint: server.URL, @@ -322,7 +322,7 @@ func TestWSTransport_SubscriberDrain(t *testing.T) { } }) - tr := NewWSTransport(t.Context(), WSTransportOptions{}) + tr := newTestWSTransport(t, WSTransportOptions{}) cancel, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, common.Options{ Endpoint: server.URL, @@ -351,7 +351,7 @@ func TestWSTransport_SubscriberDrain(t *testing.T) { } }) - tr := NewWSTransport(t.Context(), WSTransportOptions{}) + tr := newTestWSTransport(t, WSTransportOptions{}) opts := common.Options{Endpoint: server.URL, Transport: common.TransportWS} @@ -401,7 +401,7 @@ func TestWSTransport_ConcurrentSubscribe(t *testing.T) { } }) - tr := NewWSTransport(t.Context(), WSTransportOptions{ + tr := newTestWSTransport(t, WSTransportOptions{ IdleTimeout: 30 * time.Second, }) @@ -479,7 +479,7 @@ func TestWSTransport_InitPayloadForwarding(t *testing.T) { })) t.Cleanup(server.Close) - tr := NewWSTransport(t.Context(), WSTransportOptions{}) + tr := newTestWSTransport(t, WSTransportOptions{}) initPayload := map[string]any{ "Authorization": "Bearer secret-token", @@ -564,7 +564,7 @@ func TestWSTransport_InitPayloadForwarding(t *testing.T) { })) t.Cleanup(server.Close) - tr := NewWSTransport(t.Context(), WSTransportOptions{}) + tr := newTestWSTransport(t, WSTransportOptions{}) initPayload := map[string]any{ "token": "legacy-auth-token", @@ -643,7 +643,7 @@ func TestWSTransport_InitPayloadForwarding(t *testing.T) { })) t.Cleanup(server.Close) - tr := NewWSTransport(t.Context(), WSTransportOptions{}) + tr := newTestWSTransport(t, WSTransportOptions{}) handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ @@ -722,7 +722,7 @@ func TestWSTransport_InitPayloadForwarding(t *testing.T) { })) t.Cleanup(server.Close) - tr := NewWSTransport(t.Context(), WSTransportOptions{}) + tr := newTestWSTransport(t, WSTransportOptions{}) // First subscription with user1 token handler1, receive1 := collectingHandler() @@ -797,7 +797,7 @@ func TestWSTransport_LegacyProtocol(t *testing.T) { }) }) - tr := NewWSTransport(t.Context(), WSTransportOptions{}) + tr := newTestWSTransport(t, WSTransportOptions{}) handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ @@ -842,7 +842,7 @@ func TestWSTransport_LegacyProtocol(t *testing.T) { }) }) - tr := NewWSTransport(t.Context(), WSTransportOptions{}) + tr := newTestWSTransport(t, WSTransportOptions{}) handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ @@ -883,7 +883,7 @@ func TestWSTransport_LegacyProtocol(t *testing.T) { }) }) - tr := NewWSTransport(t.Context(), WSTransportOptions{}) + tr := newTestWSTransport(t, WSTransportOptions{}) handler, receive := collectingHandler() cancel, err := tr.Subscribe(context.Background(), &common.Request{ @@ -926,7 +926,7 @@ func TestWSTransport_Heartbeat(t *testing.T) { } }) - tr := NewWSTransport(t.Context(), WSTransportOptions{PingInterval: 50 * time.Millisecond}) + tr := newTestWSTransport(t, WSTransportOptions{PingInterval: 50 * time.Millisecond}) cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", @@ -960,7 +960,7 @@ func TestWSTransport_Heartbeat(t *testing.T) { } }) - tr := NewWSTransport(t.Context(), WSTransportOptions{PingInterval: 100 * time.Millisecond, PingTimeout: 50 * time.Millisecond}) + tr := newTestWSTransport(t, WSTransportOptions{PingInterval: 100 * time.Millisecond, PingTimeout: 50 * time.Millisecond}) handler, receive := collectingHandler() _, err := tr.Subscribe(context.Background(), &common.Request{ @@ -1003,7 +1003,7 @@ func TestWSTransport_Heartbeat(t *testing.T) { }) // PingInterval set, PingTimeout left at zero (disabled) - tr := NewWSTransport(t.Context(), WSTransportOptions{PingInterval: 50 * time.Millisecond}) + tr := newTestWSTransport(t, WSTransportOptions{PingInterval: 50 * time.Millisecond}) cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", @@ -1043,7 +1043,7 @@ func TestWSTransport_Heartbeat(t *testing.T) { } }) - tr := NewWSTransport(t.Context(), WSTransportOptions{PingInterval: 50 * time.Millisecond, PingTimeout: 200 * time.Millisecond}) + tr := newTestWSTransport(t, WSTransportOptions{PingInterval: 50 * time.Millisecond, PingTimeout: 200 * time.Millisecond}) cancel, err := tr.Subscribe(context.Background(), &common.Request{ Query: "subscription { test }", @@ -1092,7 +1092,7 @@ func TestWSTransport_Heartbeat(t *testing.T) { // Enable ping loop with a tight timeout. Legacy connections are // unaffected because sendPing is a no-op for non-Pinger protocols, // so lastPingSentAt stays zero and pongOverdue never triggers. - tr := NewWSTransport(t.Context(), WSTransportOptions{ + tr := newTestWSTransport(t, WSTransportOptions{ PingInterval: 50 * time.Millisecond, PingTimeout: 150 * time.Millisecond, WriteTimeout: 100 * time.Millisecond, @@ -1134,7 +1134,7 @@ func TestWSTransport_Heartbeat(t *testing.T) { } }) - tr := NewWSTransport(t.Context(), WSTransportOptions{ + tr := newTestWSTransport(t, WSTransportOptions{ PingInterval: 50 * time.Millisecond, }) @@ -1156,43 +1156,6 @@ func TestWSTransport_Heartbeat(t *testing.T) { }) } -func TestWSTransport_Defaults(t *testing.T) { - t.Parallel() - - t.Run("applies default read limit when omitted", func(t *testing.T) { - t.Parallel() - - tr := NewWSTransport(t.Context(), WSTransportOptions{}) - - assert.Equal(t, defaultReadLimit, tr.opts.ReadLimit) - }) - - t.Run("applies default read limit for zero value", func(t *testing.T) { - t.Parallel() - - tr := NewWSTransport(t.Context(), WSTransportOptions{ReadLimit: 0}) - - assert.Equal(t, defaultReadLimit, tr.opts.ReadLimit) - }) - - t.Run("overrides read limit when provided", func(t *testing.T) { - t.Parallel() - - tr := NewWSTransport(t.Context(), WSTransportOptions{ReadLimit: 2 * 1024 * 1024}) - - assert.Equal(t, int64(2*1024*1024), tr.opts.ReadLimit) - }) - - t.Run("ignores negative read limit", func(t *testing.T) { - t.Parallel() - - tr := NewWSTransport(t.Context(), WSTransportOptions{ReadLimit: -1}) - - assert.Equal(t, defaultReadLimit, tr.opts.ReadLimit) - }) - -} - // Test helpers func newGraphQLWSServer(t *testing.T, handler func(ctx context.Context, conn *websocket.Conn)) *httptest.Server { From c5b06d907c26916628fab92e1abeaee3f0f712ff Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 22 Apr 2026 15:46:06 +0100 Subject: [PATCH 48/52] refactor: unexport subscription client config --- .../graphql_subscription_client.go | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 882f9b3acf..b2b92ad28e 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -15,8 +15,8 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) -// SubscriptionClientConfig holds the subscription client configuration. -type SubscriptionClientConfig struct { +// subscriptionClientConfig holds the subscription client configuration. +type subscriptionClientConfig struct { UpgradeClient *http.Client StreamingClient *http.Client Logger abstractlogger.Logger @@ -34,8 +34,8 @@ type SubscriptionClientConfig struct { DefaultErrorExtensionCode string } -func defaultSubscriptionClientConfig() *SubscriptionClientConfig { - return &SubscriptionClientConfig{ +func defaultsubscriptionClientConfig() *subscriptionClientConfig { + return &subscriptionClientConfig{ UpgradeClient: http.DefaultClient, StreamingClient: http.DefaultClient, Logger: abstractlogger.NoopLogger, @@ -46,11 +46,11 @@ func defaultSubscriptionClientConfig() *SubscriptionClientConfig { } // SubscriptionClientOption configures the subscription client. -type SubscriptionClientOption func(*SubscriptionClientConfig) +type SubscriptionClientOption func(*subscriptionClientConfig) // WithUpgradeClient sets the HTTP client used for WebSocket upgrade requests. func WithUpgradeClient(c *http.Client) SubscriptionClientOption { - return func(cfg *SubscriptionClientConfig) { + return func(cfg *subscriptionClientConfig) { if c != nil { cfg.UpgradeClient = c } @@ -60,7 +60,7 @@ func WithUpgradeClient(c *http.Client) SubscriptionClientOption { // WithStreamingClient sets the HTTP client used for SSE requests. // This client should have appropriate timeouts for long-lived connections. func WithStreamingClient(c *http.Client) SubscriptionClientOption { - return func(cfg *SubscriptionClientConfig) { + return func(cfg *subscriptionClientConfig) { if c != nil { cfg.StreamingClient = c } @@ -70,7 +70,7 @@ func WithStreamingClient(c *http.Client) SubscriptionClientOption { // WithLogger sets the logger for the client and its transports. // If not set, logging is disabled (silent operation). func WithLogger(log abstractlogger.Logger) SubscriptionClientOption { - return func(cfg *SubscriptionClientConfig) { + return func(cfg *subscriptionClientConfig) { if log != nil { cfg.Logger = log } @@ -81,7 +81,7 @@ func WithLogger(log abstractlogger.Logger) SubscriptionClientOption { // Only applies to graphql-transport-ws protocol (legacy graphql-ws uses server-initiated keepalive). // Default: 30s. Set to 0 to disable client-initiated pings. func WithPingInterval(d time.Duration) SubscriptionClientOption { - return func(cfg *SubscriptionClientConfig) { + return func(cfg *subscriptionClientConfig) { cfg.PingInterval = d } } @@ -90,21 +90,21 @@ func WithPingInterval(d time.Duration) SubscriptionClientOption { // If no pong is received within this duration, the connection is considered dead. // Default: 10s. Set to 0 to disable the pong-timeout check. func WithPingTimeout(d time.Duration) SubscriptionClientOption { - return func(cfg *SubscriptionClientConfig) { + return func(cfg *subscriptionClientConfig) { cfg.PingTimeout = d } } // WithAckTimeout sets the maximum time to wait for connection_ack after connection_init. func WithAckTimeout(d time.Duration) SubscriptionClientOption { - return func(cfg *SubscriptionClientConfig) { + return func(cfg *subscriptionClientConfig) { cfg.AckTimeout = d } } // WithWriteTimeout sets the timeout for WebSocket write operations (subscribe, unsubscribe, ping, pong). func WithWriteTimeout(d time.Duration) SubscriptionClientOption { - return func(cfg *SubscriptionClientConfig) { + return func(cfg *subscriptionClientConfig) { cfg.WriteTimeout = d } } @@ -112,14 +112,14 @@ func WithWriteTimeout(d time.Duration) SubscriptionClientOption { // WithDefaultErrorExtensionCode sets the extension code attached to GraphQL // errors produced by upstream connection failures. func WithDefaultErrorExtensionCode(code string) SubscriptionClientOption { - return func(cfg *SubscriptionClientConfig) { + return func(cfg *subscriptionClientConfig) { cfg.DefaultErrorExtensionCode = code } } // WithReadLimit sets the maximum size in bytes for incoming WebSocket messages. func WithReadLimit(n int64) SubscriptionClientOption { - return func(cfg *SubscriptionClientConfig) { + return func(cfg *subscriptionClientConfig) { cfg.ReadLimit = n } } @@ -133,7 +133,7 @@ type subscriptionClientV2 struct { // NewGraphQLSubscriptionClient creates a new subscription client. func NewGraphQLSubscriptionClient(ctx context.Context, opts ...SubscriptionClientOption) GraphQLSubscriptionClient { - cfg := defaultSubscriptionClientConfig() + cfg := defaultsubscriptionClientConfig() for _, opt := range opts { opt(cfg) } From 9c9bb1009fb9bf4989e5962dd8c7ed36fe70fbcf Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 22 Apr 2026 15:47:42 +0100 Subject: [PATCH 49/52] docs: Add comment for ErrClientClosed --- .../datasource/graphql_datasource/subscriptionclient/client.go | 1 + 1 file changed, 1 insertion(+) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go index 164d34ca0e..6b844621ee 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go @@ -13,6 +13,7 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport" ) +// ErrClientClosed is returned when Subscribe is called after the client's context has been canceled. var ErrClientClosed = errors.New("client closed") const ( From f98e29c4b52b10bf2d5c7a4944cf0def3604dd3a Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 22 Apr 2026 15:52:14 +0100 Subject: [PATCH 50/52] sse: simplify line splitting could possibly cause issues with weird servers that split with just \r --- .../subscriptionclient/transport/sse_conn.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go index 4ba2f48174..a98999bf5e 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go @@ -80,8 +80,8 @@ func (c *sseConnection) parseEventBytes(msg []byte) (eventType string, data []by return "", nil } - // Split by newlines (normalize CR/LF) - for _, line := range bytes.FieldsFunc(msg, func(r rune) bool { return r == '\n' || r == '\r' }) { + for line := range bytes.Lines(msg) { + line = bytes.TrimRight(line, "\r\n") switch { case bytes.HasPrefix(line, headerEvent): eventType = string(trimHeader(len(headerEvent), line)) From 7e7d60cfa46e7c726538fac786e4b300700b9e70 Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 22 Apr 2026 15:54:58 +0100 Subject: [PATCH 51/52] review comments --- .../subscriptionclient/transport/ws_conn.go | 3 +++ .../subscriptionclient/transport/ws_transport.go | 7 +++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go index 9ce29b20f3..6f7c6a71c6 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go @@ -75,6 +75,9 @@ func newWSConnection(conn *websocket.Conn, proto protocol.Protocol, opts wsConne return c } +// subscribe registers handler for a new subscription on this connection and +// sends the protocol-level subscribe message. The returned function unsubscribes +// and, if this was the last subscription, triggers the idle-close flow. func (c *wsConnection) subscribe(ctx context.Context, id string, req *common.Request, handler common.Handler) (func(), error) { c.subsMu.Lock() diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go index 3cbc9e0ee1..c1b7df7e68 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go @@ -121,6 +121,9 @@ func NewWSTransport(ctx context.Context, opts WSTransportOptions) *WSTransport { return t } +// Subscribe initiates a GraphQL subscription over WebSocket. It reuses an +// existing connection when one is available for the same endpoint, subprotocol, +// headers, and init payload, dialing a new one otherwise. func (t *WSTransport) Subscribe(ctx context.Context, req *common.Request, opts common.Options, handler common.Handler) (func(), error) { conn, err := t.getOrDial(ctx, opts) if err != nil { @@ -262,7 +265,7 @@ func (t *WSTransport) dial(ctx context.Context, key uint64, opts common.Options) abstractlogger.String("error", "subprotocol negotiation failed"), abstractlogger.Error(err), ) - wsConn.Close(websocket.StatusProtocolError, err.Error()) + _ = wsConn.Close(websocket.StatusProtocolError, err.Error()) return nil, err } @@ -275,7 +278,7 @@ func (t *WSTransport) dial(ctx context.Context, key uint64, opts common.Options) abstractlogger.String("error", "protocol init failed"), abstractlogger.Error(err), ) - wsConn.Close(websocket.StatusProtocolError, "init failed") + _ = wsConn.Close(websocket.StatusProtocolError, "init failed") return nil, fmt.Errorf("%w: %w", ErrInitFailed, err) } From 03de100137f1a15f3f29db1bd539c7922db21d6b Mon Sep 17 00:00:00 2001 From: endigma Date: Wed, 22 Apr 2026 16:01:18 +0100 Subject: [PATCH 52/52] document maps in wstransport --- .../subscriptionclient/transport/ws_transport.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go index c1b7df7e68..33c7a9b0af 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go @@ -80,6 +80,14 @@ type WSTransport struct { ctx context.Context opts WSTransportOptions + // mu guards both dialing and conns. + // + // dialing coalesces concurrent dial attempts: waiters block on + // dialResult.done rather than each dialing independently. + // + // conns holds only fully established connections (dial + protocol init + // complete) so every entry is always fully usable, never in a partial + // ready state. mu sync.Mutex dialing map[uint64]*dialResult conns map[uint64]*wsConnection