@@ -131,3 +131,58 @@ type roundTripperFunc func(*http.Request) (*http.Response, error)
131131func (f roundTripperFunc ) RoundTrip (req * http.Request ) (* http.Response , error ) {
132132 return f (req )
133133}
134+
135+ func TestSSEClientTransport_HTTPErrors (t * testing.T ) {
136+ tests := []struct {
137+ name string
138+ statusCode int
139+ wantErrContain string
140+ }{
141+ {
142+ name : "401 Unauthorized" ,
143+ statusCode : http .StatusUnauthorized ,
144+ wantErrContain : "Unauthorized" ,
145+ },
146+ {
147+ name : "403 Forbidden" ,
148+ statusCode : http .StatusForbidden ,
149+ wantErrContain : "Forbidden" ,
150+ },
151+ {
152+ name : "404 Not Found" ,
153+ statusCode : http .StatusNotFound ,
154+ wantErrContain : "Not Found" ,
155+ },
156+ {
157+ name : "500 Internal Server Error" ,
158+ statusCode : http .StatusInternalServerError ,
159+ wantErrContain : "Internal Server Error" ,
160+ },
161+ }
162+
163+ for _ , tt := range tests {
164+ t .Run (tt .name , func (t * testing.T ) {
165+ // Create a test server that returns the specified status code
166+ httpServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
167+ http .Error (w , http .StatusText (tt .statusCode ), tt .statusCode )
168+ }))
169+ defer httpServer .Close ()
170+
171+ clientTransport := & SSEClientTransport {
172+ Endpoint : httpServer .URL ,
173+ }
174+
175+ c := NewClient (testImpl , nil )
176+ _ , err := c .Connect (context .Background (), clientTransport , nil )
177+
178+ if err == nil {
179+ t .Fatalf ("expected error, got nil" )
180+ }
181+
182+ errStr := err .Error ()
183+ if ! bytes .Contains ([]byte (errStr ), []byte (tt .wantErrContain )) {
184+ t .Errorf ("error message %q does not contain %q" , errStr , tt .wantErrContain )
185+ }
186+ })
187+ }
188+ }
0 commit comments