Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 40 additions & 8 deletions server/pkg/websocket/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net"
"sync"
"sync/atomic"
"time"

"github.com/cortezaproject/corteza/server/pkg/auth"
Expand Down Expand Up @@ -39,9 +40,10 @@ type (
session struct {
l sync.RWMutex

id uint64
once sync.Once
conn conection
id uint64
once sync.Once
conn conection
closed atomic.Bool

ctx context.Context
ctxCancel context.CancelFunc
Expand Down Expand Up @@ -92,6 +94,9 @@ func (s *session) disconnect() {
s.l.Lock()
defer s.l.Unlock()

// Mark session as closed before closing channels
s.closed.Store(true)

// Cancel context
s.ctxCancel()

Expand Down Expand Up @@ -190,6 +195,11 @@ func (s *session) read() (raw []byte, err error) {
s.l.RLock()
defer s.l.RUnlock()

// Check if connection was closed by disconnect()
if s.conn == nil {
return nil, net.ErrClosed
}

if _, raw, err = s.conn.ReadMessage(); err != nil {
return nil, errHandler("websocket read failed", err)
}
Expand Down Expand Up @@ -291,6 +301,11 @@ func (s *session) write(t int, msg []byte) (err error) {
}
}()

// Check if connection was closed by disconnect()
if s.conn == nil {
return net.ErrClosed
}

if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
return fmt.Errorf("deadline error: %w", err)
}
Expand Down Expand Up @@ -324,13 +339,21 @@ func (s *session) authenticate(p *payloadAuth) error {

// sendBytes sends byte to channel or timeout
func (s *session) Write(p []byte) (int, error) {
// Check if session is closed before attempting to send
if s.closed.Load() {
return 0, net.ErrClosed
}

defer func() {
if recovered := recover(); recovered != nil {
s.logger.Debug("recovering from websocket write panic", zap.Any("recovered-error", recovered))
}
}()

select {
case <-s.ctx.Done():
// Session is disconnecting, channel may be closed
return 0, net.ErrClosed
case s.send <- p:
return len(p), nil
case <-time.After(2 * time.Millisecond):
Expand All @@ -343,14 +366,23 @@ func errHandler(prefix string, err error) error {
return nil
}

if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
// normal closing
return nil
// Handle websocket close errors - these are expected during disconnection
if websocket.IsCloseError(err,
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseAbnormalClosure,
websocket.CloseNoStatusReceived,
) {
return net.ErrClosed
}

if errors.Is(err, net.ErrClosed) {
// suppress errors when reading/writing from/to a closed connection
return nil
return net.ErrClosed
}

// "close sent" occurs when writing to a connection that's closing
if err.Error() == "websocket: close sent" {
return net.ErrClosed
}

return fmt.Errorf(prefix+": %w", err)
Expand Down
Loading