From 80393295c1185e50d0b784d4bc5ffaa918d187b9 Mon Sep 17 00:00:00 2001 From: LiuYang Date: Thu, 17 Aug 2023 23:25:09 +0800 Subject: [PATCH 1/2] Correct way to save memory using write buffer pool and freeing net.http default buffers (#761) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Summary of Changes** 1. Add an example that uses the write buffer pool The loop process of the websocket connection is inner the http handler at existing examples, This usage will cause the 8k buffer(4k read buffer + 4k write buffer) allocated by net.http can't be GC(Observed by heap profiling, see picture below) . The purpose of saving memory is not achieved even if the WriteBufferPool is used. In example bufferpool, server process websocket connection in a new goroutine, and the goroutine created by the net.http will exit, then the 8k buffer will be GC. ![heap](https://user-images.githubusercontent.com/12793501/148676918-872d1a6d-ce10-4146-ba01-7de114db09f5.png) Co-authored-by: hakunaliu Co-authored-by: Corey Daley --- README.md | 1 + examples/bufferpool/client.go | 89 +++++++++++++++++++++++++++++++++++ examples/bufferpool/server.go | 55 ++++++++++++++++++++++ 3 files changed, 145 insertions(+) create mode 100644 examples/bufferpool/client.go create mode 100644 examples/bufferpool/server.go diff --git a/README.md b/README.md index d33ed7f..9b6956f 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ Gorilla WebSocket is a [Go](http://golang.org/) implementation of the * [Command example](https://github.com/gorilla/websocket/tree/master/examples/command) * [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo) * [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch) +* [Write buffer pool example](https://github.com/gorilla/websocket/tree/master/examples/bufferpool) ### Status diff --git a/examples/bufferpool/client.go b/examples/bufferpool/client.go new file mode 100644 index 0000000..a3719a9 --- /dev/null +++ b/examples/bufferpool/client.go @@ -0,0 +1,89 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "flag" + "log" + "net/url" + "os" + "os/signal" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +var addr = flag.String("addr", "localhost:8080", "http service address") + +func runNewConn(wg *sync.WaitGroup) { + defer wg.Done() + + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt) + + u := url.URL{Scheme: "ws", Host: *addr, Path: "/ws"} + log.Printf("connecting to %s", u.String()) + c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + log.Fatal("dial:", err) + } + defer c.Close() + + done := make(chan struct{}) + + go func() { + defer close(done) + for { + _, message, err := c.ReadMessage() + if err != nil { + log.Println("read:", err) + return + } + log.Printf("recv: %s", message) + } + }() + + ticker := time.NewTicker(time.Minute * 5) + defer ticker.Stop() + + for { + select { + case <-done: + return + case t := <-ticker.C: + err := c.WriteMessage(websocket.TextMessage, []byte(t.String())) + if err != nil { + log.Println("write:", err) + return + } + case <-interrupt: + log.Println("interrupt") + + // Cleanly close the connection by sending a close message and then + // waiting (with timeout) for the server to close the connection. + err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + log.Println("write close:", err) + return + } + select { + case <-done: + case <-time.After(time.Second): + } + return + } + } +} + +func main() { + flag.Parse() + log.SetFlags(0) + wg := &sync.WaitGroup{} + for i := 0; i < 1000; i++ { + wg.Add(1) + go runNewConn(wg) + } + wg.Wait() +} diff --git a/examples/bufferpool/server.go b/examples/bufferpool/server.go new file mode 100644 index 0000000..25bb20f --- /dev/null +++ b/examples/bufferpool/server.go @@ -0,0 +1,55 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "flag" + "log" + "net/http" + "sync" + + _ "net/http/pprof" + + "github.com/gorilla/websocket" +) + +var addr = flag.String("addr", "localhost:8080", "http service address") + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 256, + WriteBufferSize: 256, + WriteBufferPool: &sync.Pool{}, +} + +func process(c *websocket.Conn) { + defer c.Close() + for { + _, message, err := c.ReadMessage() + if err != nil { + log.Println("read:", err) + break + } + log.Printf("recv: %s", message) + } +} + +func handler(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Print("upgrade:", err) + return + } + + // Process connection in a new goroutine + go process(c) + + // Let the http handler return, the 8k buffer created by it will be garbage collected +} + +func main() { + flag.Parse() + log.SetFlags(0) + http.HandleFunc("/ws", handler) + log.Fatal(http.ListenAndServe(*addr, nil)) +} From 666c197fc9157896b57515c3a3326c3f8c8319fe Mon Sep 17 00:00:00 2001 From: Corey Daley Date: Sat, 26 Aug 2023 16:01:45 -0400 Subject: [PATCH 2/2] Update go version & add verification/testing tools (#840) Fixes # **Summary of Changes** 1. 2. 3. > PS: Make sure your PR includes/updates tests! If you need help with this part, just ask! --- .circleci/config.yml | 70 --- .editorconfig | 20 + .github/release-drafter.yml | 7 - .github/workflows/issues.yml | 21 + .github/workflows/test.yml | 55 ++ .gitignore | 26 +- .golangci.yml | 3 + AUTHORS | 9 - LICENSE | 39 +- Makefile | 34 ++ README.md | 14 +- client.go | 26 +- client_server_test.go | 28 +- compression.go | 9 +- compression_test.go | 17 +- conn.go | 75 ++- conn_broadcast_test.go | 11 +- conn_test.go | 75 ++- examples/autobahn/server.go | 6 +- examples/chat/main.go | 7 +- examples/command/main.go | 6 +- examples/filewatch/main.go | 9 +- go.mod | 4 +- go.sum | 2 + join_test.go | 4 +- mask.go | 4 + prepared_test.go | 16 +- proxy.go | 17 +- server.go | 38 +- tls_handshake.go | 3 - tls_handshake_116.go | 21 - util.go | 4 +- vendor/golang.org/x/net/LICENSE | 27 + vendor/golang.org/x/net/PATENTS | 22 + .../golang.org/x/net/internal/socks/client.go | 168 +++++++ .../golang.org/x/net/internal/socks/socks.go | 317 ++++++++++++ vendor/golang.org/x/net/proxy/dial.go | 54 ++ vendor/golang.org/x/net/proxy/direct.go | 31 ++ vendor/golang.org/x/net/proxy/per_host.go | 155 ++++++ vendor/golang.org/x/net/proxy/proxy.go | 149 ++++++ vendor/golang.org/x/net/proxy/socks5.go | 42 ++ vendor/modules.txt | 4 + x_net_proxy.go | 473 ------------------ 43 files changed, 1388 insertions(+), 734 deletions(-) delete mode 100644 .circleci/config.yml create mode 100644 .editorconfig delete mode 100644 .github/release-drafter.yml create mode 100644 .github/workflows/issues.yml create mode 100644 .github/workflows/test.yml create mode 100644 .golangci.yml delete mode 100644 AUTHORS create mode 100644 Makefile delete mode 100644 tls_handshake_116.go create mode 100644 vendor/golang.org/x/net/LICENSE create mode 100644 vendor/golang.org/x/net/PATENTS create mode 100644 vendor/golang.org/x/net/internal/socks/client.go create mode 100644 vendor/golang.org/x/net/internal/socks/socks.go create mode 100644 vendor/golang.org/x/net/proxy/dial.go create mode 100644 vendor/golang.org/x/net/proxy/direct.go create mode 100644 vendor/golang.org/x/net/proxy/per_host.go create mode 100644 vendor/golang.org/x/net/proxy/proxy.go create mode 100644 vendor/golang.org/x/net/proxy/socks5.go create mode 100644 vendor/modules.txt delete mode 100644 x_net_proxy.go diff --git a/.circleci/config.yml b/.circleci/config.yml deleted file mode 100644 index ecb33f6..0000000 --- a/.circleci/config.yml +++ /dev/null @@ -1,70 +0,0 @@ -version: 2.1 - -jobs: - "test": - parameters: - version: - type: string - default: "latest" - golint: - type: boolean - default: true - modules: - type: boolean - default: true - goproxy: - type: string - default: "" - docker: - - image: "cimg/go:<< parameters.version >>" - working_directory: /home/circleci/project/go/src/github.com/gorilla/websocket - environment: - GO111MODULE: "on" - GOPROXY: "<< parameters.goproxy >>" - steps: - - checkout - - run: - name: "Print the Go version" - command: > - go version - - run: - name: "Fetch dependencies" - command: > - if [[ << parameters.modules >> = true ]]; then - go mod download - export GO111MODULE=on - else - go get -v ./... - fi - # Only run gofmt, vet & lint against the latest Go version - - run: - name: "Run golint" - command: > - if [ << parameters.version >> = "latest" ] && [ << parameters.golint >> = true ]; then - go get -u golang.org/x/lint/golint - golint ./... - fi - - run: - name: "Run gofmt" - command: > - if [[ << parameters.version >> = "latest" ]]; then - diff -u <(echo -n) <(gofmt -d -e .) - fi - - run: - name: "Run go vet" - command: > - if [[ << parameters.version >> = "latest" ]]; then - go vet -v ./... - fi - - run: - name: "Run go test (+ race detector)" - command: > - go test -v -race ./... - -workflows: - tests: - jobs: - - test: - matrix: - parameters: - version: ["1.18", "1.17", "1.16"] diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..2940ec9 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,20 @@ +; https://editorconfig.org/ + +root = true + +[*] +insert_final_newline = true +charset = utf-8 +trim_trailing_whitespace = true +indent_style = space +indent_size = 2 + +[{Makefile,go.mod,go.sum,*.go,.gitmodules}] +indent_style = tab +indent_size = 4 + +[*.md] +indent_size = 4 +trim_trailing_whitespace = false + +eclint_indent_style = unset diff --git a/.github/release-drafter.yml b/.github/release-drafter.yml deleted file mode 100644 index 0986b3e..0000000 --- a/.github/release-drafter.yml +++ /dev/null @@ -1,7 +0,0 @@ -# Config for https://github.com/apps/release-drafter -template: | - - - - ## CHANGELOG - $CHANGES diff --git a/.github/workflows/issues.yml b/.github/workflows/issues.yml new file mode 100644 index 0000000..2a38d75 --- /dev/null +++ b/.github/workflows/issues.yml @@ -0,0 +1,21 @@ +# Add issues or pull-requests created to the project. +name: Add issue or pull request to Project + +on: + issues: + types: + - opened + pull_request_target: + types: + - opened + - reopened + +jobs: + add-to-project: + runs-on: ubuntu-latest + steps: + - name: Add issue to project + uses: actions/add-to-project@v0.5.0 + with: + project-url: https://github.com/orgs/gorilla/projects/4 + github-token: ${{ secrets.ADD_TO_PROJECT_TOKEN }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..a804cc8 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,55 @@ +name: CI +on: + push: + branches: + - main + pull_request: + branches: + - main + +permissions: + contents: read + +jobs: + verify-and-test: + strategy: + matrix: + go: ['1.19','1.20'] + os: [ubuntu-latest, macos-latest, windows-latest] + fail-fast: true + runs-on: ${{ matrix.os }} + steps: + - name: Checkout Code + uses: actions/checkout@v3 + + - name: Setup Go ${{ matrix.go }} + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go }} + cache: false + + - name: Run GolangCI-Lint + uses: golangci/golangci-lint-action@v3 + with: + version: v1.53 + args: --timeout=5m + + - name: Run GoSec + if: matrix.os == 'ubuntu-latest' + uses: securego/gosec@master + with: + args: -exclude-dir examples ./... + + - name: Run GoVulnCheck + uses: golang/govulncheck-action@v1 + with: + go-version-input: ${{ matrix.go }} + go-package: ./... + + - name: Run Tests + run: go test -race -cover -coverprofile=coverage -covermode=atomic -v ./... + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + files: ./coverage diff --git a/.gitignore b/.gitignore index cd3fcd1..84039fe 100644 --- a/.gitignore +++ b/.gitignore @@ -1,25 +1 @@ -# Compiled Object files, Static and Dynamic libs (Shared Objects) -*.o -*.a -*.so - -# Folders -_obj -_test - -# Architecture specific extensions/prefixes -*.[568vq] -[568vq].out - -*.cgo1.go -*.cgo2.c -_cgo_defun.c -_cgo_gotypes.go -_cgo_export.* - -_testmain.go - -*.exe - -.idea/ -*.iml +coverage.coverprofile diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..3488213 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,3 @@ +run: + skip-dirs: + - examples/*.go diff --git a/AUTHORS b/AUTHORS deleted file mode 100644 index 1931f40..0000000 --- a/AUTHORS +++ /dev/null @@ -1,9 +0,0 @@ -# This is the official list of Gorilla WebSocket authors for copyright -# purposes. -# -# Please keep the list sorted. - -Gary Burd -Google LLC (https://opensource.google.com/) -Joachim Bauch - diff --git a/LICENSE b/LICENSE index 9171c97..bb9d80b 100644 --- a/LICENSE +++ b/LICENSE @@ -1,22 +1,27 @@ -Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved. +Copyright (c) 2023 The Gorilla Authors. All rights reserved. Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: +modification, are permitted provided that the following conditions are +met: - Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. - Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..603a63f --- /dev/null +++ b/Makefile @@ -0,0 +1,34 @@ +GO_LINT=$(shell which golangci-lint 2> /dev/null || echo '') +GO_LINT_URI=github.com/golangci/golangci-lint/cmd/golangci-lint@latest + +GO_SEC=$(shell which gosec 2> /dev/null || echo '') +GO_SEC_URI=github.com/securego/gosec/v2/cmd/gosec@latest + +GO_VULNCHECK=$(shell which govulncheck 2> /dev/null || echo '') +GO_VULNCHECK_URI=golang.org/x/vuln/cmd/govulncheck@latest + +.PHONY: golangci-lint +golangci-lint: + $(if $(GO_LINT), ,go install $(GO_LINT_URI)) + @echo "##### Running golangci-lint" + golangci-lint run -v + +.PHONY: gosec +gosec: + $(if $(GO_SEC), ,go install $(GO_SEC_URI)) + @echo "##### Running gosec" + gosec -exclude-dir examples ./... + +.PHONY: govulncheck +govulncheck: + $(if $(GO_VULNCHECK), ,go install $(GO_VULNCHECK_URI)) + @echo "##### Running govulncheck" + govulncheck ./... + +.PHONY: verify +verify: golangci-lint gosec govulncheck + +.PHONY: test +test: + @echo "##### Running tests" + go test -race -cover -coverprofile=coverage.coverprofile -covermode=atomic -v ./... diff --git a/README.md b/README.md index 9b6956f..1fd5e9c 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,13 @@ -# Gorilla WebSocket +# gorilla/websocket -[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket) -[![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket) +![testing](https://github.com/gorilla/websocket/actions/workflows/test.yml/badge.svg) +[![codecov](https://codecov.io/github/gorilla/websocket/branch/main/graph/badge.svg)](https://codecov.io/github/gorilla/websocket) +[![godoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket) +[![sourcegraph](https://sourcegraph.com/github.com/gorilla/websocket/-/badge.svg)](https://sourcegraph.com/github.com/gorilla/websocket?badge) -Gorilla WebSocket is a [Go](http://golang.org/) implementation of the -[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. +Gorilla WebSocket is a [Go](http://golang.org/) implementation of the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. + +![Gorilla Logo](https://github.com/gorilla/.github/assets/53367916/d92caabf-98e0-473e-bfbf-ab554ba435e5) ### Documentation @@ -31,4 +34,3 @@ package API is stable. The Gorilla WebSocket package passes the server tests in the [Autobahn Test Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). - diff --git a/client.go b/client.go index 04fdafe..815b0ca 100644 --- a/client.go +++ b/client.go @@ -11,13 +11,16 @@ import ( "errors" "fmt" "io" - "io/ioutil" + "log" + "net" "net/http" "net/http/httptrace" "net/url" "strings" "time" + + "golang.org/x/net/proxy" ) // ErrBadHandshake is returned when the server response to opening handshake is @@ -225,6 +228,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h k == "Connection" || k == "Sec-Websocket-Key" || k == "Sec-Websocket-Version" || + //#nosec G101 (CWE-798): Potential HTTP request smuggling via parameter pollution k == "Sec-Websocket-Extensions" || (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) @@ -290,7 +294,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } err = c.SetDeadline(deadline) if err != nil { - c.Close() + if err := c.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } return nil, err } return c, nil @@ -304,7 +310,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h return nil, nil, err } if proxyURL != nil { - dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial)) + dialer, err := proxy.FromURL(proxyURL, netDialerFunc(netDial)) if err != nil { return nil, nil, err } @@ -330,7 +336,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h defer func() { if netConn != nil { - netConn.Close() + if err := netConn.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } } }() @@ -400,7 +408,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h // debugging. buf := make([]byte, 1024) n, _ := io.ReadFull(resp.Body, buf) - resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) + resp.Body = io.NopCloser(bytes.NewReader(buf[:n])) return nil, resp, ErrBadHandshake } @@ -418,17 +426,19 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h break } - resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) + resp.Body = io.NopCloser(bytes.NewReader([]byte{})) conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") - netConn.SetDeadline(time.Time{}) + if err := netConn.SetDeadline(time.Time{}); err != nil { + return nil, nil, err + } netConn = nil // to avoid close in defer. return conn, resp, nil } func cloneTLSConfig(cfg *tls.Config) *tls.Config { if cfg == nil { - return &tls.Config{} + return &tls.Config{MinVersion: tls.VersionTLS12} } return cfg.Clone() } diff --git a/client_server_test.go b/client_server_test.go index a47df48..6095c09 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -14,7 +14,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" "net" "net/http" @@ -50,7 +49,6 @@ type cstHandler struct{ *testing.T } type cstServer struct { *httptest.Server URL string - t *testing.T } const ( @@ -94,7 +92,6 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}) if err != nil { - t.Logf("Upgrade: %v", err) return } defer ws.Close() @@ -106,20 +103,16 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } op, rd, err := ws.NextReader() if err != nil { - t.Logf("NextReader: %v", err) return } wr, err := ws.NextWriter(op) if err != nil { - t.Logf("NextWriter: %v", err) return } if _, err = io.Copy(wr, rd); err != nil { - t.Logf("NextWriter: %v", err) return } if err := wr.Close(); err != nil { - t.Logf("Close: %v", err) return } } @@ -531,7 +524,9 @@ func TestRespOnBadHandshake(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(expectedStatus) - io.WriteString(w, expectedBody) + if _, err := io.WriteString(w, expectedBody); err != nil { + t.Fatalf("WriteString: %v", err) + } })) defer s.Close() @@ -549,7 +544,7 @@ func TestRespOnBadHandshake(t *testing.T) { t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) } - p, err := ioutil.ReadAll(resp.Body) + p, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("ReadFull(resp.Body) returned error %v", err) } @@ -564,7 +559,6 @@ type testLogWriter struct { } func (w testLogWriter) Write(p []byte) (int, error) { - w.t.Logf("%s", p) return len(p), nil } @@ -781,7 +775,10 @@ func TestSocksProxyDial(t *testing.T) { } defer c1.Close() - c1.SetDeadline(time.Now().Add(30 * time.Second)) + if err := c1.SetDeadline(time.Now().Add(30 * time.Second)); err != nil { + t.Errorf("set deadline failed: %v", err) + return + } buf := make([]byte, 32) if _, err := io.ReadFull(c1, buf[:3]); err != nil { @@ -820,10 +817,15 @@ func TestSocksProxyDial(t *testing.T) { defer c2.Close() done := make(chan struct{}) go func() { - io.Copy(c1, c2) + if _, err := io.Copy(c1, c2); err != nil { + t.Errorf("copy failed: %v", err) + } close(done) }() - io.Copy(c2, c1) + if _, err := io.Copy(c2, c1); err != nil { + t.Errorf("copy failed: %v", err) + return + } <-done }() diff --git a/compression.go b/compression.go index 813ffb1..9fed0ef 100644 --- a/compression.go +++ b/compression.go @@ -8,6 +8,7 @@ import ( "compress/flate" "errors" "io" + "log" "strings" "sync" ) @@ -33,7 +34,9 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser { "\x01\x00\x00\xff\xff" fr, _ := flateReaderPool.Get().(io.ReadCloser) - fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) + if err := fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil); err != nil { + panic(err) + } return &flateReadWrapper{fr} } @@ -132,7 +135,9 @@ func (r *flateReadWrapper) Read(p []byte) (int, error) { // Preemptively place the reader back in the pool. This helps with // scenarios where the application does not call NextReader() soon after // this final read. - r.Close() + if err := r.Close(); err != nil { + log.Printf("websocket: flateReadWrapper.Close() returned error: %v", err) + } } return n, err } diff --git a/compression_test.go b/compression_test.go index 8a26b30..80cbc4e 100644 --- a/compression_test.go +++ b/compression_test.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "io" - "io/ioutil" "testing" ) @@ -23,7 +22,9 @@ func TestTruncWriter(t *testing.T) { if m > n { m = n } - w.Write(p[:m]) + if _, err := w.Write(p[:m]); err != nil { + t.Fatal(err) + } p = p[m:] } if b.String() != data[:len(data)-len(w.p)] { @@ -42,25 +43,29 @@ func textMessages(num int) [][]byte { } func BenchmarkWriteNoCompression(b *testing.B) { - w := ioutil.Discard + w := io.Discard c := newTestConn(nil, w, false) messages := textMessages(100) b.ResetTimer() for i := 0; i < b.N; i++ { - c.WriteMessage(TextMessage, messages[i%len(messages)]) + if err := c.WriteMessage(TextMessage, messages[i%len(messages)]); err != nil { + b.Fatal(err) + } } b.ReportAllocs() } func BenchmarkWriteWithCompression(b *testing.B) { - w := ioutil.Discard + w := io.Discard c := newTestConn(nil, w, false) messages := textMessages(100) c.enableWriteCompression = true c.newCompressionWriter = compressNoContextTakeover b.ResetTimer() for i := 0; i < b.N; i++ { - c.WriteMessage(TextMessage, messages[i%len(messages)]) + if err := c.WriteMessage(TextMessage, messages[i%len(messages)]); err != nil { + b.Fatal(err) + } } b.ReportAllocs() } diff --git a/conn.go b/conn.go index 5161ef8..221e6cf 100644 --- a/conn.go +++ b/conn.go @@ -6,11 +6,11 @@ package websocket import ( "bufio" + "crypto/rand" "encoding/binary" "errors" "io" - "io/ioutil" - "math/rand" + "log" "net" "strconv" "strings" @@ -181,13 +181,20 @@ var ( errInvalidControlFrame = errors.New("websocket: invalid control frame") ) +// maskRand is an io.Reader for generating mask bytes. The reader is initialized +// to crypto/rand Reader. Tests swap the reader to a math/rand reader for +// reproducible results. +var maskRand = rand.Reader + +// newMaskKey returns a new 32 bit value for masking client frames. func newMaskKey() [4]byte { - n := rand.Uint32() - return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} + var k [4]byte + _, _ = io.ReadFull(maskRand, k[:]) + return k } func hideTempErr(err error) error { - if e, ok := err.(net.Error); ok && e.Temporary() { + if e, ok := err.(net.Error); ok { err = &netError{msg: e.Error(), timeout: e.Timeout()} } return err @@ -372,7 +379,9 @@ func (c *Conn) read(n int) ([]byte, error) { if err == io.EOF { err = errUnexpectedEOF } - c.br.Discard(len(p)) + if _, err := c.br.Discard(len(p)); err != nil { + return p, err + } return p, err } @@ -387,7 +396,9 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error return err } - c.conn.SetWriteDeadline(deadline) + if err := c.conn.SetWriteDeadline(deadline); err != nil { + return c.writeFatal(err) + } if len(buf1) == 0 { _, err = c.conn.Write(buf0) } else { @@ -397,7 +408,7 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error return c.writeFatal(err) } if frameType == CloseMessage { - c.writeFatal(ErrCloseSent) + _ = c.writeFatal(ErrCloseSent) } return nil } @@ -438,7 +449,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er d := 1000 * time.Hour if !deadline.IsZero() { - d = deadline.Sub(time.Now()) + d = time.Until(deadline) if d < 0 { return errWriteTimeout } @@ -460,13 +471,15 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er return err } - c.conn.SetWriteDeadline(deadline) + if err := c.conn.SetWriteDeadline(deadline); err != nil { + return c.writeFatal(err) + } _, err = c.conn.Write(buf) if err != nil { return c.writeFatal(err) } if messageType == CloseMessage { - c.writeFatal(ErrCloseSent) + _ = c.writeFatal(ErrCloseSent) } return err } @@ -477,7 +490,9 @@ func (c *Conn) beginMessage(mw *messageWriter, messageType int) error { // probably better to return an error in this situation, but we cannot // change this without breaking existing applications. if c.writer != nil { - c.writer.Close() + if err := c.writer.Close(); err != nil { + log.Printf("websocket: discarding writer close error: %v", err) + } c.writer = nil } @@ -630,7 +645,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error { } if final { - w.endMessage(errWriteClosed) + _ = w.endMessage(errWriteClosed) return nil } @@ -795,7 +810,7 @@ func (c *Conn) advanceFrame() (int, error) { // 1. Skip remainder of previous frame. if c.readRemaining > 0 { - if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil { + if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil { return noFrame, err } } @@ -817,7 +832,9 @@ func (c *Conn) advanceFrame() (int, error) { rsv2 := p[0]&rsv2Bit != 0 rsv3 := p[0]&rsv3Bit != 0 mask := p[1]&maskBit != 0 - c.setReadRemaining(int64(p[1] & 0x7f)) + if err := c.setReadRemaining(int64(p[1] & 0x7f)); err != nil { + return noFrame, err + } c.readDecompress = false if rsv1 { @@ -922,7 +939,9 @@ func (c *Conn) advanceFrame() (int, error) { } if c.readLimit > 0 && c.readLength > c.readLimit { - c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) + if err := c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)); err != nil { + return noFrame, err + } return noFrame, ErrReadLimit } @@ -934,7 +953,9 @@ func (c *Conn) advanceFrame() (int, error) { var payload []byte if c.readRemaining > 0 { payload, err = c.read(int(c.readRemaining)) - c.setReadRemaining(0) + if err := c.setReadRemaining(0); err != nil { + return noFrame, err + } if err != nil { return noFrame, err } @@ -981,7 +1002,9 @@ func (c *Conn) handleProtocolError(message string) error { if len(data) > maxControlFramePayloadSize { data = data[:maxControlFramePayloadSize] } - c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)) + if err := c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)); err != nil { + return err + } return errors.New("websocket: " + message) } @@ -998,7 +1021,9 @@ func (c *Conn) handleProtocolError(message string) error { func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { // Close previous reader, only relevant for decompression. if c.reader != nil { - c.reader.Close() + if err := c.reader.Close(); err != nil { + log.Printf("websocket: discarding reader close error: %v", err) + } c.reader = nil } @@ -1054,7 +1079,9 @@ func (r *messageReader) Read(b []byte) (int, error) { } rem := c.readRemaining rem -= int64(n) - c.setReadRemaining(rem) + if err := c.setReadRemaining(rem); err != nil { + return 0, err + } if c.readRemaining > 0 && c.readErr == io.EOF { c.readErr = errUnexpectedEOF } @@ -1094,7 +1121,7 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { if err != nil { return messageType, nil, err } - p, err = ioutil.ReadAll(r) + p, err = io.ReadAll(r) return messageType, p, err } @@ -1136,7 +1163,9 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) { if h == nil { h = func(code int, text string) error { message := FormatCloseMessage(code, "") - c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) + if err := c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)); err != nil { + return err + } return nil } } @@ -1161,7 +1190,7 @@ func (c *Conn) SetPingHandler(h func(appData string) error) { err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) if err == ErrCloseSent { return nil - } else if e, ok := err.(net.Error); ok && e.Temporary() { + } else if _, ok := err.(net.Error); ok { return nil } return err diff --git a/conn_broadcast_test.go b/conn_broadcast_test.go index 6e744fc..f63f62f 100644 --- a/conn_broadcast_test.go +++ b/conn_broadcast_test.go @@ -6,7 +6,6 @@ package websocket import ( "io" - "io/ioutil" "sync/atomic" "testing" ) @@ -45,7 +44,7 @@ func newBroadcastConn(c *Conn) *broadcastConn { func newBroadcastBench(usePrepared, compression bool) *broadcastBench { bench := &broadcastBench{ - w: ioutil.Discard, + w: io.Discard, doneCh: make(chan struct{}), closeCh: make(chan struct{}), usePrepared: usePrepared, @@ -70,9 +69,13 @@ func (b *broadcastBench) makeConns(numConns int) { select { case msg := <-c.msgCh: if msg.prepared != nil { - c.conn.WritePreparedMessage(msg.prepared) + if err := c.conn.WritePreparedMessage(msg.prepared); err != nil { + panic(err) + } } else { - c.conn.WriteMessage(TextMessage, msg.payload) + if err := c.conn.WriteMessage(TextMessage, msg.payload); err != nil { + panic(err) + } } val := atomic.AddInt32(&b.count, 1) if val%int32(numConns) == 0 { diff --git a/conn_test.go b/conn_test.go index 06e5184..2b823dd 100644 --- a/conn_test.go +++ b/conn_test.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "reflect" "sync" @@ -124,8 +123,7 @@ func TestFraming(t *testing.T) { continue } - t.Logf("frame size: %d", n) - rbuf, err := ioutil.ReadAll(r) + rbuf, err := io.ReadAll(r) if err != nil { t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) continue @@ -158,7 +156,10 @@ func TestControl(t *testing.T) { wc := newTestConn(nil, &connBuf, isServer) rc := newTestConn(&connBuf, nil, !isServer) if isWriteControl { - wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) + if err := wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil { + t.Errorf("%s: wc.WriteControl() returned %v", name, err) + continue + } } else { w, err := wc.NextWriter(PongMessage) if err != nil { @@ -175,7 +176,9 @@ func TestControl(t *testing.T) { } var actualMessage string rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) - rc.NextReader() + if _, _, err := rc.NextReader(); err != nil { + continue + } if actualMessage != message { t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) continue @@ -359,15 +362,19 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { rc := newTestConn(&b1, &b2, true) w, _ := wc.NextWriter(BinaryMessage) - w.Write(make([]byte, bufSize+bufSize/2)) - wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) + if _, err := w.Write(make([]byte, bufSize+bufSize/2)); err != nil { + t.Fatalf("w.Write() returned %v", err) + } + if err := wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)); err != nil { + t.Fatalf("wc.WriteControl() returned %v", err) + } w.Close() op, r, err := rc.NextReader() if op != BinaryMessage || err != nil { t.Fatalf("NextReader() returned %d, %v", op, err) } - _, err = io.Copy(ioutil.Discard, r) + _, err = io.Copy(io.Discard, r) if !reflect.DeepEqual(err, expectedErr) { t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) } @@ -386,7 +393,9 @@ func TestEOFWithinFrame(t *testing.T) { rc := newTestConn(&b, nil, true) w, _ := wc.NextWriter(BinaryMessage) - w.Write(make([]byte, bufSize)) + if _, err := w.Write(make([]byte, bufSize)); err != nil { + t.Fatalf("%d: w.Write() returned %v", n, err) + } w.Close() if n >= b.Len() { @@ -401,7 +410,7 @@ func TestEOFWithinFrame(t *testing.T) { if op != BinaryMessage || err != nil { t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) } - _, err = io.Copy(ioutil.Discard, r) + _, err = io.Copy(io.Discard, r) if err != errUnexpectedEOF { t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) } @@ -420,13 +429,15 @@ func TestEOFBeforeFinalFrame(t *testing.T) { rc := newTestConn(&b1, &b2, true) w, _ := wc.NextWriter(BinaryMessage) - w.Write(make([]byte, bufSize+bufSize/2)) + if _, err := w.Write(make([]byte, bufSize+bufSize/2)); err != nil { + t.Fatalf("w.Write() returned %v", err) + } op, r, err := rc.NextReader() if op != BinaryMessage || err != nil { t.Fatalf("NextReader() returned %d, %v", op, err) } - _, err = io.Copy(ioutil.Discard, r) + _, err = io.Copy(io.Discard, r) if err != errUnexpectedEOF { t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) } @@ -439,7 +450,9 @@ func TestEOFBeforeFinalFrame(t *testing.T) { func TestWriteAfterMessageWriterClose(t *testing.T) { wc := newTestConn(nil, &bytes.Buffer{}, false) w, _ := wc.NextWriter(BinaryMessage) - io.WriteString(w, "hello") + if _, err := io.WriteString(w, "hello"); err != nil { + t.Fatalf("unexpected error writing, %v", err) + } if err := w.Close(); err != nil { t.Fatalf("unxpected error closing message writer, %v", err) } @@ -449,7 +462,9 @@ func TestWriteAfterMessageWriterClose(t *testing.T) { } w, _ = wc.NextWriter(BinaryMessage) - io.WriteString(w, "hello") + if _, err := io.WriteString(w, "hello"); err != nil { + t.Fatalf("unexpected error writing after getting new writer, %v", err) + } // close w by getting next writer _, err := wc.NextWriter(BinaryMessage) @@ -474,13 +489,21 @@ func TestReadLimit(t *testing.T) { // Send message at the limit with interleaved pong. w, _ := wc.NextWriter(BinaryMessage) - w.Write(message[:readLimit-1]) - wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) - w.Write(message[:1]) + if _, err := w.Write(message[:readLimit-1]); err != nil { + t.Fatalf("w.WriteMessage() returned %v", err) + } + if err := wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)); err != nil { + t.Fatalf("wc.WriteControl() returned %v", err) + } + if _, err := w.Write(message[:1]); err != nil { + t.Fatalf("w.Write() returned %v", err) + } w.Close() // Send message larger than the limit. - wc.WriteMessage(BinaryMessage, message[:readLimit+1]) + if err := wc.WriteMessage(BinaryMessage, message[:readLimit+1]); err != nil { + t.Fatalf("wc.WriteMessage() returned %v", err) + } op, _, err := rc.NextReader() if op != BinaryMessage || err != nil { @@ -490,7 +513,7 @@ func TestReadLimit(t *testing.T) { if op != BinaryMessage || err != nil { t.Fatalf("2: NextReader() returned %d, %v", op, err) } - _, err = io.Copy(ioutil.Discard, r) + _, err = io.Copy(io.Discard, r) if err != ErrReadLimit { t.Fatalf("io.Copy() returned %v", err) } @@ -593,7 +616,9 @@ func TestBufioReadBytes(t *testing.T) { rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) w, _ := wc.NextWriter(BinaryMessage) - w.Write(m) + if _, err := w.Write(m); err != nil { + t.Fatalf("w.Write() returned %v", err) + } w.Close() op, r, err := rc.NextReader() @@ -667,7 +692,9 @@ func TestConcurrentWritePanic(t *testing.T) { w := blockingWriter{make(chan struct{}), make(chan struct{})} c := newTestConn(nil, w, false) go func() { - c.WriteMessage(TextMessage, []byte{}) + if err := c.WriteMessage(TextMessage, []byte{}); err != nil { + t.Error(err) + } }() // wait for goroutine to block in write. @@ -680,7 +707,9 @@ func TestConcurrentWritePanic(t *testing.T) { } }() - c.WriteMessage(TextMessage, []byte{}) + if err := c.WriteMessage(TextMessage, []byte{}); err != nil { + t.Error(err) + } t.Fatal("should not get here") } @@ -700,7 +729,7 @@ func TestFailedConnectionReadPanic(t *testing.T) { }() for i := 0; i < 20000; i++ { - c.ReadMessage() + _, _, _ = c.ReadMessage() } t.Fatal("should not get here") } diff --git a/examples/autobahn/server.go b/examples/autobahn/server.go index 8b17fe3..1cd273f 100644 --- a/examples/autobahn/server.go +++ b/examples/autobahn/server.go @@ -178,7 +178,11 @@ func main() { http.HandleFunc("/r", echoReadAllWriter) http.HandleFunc("/m", echoReadAllWriteMessage) http.HandleFunc("/p", echoReadAllWritePreparedMessage) - err := http.ListenAndServe(*addr, nil) + server := &http.Server{ + Addr: *addr, + ReadHeaderTimeout: 3 * time.Second, + } + err := server.ListenAndServe() if err != nil { log.Fatal("ListenAndServe: ", err) } diff --git a/examples/chat/main.go b/examples/chat/main.go index 474709f..591cb89 100644 --- a/examples/chat/main.go +++ b/examples/chat/main.go @@ -8,6 +8,7 @@ import ( "flag" "log" "net/http" + "time" ) var addr = flag.String("addr", ":8080", "http service address") @@ -33,7 +34,11 @@ func main() { http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { serveWs(hub, w, r) }) - err := http.ListenAndServe(*addr, nil) + server := &http.Server{ + Addr: *addr, + ReadHeaderTimeout: 3 * time.Second, + } + err := server.ListenAndServe() if err != nil { log.Fatal("ListenAndServe: ", err) } diff --git a/examples/command/main.go b/examples/command/main.go index 38d9f6c..11b52e5 100644 --- a/examples/command/main.go +++ b/examples/command/main.go @@ -189,5 +189,9 @@ func main() { } http.HandleFunc("/", serveHome) http.HandleFunc("/ws", serveWs) - log.Fatal(http.ListenAndServe(*addr, nil)) + server := &http.Server{ + Addr: *addr, + ReadHeaderTimeout: 3 * time.Second, + } + log.Fatal(server.ListenAndServe()) } diff --git a/examples/filewatch/main.go b/examples/filewatch/main.go index d4bf80e..ddb613c 100644 --- a/examples/filewatch/main.go +++ b/examples/filewatch/main.go @@ -11,6 +11,7 @@ import ( "log" "net/http" "os" + "path/filepath" "strconv" "time" @@ -49,7 +50,7 @@ func readFileIfModified(lastMod time.Time) ([]byte, time.Time, error) { if !fi.ModTime().After(lastMod) { return nil, lastMod, nil } - p, err := ioutil.ReadFile(filename) + p, err := ioutil.ReadFile(filepath.Clean(filename)) if err != nil { return nil, fi.ModTime(), err } @@ -163,7 +164,11 @@ func main() { filename = flag.Args()[0] http.HandleFunc("/", serveHome) http.HandleFunc("/ws", serveWs) - if err := http.ListenAndServe(*addr, nil); err != nil { + server := &http.Server{ + Addr: *addr, + ReadHeaderTimeout: 3 * time.Second, + } + if err := server.ListenAndServe(); err != nil { log.Fatal(err) } } diff --git a/go.mod b/go.mod index 1a7afd5..c5c4f92 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/gorilla/websocket -go 1.12 +go 1.19 + +require golang.org/x/net v0.14.0 diff --git a/go.sum b/go.sum index e69de29..f888ee0 100644 --- a/go.sum +++ b/go.sum @@ -0,0 +1,2 @@ +golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= +golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= diff --git a/join_test.go b/join_test.go index 961ac04..f06601d 100644 --- a/join_test.go +++ b/join_test.go @@ -19,7 +19,9 @@ func TestJoinMessages(t *testing.T) { wc := newTestConn(nil, &connBuf, true) rc := newTestConn(&connBuf, nil, false) for _, m := range messages { - wc.WriteMessage(BinaryMessage, []byte(m)) + if err := wc.WriteMessage(BinaryMessage, []byte(m)); err != nil { + t.Fatalf("write %q: %v", m, err) + } } var result bytes.Buffer diff --git a/mask.go b/mask.go index d0742bf..67d0968 100644 --- a/mask.go +++ b/mask.go @@ -9,6 +9,7 @@ package websocket import "unsafe" +// #nosec G103 -- (CWE-242) Has been audited const wordSize = int(unsafe.Sizeof(uintptr(0))) func maskBytes(key [4]byte, pos int, b []byte) int { @@ -22,6 +23,7 @@ func maskBytes(key [4]byte, pos int, b []byte) int { } // Mask one byte at a time to word boundary. + //#nosec G103 -- (CWE-242) Has been audited if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 { n = wordSize - n for i := range b[:n] { @@ -36,11 +38,13 @@ func maskBytes(key [4]byte, pos int, b []byte) int { for i := range k { k[i] = key[(pos+i)&3] } + //#nosec G103 -- (CWE-242) Has been audited kw := *(*uintptr)(unsafe.Pointer(&k)) // Mask one word at a time. n := (len(b) / wordSize) * wordSize for i := 0; i < n; i += wordSize { + //#nosec G103 -- (CWE-242) Has been audited *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw } diff --git a/prepared_test.go b/prepared_test.go index 2297802..dc77ef0 100644 --- a/prepared_test.go +++ b/prepared_test.go @@ -33,6 +33,11 @@ var preparedMessageTests = []struct { } func TestPreparedMessage(t *testing.T) { + testRand := rand.New(rand.NewSource(99)) + prevMaskRand := maskRand + maskRand = testRand + defer func() { maskRand = prevMaskRand }() + for _, tt := range preparedMessageTests { var data = []byte("this is a test") var buf bytes.Buffer @@ -40,10 +45,13 @@ func TestPreparedMessage(t *testing.T) { if tt.enableWriteCompression { c.newCompressionWriter = compressNoContextTakeover } - c.SetCompressionLevel(tt.compressionLevel) + + if err := c.SetCompressionLevel(tt.compressionLevel); err != nil { + t.Fatal(err) + } // Seed random number generator for consistent frame mask. - rand.Seed(1234) + testRand.Seed(1234) if err := c.WriteMessage(tt.messageType, data); err != nil { t.Fatal(err) @@ -59,7 +67,7 @@ func TestPreparedMessage(t *testing.T) { copy(data, "hello world") // Seed random number generator for consistent frame mask. - rand.Seed(1234) + testRand.Seed(1234) buf.Reset() if err := c.WritePreparedMessage(pm); err != nil { @@ -68,7 +76,7 @@ func TestPreparedMessage(t *testing.T) { got := buf.String() if got != want { - t.Errorf("write message != prepared message for %+v", tt) + t.Errorf("write message != prepared message, got %#v, want %#v", got, want) } } } diff --git a/proxy.go b/proxy.go index e0f466b..80f55d1 100644 --- a/proxy.go +++ b/proxy.go @@ -8,10 +8,13 @@ import ( "bufio" "encoding/base64" "errors" + "log" "net" "net/http" "net/url" "strings" + + "golang.org/x/net/proxy" ) type netDialerFunc func(network, addr string) (net.Conn, error) @@ -21,7 +24,7 @@ func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { } func init() { - proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { + proxy.RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy.Dialer) (proxy.Dialer, error) { return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil }) } @@ -55,7 +58,9 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) } if err := connectReq.Write(conn); err != nil { - conn.Close() + if err := conn.Close(); err != nil { + log.Printf("httpProxyDialer: failed to close connection: %v", err) + } return nil, err } @@ -64,12 +69,16 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) br := bufio.NewReader(conn) resp, err := http.ReadResponse(br, connectReq) if err != nil { - conn.Close() + if err := conn.Close(); err != nil { + log.Printf("httpProxyDialer: failed to close connection: %v", err) + } return nil, err } if resp.StatusCode != 200 { - conn.Close() + if err := conn.Close(); err != nil { + log.Printf("httpProxyDialer: failed to close connection: %v", err) + } f := strings.SplitN(resp.Status, " ", 2) return nil, errors.New(f[1]) } diff --git a/server.go b/server.go index bb33597..1e720e1 100644 --- a/server.go +++ b/server.go @@ -8,6 +8,7 @@ import ( "bufio" "errors" "io" + "log" "net/http" "net/url" "strings" @@ -183,7 +184,9 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade } if brw.Reader.Buffered() > 0 { - netConn.Close() + if err := netConn.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } return nil, errors.New("websocket: client sent data before handshake is complete") } @@ -248,17 +251,34 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade p = append(p, "\r\n"...) // Clear deadlines set by HTTP server. - netConn.SetDeadline(time.Time{}) + if err := netConn.SetDeadline(time.Time{}); err != nil { + if err := netConn.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } + return nil, err + } if u.HandshakeTimeout > 0 { - netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) + if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil { + if err := netConn.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } + return nil, err + } } if _, err = netConn.Write(p); err != nil { - netConn.Close() + if err := netConn.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } return nil, err } if u.HandshakeTimeout > 0 { - netConn.SetWriteDeadline(time.Time{}) + if err := netConn.SetWriteDeadline(time.Time{}); err != nil { + if err := netConn.Close(); err != nil { + log.Printf("websocket: failed to close network connection: %v", err) + } + return nil, err + } } return c, nil @@ -356,8 +376,12 @@ func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte { // bufio.Writer's underlying writer. var wh writeHook bw.Reset(&wh) - bw.WriteByte(0) - bw.Flush() + if err := bw.WriteByte(0); err != nil { + panic(err) + } + if err := bw.Flush(); err != nil { + log.Printf("websocket: bufioWriterBuffer: Flush: %v", err) + } bw.Reset(originalWriter) diff --git a/tls_handshake.go b/tls_handshake.go index a62b68c..7f38645 100644 --- a/tls_handshake.go +++ b/tls_handshake.go @@ -1,6 +1,3 @@ -//go:build go1.17 -// +build go1.17 - package websocket import ( diff --git a/tls_handshake_116.go b/tls_handshake_116.go deleted file mode 100644 index e1b2b44..0000000 --- a/tls_handshake_116.go +++ /dev/null @@ -1,21 +0,0 @@ -//go:build !go1.17 -// +build !go1.17 - -package websocket - -import ( - "context" - "crypto/tls" -) - -func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error { - if err := tlsConn.Handshake(); err != nil { - return err - } - if !cfg.InsecureSkipVerify { - if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { - return err - } - } - return nil -} diff --git a/util.go b/util.go index 31a5dee..9b1a629 100644 --- a/util.go +++ b/util.go @@ -6,7 +6,7 @@ package websocket import ( "crypto/rand" - "crypto/sha1" + "crypto/sha1" //#nosec G505 -- (CWE-327) https://datatracker.ietf.org/doc/html/rfc6455#page-54 "encoding/base64" "io" "net/http" @@ -17,7 +17,7 @@ import ( var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") func computeAcceptKey(challengeKey string) string { - h := sha1.New() + h := sha1.New() //#nosec G401 -- (CWE-326) https://datatracker.ietf.org/doc/html/rfc6455#page-54 h.Write([]byte(challengeKey)) h.Write(keyGUID) return base64.StdEncoding.EncodeToString(h.Sum(nil)) diff --git a/vendor/golang.org/x/net/LICENSE b/vendor/golang.org/x/net/LICENSE new file mode 100644 index 0000000..6a66aea --- /dev/null +++ b/vendor/golang.org/x/net/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/golang.org/x/net/PATENTS b/vendor/golang.org/x/net/PATENTS new file mode 100644 index 0000000..7330990 --- /dev/null +++ b/vendor/golang.org/x/net/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/vendor/golang.org/x/net/internal/socks/client.go b/vendor/golang.org/x/net/internal/socks/client.go new file mode 100644 index 0000000..3d6f516 --- /dev/null +++ b/vendor/golang.org/x/net/internal/socks/client.go @@ -0,0 +1,168 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package socks + +import ( + "context" + "errors" + "io" + "net" + "strconv" + "time" +) + +var ( + noDeadline = time.Time{} + aLongTimeAgo = time.Unix(1, 0) +) + +func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) { + host, port, err := splitHostPort(address) + if err != nil { + return nil, err + } + if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { + c.SetDeadline(deadline) + defer c.SetDeadline(noDeadline) + } + if ctx != context.Background() { + errCh := make(chan error, 1) + done := make(chan struct{}) + defer func() { + close(done) + if ctxErr == nil { + ctxErr = <-errCh + } + }() + go func() { + select { + case <-ctx.Done(): + c.SetDeadline(aLongTimeAgo) + errCh <- ctx.Err() + case <-done: + errCh <- nil + } + }() + } + + b := make([]byte, 0, 6+len(host)) // the size here is just an estimate + b = append(b, Version5) + if len(d.AuthMethods) == 0 || d.Authenticate == nil { + b = append(b, 1, byte(AuthMethodNotRequired)) + } else { + ams := d.AuthMethods + if len(ams) > 255 { + return nil, errors.New("too many authentication methods") + } + b = append(b, byte(len(ams))) + for _, am := range ams { + b = append(b, byte(am)) + } + } + if _, ctxErr = c.Write(b); ctxErr != nil { + return + } + + if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil { + return + } + if b[0] != Version5 { + return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) + } + am := AuthMethod(b[1]) + if am == AuthMethodNoAcceptableMethods { + return nil, errors.New("no acceptable authentication methods") + } + if d.Authenticate != nil { + if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil { + return + } + } + + b = b[:0] + b = append(b, Version5, byte(d.cmd), 0) + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + b = append(b, AddrTypeIPv4) + b = append(b, ip4...) + } else if ip6 := ip.To16(); ip6 != nil { + b = append(b, AddrTypeIPv6) + b = append(b, ip6...) + } else { + return nil, errors.New("unknown address type") + } + } else { + if len(host) > 255 { + return nil, errors.New("FQDN too long") + } + b = append(b, AddrTypeFQDN) + b = append(b, byte(len(host))) + b = append(b, host...) + } + b = append(b, byte(port>>8), byte(port)) + if _, ctxErr = c.Write(b); ctxErr != nil { + return + } + + if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil { + return + } + if b[0] != Version5 { + return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) + } + if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded { + return nil, errors.New("unknown error " + cmdErr.String()) + } + if b[2] != 0 { + return nil, errors.New("non-zero reserved field") + } + l := 2 + var a Addr + switch b[3] { + case AddrTypeIPv4: + l += net.IPv4len + a.IP = make(net.IP, net.IPv4len) + case AddrTypeIPv6: + l += net.IPv6len + a.IP = make(net.IP, net.IPv6len) + case AddrTypeFQDN: + if _, err := io.ReadFull(c, b[:1]); err != nil { + return nil, err + } + l += int(b[0]) + default: + return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3]))) + } + if cap(b) < l { + b = make([]byte, l) + } else { + b = b[:l] + } + if _, ctxErr = io.ReadFull(c, b); ctxErr != nil { + return + } + if a.IP != nil { + copy(a.IP, b) + } else { + a.Name = string(b[:len(b)-2]) + } + a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1]) + return &a, nil +} + +func splitHostPort(address string) (string, int, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return "", 0, err + } + portnum, err := strconv.Atoi(port) + if err != nil { + return "", 0, err + } + if 1 > portnum || portnum > 0xffff { + return "", 0, errors.New("port number out of range " + port) + } + return host, portnum, nil +} diff --git a/vendor/golang.org/x/net/internal/socks/socks.go b/vendor/golang.org/x/net/internal/socks/socks.go new file mode 100644 index 0000000..84fcc32 --- /dev/null +++ b/vendor/golang.org/x/net/internal/socks/socks.go @@ -0,0 +1,317 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package socks provides a SOCKS version 5 client implementation. +// +// SOCKS protocol version 5 is defined in RFC 1928. +// Username/Password authentication for SOCKS version 5 is defined in +// RFC 1929. +package socks + +import ( + "context" + "errors" + "io" + "net" + "strconv" +) + +// A Command represents a SOCKS command. +type Command int + +func (cmd Command) String() string { + switch cmd { + case CmdConnect: + return "socks connect" + case cmdBind: + return "socks bind" + default: + return "socks " + strconv.Itoa(int(cmd)) + } +} + +// An AuthMethod represents a SOCKS authentication method. +type AuthMethod int + +// A Reply represents a SOCKS command reply code. +type Reply int + +func (code Reply) String() string { + switch code { + case StatusSucceeded: + return "succeeded" + case 0x01: + return "general SOCKS server failure" + case 0x02: + return "connection not allowed by ruleset" + case 0x03: + return "network unreachable" + case 0x04: + return "host unreachable" + case 0x05: + return "connection refused" + case 0x06: + return "TTL expired" + case 0x07: + return "command not supported" + case 0x08: + return "address type not supported" + default: + return "unknown code: " + strconv.Itoa(int(code)) + } +} + +// Wire protocol constants. +const ( + Version5 = 0x05 + + AddrTypeIPv4 = 0x01 + AddrTypeFQDN = 0x03 + AddrTypeIPv6 = 0x04 + + CmdConnect Command = 0x01 // establishes an active-open forward proxy connection + cmdBind Command = 0x02 // establishes a passive-open forward proxy connection + + AuthMethodNotRequired AuthMethod = 0x00 // no authentication required + AuthMethodUsernamePassword AuthMethod = 0x02 // use username/password + AuthMethodNoAcceptableMethods AuthMethod = 0xff // no acceptable authentication methods + + StatusSucceeded Reply = 0x00 +) + +// An Addr represents a SOCKS-specific address. +// Either Name or IP is used exclusively. +type Addr struct { + Name string // fully-qualified domain name + IP net.IP + Port int +} + +func (a *Addr) Network() string { return "socks" } + +func (a *Addr) String() string { + if a == nil { + return "" + } + port := strconv.Itoa(a.Port) + if a.IP == nil { + return net.JoinHostPort(a.Name, port) + } + return net.JoinHostPort(a.IP.String(), port) +} + +// A Conn represents a forward proxy connection. +type Conn struct { + net.Conn + + boundAddr net.Addr +} + +// BoundAddr returns the address assigned by the proxy server for +// connecting to the command target address from the proxy server. +func (c *Conn) BoundAddr() net.Addr { + if c == nil { + return nil + } + return c.boundAddr +} + +// A Dialer holds SOCKS-specific options. +type Dialer struct { + cmd Command // either CmdConnect or cmdBind + proxyNetwork string // network between a proxy server and a client + proxyAddress string // proxy server address + + // ProxyDial specifies the optional dial function for + // establishing the transport connection. + ProxyDial func(context.Context, string, string) (net.Conn, error) + + // AuthMethods specifies the list of request authentication + // methods. + // If empty, SOCKS client requests only AuthMethodNotRequired. + AuthMethods []AuthMethod + + // Authenticate specifies the optional authentication + // function. It must be non-nil when AuthMethods is not empty. + // It must return an error when the authentication is failed. + Authenticate func(context.Context, io.ReadWriter, AuthMethod) error +} + +// DialContext connects to the provided address on the provided +// network. +// +// The returned error value may be a net.OpError. When the Op field of +// net.OpError contains "socks", the Source field contains a proxy +// server address and the Addr field contains a command target +// address. +// +// See func Dial of the net package of standard library for a +// description of the network and address parameters. +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + if err := d.validateTarget(network, address); err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + if ctx == nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")} + } + var err error + var c net.Conn + if d.ProxyDial != nil { + c, err = d.ProxyDial(ctx, d.proxyNetwork, d.proxyAddress) + } else { + var dd net.Dialer + c, err = dd.DialContext(ctx, d.proxyNetwork, d.proxyAddress) + } + if err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + a, err := d.connect(ctx, c, address) + if err != nil { + c.Close() + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + return &Conn{Conn: c, boundAddr: a}, nil +} + +// DialWithConn initiates a connection from SOCKS server to the target +// network and address using the connection c that is already +// connected to the SOCKS server. +// +// It returns the connection's local address assigned by the SOCKS +// server. +func (d *Dialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) { + if err := d.validateTarget(network, address); err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + if ctx == nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")} + } + a, err := d.connect(ctx, c, address) + if err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + return a, nil +} + +// Dial connects to the provided address on the provided network. +// +// Unlike DialContext, it returns a raw transport connection instead +// of a forward proxy connection. +// +// Deprecated: Use DialContext or DialWithConn instead. +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + if err := d.validateTarget(network, address); err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + var err error + var c net.Conn + if d.ProxyDial != nil { + c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress) + } else { + c, err = net.Dial(d.proxyNetwork, d.proxyAddress) + } + if err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil { + c.Close() + return nil, err + } + return c, nil +} + +func (d *Dialer) validateTarget(network, address string) error { + switch network { + case "tcp", "tcp6", "tcp4": + default: + return errors.New("network not implemented") + } + switch d.cmd { + case CmdConnect, cmdBind: + default: + return errors.New("command not implemented") + } + return nil +} + +func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) { + for i, s := range []string{d.proxyAddress, address} { + host, port, err := splitHostPort(s) + if err != nil { + return nil, nil, err + } + a := &Addr{Port: port} + a.IP = net.ParseIP(host) + if a.IP == nil { + a.Name = host + } + if i == 0 { + proxy = a + } else { + dst = a + } + } + return +} + +// NewDialer returns a new Dialer that dials through the provided +// proxy server's network and address. +func NewDialer(network, address string) *Dialer { + return &Dialer{proxyNetwork: network, proxyAddress: address, cmd: CmdConnect} +} + +const ( + authUsernamePasswordVersion = 0x01 + authStatusSucceeded = 0x00 +) + +// UsernamePassword are the credentials for the username/password +// authentication method. +type UsernamePassword struct { + Username string + Password string +} + +// Authenticate authenticates a pair of username and password with the +// proxy server. +func (up *UsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, auth AuthMethod) error { + switch auth { + case AuthMethodNotRequired: + return nil + case AuthMethodUsernamePassword: + if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) > 255 { + return errors.New("invalid username/password") + } + b := []byte{authUsernamePasswordVersion} + b = append(b, byte(len(up.Username))) + b = append(b, up.Username...) + b = append(b, byte(len(up.Password))) + b = append(b, up.Password...) + // TODO(mikio): handle IO deadlines and cancelation if + // necessary + if _, err := rw.Write(b); err != nil { + return err + } + if _, err := io.ReadFull(rw, b[:2]); err != nil { + return err + } + if b[0] != authUsernamePasswordVersion { + return errors.New("invalid username/password version") + } + if b[1] != authStatusSucceeded { + return errors.New("username/password authentication failed") + } + return nil + } + return errors.New("unsupported authentication method " + strconv.Itoa(int(auth))) +} diff --git a/vendor/golang.org/x/net/proxy/dial.go b/vendor/golang.org/x/net/proxy/dial.go new file mode 100644 index 0000000..811c2e4 --- /dev/null +++ b/vendor/golang.org/x/net/proxy/dial.go @@ -0,0 +1,54 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proxy + +import ( + "context" + "net" +) + +// A ContextDialer dials using a context. +type ContextDialer interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// Dial works like DialContext on net.Dialer but using a dialer returned by FromEnvironment. +// +// The passed ctx is only used for returning the Conn, not the lifetime of the Conn. +// +// Custom dialers (registered via RegisterDialerType) that do not implement ContextDialer +// can leak a goroutine for as long as it takes the underlying Dialer implementation to timeout. +// +// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed. +func Dial(ctx context.Context, network, address string) (net.Conn, error) { + d := FromEnvironment() + if xd, ok := d.(ContextDialer); ok { + return xd.DialContext(ctx, network, address) + } + return dialContext(ctx, d, network, address) +} + +// WARNING: this can leak a goroutine for as long as the underlying Dialer implementation takes to timeout +// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed. +func dialContext(ctx context.Context, d Dialer, network, address string) (net.Conn, error) { + var ( + conn net.Conn + done = make(chan struct{}, 1) + err error + ) + go func() { + conn, err = d.Dial(network, address) + close(done) + if conn != nil && ctx.Err() != nil { + conn.Close() + } + }() + select { + case <-ctx.Done(): + err = ctx.Err() + case <-done: + } + return conn, err +} diff --git a/vendor/golang.org/x/net/proxy/direct.go b/vendor/golang.org/x/net/proxy/direct.go new file mode 100644 index 0000000..3d66bde --- /dev/null +++ b/vendor/golang.org/x/net/proxy/direct.go @@ -0,0 +1,31 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proxy + +import ( + "context" + "net" +) + +type direct struct{} + +// Direct implements Dialer by making network connections directly using net.Dial or net.DialContext. +var Direct = direct{} + +var ( + _ Dialer = Direct + _ ContextDialer = Direct +) + +// Dial directly invokes net.Dial with the supplied parameters. +func (direct) Dial(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) +} + +// DialContext instantiates a net.Dialer and invokes its DialContext receiver with the supplied parameters. +func (direct) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, addr) +} diff --git a/vendor/golang.org/x/net/proxy/per_host.go b/vendor/golang.org/x/net/proxy/per_host.go new file mode 100644 index 0000000..573fe79 --- /dev/null +++ b/vendor/golang.org/x/net/proxy/per_host.go @@ -0,0 +1,155 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proxy + +import ( + "context" + "net" + "strings" +) + +// A PerHost directs connections to a default Dialer unless the host name +// requested matches one of a number of exceptions. +type PerHost struct { + def, bypass Dialer + + bypassNetworks []*net.IPNet + bypassIPs []net.IP + bypassZones []string + bypassHosts []string +} + +// NewPerHost returns a PerHost Dialer that directs connections to either +// defaultDialer or bypass, depending on whether the connection matches one of +// the configured rules. +func NewPerHost(defaultDialer, bypass Dialer) *PerHost { + return &PerHost{ + def: defaultDialer, + bypass: bypass, + } +} + +// Dial connects to the address addr on the given network through either +// defaultDialer or bypass. +func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + return p.dialerForRequest(host).Dial(network, addr) +} + +// DialContext connects to the address addr on the given network through either +// defaultDialer or bypass. +func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net.Conn, err error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + d := p.dialerForRequest(host) + if x, ok := d.(ContextDialer); ok { + return x.DialContext(ctx, network, addr) + } + return dialContext(ctx, d, network, addr) +} + +func (p *PerHost) dialerForRequest(host string) Dialer { + if ip := net.ParseIP(host); ip != nil { + for _, net := range p.bypassNetworks { + if net.Contains(ip) { + return p.bypass + } + } + for _, bypassIP := range p.bypassIPs { + if bypassIP.Equal(ip) { + return p.bypass + } + } + return p.def + } + + for _, zone := range p.bypassZones { + if strings.HasSuffix(host, zone) { + return p.bypass + } + if host == zone[1:] { + // For a zone ".example.com", we match "example.com" + // too. + return p.bypass + } + } + for _, bypassHost := range p.bypassHosts { + if bypassHost == host { + return p.bypass + } + } + return p.def +} + +// AddFromString parses a string that contains comma-separated values +// specifying hosts that should use the bypass proxy. Each value is either an +// IP address, a CIDR range, a zone (*.example.com) or a host name +// (localhost). A best effort is made to parse the string and errors are +// ignored. +func (p *PerHost) AddFromString(s string) { + hosts := strings.Split(s, ",") + for _, host := range hosts { + host = strings.TrimSpace(host) + if len(host) == 0 { + continue + } + if strings.Contains(host, "/") { + // We assume that it's a CIDR address like 127.0.0.0/8 + if _, net, err := net.ParseCIDR(host); err == nil { + p.AddNetwork(net) + } + continue + } + if ip := net.ParseIP(host); ip != nil { + p.AddIP(ip) + continue + } + if strings.HasPrefix(host, "*.") { + p.AddZone(host[1:]) + continue + } + p.AddHost(host) + } +} + +// AddIP specifies an IP address that will use the bypass proxy. Note that +// this will only take effect if a literal IP address is dialed. A connection +// to a named host will never match an IP. +func (p *PerHost) AddIP(ip net.IP) { + p.bypassIPs = append(p.bypassIPs, ip) +} + +// AddNetwork specifies an IP range that will use the bypass proxy. Note that +// this will only take effect if a literal IP address is dialed. A connection +// to a named host will never match. +func (p *PerHost) AddNetwork(net *net.IPNet) { + p.bypassNetworks = append(p.bypassNetworks, net) +} + +// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of +// "example.com" matches "example.com" and all of its subdomains. +func (p *PerHost) AddZone(zone string) { + if strings.HasSuffix(zone, ".") { + zone = zone[:len(zone)-1] + } + if !strings.HasPrefix(zone, ".") { + zone = "." + zone + } + p.bypassZones = append(p.bypassZones, zone) +} + +// AddHost specifies a host name that will use the bypass proxy. +func (p *PerHost) AddHost(host string) { + if strings.HasSuffix(host, ".") { + host = host[:len(host)-1] + } + p.bypassHosts = append(p.bypassHosts, host) +} diff --git a/vendor/golang.org/x/net/proxy/proxy.go b/vendor/golang.org/x/net/proxy/proxy.go new file mode 100644 index 0000000..9ff4b9a --- /dev/null +++ b/vendor/golang.org/x/net/proxy/proxy.go @@ -0,0 +1,149 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package proxy provides support for a variety of protocols to proxy network +// data. +package proxy // import "golang.org/x/net/proxy" + +import ( + "errors" + "net" + "net/url" + "os" + "sync" +) + +// A Dialer is a means to establish a connection. +// Custom dialers should also implement ContextDialer. +type Dialer interface { + // Dial connects to the given address via the proxy. + Dial(network, addr string) (c net.Conn, err error) +} + +// Auth contains authentication parameters that specific Dialers may require. +type Auth struct { + User, Password string +} + +// FromEnvironment returns the dialer specified by the proxy-related +// variables in the environment and makes underlying connections +// directly. +func FromEnvironment() Dialer { + return FromEnvironmentUsing(Direct) +} + +// FromEnvironmentUsing returns the dialer specify by the proxy-related +// variables in the environment and makes underlying connections +// using the provided forwarding Dialer (for instance, a *net.Dialer +// with desired configuration). +func FromEnvironmentUsing(forward Dialer) Dialer { + allProxy := allProxyEnv.Get() + if len(allProxy) == 0 { + return forward + } + + proxyURL, err := url.Parse(allProxy) + if err != nil { + return forward + } + proxy, err := FromURL(proxyURL, forward) + if err != nil { + return forward + } + + noProxy := noProxyEnv.Get() + if len(noProxy) == 0 { + return proxy + } + + perHost := NewPerHost(proxy, forward) + perHost.AddFromString(noProxy) + return perHost +} + +// proxySchemes is a map from URL schemes to a function that creates a Dialer +// from a URL with such a scheme. +var proxySchemes map[string]func(*url.URL, Dialer) (Dialer, error) + +// RegisterDialerType takes a URL scheme and a function to generate Dialers from +// a URL with that scheme and a forwarding Dialer. Registered schemes are used +// by FromURL. +func RegisterDialerType(scheme string, f func(*url.URL, Dialer) (Dialer, error)) { + if proxySchemes == nil { + proxySchemes = make(map[string]func(*url.URL, Dialer) (Dialer, error)) + } + proxySchemes[scheme] = f +} + +// FromURL returns a Dialer given a URL specification and an underlying +// Dialer for it to make network requests. +func FromURL(u *url.URL, forward Dialer) (Dialer, error) { + var auth *Auth + if u.User != nil { + auth = new(Auth) + auth.User = u.User.Username() + if p, ok := u.User.Password(); ok { + auth.Password = p + } + } + + switch u.Scheme { + case "socks5", "socks5h": + addr := u.Hostname() + port := u.Port() + if port == "" { + port = "1080" + } + return SOCKS5("tcp", net.JoinHostPort(addr, port), auth, forward) + } + + // If the scheme doesn't match any of the built-in schemes, see if it + // was registered by another package. + if proxySchemes != nil { + if f, ok := proxySchemes[u.Scheme]; ok { + return f(u, forward) + } + } + + return nil, errors.New("proxy: unknown scheme: " + u.Scheme) +} + +var ( + allProxyEnv = &envOnce{ + names: []string{"ALL_PROXY", "all_proxy"}, + } + noProxyEnv = &envOnce{ + names: []string{"NO_PROXY", "no_proxy"}, + } +) + +// envOnce looks up an environment variable (optionally by multiple +// names) once. It mitigates expensive lookups on some platforms +// (e.g. Windows). +// (Borrowed from net/http/transport.go) +type envOnce struct { + names []string + once sync.Once + val string +} + +func (e *envOnce) Get() string { + e.once.Do(e.init) + return e.val +} + +func (e *envOnce) init() { + for _, n := range e.names { + e.val = os.Getenv(n) + if e.val != "" { + return + } + } +} + +// reset is used by tests +func (e *envOnce) reset() { + e.once = sync.Once{} + e.val = "" +} diff --git a/vendor/golang.org/x/net/proxy/socks5.go b/vendor/golang.org/x/net/proxy/socks5.go new file mode 100644 index 0000000..c91651f --- /dev/null +++ b/vendor/golang.org/x/net/proxy/socks5.go @@ -0,0 +1,42 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proxy + +import ( + "context" + "net" + + "golang.org/x/net/internal/socks" +) + +// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given +// address with an optional username and password. +// See RFC 1928 and RFC 1929. +func SOCKS5(network, address string, auth *Auth, forward Dialer) (Dialer, error) { + d := socks.NewDialer(network, address) + if forward != nil { + if f, ok := forward.(ContextDialer); ok { + d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) { + return f.DialContext(ctx, network, address) + } + } else { + d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) { + return dialContext(ctx, forward, network, address) + } + } + } + if auth != nil { + up := socks.UsernamePassword{ + Username: auth.User, + Password: auth.Password, + } + d.AuthMethods = []socks.AuthMethod{ + socks.AuthMethodNotRequired, + socks.AuthMethodUsernamePassword, + } + d.Authenticate = up.Authenticate + } + return d, nil +} diff --git a/vendor/modules.txt b/vendor/modules.txt new file mode 100644 index 0000000..2278348 --- /dev/null +++ b/vendor/modules.txt @@ -0,0 +1,4 @@ +# golang.org/x/net v0.14.0 +## explicit; go 1.17 +golang.org/x/net/internal/socks +golang.org/x/net/proxy diff --git a/x_net_proxy.go b/x_net_proxy.go deleted file mode 100644 index 2e668f6..0000000 --- a/x_net_proxy.go +++ /dev/null @@ -1,473 +0,0 @@ -// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT. -//go:generate bundle -o x_net_proxy.go golang.org/x/net/proxy - -// Package proxy provides support for a variety of protocols to proxy network -// data. -// - -package websocket - -import ( - "errors" - "io" - "net" - "net/url" - "os" - "strconv" - "strings" - "sync" -) - -type proxy_direct struct{} - -// Direct is a direct proxy: one that makes network connections directly. -var proxy_Direct = proxy_direct{} - -func (proxy_direct) Dial(network, addr string) (net.Conn, error) { - return net.Dial(network, addr) -} - -// A PerHost directs connections to a default Dialer unless the host name -// requested matches one of a number of exceptions. -type proxy_PerHost struct { - def, bypass proxy_Dialer - - bypassNetworks []*net.IPNet - bypassIPs []net.IP - bypassZones []string - bypassHosts []string -} - -// NewPerHost returns a PerHost Dialer that directs connections to either -// defaultDialer or bypass, depending on whether the connection matches one of -// the configured rules. -func proxy_NewPerHost(defaultDialer, bypass proxy_Dialer) *proxy_PerHost { - return &proxy_PerHost{ - def: defaultDialer, - bypass: bypass, - } -} - -// Dial connects to the address addr on the given network through either -// defaultDialer or bypass. -func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) { - host, _, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - - return p.dialerForRequest(host).Dial(network, addr) -} - -func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer { - if ip := net.ParseIP(host); ip != nil { - for _, net := range p.bypassNetworks { - if net.Contains(ip) { - return p.bypass - } - } - for _, bypassIP := range p.bypassIPs { - if bypassIP.Equal(ip) { - return p.bypass - } - } - return p.def - } - - for _, zone := range p.bypassZones { - if strings.HasSuffix(host, zone) { - return p.bypass - } - if host == zone[1:] { - // For a zone ".example.com", we match "example.com" - // too. - return p.bypass - } - } - for _, bypassHost := range p.bypassHosts { - if bypassHost == host { - return p.bypass - } - } - return p.def -} - -// AddFromString parses a string that contains comma-separated values -// specifying hosts that should use the bypass proxy. Each value is either an -// IP address, a CIDR range, a zone (*.example.com) or a host name -// (localhost). A best effort is made to parse the string and errors are -// ignored. -func (p *proxy_PerHost) AddFromString(s string) { - hosts := strings.Split(s, ",") - for _, host := range hosts { - host = strings.TrimSpace(host) - if len(host) == 0 { - continue - } - if strings.Contains(host, "/") { - // We assume that it's a CIDR address like 127.0.0.0/8 - if _, net, err := net.ParseCIDR(host); err == nil { - p.AddNetwork(net) - } - continue - } - if ip := net.ParseIP(host); ip != nil { - p.AddIP(ip) - continue - } - if strings.HasPrefix(host, "*.") { - p.AddZone(host[1:]) - continue - } - p.AddHost(host) - } -} - -// AddIP specifies an IP address that will use the bypass proxy. Note that -// this will only take effect if a literal IP address is dialed. A connection -// to a named host will never match an IP. -func (p *proxy_PerHost) AddIP(ip net.IP) { - p.bypassIPs = append(p.bypassIPs, ip) -} - -// AddNetwork specifies an IP range that will use the bypass proxy. Note that -// this will only take effect if a literal IP address is dialed. A connection -// to a named host will never match. -func (p *proxy_PerHost) AddNetwork(net *net.IPNet) { - p.bypassNetworks = append(p.bypassNetworks, net) -} - -// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of -// "example.com" matches "example.com" and all of its subdomains. -func (p *proxy_PerHost) AddZone(zone string) { - if strings.HasSuffix(zone, ".") { - zone = zone[:len(zone)-1] - } - if !strings.HasPrefix(zone, ".") { - zone = "." + zone - } - p.bypassZones = append(p.bypassZones, zone) -} - -// AddHost specifies a host name that will use the bypass proxy. -func (p *proxy_PerHost) AddHost(host string) { - if strings.HasSuffix(host, ".") { - host = host[:len(host)-1] - } - p.bypassHosts = append(p.bypassHosts, host) -} - -// A Dialer is a means to establish a connection. -type proxy_Dialer interface { - // Dial connects to the given address via the proxy. - Dial(network, addr string) (c net.Conn, err error) -} - -// Auth contains authentication parameters that specific Dialers may require. -type proxy_Auth struct { - User, Password string -} - -// FromEnvironment returns the dialer specified by the proxy related variables in -// the environment. -func proxy_FromEnvironment() proxy_Dialer { - allProxy := proxy_allProxyEnv.Get() - if len(allProxy) == 0 { - return proxy_Direct - } - - proxyURL, err := url.Parse(allProxy) - if err != nil { - return proxy_Direct - } - proxy, err := proxy_FromURL(proxyURL, proxy_Direct) - if err != nil { - return proxy_Direct - } - - noProxy := proxy_noProxyEnv.Get() - if len(noProxy) == 0 { - return proxy - } - - perHost := proxy_NewPerHost(proxy, proxy_Direct) - perHost.AddFromString(noProxy) - return perHost -} - -// proxySchemes is a map from URL schemes to a function that creates a Dialer -// from a URL with such a scheme. -var proxy_proxySchemes map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error) - -// RegisterDialerType takes a URL scheme and a function to generate Dialers from -// a URL with that scheme and a forwarding Dialer. Registered schemes are used -// by FromURL. -func proxy_RegisterDialerType(scheme string, f func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) { - if proxy_proxySchemes == nil { - proxy_proxySchemes = make(map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) - } - proxy_proxySchemes[scheme] = f -} - -// FromURL returns a Dialer given a URL specification and an underlying -// Dialer for it to make network requests. -func proxy_FromURL(u *url.URL, forward proxy_Dialer) (proxy_Dialer, error) { - var auth *proxy_Auth - if u.User != nil { - auth = new(proxy_Auth) - auth.User = u.User.Username() - if p, ok := u.User.Password(); ok { - auth.Password = p - } - } - - switch u.Scheme { - case "socks5": - return proxy_SOCKS5("tcp", u.Host, auth, forward) - } - - // If the scheme doesn't match any of the built-in schemes, see if it - // was registered by another package. - if proxy_proxySchemes != nil { - if f, ok := proxy_proxySchemes[u.Scheme]; ok { - return f(u, forward) - } - } - - return nil, errors.New("proxy: unknown scheme: " + u.Scheme) -} - -var ( - proxy_allProxyEnv = &proxy_envOnce{ - names: []string{"ALL_PROXY", "all_proxy"}, - } - proxy_noProxyEnv = &proxy_envOnce{ - names: []string{"NO_PROXY", "no_proxy"}, - } -) - -// envOnce looks up an environment variable (optionally by multiple -// names) once. It mitigates expensive lookups on some platforms -// (e.g. Windows). -// (Borrowed from net/http/transport.go) -type proxy_envOnce struct { - names []string - once sync.Once - val string -} - -func (e *proxy_envOnce) Get() string { - e.once.Do(e.init) - return e.val -} - -func (e *proxy_envOnce) init() { - for _, n := range e.names { - e.val = os.Getenv(n) - if e.val != "" { - return - } - } -} - -// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address -// with an optional username and password. See RFC 1928 and RFC 1929. -func proxy_SOCKS5(network, addr string, auth *proxy_Auth, forward proxy_Dialer) (proxy_Dialer, error) { - s := &proxy_socks5{ - network: network, - addr: addr, - forward: forward, - } - if auth != nil { - s.user = auth.User - s.password = auth.Password - } - - return s, nil -} - -type proxy_socks5 struct { - user, password string - network, addr string - forward proxy_Dialer -} - -const proxy_socks5Version = 5 - -const ( - proxy_socks5AuthNone = 0 - proxy_socks5AuthPassword = 2 -) - -const proxy_socks5Connect = 1 - -const ( - proxy_socks5IP4 = 1 - proxy_socks5Domain = 3 - proxy_socks5IP6 = 4 -) - -var proxy_socks5Errors = []string{ - "", - "general failure", - "connection forbidden", - "network unreachable", - "host unreachable", - "connection refused", - "TTL expired", - "command not supported", - "address type not supported", -} - -// Dial connects to the address addr on the given network via the SOCKS5 proxy. -func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) { - switch network { - case "tcp", "tcp6", "tcp4": - default: - return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network) - } - - conn, err := s.forward.Dial(s.network, s.addr) - if err != nil { - return nil, err - } - if err := s.connect(conn, addr); err != nil { - conn.Close() - return nil, err - } - return conn, nil -} - -// connect takes an existing connection to a socks5 proxy server, -// and commands the server to extend that connection to target, -// which must be a canonical address with a host and port. -func (s *proxy_socks5) connect(conn net.Conn, target string) error { - host, portStr, err := net.SplitHostPort(target) - if err != nil { - return err - } - - port, err := strconv.Atoi(portStr) - if err != nil { - return errors.New("proxy: failed to parse port number: " + portStr) - } - if port < 1 || port > 0xffff { - return errors.New("proxy: port number out of range: " + portStr) - } - - // the size here is just an estimate - buf := make([]byte, 0, 6+len(host)) - - buf = append(buf, proxy_socks5Version) - if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 { - buf = append(buf, 2 /* num auth methods */, proxy_socks5AuthNone, proxy_socks5AuthPassword) - } else { - buf = append(buf, 1 /* num auth methods */, proxy_socks5AuthNone) - } - - if _, err := conn.Write(buf); err != nil { - return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error()) - } - - if _, err := io.ReadFull(conn, buf[:2]); err != nil { - return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error()) - } - if buf[0] != 5 { - return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0]))) - } - if buf[1] == 0xff { - return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication") - } - - // See RFC 1929 - if buf[1] == proxy_socks5AuthPassword { - buf = buf[:0] - buf = append(buf, 1 /* password protocol version */) - buf = append(buf, uint8(len(s.user))) - buf = append(buf, s.user...) - buf = append(buf, uint8(len(s.password))) - buf = append(buf, s.password...) - - if _, err := conn.Write(buf); err != nil { - return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) - } - - if _, err := io.ReadFull(conn, buf[:2]); err != nil { - return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) - } - - if buf[1] != 0 { - return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password") - } - } - - buf = buf[:0] - buf = append(buf, proxy_socks5Version, proxy_socks5Connect, 0 /* reserved */) - - if ip := net.ParseIP(host); ip != nil { - if ip4 := ip.To4(); ip4 != nil { - buf = append(buf, proxy_socks5IP4) - ip = ip4 - } else { - buf = append(buf, proxy_socks5IP6) - } - buf = append(buf, ip...) - } else { - if len(host) > 255 { - return errors.New("proxy: destination host name too long: " + host) - } - buf = append(buf, proxy_socks5Domain) - buf = append(buf, byte(len(host))) - buf = append(buf, host...) - } - buf = append(buf, byte(port>>8), byte(port)) - - if _, err := conn.Write(buf); err != nil { - return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) - } - - if _, err := io.ReadFull(conn, buf[:4]); err != nil { - return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) - } - - failure := "unknown error" - if int(buf[1]) < len(proxy_socks5Errors) { - failure = proxy_socks5Errors[buf[1]] - } - - if len(failure) > 0 { - return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure) - } - - bytesToDiscard := 0 - switch buf[3] { - case proxy_socks5IP4: - bytesToDiscard = net.IPv4len - case proxy_socks5IP6: - bytesToDiscard = net.IPv6len - case proxy_socks5Domain: - _, err := io.ReadFull(conn, buf[:1]) - if err != nil { - return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error()) - } - bytesToDiscard = int(buf[0]) - default: - return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr) - } - - if cap(buf) < bytesToDiscard { - buf = make([]byte, bytesToDiscard) - } else { - buf = buf[:bytesToDiscard] - } - if _, err := io.ReadFull(conn, buf); err != nil { - return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error()) - } - - // Also need to discard the port number - if _, err := io.ReadFull(conn, buf[:2]); err != nil { - return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error()) - } - - return nil -}