diff --git a/codec/websocket/stream.go b/codec/websocket/stream.go index dbcb02f..e45dec4 100644 --- a/codec/websocket/stream.go +++ b/codec/websocket/stream.go @@ -234,6 +234,17 @@ func (s *Stream) NextFrame() (f Frame, err error) { func (s *Stream) nextFrame() (f Frame, err error) { f, err = s.codecConn.ReadNext() + + // If we get an EOF error, TCP stream was closed + // This is an abnormal closure from the server + if s.state != StateTerminated && err == io.EOF { + s.state = StateTerminated + + // Prepare and return the 1006 close frame directly to client + f = NewFrame() + f.SetFIN().SetClose().SetPayload(EncodeCloseFramePayload(CloseAbnormal, "")) + } + if err == nil { err = s.handleFrame(f) } @@ -273,9 +284,17 @@ func (s *Stream) asyncNextFrame(callback AsyncFrameCallback) { s.codecConn.AsyncReadNext(func(err error, f Frame) { if err == nil { err = s.handleFrame(f) - } else if err == io.EOF { + + // If we get an EOF error, TCP stream was closed + // This is an abnormal closure from the server + } else if s.state != StateTerminated && err == io.EOF { s.state = StateTerminated + + // Prepare and return the 1006 close frame directly + f = NewFrame() + f.SetFIN().SetClose().SetPayload(EncodeCloseFramePayload(CloseAbnormal, "")) } + callback(err, f) }) } diff --git a/codec/websocket/stream_test.go b/codec/websocket/stream_test.go index b5e20ac..09d22c7 100644 --- a/codec/websocket/stream_test.go +++ b/codec/websocket/stream_test.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/tls" "errors" + "fmt" "io" "net/http" "testing" @@ -1639,3 +1640,110 @@ func TestClientAsyncCloseHandshakePeerStarts(t *testing.T) { } }) } + +func TestClientAbnormalClose(t *testing.T) { + assert := assert.New(t) + + portChan := make(chan int, 1) + + go func() { + srv := &MockServer{} + defer srv.Close() + + err := srv.Accept("localhost:0", func(port int) { + if port <= 0 { + panic(fmt.Sprintf("Got invalid port from MockServer: %d", port)) + } + portChan <- port + }) + if err != nil { + panic(err) + } + + // Simulate an abnormal closure (close the TCP connection without sending a WebSocket close frame) + srv.Close() + }() + time.Sleep(10 * time.Millisecond) + + wsURI := fmt.Sprintf("ws://localhost:%d", <-portChan) + + ioc := sonic.MustIO() + defer ioc.Close() + + ws, err := NewWebsocketStream(ioc, nil, RoleClient) + assert.Nil(err) + + err = ws.Handshake(wsURI) + assert.Nil(err) + assert.Equal(ws.State(), StateActive) // Verify WebSocket active + + // Attempt to read a frame; this should return the 1006 close frame directly + frame, err := ws.NextFrame() + assert.Equal(io.EOF, err) // Verify frame error is EOF + + assert.True(frame.Opcode().IsClose()) // Verify we got a close frame + + closeCode, reason := DecodeCloseFramePayload(frame.Payload()) + assert.Equal(CloseAbnormal, closeCode) // Verify the close code is 1006 + assert.Empty(reason) // Verify there is no reason payload + + assert.Equal(ws.State(), StateTerminated) // Verify the WebSocket's state +} + +func TestClientAsyncAbnormalClose(t *testing.T) { + assert := assert.New(t) + + portChan := make(chan int, 1) + + go func() { + srv := &MockServer{} + defer srv.Close() + + err := srv.Accept("localhost:0", func(port int) { + if port <= 0 { + panic(fmt.Sprintf("Got invalid port from MockServer: %d", port)) + } + portChan <- port + }) + if err != nil { + panic(err) + } + + // Simulate an abnormal closure (close the TCP connection without sending a WebSocket close frame) + srv.Close() + }() + time.Sleep(10 * time.Millisecond) + + wsURI := fmt.Sprintf("ws://localhost:%d", <-portChan) + + ioc := sonic.MustIO() + defer ioc.Close() + + ws, err := NewWebsocketStream(ioc, nil, RoleClient) + assert.Nil(err) + + done := false + ws.AsyncHandshake(wsURI, func(err error) { + assert.Nil(err) + assert.Equal(ws.State(), StateActive) // Verify WebSocket active + + // Attempt to read a frame; this should fail due to the server's abnormal closure + ws.AsyncNextFrame(func(err error, f Frame) { + assert.Equal(io.EOF, err) // Verify frame error is EOF + + assert.True(f.Opcode().IsClose()) // Verify we got a close frame + + closeCode, reason := DecodeCloseFramePayload(f.Payload()) + assert.Equal(CloseAbnormal, closeCode) // Verify close code is 1006 + assert.Empty(reason) // Verify there is no reason payload + + assert.Equal(ws.State(), StateTerminated) // Verify the WebSocket's state + + done = true + }) + }) + + for !done { + ioc.PollOne() + } +} diff --git a/codec/websocket/test_main.go b/codec/websocket/test_main.go index da91cf4..368d841 100644 --- a/codec/websocket/test_main.go +++ b/codec/websocket/test_main.go @@ -22,12 +22,21 @@ type MockServer struct { Upgrade *http.Request } -func (s *MockServer) Accept(addr string) (err error) { +// Accept starts the mock server, listening on the specified address. +// If a callback is provided, it is invoked with the assigned port. +func (s *MockServer) Accept(addr string, opts ...func(int)) (err error) { s.ln, err = net.Listen("tcp", addr) if err != nil { return err } - atomic.StoreInt32(&s.port, int32(s.ln.Addr().(*net.TCPAddr).Port)) + + port := int(s.ln.Addr().(*net.TCPAddr).Port) + atomic.StoreInt32(&s.port, int32(port)) + + // Call port callback if provided + if len(opts) > 0 && opts[0] != nil { + opts[0](port) + } conn, err := s.ln.Accept() if err != nil {