diff --git a/.changelog/19932.txt b/.changelog/19932.txt new file mode 100644 index 00000000000..aabb4aa2bd8 --- /dev/null +++ b/.changelog/19932.txt @@ -0,0 +1,3 @@ +```release-note:bug +exec: Fixed a bug in `alloc exec` where closing websocket streams could cause a panic +``` diff --git a/command/agent/alloc_endpoint.go b/command/agent/alloc_endpoint.go index 4b259fb897b..ff19e65b2ff 100644 --- a/command/agent/alloc_endpoint.go +++ b/command/agent/alloc_endpoint.go @@ -6,6 +6,7 @@ package agent import ( "context" "encoding/json" + "errors" "fmt" "io" "net" @@ -516,7 +517,7 @@ func (s *HTTPServer) allocExec(allocID string, resp http.ResponseWriter, req *ht return nil, err } - return s.execStreamImpl(conn, &args) + return s.execStream(conn, &args) } // readWsHandshake reads the websocket handshake message and sets @@ -552,7 +553,9 @@ type wsHandshakeMessage struct { AuthToken string `json:"auth_token"` } -func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExecRequest) (interface{}, error) { +// execStream finds the appropriate RPC handler and then runs the bidirectional +// websocket-to-RPC stream +func (s *HTTPServer) execStream(ws *websocket.Conn, args *cstructs.AllocExecRequest) (any, error) { allocID := args.AllocID method := "Allocations.Exec" @@ -572,6 +575,13 @@ func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExec return nil, CodedError(500, handlerErr.Error()) } + return s.execStreamImpl(ws, args, handler) +} + +// execStreamImpl is called by execStream with the appropriate RPC handler and +// then runs the bidirectional websocket-to-RPC stream. +func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExecRequest, handler structs.StreamingRpcHandler) (any, error) { + // Create a pipe connecting the (possibly remote) handler to the http response httpPipe, handlerPipe := net.Pipe() decoder := codec.NewDecoder(httpPipe, structs.MsgpackHandle) @@ -586,33 +596,37 @@ func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExec // don't close ws - wait to drain messages }() - // Create a channel that decodes the results - errCh := make(chan HTTPCodedError, 2) + // Create a channel for the final result + resultCh := make(chan HTTPCodedError, 1) - // stream response + // stream response back to the websocket: this should be the only goroutine + // that writes to this websocket connection go func() { defer cancel() + errCh := make(chan HTTPCodedError, 2) // Send the request if err := encoder.Encode(args); err != nil { - errCh <- CodedError(500, err.Error()) + resultCh <- s.execStreamHandleError(ws, CodedError(500, err.Error())) return } - go forwardExecInput(encoder, ws, errCh) + // only start this after we've tried to send the initial args + go forwardExecInput(ctx, encoder, ws, errCh) for { - var res cstructs.StreamErrWrapper - err := decoder.Decode(&res) - if isClosedError(err) { - ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) - errCh <- nil + select { + case codedErr := <-errCh: + resultCh <- s.execStreamHandleError(ws, codedErr) return + default: } + var res cstructs.StreamErrWrapper + err := decoder.Decode(&res) if err != nil { errCh <- CodedError(500, err.Error()) - return + continue } decoder.Reset(httpPipe) @@ -622,39 +636,47 @@ func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExec code = int(*err.Code) } errCh <- CodedError(code, err.Error()) - return + continue } - if err := ws.WriteMessage(websocket.TextMessage, res.Payload); err != nil { errCh <- CodedError(500, err.Error()) - return + continue } } }() - // start streaming request to streaming RPC - returns when streaming completes or errors + // start streaming request to streaming RPC - returns when streaming + // completes or errors handler(handlerPipe) - // stop streaming background goroutines for streaming - but not websocket activity + + // stop streaming background goroutines for streaming - but not websocket + // activity cancel() - // retrieve any error and/or wait until goroutine stop and close errCh connection before - // closing websocket connection - codedErr := <-errCh + // retrieve any error and/or wait until goroutine stop and close errCh + // connection before closing websocket connection + result := <-resultCh + ws.Close() + return nil, result +} + +// execStreamHandleError writes a CloseMessage to the websocket if we get an +// error that isn't a "close error" caused by the RPC pipe finishing up. Note +// that this should *only* ever be called in the same goroutine as we're +// streaming the responses +func (s *HTTPServer) execStreamHandleError(ws *websocket.Conn, codedErr HTTPCodedError) HTTPCodedError { // we won't return an error on ws close, but at least make it available in // the logs so we can trace spurious disconnects - if codedErr != nil { - s.logger.Debug("alloc exec channel closed with error", "error", codedErr) - } + s.logger.Trace("alloc exec channel closed with error", "error", codedErr) if isClosedError(codedErr) { - codedErr = nil + return nil // we're intentionally throwing this error away } else if codedErr != nil { ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(toWsCode(codedErr.Code()), codedErr.Error())) + return codedErr } - ws.Close() - - return nil, codedErr + return nil } func toWsCode(httpCode int) int { @@ -667,30 +689,34 @@ func toWsCode(httpCode int) int { } } +// isClosedError checks if the websocket "error" is one of the benign "close" status codes func isClosedError(err error) bool { if err == nil { return false } - // check if the websocket "error" is one of the benign "close" status codes - if codedErr, ok := err.(HTTPCodedError); ok { - return slices.ContainsFunc([]string{ + return errors.Is(err, io.EOF) || + errors.Is(err, io.ErrClosedPipe) || + err == io.ErrClosedPipe || + slices.ContainsFunc([]string{ + "closed", // msgpack decode error [pos 0]: io: read/write on closed pipe" + "EOF", "close 1000", // CLOSE_NORMAL "close 1001", // CLOSE_GOING_AWAY "close 1005", // CLOSED_NO_STATUS - }, func(s string) bool { return strings.Contains(codedErr.Error(), s) }) - } - - return err == io.EOF || - err == io.ErrClosedPipe || - strings.Contains(err.Error(), "closed") || - strings.Contains(err.Error(), "EOF") + }, func(s string) bool { return strings.Contains(err.Error(), s) }) } // forwardExecInput forwards exec input (e.g. stdin) from websocket connection // to the streaming RPC connection to client -func forwardExecInput(encoder *codec.Encoder, ws *websocket.Conn, errCh chan<- HTTPCodedError) { +func forwardExecInput(ctx context.Context, encoder *codec.Encoder, ws *websocket.Conn, errCh chan<- HTTPCodedError) { for { + select { + case <-ctx.Done(): + return + default: + } + sf := &drivers.ExecTaskStreamingRequestMsg{} err := ws.ReadJSON(sf) if err == io.EOF { diff --git a/command/agent/alloc_endpoint_test.go b/command/agent/alloc_endpoint_test.go index 7c0d730df75..8af8d154a75 100644 --- a/command/agent/alloc_endpoint_test.go +++ b/command/agent/alloc_endpoint_test.go @@ -5,6 +5,7 @@ package agent import ( "archive/tar" + "context" "fmt" "io" "net/http" @@ -14,16 +15,22 @@ import ( "strconv" "strings" "testing" + "time" "github.com/golang/snappy" + "github.com/gorilla/websocket" + "github.com/hashicorp/go-msgpack/codec" "github.com/hashicorp/nomad/acl" "github.com/hashicorp/nomad/ci" "github.com/hashicorp/nomad/client/allocdir" + cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/helper/pointer" "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "github.com/stretchr/testify/require" ) @@ -1123,3 +1130,108 @@ func TestHTTP_ReadWsHandshake(t *testing.T) { }) } } + +// TestHTTP_AllocsExecStream_SafeClose verifies that we are safely closing the +// AllocExec stream when we're done without making concurrent writes to the +// websocket that can cause a panic +func TestHTTP_AllocsExecStream_SafeClose(t *testing.T) { + httpTest(t, + func(c *Config) { c.Server.NumSchedulers = pointer.Of(0) }, + func(s *TestAgent) { + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) + + rpcHandler := mockStreamingRpcHandler(t, [][]byte{ + []byte("one"), []byte("two"), []byte("done!")}) + + // This replaces the top-level HTTP handler, which is not under test + // here. It will call execStreamImpl using the mock streaming RPC + // handler defined above. + wsHandler := func(w http.ResponseWriter, r *http.Request) { + var upgrader = websocket.Upgrader{} + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + must.NoError(t, err, must.Sprint("during ws upgrade")) + return + } + defer conn.Close() + + args := cstructs.AllocExecRequest{ + AllocID: uuid.Generate(), + Task: "foo", + Cmd: []string{"bar"}, + } + + _, err = s.Server.execStreamImpl(conn, &args, rpcHandler) + must.NoError(t, err) + } + + // Spin up a HTTP server that only handles our websocket + srv := httptest.NewServer(http.HandlerFunc(wsHandler)) + t.Cleanup(srv.Close) + u := strings.Replace(srv.URL, "http://", "ws://", 1) + conn, _, err := websocket.DefaultDialer.Dial(u, nil) + must.NoError(t, err, must.Sprint("failed to dial")) + defer conn.Close() + + drainResp := func() []string { + resp := []string{} + for { + select { + case <-ctx.Done(): + return resp + default: + _, message, err := conn.ReadMessage() + if err != nil { + if !isClosedError(err) { + resp = append(resp, err.Error()) + return resp + } + return resp + } + resp = append(resp, string(message)) + } + } + } + + must.Eq(t, []string{"one", "two", "done!"}, drainResp()) + }) +} + +// mockStreamingRpcHandler returns a function that can stand in for any +// structs.StreamingRpcHandler and streams the slice of payloads before +// closing. It marks a test failure if we get a non-close error. +func mockStreamingRpcHandler(t *testing.T, payloads [][]byte) func(io.ReadWriteCloser) { + + return func(conn io.ReadWriteCloser) { + + decoder := codec.NewDecoder(conn, structs.MsgpackHandle) + encoder := codec.NewEncoder(conn, structs.MsgpackHandle) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // drain any incoming requests + go func() { + for { + select { + case <-ctx.Done(): + return + default: + } + var res cstructs.StreamErrWrapper + err := decoder.Decode(&res) + if !isClosedError(err) { + test.NoError(t, err, test.Sprint("unexpected non-close error")) + } + } + }() + + for _, payload := range payloads { + err := encoder.Encode(cstructs.StreamErrWrapper{Payload: payload}) + test.NoError(t, err, test.Sprint("could not send RPC payload")) + } + test.NoError(t, conn.Close()) + } +}