Skip to content

Commit 1fc0dd9

Browse files
committed
fix: improve ws subprotocol selection
1 parent 22c2f9a commit 1fc0dd9

10 files changed

Lines changed: 37 additions & 36 deletions

execution/engine/config_factory_federation.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,15 +447,13 @@ func (f *FederationEngineConfigFactory) subscriptionClient(
447447
httpClient,
448448
streamingClient,
449449
f.engineCtx,
450-
graphql_datasource.WithWSSubProtocol(graphql_datasource.ProtocolGraphQLTWS),
451450
)
452451
default:
453452
// for compatibility reasons we fall back to graphql-ws protocol
454453
graphqlSubscriptionClient = subscriptionClientFactory.NewSubscriptionClient(
455454
httpClient,
456455
streamingClient,
457456
f.engineCtx,
458-
graphql_datasource.WithWSSubProtocol(graphql_datasource.ProtocolGraphQLWS),
459457
)
460458
}
461459

execution/engine/engine_config.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,13 @@ func (d *graphqlDataSourceGenerator) generateSubscriptionClient(httpClient *http
150150
httpClient,
151151
definedOptions.streamingClient,
152152
nil,
153-
graphql_datasource.WithWSSubProtocol(graphql_datasource.ProtocolGraphQLTWS),
154153
)
155154
default:
156155
// for compatibility reasons we fall back to graphql-ws protocol
157156
graphqlSubscriptionClient = definedOptions.subscriptionClientFactory.NewSubscriptionClient(
158157
httpClient,
159158
definedOptions.streamingClient,
160159
nil,
161-
graphql_datasource.WithWSSubProtocol(graphql_datasource.ProtocolGraphQLWS),
162160
)
163161
}
164162

v2/pkg/engine/datasource/graphql_datasource/configuration.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ type SubscriptionConfiguration struct {
106106
// which connections can be multiplexed together, but the subscription engine does not forward
107107
// these headers by itself.
108108
ForwardedClientHeaderRegularExpressions []*regexp.Regexp
109+
WsSubProtocol string
109110
}
110111

111112
type FetchConfiguration struct {

v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ func (p *Planner[T]) ConfigureSubscription() plan.SubscriptionConfiguration {
362362
input = httpclient.SetInputFlag(input, httpclient.SSE_METHOD_POST)
363363
}
364364
}
365+
input = httpclient.SetInputWSSubprotocol(input, []byte(p.config.subscription.WsSubProtocol))
365366

