diff --git a/server/pkg/websocket/session.go b/server/pkg/websocket/session.go index bbdef8b551..89724c5938 100644 --- a/server/pkg/websocket/session.go +++ b/server/pkg/websocket/session.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "sync" + "sync/atomic" "time" "github.com/cortezaproject/corteza/server/pkg/auth" @@ -39,9 +40,10 @@ type ( session struct { l sync.RWMutex - id uint64 - once sync.Once - conn conection + id uint64 + once sync.Once + conn conection + closed atomic.Bool ctx context.Context ctxCancel context.CancelFunc @@ -92,6 +94,9 @@ func (s *session) disconnect() { s.l.Lock() defer s.l.Unlock() + // Mark session as closed before closing channels + s.closed.Store(true) + // Cancel context s.ctxCancel() @@ -190,6 +195,11 @@ func (s *session) read() (raw []byte, err error) { s.l.RLock() defer s.l.RUnlock() + // Check if connection was closed by disconnect() + if s.conn == nil { + return nil, net.ErrClosed + } + if _, raw, err = s.conn.ReadMessage(); err != nil { return nil, errHandler("websocket read failed", err) } @@ -291,6 +301,11 @@ func (s *session) write(t int, msg []byte) (err error) { } }() + // Check if connection was closed by disconnect() + if s.conn == nil { + return net.ErrClosed + } + if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil { return fmt.Errorf("deadline error: %w", err) } @@ -324,6 +339,11 @@ func (s *session) authenticate(p *payloadAuth) error { // sendBytes sends byte to channel or timeout func (s *session) Write(p []byte) (int, error) { + // Check if session is closed before attempting to send + if s.closed.Load() { + return 0, net.ErrClosed + } + defer func() { if recovered := recover(); recovered != nil { s.logger.Debug("recovering from websocket write panic", zap.Any("recovered-error", recovered)) @@ -331,6 +351,9 @@ func (s *session) Write(p []byte) (int, error) { }() select { + case <-s.ctx.Done(): + // Session is disconnecting, channel may be closed + return 0, net.ErrClosed case s.send <- p: return len(p), nil case <-time.After(2 * time.Millisecond): @@ -343,14 +366,23 @@ func errHandler(prefix string, err error) error { return nil } - if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { - // normal closing - return nil + // Handle websocket close errors - these are expected during disconnection + if websocket.IsCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseAbnormalClosure, + websocket.CloseNoStatusReceived, + ) { + return net.ErrClosed } if errors.Is(err, net.ErrClosed) { - // suppress errors when reading/writing from/to a closed connection - return nil + return net.ErrClosed + } + + // "close sent" occurs when writing to a connection that's closing + if err.Error() == "websocket: close sent" { + return net.ErrClosed } return fmt.Errorf(prefix+": %w", err)