Skip to content

Commit

Permalink
Connect related changes - add optional hooks (#504)
Browse files Browse the repository at this point in the history
* adding return value to connectReqHandler

* adding optional connectRespHandler

* calling Hijack in handleHttps if defined

* fixing PR comments

* Update https.go

---------

Co-authored-by: Roman Manz <[email protected]>
Co-authored-by: Roman Manz <[email protected]>
  • Loading branch information
3 people authored Feb 24, 2025
1 parent 0003d27 commit 44bab45
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 9 deletions.
3 changes: 2 additions & 1 deletion examples/cascadeproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ func main() {
// socks5://localhost:8082
return url.Parse("http://localhost:8082")
}
connectReqHandler := func(req *http.Request) {
connectReqHandler := func(req *http.Request) error {
SetBasicAuth(username, password, req)
return nil
}
middleProxy.ConnectDial = middleProxy.NewConnectDialToProxyWithHandler("http://localhost:8082", connectReqHandler)

Expand Down
50 changes: 42 additions & 8 deletions https.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,11 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
return
}
ctx.Logf("Accepting CONNECT to %s", host)
_, _ = proxyClient.Write([]byte("HTTP/1.0 200 Connection established\r\n\r\n"))
if todo.Hijack != nil {
todo.Hijack(r, proxyClient, ctx)
} else {
_, _ = proxyClient.Write([]byte("HTTP/1.0 200 Connection established\r\n\r\n"))
}

targetTCP, targetOK := targetSiteCon.(halfClosable)
proxyClientTCP, clientOK := proxyClient.(halfClosable)
Expand Down Expand Up @@ -194,7 +198,11 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
case ConnectHijack:
todo.Hijack(r, proxyClient, ctx)
case ConnectHTTPMitm:
_, _ = proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n"))
if todo.Hijack != nil {
todo.Hijack(r, proxyClient, ctx)
} else {
_, _ = proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n"))
}
ctx.Logf("Assuming CONNECT is plain HTTP tunneling, mitm proxying it")

var targetSiteCon net.Conn
Expand Down Expand Up @@ -265,7 +273,11 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
}
}
case ConnectMitm:
_, _ = proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n"))
if todo.Hijack != nil {
todo.Hijack(r, proxyClient, ctx)
} else {
_, _ = proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n"))
}
ctx.Logf("Assuming CONNECT is TLS, mitm proxying it")
// this goes in a separate goroutine, so that the net/http server won't think we're
// still handling the request even after hijacking the connection. Those HTTP CONNECT
Expand Down Expand Up @@ -534,7 +546,15 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxy(httpsProxy string) func(netw

func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(
httpsProxy string,
connectReqHandler func(req *http.Request),
connectReqHandler func(req *http.Request) error,
) func(network, addr string) (net.Conn, error) {
return proxy.NewConnectDialToProxyWithMoreHandlers(httpsProxy, connectReqHandler, nil)
}

func (proxy *ProxyHttpServer) NewConnectDialToProxyWithMoreHandlers(
httpsProxy string,
connectReqHandler func(req *http.Request) error,
connectRespHandler func(req *http.Response) error,
) func(network, addr string) (net.Conn, error) {
u, err := url.Parse(httpsProxy)
if err != nil {
Expand All @@ -552,7 +572,9 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(
Header: make(http.Header),
}
if connectReqHandler != nil {
connectReqHandler(connectReq)
if err := connectReqHandler(connectReq); err != nil {
return nil, err
}
}
c, err := proxy.dial(&ProxyCtx{Req: &http.Request{}}, network, u.Host)
if err != nil {
Expand All @@ -569,7 +591,12 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if connectRespHandler != nil {
if err := connectRespHandler(resp); err != nil {
c.Close()
return nil, err
}
} else if resp.StatusCode != http.StatusOK {
resp, err := io.ReadAll(io.LimitReader(resp.Body, _errorRespMaxLength))
if err != nil {
return nil, err
Expand Down Expand Up @@ -603,7 +630,9 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(
Header: make(http.Header),
}
if connectReqHandler != nil {
connectReqHandler(connectReq)
if err := connectReqHandler(connectReq); err != nil {
return nil, err
}
}
_ = connectReq.Write(c)
// Read response.
Expand All @@ -616,7 +645,12 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if connectRespHandler != nil {
if err := connectRespHandler(resp); err != nil {
c.Close()
return nil, err
}
} else if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(io.LimitReader(resp.Body, _errorRespMaxLength))
if err != nil {
return nil, err
Expand Down

0 comments on commit 44bab45

Please sign in to comment.