Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle abnormal WebSocket close on client (fixes #114) #140

Merged
merged 5 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion codec/websocket/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,17 @@ func (s *Stream) NextFrame() (f Frame, err error) {

func (s *Stream) nextFrame() (f Frame, err error) {
f, err = s.codecConn.ReadNext()

// If we get an EOF error, TCP stream was closed
// This is an abnormal closure from the server
if s.state != StateTerminated && err == io.EOF {
s.state = StateTerminated

// Prepare and return the 1006 close frame directly to client
f = NewFrame()
f.SetFIN().SetClose().SetPayload(EncodeCloseFramePayload(CloseAbnormal, ""))
}

if err == nil {
err = s.handleFrame(f)
}
Expand Down Expand Up @@ -273,9 +284,17 @@ func (s *Stream) asyncNextFrame(callback AsyncFrameCallback) {
s.codecConn.AsyncReadNext(func(err error, f Frame) {
if err == nil {
err = s.handleFrame(f)
} else if err == io.EOF {

// If we get an EOF error, TCP stream was closed
// This is an abnormal closure from the server
} else if s.state != StateTerminated && err == io.EOF {
s.state = StateTerminated

// Prepare and return the 1006 close frame directly
f = NewFrame()
f.SetFIN().SetClose().SetPayload(EncodeCloseFramePayload(CloseAbnormal, ""))
}

callback(err, f)
})
}
Expand Down
108 changes: 108 additions & 0 deletions codec/websocket/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net/http"
"testing"
Expand Down Expand Up @@ -1639,3 +1640,110 @@ func TestClientAsyncCloseHandshakePeerStarts(t *testing.T) {
}
})
}

func TestClientAbnormalClose(t *testing.T) {
assert := assert.New(t)

portChan := make(chan int, 1)

go func() {
srv := &MockServer{}
defer srv.Close()

err := srv.Accept("localhost:0", func(port int) {
if port <= 0 {
panic(fmt.Sprintf("Got invalid port from MockServer: %d", port))
}
portChan <- port
})
if err != nil {
panic(err)
}

// Simulate an abnormal closure (close the TCP connection without sending a WebSocket close frame)
srv.Close()
}()
time.Sleep(10 * time.Millisecond)

wsURI := fmt.Sprintf("ws://localhost:%d", <-portChan)

ioc := sonic.MustIO()
defer ioc.Close()

ws, err := NewWebsocketStream(ioc, nil, RoleClient)
assert.Nil(err)

err = ws.Handshake(wsURI)
assert.Nil(err)
assert.Equal(ws.State(), StateActive) // Verify WebSocket active

MaxMaeder marked this conversation as resolved.
Show resolved Hide resolved
// Attempt to read a frame; this should return the 1006 close frame directly
frame, err := ws.NextFrame()
assert.Equal(io.EOF, err) // Verify frame error is EOF

assert.True(frame.Opcode().IsClose()) // Verify we got a close frame

closeCode, reason := DecodeCloseFramePayload(frame.Payload())
assert.Equal(CloseAbnormal, closeCode) // Verify the close code is 1006
assert.Empty(reason) // Verify there is no reason payload

assert.Equal(ws.State(), StateTerminated) // Verify the WebSocket's state
}

func TestClientAsyncAbnormalClose(t *testing.T) {
assert := assert.New(t)

portChan := make(chan int, 1)

go func() {
srv := &MockServer{}
defer srv.Close()

err := srv.Accept("localhost:0", func(port int) {
if port <= 0 {
panic(fmt.Sprintf("Got invalid port from MockServer: %d", port))
}
portChan <- port
})
if err != nil {
panic(err)
}

// Simulate an abnormal closure (close the TCP connection without sending a WebSocket close frame)
srv.Close()
}()
time.Sleep(10 * time.Millisecond)

wsURI := fmt.Sprintf("ws://localhost:%d", <-portChan)

ioc := sonic.MustIO()
defer ioc.Close()

ws, err := NewWebsocketStream(ioc, nil, RoleClient)
assert.Nil(err)

done := false
ws.AsyncHandshake(wsURI, func(err error) {
assert.Nil(err)
assert.Equal(ws.State(), StateActive) // Verify WebSocket active

MaxMaeder marked this conversation as resolved.
Show resolved Hide resolved
// Attempt to read a frame; this should fail due to the server's abnormal closure
ws.AsyncNextFrame(func(err error, f Frame) {
assert.Equal(io.EOF, err) // Verify frame error is EOF

assert.True(f.Opcode().IsClose()) // Verify we got a close frame

closeCode, reason := DecodeCloseFramePayload(f.Payload())
assert.Equal(CloseAbnormal, closeCode) // Verify close code is 1006
assert.Empty(reason) // Verify there is no reason payload

assert.Equal(ws.State(), StateTerminated) // Verify the WebSocket's state

done = true
})
})

for !done {
ioc.PollOne()
}
}
13 changes: 11 additions & 2 deletions codec/websocket/test_main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,21 @@ type MockServer struct {
Upgrade *http.Request
}

func (s *MockServer) Accept(addr string) (err error) {
// Accept starts the mock server, listening on the specified address.
// If a callback is provided, it is invoked with the assigned port.
func (s *MockServer) Accept(addr string, opts ...func(int)) (err error) {
s.ln, err = net.Listen("tcp", addr)
if err != nil {
return err
}
atomic.StoreInt32(&s.port, int32(s.ln.Addr().(*net.TCPAddr).Port))

port := int(s.ln.Addr().(*net.TCPAddr).Port)
atomic.StoreInt32(&s.port, int32(port))

// Call port callback if provided
if len(opts) > 0 && opts[0] != nil {
opts[0](port)
}

conn, err := s.ln.Accept()
if err != nil {
Expand Down