Skip to content

Commit

Permalink
Merge pull request #163 from euank/join-halfclose
Browse files Browse the repository at this point in the history
Propagate half-closes correctly in forward
  • Loading branch information
euank authored May 29, 2024
2 parents 958f163 + c67a6d2 commit 4917562
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 8 deletions.
20 changes: 12 additions & 8 deletions forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,20 @@ func (fwd *forwarder) Wait() error {
// compile-time check that we're implementing the proper interface
var _ Forwarder = (*forwarder)(nil)

func join(ctx context.Context, left, right io.ReadWriter) {
func join(logger log15.Logger, left, right net.Conn) {
g := &sync.WaitGroup{}
g.Add(2)
go func() {
_, _ = io.Copy(left, right)
g.Done()
defer g.Done()
defer left.Close()
n, err := io.Copy(left, right)
logger.Debug("left join finished", "err", err, "bytes", n)
}()
go func() {
_, _ = io.Copy(right, left)
g.Done()
defer g.Done()
defer right.Close()
n, err := io.Copy(right, left)
logger.Debug("right join finished", "err", err, "bytes", n)
}()
g.Wait()
}
Expand All @@ -85,21 +89,21 @@ func forwardTunnel(ctx context.Context, tun Tunnel, url *url.URL) Forwarder {
if err != nil {
return err
}
logger.Debug("accept connection from", "address", conn.RemoteAddr())
fwdTasks.Add(1)

go func() {
ngrokConn := conn.(Conn)
defer ngrokConn.Close()

backend, err := openBackend(ctx, logger, tun, ngrokConn, url)
if err != nil {
defer ngrokConn.Close()
logger.Warn("failed to connect to backend url", "error", err)
fwdTasks.Done()
return
}

defer backend.Close()
join(ctx, ngrokConn, backend)
join(logger.New("url", url), ngrokConn, backend)
fwdTasks.Done()
}()
}
Expand Down
46 changes: 46 additions & 0 deletions forward_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package ngrok

import (
"errors"
"io"
"net"
"testing"

"github.com/inconshreveable/log15/v3"
"github.com/stretchr/testify/require"
)

func TestHalfCloseJoin(t *testing.T) {
srv, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)

waitSrvConn := make(chan net.Conn)
go func() {
srvConn, err := srv.Accept()
if err != nil {
panic(err)
}
waitSrvConn <- srvConn
}()

browser, ngrokEndpoint := net.Pipe()
agent, userService := net.Pipe()

waitJoinDone := make(chan struct{})
go func() {
defer close(waitJoinDone)
join(log15.New(), ngrokEndpoint, agent)
}()

_, err = browser.Write([]byte("hello world"))
require.NoError(t, err)
var b [len("hello world")]byte
_, err = userService.Read(b[:])
require.NoError(t, err)
require.Equal(t, []byte("hello world"), b[:])
browser.Close()
_, err = userService.Read(b[:])
require.Truef(t, errors.Is(err, io.EOF), "io.EOF expected, got %v", err)

<-waitJoinDone
}

0 comments on commit 4917562

Please sign in to comment.