From dfb69b5d50dee796ce3d6789dc97de2fc8589d65 Mon Sep 17 00:00:00 2001 From: niv Date: Tue, 29 Dec 2020 14:07:35 +0200 Subject: [PATCH] added option for dialer to follow redirects --- dialer.go | 48 ++++++++++++++++++++++++++++++++++++------------ http.go | 2 ++ 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/dialer.go b/dialer.go index d35dc14..cad5ffd 100644 --- a/dialer.go +++ b/dialer.go @@ -129,6 +129,10 @@ type Dialer struct { // Note that for debugging purposes of an http handshake (e.g. sent request // and received response), there is an wsutil.DebugDialer struct. WrapConn func(conn net.Conn) net.Conn + + // FollowRedirects is a boolean specifying if to follow redirects + // Redirects are any 3xx responses coming from the server + FollowRedirects bool } // Dial connects to the url host and upgrades connection to WebSocket. @@ -188,7 +192,10 @@ func (d Dialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *buf } br, hs, err = d.Upgrade(conn, u) - + + if e, ok := err.(RedirectError); ok { + return d.Dial(ctx, string(e)) + } return } @@ -332,19 +339,24 @@ func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Ha err = ErrHandshakeBadProtocol return } + var isRedirect bool if resp.status != 101 { - err = StatusError(resp.status) - if onStatusError := d.OnStatusError; onStatusError != nil { - // Invoke callback with multireader of status-line bytes br. - onStatusError(resp.status, resp.reason, - io.MultiReader( - bytes.NewReader(sl), - strings.NewReader(crlf), - br, - ), - ) + if resp.status >= 300 && resp.status < 400 && d.FollowRedirects { + isRedirect = true // Cant return, need to process Location header + } else { + err = StatusError(resp.status) + if onStatusError := d.OnStatusError; onStatusError != nil { + // Invoke callback with multireader of status-line bytes br. + onStatusError(resp.status, resp.reason, + io.MultiReader( + bytes.NewReader(sl), + strings.NewReader(crlf), + br, + ), + ) + } + return } - return } // If response status is 101 then we expect all technical headers to be // valid. If not, then we stop processing response without giving user @@ -369,6 +381,12 @@ func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Ha } switch btsToString(k) { + case headerLocationCanonical: + if d.FollowRedirects && isRedirect { + err = RedirectError(string(v)) + return + } + case headerUpgradeCanonical: headerSeen |= headerSeenUpgrade if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) { @@ -457,6 +475,12 @@ func (s StatusError) Error() string { return "unexpected HTTP response status: " + strconv.Itoa(int(s)) } +type RedirectError string + +func (r RedirectError) Error() string { + return "received redirect from server to: " + string(r) +} + func isTimeoutError(err error) bool { t, ok := err.(net.Error) return ok && t.Timeout() diff --git a/http.go b/http.go index 7d7175a..930cf6c 100644 --- a/http.go +++ b/http.go @@ -42,6 +42,7 @@ var ( headerHost = "Host" headerUpgrade = "Upgrade" headerConnection = "Connection" + headerLocation = "Location" headerSecVersion = "Sec-WebSocket-Version" headerSecProtocol = "Sec-WebSocket-Protocol" headerSecExtensions = "Sec-WebSocket-Extensions" @@ -51,6 +52,7 @@ var ( headerHostCanonical = textproto.CanonicalMIMEHeaderKey(headerHost) headerUpgradeCanonical = textproto.CanonicalMIMEHeaderKey(headerUpgrade) headerConnectionCanonical = textproto.CanonicalMIMEHeaderKey(headerConnection) + headerLocationCanonical = textproto.CanonicalMIMEHeaderKey(headerLocation) headerSecVersionCanonical = textproto.CanonicalMIMEHeaderKey(headerSecVersion) headerSecProtocolCanonical = textproto.CanonicalMIMEHeaderKey(headerSecProtocol) headerSecExtensionsCanonical = textproto.CanonicalMIMEHeaderKey(headerSecExtensions)