Skip to content

Commit

Permalink
websocket: add support for dialing with context
Browse files Browse the repository at this point in the history
Right now there is no way to pass context.Context to websocket.Dial.
In addition, this method can block indefinitely in the NewClient call.

Fixes golang/go#57953.

Change-Id: Ic52d4b8306cd0850e78d683abb1bf11f0d4247ca
GitHub-Last-Rev: 5e8c3a7
GitHub-Pull-Request: #160
Reviewed-on: https://go-review.googlesource.com/c/net/+/463097
Auto-Submit: Damien Neil <[email protected]>
Reviewed-by: Damien Neil <[email protected]>
LUCI-TryBot-Result: Go LUCI <[email protected]>
Reviewed-by: Dmitri Shuralyov <[email protected]>
  • Loading branch information
Cyberax authored and gopherbot committed Feb 27, 2024
1 parent fa11427 commit 3dfd003
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 15 deletions.
56 changes: 44 additions & 12 deletions websocket/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ package websocket

import (
"bufio"
"context"
"io"
"net"
"net/http"
"net/url"
"time"
)

// DialError is an error that occurs while dialling a websocket server.
Expand Down Expand Up @@ -77,30 +79,60 @@ func parseAuthority(location *url.URL) string {
return location.Host
}

// DialConfig opens a new client connection to a WebSocket with a config.
func DialConfig(config *Config) (ws *Conn, err error) {
var client net.Conn
return config.DialContext(context.Background())
}

// DialContext opens a new client connection to a WebSocket, with context support for timeouts/cancellation.
func (config *Config) DialContext(ctx context.Context) (*Conn, error) {
if config.Location == nil {
return nil, &DialError{config, ErrBadWebSocketLocation}
}
if config.Origin == nil {
return nil, &DialError{config, ErrBadWebSocketOrigin}
}

dialer := config.Dialer
if dialer == nil {
dialer = &net.Dialer{}
}
client, err = dialWithDialer(dialer, config)
if err != nil {
goto Error
}
ws, err = NewClient(config, client)

client, err := dialWithDialer(ctx, dialer, config)
if err != nil {
client.Close()
goto Error
return nil, &DialError{config, err}
}
return

Error:
return nil, &DialError{config, err}
// Cleanup the connection if we fail to create the websocket successfully
success := false
defer func() {
if !success {
_ = client.Close()
}
}()

var ws *Conn
var wsErr error
doneConnecting := make(chan struct{})
go func() {
defer close(doneConnecting)
ws, err = NewClient(config, client)
if err != nil {
wsErr = &DialError{config, err}
}
}()

// The websocket.NewClient() function can block indefinitely, make sure that we
// respect the deadlines specified by the context.
select {
case <-ctx.Done():
// Force the pending operations to fail, terminating the pending connection attempt
_ = client.SetDeadline(time.Now())
<-doneConnecting // Wait for the goroutine that tries to establish the connection to finish
return nil, &DialError{config, ctx.Err()}
case <-doneConnecting:
if wsErr == nil {
success = true // Disarm the deferred connection cleanup
}
return ws, wsErr
}
}
11 changes: 8 additions & 3 deletions websocket/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,23 @@
package websocket

import (
"context"
"crypto/tls"
"net"
)

func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
func dialWithDialer(ctx context.Context, dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
switch config.Location.Scheme {
case "ws":
conn, err = dialer.Dial("tcp", parseAuthority(config.Location))
conn, err = dialer.DialContext(ctx, "tcp", parseAuthority(config.Location))

case "wss":
conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig)
tlsDialer := &tls.Dialer{
NetDialer: dialer,
Config: config.TlsConfig,
}

conn, err = tlsDialer.DialContext(ctx, "tcp", parseAuthority(config.Location))
default:
err = ErrBadScheme
}
Expand Down
37 changes: 37 additions & 0 deletions websocket/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
package websocket

import (
"context"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
Expand Down Expand Up @@ -41,3 +44,37 @@ func TestDialConfigTLSWithDialer(t *testing.T) {
t.Fatalf("expected timeout error, got %#v", neterr)
}
}

func TestDialConfigTLSWithTimeouts(t *testing.T) {
t.Parallel()

finishedRequest := make(chan bool)

// Context for cancellation
ctx, cancel := context.WithCancel(context.Background())

// This is a TLS server that blocks each request indefinitely (and cancels the context)
tlsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cancel()
<-finishedRequest
}))

tlsServerAddr := tlsServer.Listener.Addr().String()
log.Print("Test TLS WebSocket server listening on ", tlsServerAddr)
defer tlsServer.Close()
defer close(finishedRequest)

config, _ := NewConfig(fmt.Sprintf("wss://%s/echo", tlsServerAddr), "http://localhost")
config.TlsConfig = &tls.Config{
InsecureSkipVerify: true,
}

_, err := config.DialContext(ctx)
dialerr, ok := err.(*DialError)
if !ok {
t.Fatalf("DialError expected, got %#v", err)
}
if !errors.Is(dialerr.Err, context.Canceled) {
t.Fatalf("context.Canceled error expected, got %#v", dialerr.Err)
}
}

0 comments on commit 3dfd003

Please sign in to comment.