366367
header, err := json.Marshal(p.config.subscription.Header)
367368
if err == nil && len(header) != 0 && !bytes.Equal(header, literal.NULL) {
@@ -1668,6 +1669,7 @@ type GraphQLSubscriptionOptions struct {
16681669
SSEMethodPost bool `json:"sse_method_post"`
16691670
ForwardedClientHeaderNames []string `json:"forwarded_client_header_names"`
16701671
ForwardedClientHeaderRegularExpressions []*regexp.Regexp `json:"forwarded_client_header_regular_expressions"`
1672+
WsSubProtocol string `json:"ws_sub_protocol"`
16711673
}
16721674

16731675
type GraphQLBody struct {

v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8851,7 +8851,7 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) {
88518851
newSubscriptionSource := func(ctx context.Context) SubscriptionSource {
88528852
httpClient := http.Client{}
88538853
subscriptionSource := SubscriptionSource{
8854-
client: NewGraphQLSubscriptionClient(&httpClient, http.DefaultClient, ctx, WithWSSubProtocol(ProtocolGraphQLTWS)),
8854+
client: NewGraphQLSubscriptionClient(&httpClient, http.DefaultClient, ctx),
88558855
}
88568856
return subscriptionSource
88578857
}

v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,25 @@ type subscriptionClient struct {
3030
hashPool sync.Pool
3131
handlers map[uint64]ConnectionHandler
3232
handlersMu sync.Mutex
33-
wsSubProtocol string
3433
onWsConnectionInitCallback *OnWsConnectionInitCallback
3534

3635
readTimeout time.Duration
3736
}
3837

38+
type InvalidWsSubprotocolError struct {
39+
Message string
40+
}
41+
42+
func (e InvalidWsSubprotocolError) Error() string {
43+
return e.Message
44+
}
45+
46+
func NewInvalidWsSubprotocolError(message string) InvalidWsSubprotocolError {
47+
return InvalidWsSubprotocolError{
48+
Message: message,
49+
}
50+
}
51+
3952
type Options func(options *opts)
4053

4154
func WithLogger(log abstractlogger.Logger) Options {
@@ -50,12 +63,6 @@ func WithReadTimeout(timeout time.Duration) Options {
5063
}
5164
}
5265

53-
func WithWSSubProtocol(protocol string) Options {
54-
return func(options *opts) {
55-
options.wsSubProtocol = protocol
56-
}
57-
}
58-
5966
func WithOnWsConnectionInitCallback(callback *OnWsConnectionInitCallback) Options {
6067
return func(options *opts) {
6168
options.onWsConnectionInitCallback = callback
@@ -65,7 +72,6 @@ func WithOnWsConnectionInitCallback(callback *OnWsConnectionInitCallback) Option
6572
type opts struct {
6673
readTimeout time.Duration
6774
log abstractlogger.Logger
68-
wsSubProtocol string
6975
onWsConnectionInitCallback *OnWsConnectionInitCallback
7076
}
7177

@@ -106,7 +112,6 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi
106112
return xxhash.New()
107113
},
108114
},
109-
wsSubProtocol: op.wsSubProtocol,
110115
onWsConnectionInitCallback: op.onWsConnectionInitCallback,
111116
}
112117
}
@@ -288,8 +293,8 @@ func (c *subscriptionClient) requestHash(ctx *resolve.Context, options GraphQLSu
288293

289294
func (c *subscriptionClient) newWSConnectionHandler(reqCtx context.Context, options GraphQLSubscriptionOptions) (ConnectionHandler, error) {
290295
subProtocols := []string{ProtocolGraphQLWS, ProtocolGraphQLTWS}
291-
if c.wsSubProtocol != "" {
292-
subProtocols = []string{c.wsSubProtocol}
296+
if options.WsSubProtocol != "" && options.WsSubProtocol != "auto" {
297+
subProtocols = []string{options.WsSubProtocol}
293298
}
294299

295300
conn, upgradeResponse, err := websocket.Dial(reqCtx, options.URL, &websocket.DialOptions{
@@ -333,21 +338,25 @@ func (c *subscriptionClient) newWSConnectionHandler(reqCtx context.Context, opti
333338
return nil, err
334339
}
335340

336-
if c.wsSubProtocol == "" {
337-
c.wsSubProtocol = conn.Subprotocol()
341+
wsSubProtocol := subProtocols[0]
342+
if options.WsSubProtocol == "" || options.WsSubProtocol == "auto" {
343+
wsSubProtocol = conn.Subprotocol()
344+
if wsSubProtocol == "" {
345+
wsSubProtocol = ProtocolGraphQLWS
346+
}
338347
}
339348

340349
if err := waitForAck(reqCtx, conn); err != nil {
341350
return nil, err
342351
}
343352

344-
switch c.wsSubProtocol {
353+
switch wsSubProtocol {
345354
case ProtocolGraphQLWS:
346355
return newGQLWSConnectionHandler(c.engineCtx, conn, c.readTimeout, c.log), nil
347356
case ProtocolGraphQLTWS:
348357
return newGQLTWSConnectionHandler(c.engineCtx, conn, c.readTimeout, c.log), nil
349358
default:
350-
return nil, fmt.Errorf("unknown protocol %s", conn.Subprotocol())
359+
return nil, NewInvalidWsSubprotocolError(fmt.Sprintf("provided websocket subprotocol %s is not supported. The supported subprotocols are graphql-ws and graphql-transport-ws. Please configure your subsciptions with the mentioned subprotocols", wsSubProtocol))
351360
}
352361
}
353362

v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,6 @@ func TestWebsocketSubscriptionClientDeDuplication(t *testing.T) {
156156
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
157157
WithReadTimeout(time.Millisecond),
158158
WithLogger(logger()),
159-
WithWSSubProtocol(ProtocolGraphQLWS),
160159
)
161160
clientsDone := &sync.WaitGroup{}
162161

@@ -215,7 +214,6 @@ func TestWebsocketSubscriptionClientImmediateClientCancel(t *testing.T) {
215214
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
216215
WithReadTimeout(time.Millisecond),
217216
WithLogger(logger()),
218-
WithWSSubProtocol(ProtocolGraphQLWS),
219217
).(*subscriptionClient)
220218
updater := &testSubscriptionUpdater{}
221219
err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
@@ -270,7 +268,6 @@ func TestWebsocketSubscriptionClientWithServerDisconnect(t *testing.T) {
270268
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
271269
WithReadTimeout(time.Millisecond),
272270
WithLogger(logger()),
273-
WithWSSubProtocol(ProtocolGraphQLWS),
274271
).(*subscriptionClient)
275272
updater := &testSubscriptionUpdater{}
276273
err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{

v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ func TestWebsocketSubscriptionClient_GQLTWS(t *testing.T) {
6363
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
6464
WithReadTimeout(time.Millisecond),
6565
WithLogger(logger()),
66-
WithWSSubProtocol(ProtocolGraphQLTWS),
6766
).(*subscriptionClient)
6867

6968
updater := &testSubscriptionUpdater{}
@@ -142,7 +141,6 @@ func TestWebsocketSubscriptionClientPing_GQLTWS(t *testing.T) {
142141
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
143142
WithReadTimeout(time.Millisecond),
144143
WithLogger(logger()),
145-
WithWSSubProtocol(ProtocolGraphQLTWS),
146144
).(*subscriptionClient)
147145

148146
updater := &testSubscriptionUpdater{}
@@ -210,7 +208,6 @@ func TestWebsocketSubscriptionClientError_GQLTWS(t *testing.T) {
210208
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
211209
WithReadTimeout(time.Millisecond),
212210
WithLogger(logger()),
213-
WithWSSubProtocol(ProtocolGraphQLTWS),
214211
)
215212

216213
updater := &testSubscriptionUpdater{}
@@ -298,7 +295,6 @@ func TestWebSocketSubscriptionClientInitIncludePing_GQLTWS(t *testing.T) {
298295
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
299296
WithReadTimeout(time.Millisecond),
300297
WithLogger(logger()),
301-
WithWSSubProtocol(ProtocolGraphQLTWS),
302298
).(*subscriptionClient)
303299
updater := &testSubscriptionUpdater{}
304300
err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
@@ -373,7 +369,6 @@ func TestWebsocketSubscriptionClient_GQLTWS_Upstream_Dies(t *testing.T) {
373369
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
374370
WithReadTimeout(time.Second),
375371
WithLogger(logger()),
376-
WithWSSubProtocol(ProtocolGraphQLTWS),
377372
).(*subscriptionClient)
378373

379374
updater := &testSubscriptionUpdater{}

v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ func TestWebSocketSubscriptionClientInitIncludeKA_GQLWS(t *testing.T) {
7676
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
7777
WithReadTimeout(time.Millisecond),
7878
WithLogger(logger()),
79-
WithWSSubProtocol(ProtocolGraphQLWS),
8079
).(*subscriptionClient)
8180
updater := &testSubscriptionUpdater{}
8281
err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
@@ -144,7 +143,6 @@ func TestWebsocketSubscriptionClient_GQLWS(t *testing.T) {
144143
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
145144
WithReadTimeout(time.Millisecond),
146145
WithLogger(logger()),
147-
WithWSSubProtocol(ProtocolGraphQLWS),
148146
).(*subscriptionClient)
149147
updater := &testSubscriptionUpdater{}
150148
err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
@@ -208,7 +206,6 @@ func TestWebsocketSubscriptionClientErrorArray(t *testing.T) {
208206
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
209207
WithReadTimeout(time.Millisecond),
210208
WithLogger(logger()),
211-
WithWSSubProtocol(ProtocolGraphQLWS),
212209
)
213210
updater := &testSubscriptionUpdater{}
214211
err := client.Subscribe(resolve.NewContext(clientCtx), GraphQLSubscriptionOptions{
@@ -264,7 +261,6 @@ func TestWebsocketSubscriptionClientErrorObject(t *testing.T) {
264261
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
265262
WithReadTimeout(time.Millisecond),
266263
WithLogger(logger()),
267-
WithWSSubProtocol(ProtocolGraphQLWS),
268264
)
269265
updater := &testSubscriptionUpdater{}
270266
err := client.Subscribe(resolve.NewContext(clientCtx), GraphQLSubscriptionOptions{
@@ -329,7 +325,6 @@ func TestWebsocketSubscriptionClient_GQLWS_Upstream_Dies(t *testing.T) {
329325
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
330326
WithReadTimeout(time.Second),
331327
WithLogger(logger()),
332-
WithWSSubProtocol(ProtocolGraphQLWS),
333328
).(*subscriptionClient)
334329
updater := &testSubscriptionUpdater{}
335330
err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
@@ -381,7 +376,6 @@ func TestWebsocketConnectionReuse(t *testing.T) {
381376
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
382377
WithReadTimeout(time.Millisecond),
383378
WithLogger(logger()),
384-
WithWSSubProtocol(ProtocolGraphQLWS),
385379
).(*subscriptionClient)
386380

387381
updater := &testSubscriptionUpdater{}
@@ -432,7 +426,6 @@ func TestWebsocketConnectionReuse(t *testing.T) {
432426
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
433427
WithReadTimeout(time.Millisecond),
434428
WithLogger(logger()),
435-
WithWSSubProtocol(ProtocolGraphQLWS),
436429
).(*subscriptionClient)
437430

438431
updater := &testSubscriptionUpdater{}
@@ -471,7 +464,6 @@ func TestWebsocketConnectionReuse(t *testing.T) {
471464
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
472465
WithReadTimeout(time.Millisecond),
473466
WithLogger(logger()),
474-
WithWSSubProtocol(ProtocolGraphQLWS),
475467
).(*subscriptionClient)
476468

477469
updater := &testSubscriptionUpdater{}

v2/pkg/engine/datasource/httpclient/httpclient.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ const (
3333
FORWARDED_CLIENT_HEADER_NAMES = "forwarded_client_header_names"
3434
FORWARDED_CLIENT_HEADER_REGULAR_EXPRESSIONS = "forwarded_client_header_regular_expressions"
3535
TRACE = "__trace__"
36+
WsSubProtocol = "ws_sub_protocol"
3637
)
3738

3839
var (
@@ -118,6 +119,14 @@ func SetInputFlag(input []byte, flagName string) []byte {
118119
return out
119120
}
120121

122+
func SetInputWSSubprotocol(input, wsSubProtocol []byte) []byte {
123+
if len(wsSubProtocol) == 0 {
124+
return input
125+
}
126+
out, _ := sjson.SetRawBytes(input, WsSubProtocol, wrapQuotesIfString(wsSubProtocol))
127+
return out
128+
}
129+
121130
func IsInputFlagSet(input []byte, flagName string) bool {
122131
value, dataType, _, err := jsonparser.Get(input, flagName)
123132
if err != nil {

0 commit comments

Comments
 (0)