From 44bab45066193086a2d446b582f92a40ef2f65fe Mon Sep 17 00:00:00 2001 From: RomanManz Date: Mon, 24 Feb 2025 17:27:54 +0100 Subject: [PATCH] Connect related changes - add optional hooks (#504) * adding return value to connectReqHandler * adding optional connectRespHandler * calling Hijack in handleHttps if defined * fixing PR comments * Update https.go --------- Co-authored-by: Roman Manz Co-authored-by: Roman Manz --- examples/cascadeproxy/main.go | 3 ++- https.go | 50 +++++++++++++++++++++++++++++------ 2 files changed, 44 insertions(+), 9 deletions(-) diff --git a/examples/cascadeproxy/main.go b/examples/cascadeproxy/main.go index 2abd4a09..ac71c606 100644 --- a/examples/cascadeproxy/main.go +++ b/examples/cascadeproxy/main.go @@ -41,8 +41,9 @@ func main() { // socks5://localhost:8082 return url.Parse("http://localhost:8082") } - connectReqHandler := func(req *http.Request) { + connectReqHandler := func(req *http.Request) error { SetBasicAuth(username, password, req) + return nil } middleProxy.ConnectDial = middleProxy.NewConnectDialToProxyWithHandler("http://localhost:8082", connectReqHandler) diff --git a/https.go b/https.go index 016e5134..2bafacc2 100644 --- a/https.go +++ b/https.go @@ -143,7 +143,11 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request return } ctx.Logf("Accepting CONNECT to %s", host) - _, _ = proxyClient.Write([]byte("HTTP/1.0 200 Connection established\r\n\r\n")) + if todo.Hijack != nil { + todo.Hijack(r, proxyClient, ctx) + } else { + _, _ = proxyClient.Write([]byte("HTTP/1.0 200 Connection established\r\n\r\n")) + } targetTCP, targetOK := targetSiteCon.(halfClosable) proxyClientTCP, clientOK := proxyClient.(halfClosable) @@ -194,7 +198,11 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request case ConnectHijack: todo.Hijack(r, proxyClient, ctx) case ConnectHTTPMitm: - _, _ = proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) + if todo.Hijack != nil { + todo.Hijack(r, proxyClient, ctx) + } else { + _, _ = proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) + } ctx.Logf("Assuming CONNECT is plain HTTP tunneling, mitm proxying it") var targetSiteCon net.Conn @@ -265,7 +273,11 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request } } case ConnectMitm: - _, _ = proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) + if todo.Hijack != nil { + todo.Hijack(r, proxyClient, ctx) + } else { + _, _ = proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) + } ctx.Logf("Assuming CONNECT is TLS, mitm proxying it") // this goes in a separate goroutine, so that the net/http server won't think we're // still handling the request even after hijacking the connection. Those HTTP CONNECT @@ -534,7 +546,15 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxy(httpsProxy string) func(netw func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler( httpsProxy string, - connectReqHandler func(req *http.Request), + connectReqHandler func(req *http.Request) error, +) func(network, addr string) (net.Conn, error) { + return proxy.NewConnectDialToProxyWithMoreHandlers(httpsProxy, connectReqHandler, nil) +} + +func (proxy *ProxyHttpServer) NewConnectDialToProxyWithMoreHandlers( + httpsProxy string, + connectReqHandler func(req *http.Request) error, + connectRespHandler func(req *http.Response) error, ) func(network, addr string) (net.Conn, error) { u, err := url.Parse(httpsProxy) if err != nil { @@ -552,7 +572,9 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler( Header: make(http.Header), } if connectReqHandler != nil { - connectReqHandler(connectReq) + if err := connectReqHandler(connectReq); err != nil { + return nil, err + } } c, err := proxy.dial(&ProxyCtx{Req: &http.Request{}}, network, u.Host) if err != nil { @@ -569,7 +591,12 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler( return nil, err } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { + if connectRespHandler != nil { + if err := connectRespHandler(resp); err != nil { + c.Close() + return nil, err + } + } else if resp.StatusCode != http.StatusOK { resp, err := io.ReadAll(io.LimitReader(resp.Body, _errorRespMaxLength)) if err != nil { return nil, err @@ -603,7 +630,9 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler( Header: make(http.Header), } if connectReqHandler != nil { - connectReqHandler(connectReq) + if err := connectReqHandler(connectReq); err != nil { + return nil, err + } } _ = connectReq.Write(c) // Read response. @@ -616,7 +645,12 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler( return nil, err } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { + if connectRespHandler != nil { + if err := connectRespHandler(resp); err != nil { + c.Close() + return nil, err + } + } else if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(io.LimitReader(resp.Body, _errorRespMaxLength)) if err != nil { return nil, err