Skip to content

Commit 27b3354

Browse files
authored
mcp: fix SSEClientTransport to report HTTP errors properly (#740)
Fix SSEClientTransport.Connect to check HTTP status code before attempting to parse SSE events. Fixes #714
1 parent 1d5938c commit 27b3354

2 files changed

Lines changed: 63 additions & 0 deletions

File tree

mcp/sse.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,14 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
351351
return nil, err
352352
}
353353

354+
// Check HTTP status code before attempting to parse SSE events.
355+
// This ensures proper error reporting for authentication failures (401),
356+
// authorization failures (403), and other HTTP errors.
357+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
358+
resp.Body.Close()
359+
return nil, fmt.Errorf("failed to connect: %s", http.StatusText(resp.StatusCode))
360+
}
361+
354362
msgEndpoint, err := func() (*url.URL, error) {
355363
var evt Event
356364
for evt, err = range scanEvents(resp.Body) {

mcp/sse_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,58 @@ type roundTripperFunc func(*http.Request) (*http.Response, error)
131131
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
132132
return f(req)
133133
}
134+
135+
func TestSSEClientTransport_HTTPErrors(t *testing.T) {
136+
tests := []struct {
137+
name string
138+
statusCode int
139+
wantErrContain string
140+
}{
141+
{
142+
name: "401 Unauthorized",
143+
statusCode: http.StatusUnauthorized,
144+
wantErrContain: "Unauthorized",
145+
},
146+
{
147+
name: "403 Forbidden",
148+
statusCode: http.StatusForbidden,
149+
wantErrContain: "Forbidden",
150+
},
151+
{
152+
name: "404 Not Found",
153+
statusCode: http.StatusNotFound,
154+
wantErrContain: "Not Found",
155+
},
156+
{
157+
name: "500 Internal Server Error",
158+
statusCode: http.StatusInternalServerError,
159+
wantErrContain: "Internal Server Error",
160+
},
161+
}
162+
163+
for _, tt := range tests {
164+
t.Run(tt.name, func(t *testing.T) {
165+
// Create a test server that returns the specified status code
166+
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
167+
http.Error(w, http.StatusText(tt.statusCode), tt.statusCode)
168+
}))
169+
defer httpServer.Close()
170+
171+
clientTransport := &SSEClientTransport{
172+
Endpoint: httpServer.URL,
173+
}
174+
175+
c := NewClient(testImpl, nil)
176+
_, err := c.Connect(context.Background(), clientTransport, nil)
177+
178+
if err == nil {
179+
t.Fatalf("expected error, got nil")
180+
}
181+
182+
errStr := err.Error()
183+
if !bytes.Contains([]byte(errStr), []byte(tt.wantErrContain)) {
184+
t.Errorf("error message %q does not contain %q", errStr, tt.wantErrContain)
185+
}
186+
})
187+
}
188+
}

0 commit comments

Comments
 (0)