Skip to content

Commit a8b66c2

Browse files
committed
wip: error handling improvements
1 parent 2bc0326 commit a8b66c2

12 files changed

Lines changed: 106 additions & 67 deletions

File tree

v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,11 @@ func (c *subscriptionClientV2) Subscribe(ctx *resolve.Context, options GraphQLSu
150150

151151
msgCh, cancel, err := c.client.Subscribe(ctx.Context(), req, opts)
152152
if err != nil {
153+
if isUpstreamError(err) {
154+
updater.Error(formatUpstreamServiceError(err))
155+
updater.Done()
156+
return nil
157+
}
153158
return err
154159
}
155160

@@ -174,32 +179,28 @@ func (c *subscriptionClientV2) readLoop(ctx context.Context, msgCh <-chan *clien
174179
return
175180
}
176181

177-
if msg.Err != nil {
178-
if isConnectionError(msg.Err) {
179-
updater.Error(formatUpstreamServiceError(msg.Err))
180-
} else {
181-
updater.Error(formatSubscriptionError(msg.Err))
182-
}
182+
switch msg.Type {
183+
case client.MessageTypeConnectionError:
184+
updater.Error(formatUpstreamServiceError(msg.Err))
185+
updater.Done()
186+
return
187+
188+
case client.MessageTypeError:
189+
data, _ := json.Marshal(msg.Payload)
190+
updater.Error(data)
183191
updater.Done()
184192
return
185-
}
186193

187-
if msg.Payload != nil {
194+
case client.MessageTypeData:
188195
data, err := json.Marshal(msg.Payload)
189196
if err != nil {
190197
updater.Error(formatSubscriptionError(err))
191198
updater.Done()
192199
return
193200
}
194-
if msg.Done {
195-
updater.Error(data)
196-
updater.Done()
197-
return
198-
}
199201
updater.Update(data)
200-
}
201202

202-
if msg.Done {
203+
case client.MessageTypeComplete:
203204
updater.Complete()
204205
updater.Done()
205206
return
@@ -208,9 +209,13 @@ func (c *subscriptionClientV2) readLoop(ctx context.Context, msgCh <-chan *clien
208209
}
209210
}
210211

211-
func isConnectionError(err error) bool {
212+
// isUpstreamError reports whether err is a connection-level upstream error
213+
// that should be reported to the client as an UPSTREAM_SERVICE_ERROR.
214+
func isUpstreamError(err error) bool {
212215
return errors.Is(err, client.ErrConnectionClosed) ||
213216
errors.Is(err, client.ErrConnectionError) ||
217+
errors.Is(err, client.ErrInitFailed) ||
218+
errors.Is(err, client.ErrDialFailed) ||
214219
errors.Is(err, context.Canceled) ||
215220
errors.Is(err, context.DeadlineExceeded)
216221
}
@@ -267,8 +272,9 @@ func mapWSSubprotocol(proto string) client.WSSubprotocol {
267272
}
268273

269274
// formatUpstreamServiceError formats a connection-level error as a GraphQL error
270-
// response with the UPSTREAM_SERVICE_ERROR extension code. If the error is a
271-
// WebSocket close error, the close code and reason are included in extensions.
275+
// response with the UPSTREAM_SERVICE_ERROR extension code. If the error chain
276+
// contains a WebSocket close error, the close code and reason are included in
277+
// extensions.
272278
func formatUpstreamServiceError(err error) []byte {
273279
type errorExtensions struct {
274280
Code string `json:"code"`
@@ -281,18 +287,21 @@ func formatUpstreamServiceError(err error) []byte {
281287
Extensions errorExtensions `json:"extensions"`
282288
}
283289

284-
ext := errorExtensions{Code: "UPSTREAM_SERVICE_ERROR"}
290+
gqlErr := graphqlError{
291+
Message: "upstream service error",
292+
Extensions: errorExtensions{Code: "UPSTREAM_SERVICE_ERROR"},
293+
}
285294

286295
var closeErr websocket.CloseError
287296
if errors.As(err, &closeErr) {
288-
ext.CloseCode = int(closeErr.Code)
289-
ext.Reason = closeErr.Reason
297+
gqlErr.Extensions.CloseCode = int(closeErr.Code)
298+
gqlErr.Extensions.Reason = closeErr.Reason
290299
}
291300

292301
resp := struct {
293302
Errors []graphqlError `json:"errors"`
294303
}{
295-
Errors: []graphqlError{{Message: "upstream service closed the connection", Extensions: ext}},
304+
Errors: []graphqlError{gqlErr},
296305
}
297306
data, _ := json.Marshal(resp)
298307
return data

v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func TestReadLoopErrorHandling(t *testing.T) {
4747
t.Run("connection errors deliver error and done without updates", func(t *testing.T) {
4848
updater := &testBridgeUpdater{}
4949
msgCh := make(chan *client.Message, 1)
50-
msgCh <- &client.Message{Err: client.ErrConnectionClosed}
50+
msgCh <- &client.Message{Type: client.MessageTypeConnectionError, Err: client.ErrConnectionClosed}
5151
close(msgCh)
5252

5353
subClient := &subscriptionClientV2{}
@@ -62,7 +62,7 @@ func TestReadLoopErrorHandling(t *testing.T) {
6262
t.Run("non-connection errors deliver error and done without updates", func(t *testing.T) {
6363
updater := &testBridgeUpdater{}
6464
msgCh := make(chan *client.Message, 1)
65-
msgCh <- &client.Message{Err: errors.New("validation failed")}
65+
msgCh <- &client.Message{Type: client.MessageTypeConnectionError, Err: errors.New("validation failed")}
6666
close(msgCh)
6767

6868
subClient := &subscriptionClientV2{}
@@ -102,7 +102,7 @@ func TestReadLoopErrorHandling(t *testing.T) {
102102
t.Run("done message calls complete then done", func(t *testing.T) {
103103
updater := &testBridgeUpdater{}
104104
msgCh := make(chan *client.Message, 1)
105-
msgCh <- &client.Message{Done: true}
105+
msgCh <- &client.Message{Type: client.MessageTypeComplete}
106106
close(msgCh)
107107

108108
subClient := &subscriptionClientV2{}

v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,26 @@ import (
77

88
var ErrConnectionClosed = errors.New("connection closed")
99

10+
// MessageType identifies the kind of message delivered on a subscription channel.
11+
type MessageType int
12+
13+
const (
14+
MessageTypeUnknown MessageType = iota
15+
MessageTypeData // normal data payload
16+
MessageTypeError // GraphQL-level error from server (has Payload)
17+
MessageTypeComplete // subscription completed normally
18+
MessageTypeConnectionError // connection-level error (has Err)
19+
)
20+
21+
// IsTerminal reports whether the message type signals end-of-stream.
22+
func (t MessageType) IsTerminal() bool {
23+
return t == MessageTypeError || t == MessageTypeComplete || t == MessageTypeConnectionError
24+
}
25+
1026
type Message struct {
27+
Type MessageType
1128
Payload *ExecutionResult
12-
Err error
13-
Done bool
29+
Err error // only set when Type == MessageTypeConnectionError
1430
}
1531

1632
type ExecutionResult struct {

v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
type (
1212
Message = common.Message
13+
MessageType = common.MessageType
1314
ExecutionResult = common.ExecutionResult
1415
Request = common.Request
1516
Options = common.Options
@@ -21,6 +22,12 @@ type (
2122
// Re-export constants.
2223

2324
const (
25+
MessageTypeUnknown = common.MessageTypeUnknown
26+
MessageTypeData = common.MessageTypeData
27+
MessageTypeError = common.MessageTypeError
28+
MessageTypeComplete = common.MessageTypeComplete
29+
MessageTypeConnectionError = common.MessageTypeConnectionError
30+
2431
TransportWS = common.TransportWS
2532
TransportSSE = common.TransportSSE
2633

@@ -48,4 +55,6 @@ var (
4855
ErrAckTimeout = protocol.ErrAckTimeout
4956
ErrAckNotReceived = protocol.ErrAckNotReceived
5057
ErrSubscriptionExists = transport.ErrSubscriptionExists
58+
ErrDialFailed = transport.ErrDialFailed
59+
ErrInitFailed = transport.ErrInitFailed
5160
)

v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,16 @@ type Message struct {
3939
func (m *Message) IntoClientMessage() *common.Message {
4040
switch m.Type {
4141
case MessageData:
42-
return &common.Message{Payload: m.Payload}
42+
return &common.Message{Type: common.MessageTypeData, Payload: m.Payload}
4343
case MessageError:
4444
if m.Payload != nil {
45-
return &common.Message{Payload: m.Payload, Done: true}
45+
return &common.Message{Type: common.MessageTypeError, Payload: m.Payload}
4646
}
47-
return &common.Message{Err: m.Err, Done: true}
47+
return &common.Message{Type: common.MessageTypeConnectionError, Err: m.Err}
4848
case MessageComplete:
49-
return &common.Message{Done: true}
49+
return &common.Message{Type: common.MessageTypeComplete}
5050
default:
51-
return &common.Message{}
51+
return &common.Message{Type: common.MessageTypeUnknown}
5252
}
5353
}
5454

v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func (c *sseConnection) readLoop() {
7171
return
7272
}
7373

74-
if msg.Done {
74+
if msg.Type.IsTerminal() {
7575
return
7676
}
7777
}
@@ -129,36 +129,30 @@ func (c *sseConnection) parseEvent(eventType string, data []byte) *common.Messag
129129
case "next":
130130
var resp common.ExecutionResult
131131
if err := json.Unmarshal(data, &resp); err != nil {
132-
return &common.Message{
133-
Err: err,
134-
Done: true,
135-
}
132+
return &common.Message{Type: common.MessageTypeConnectionError, Err: err}
136133
}
137-
return &common.Message{Payload: &resp}
134+
return &common.Message{Type: common.MessageTypeData, Payload: &resp}
138135

139136
case "error":
140137
return &common.Message{
138+
Type: common.MessageTypeError,
141139
Payload: &common.ExecutionResult{Errors: data},
142-
Done: true,
143140
}
144141

145142
case "complete":
146-
return &common.Message{Done: true}
143+
return &common.Message{Type: common.MessageTypeComplete}
147144

148145
default:
149146
// Unknown event type or no event type specified - treat as data
150147
// This handles servers that send data without an event type
151148
if len(data) == 0 {
152-
return &common.Message{Done: true}
149+
return &common.Message{Type: common.MessageTypeComplete}
153150
}
154151
var resp common.ExecutionResult
155152
if err := json.Unmarshal(data, &resp); err != nil {
156-
return &common.Message{
157-
Err: err,
158-
Done: true,
159-
}
153+
return &common.Message{Type: common.MessageTypeConnectionError, Err: err}
160154
}
161-
return &common.Message{Payload: &resp}
155+
return &common.Message{Type: common.MessageTypeData, Payload: &resp}
162156
}
163157
}
164158

@@ -167,7 +161,7 @@ func (c *sseConnection) sendError(err error) {
167161
return
168162
}
169163
select {
170-
case c.ch <- &common.Message{Err: err, Done: true}:
164+
case c.ch <- &common.Message{Type: common.MessageTypeConnectionError, Err: err}:
171165
case <-c.done:
172166
}
173167
}

v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010

1111
"github.com/stretchr/testify/assert"
1212
"github.com/stretchr/testify/require"
13+
14+
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common"
1315
)
1416

1517
func TestSSEConnection_ReadLoop(t *testing.T) {
@@ -52,7 +54,7 @@ func TestSSEConnection_ReadLoop(t *testing.T) {
5254

5355
msg := <-conn.ch
5456
require.Error(t, msg.Err)
55-
require.True(t, msg.Done)
57+
assert.Equal(t, common.MessageTypeConnectionError, msg.Type)
5658
})
5759

5860
t.Run("stops on complete event", func(t *testing.T) {
@@ -69,11 +71,11 @@ func TestSSEConnection_ReadLoop(t *testing.T) {
6971
// First message
7072
msg1 := <-conn.ch
7173
assert.NotNil(t, msg1.Payload)
72-
assert.False(t, msg1.Done)
74+
assert.Equal(t, common.MessageTypeData, msg1.Type)
7375

7476
// Complete message
7577
msg2 := <-conn.ch
76-
assert.True(t, msg2.Done)
78+
assert.Equal(t, common.MessageTypeComplete, msg2.Type)
7779

7880
// Channel should close, no third message
7981
select {

v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func TestSSETransport_Subscribe(t *testing.T) {
7777

7878
// Receive complete message
7979
msg = receiveWithTimeout(t, ch, time.Second)
80-
assert.True(t, msg.Done)
80+
assert.Equal(t, common.MessageTypeComplete, msg.Type)
8181
})
8282

8383
t.Run("passes custom headers", func(t *testing.T) {
@@ -142,7 +142,7 @@ func TestSSETransport_Subscribe(t *testing.T) {
142142
msg := receiveWithTimeout(t, ch, time.Second)
143143
require.NotNil(t, msg.Payload)
144144
assert.Contains(t, string(msg.Payload.Data), "Alice")
145-
assert.False(t, msg.Done)
145+
assert.Equal(t, common.MessageTypeData, msg.Type)
146146
})
147147

148148
t.Run("handles error event", func(t *testing.T) {
@@ -164,7 +164,7 @@ func TestSSETransport_Subscribe(t *testing.T) {
164164
defer cancel()
165165

166166
msg := receiveWithTimeout(t, ch, time.Second)
167-
assert.True(t, msg.Done)
167+
assert.Equal(t, common.MessageTypeError, msg.Type)
168168
require.NotNil(t, msg.Payload)
169169
assert.Contains(t, string(msg.Payload.Errors), "Something went wrong")
170170
})
@@ -188,7 +188,7 @@ func TestSSETransport_Subscribe(t *testing.T) {
188188
defer cancel()
189189

190190
msg := receiveWithTimeout(t, ch, time.Second)
191-
assert.True(t, msg.Done)
191+
assert.Equal(t, common.MessageTypeComplete, msg.Type)
192192
assert.Nil(t, msg.Err)
193193
assert.Nil(t, msg.Payload)
194194
})
@@ -263,7 +263,7 @@ func TestSSETransport_Subscribe(t *testing.T) {
263263
// Should only receive 2 messages (next + complete), not comments
264264
for msg := range ch {
265265
messageCount.Add(1)
266-
if msg.Done {
266+
if msg.Type.IsTerminal() {
267267
break
268268
}
269269
}
@@ -616,7 +616,7 @@ func TestSSETransport_ContentTypeValidation(t *testing.T) {
616616
defer cancel()
617617

618618
msg := receiveWithTimeout(t, ch, time.Second)
619-
assert.True(t, msg.Done)
619+
assert.Equal(t, common.MessageTypeComplete, msg.Type)
620620
})
621621

622622
t.Run("rejects non-SSE content type", func(t *testing.T) {
@@ -701,7 +701,7 @@ func TestSSETransport_GETMethod(t *testing.T) {
701701

702702
// Receive complete message
703703
msg = receiveWithTimeout(t, ch, time.Second)
704-
assert.True(t, msg.Done)
704+
assert.Equal(t, common.MessageTypeComplete, msg.Type)
705705
})
706706

707707
t.Run("GET preserves existing query parameters", func(t *testing.T) {

v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ func (c *wsConnection) shutdown(err error) {
246246
c.subs = make(map[string]chan<- *common.Message)
247247
c.subsMu.Unlock()
248248

249-
errMsg := &common.Message{Err: err, Done: true}
249+
errMsg := &common.Message{Type: common.MessageTypeConnectionError, Err: err}
250250
for _, ch := range subs {
251251
select {
252252
case ch <- errMsg:

0 commit comments

Comments
 (0)