Skip to content

Commit

Permalink
Handle StopTunnel request
Browse files Browse the repository at this point in the history
  • Loading branch information
bobzilladev committed Nov 6, 2023
1 parent 77b6cd9 commit 56dcb5e
Show file tree
Hide file tree
Showing 10 changed files with 160 additions and 30 deletions.
23 changes: 23 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,26 @@ func (e errSessionDial) Is(target error) bool {
_, ok := target.(errSessionDial)
return ok
}

// Generic ngrok error that requires no parsing
type ngrokError struct {
Message string
ErrCode string
}

func (m ngrokError) Error() string {
return m.Message + "\n\n" + m.ErrCode
}

func (m ngrokError) Msg() string {
return m.Message
}

func (m ngrokError) ErrorCode() string {
return m.ErrCode
}

func (e ngrokError) Is(target error) bool {
_, ok := target.(ngrokError)
return ok
}
4 changes: 4 additions & 0 deletions examples/http-full/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ func run(ctx context.Context) error {
log.Println("session update:", sess)
return nil
}),
ngrok.WithStopTunnelHandler(func(ctx context.Context, sess ngrok.Session, clientID string, err error) error {
log.Println("tunnel stop:", sess, "clientid:", clientID, "error:", err)
return nil
}),
)
if err != nil {
return err
Expand Down
27 changes: 20 additions & 7 deletions examples/ngrok-forward-lite/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,32 @@ func main() {
}

func run(ctx context.Context, backend *url.URL) error {
fwd, err := ngrok.ListenAndForward(ctx,
backend,
config.HTTPEndpoint(),
sess, err := ngrok.Connect(ctx,
ngrok.WithAuthtokenFromEnv(),
ngrok.WithLogger(&logger{lvl: ngrok_log.LogLevelDebug}),
)
if err != nil {
return err
}

l.Log(ctx, ngrok_log.LogLevelInfo, "tunnel created", map[string]any{
"url": fwd.URL(),
})
for {
fwd, err := sess.ListenAndForward(ctx,
backend,
config.HTTPEndpoint(),
)
if err != nil {
return err
}

return fwd.Wait()
l.Log(ctx, ngrok_log.LogLevelInfo, "tunnel created", map[string]any{
"url": fwd.URL(),
})

err = fwd.Wait()
if err == nil {
return nil
}
l.Log(ctx, ngrok_log.LogLevelWarn, "tunnel accept error. now setting up a new listener.",
map[string]any{"err": err})
}
}
14 changes: 13 additions & 1 deletion internal/tunnel/client/raw_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type SessionHandler interface {
OnStop(*proto.Stop, HandlerRespFunc)
OnRestart(*proto.Restart, HandlerRespFunc)
OnUpdate(*proto.Update, HandlerRespFunc)
OnStopTunnel(*proto.StopTunnel, HandlerRespFunc) error
}

// A RawSession is a client session which handles authorization with the tunnel
Expand Down Expand Up @@ -75,7 +76,7 @@ func (s *rawSession) Auth(id string, extra proto.AuthExtra) (resp proto.AuthResp
req := proto.Auth{
ClientID: id,
Extra: extra,
Version: []string{proto.Version},
Version: proto.Version,
}
if err = s.rpc(proto.AuthReq, &req, &resp); err != nil {
return
Expand Down Expand Up @@ -201,6 +202,17 @@ func (s *rawSession) Accept() (netx.LoggedConn, error) {
if deserialize(&req) {
go s.handler.OnUpdate(&req, respFunc)
}
case proto.StopTunnelReq:
var req proto.StopTunnel
if deserialize(&req) {
// allow the handler to return an error to reconnect the session
err := s.handler.OnStopTunnel(&req, respFunc)
if err != nil {
// close the connection as this is going to reconnect
s.mux.Close()
}
return nil, err
}
default:
return netx.NewLoggedConn(s.Logger, raw, "type", "proxy", "sess", s.id), nil
}
Expand Down
4 changes: 3 additions & 1 deletion internal/tunnel/client/reconnecting.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ func (s *reconnectingSession) receive() {
// accept the next proxy connection
proxy, err := s.raw.Accept()
if err == nil {
go s.handleProxy(proxy)
if proxy != nil {
go s.handleProxy(proxy)
}
continue
}

Expand Down
13 changes: 13 additions & 0 deletions internal/tunnel/client/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ type Session interface {
// Latency updates
Latency() <-chan time.Duration

// Close the tunnel with this clientID, with an error that will be reported
// from the tunnel's Accept() method.
CloseTunnel(clientID string, err error) error

// Closes the session
Close() error
}
Expand Down Expand Up @@ -176,6 +180,15 @@ func (s *session) SrvInfo() (proto.SrvInfoResp, error) {
return s.raw.SrvInfo()
}

func (s *session) CloseTunnel(clientId string, err error) error {
t, ok := s.getTunnel(clientId)
if !ok {
return proto.StringError("no tunnel found for client id " + clientId)
}
t.CloseWithError(err)
return nil
}

func (s *session) Close() error {
return s.raw.Close()
}
Expand Down
17 changes: 14 additions & 3 deletions internal/tunnel/client/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ type tunnel struct {
labels map[string]string
forwardsTo string

accept chan *ProxyConn // new connections come on this channel
unlisten func() error // call this function to close the tunnel
accept chan *ProxyConn // new connections come on this channel
unlisten func() error // call this function to close the tunnel
closeError error // error to use on accept error after a tunnel close

shut shutdown // for clean shutdowns
}
Expand All @@ -54,6 +55,7 @@ func newTunnel(resp proto.BindResp, extra proto.BindExtra, s *session, forwardsT
accept: make(chan *ProxyConn),
unlisten: func() error { return s.unlisten(resp.ClientID) },
forwardsTo: forwardsTo,
closeError: errors.New("Tunnel closed"),
}
}

Expand All @@ -69,6 +71,7 @@ func newTunnelLabel(resp proto.StartTunnelWithLabelResp, metadata string, labels
accept: make(chan *ProxyConn),
unlisten: func() error { return s.unlisten(resp.ID) },
forwardsTo: forwardsTo,
closeError: errors.New("Tunnel closed"),
}
}

Expand All @@ -83,11 +86,19 @@ func (t *tunnel) handleConn(r *ProxyConn) {
func (t *tunnel) Accept() (*ProxyConn, error) {
conn, ok := <-t.accept
if !ok {
return nil, errors.New("Tunnel closed")
return nil, t.closeError
}
return conn, nil
}

func (t *tunnel) CloseWithError(closeError error) {
t.closeError = closeError
// Skips the call to unlisten, since the remote has already rejected it.
t.shut.Shut(func() {
close(t.accept)
})
}

// Closes the Tunnel by asking the remote machine to deallocate its listener, or
// an error if the request failed.
func (t *tunnel) Close() (err error) {
Expand Down
18 changes: 13 additions & 5 deletions internal/tunnel/proto/msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,17 @@ const (
StartTunnelWithLabelReq ReqType = 7

// sent from the server to the client
ProxyReq ReqType = 3
RestartReq ReqType = 4
StopReq ReqType = 5
UpdateReq ReqType = 6
ProxyReq ReqType = 3
RestartReq ReqType = 4
StopReq ReqType = 5
UpdateReq ReqType = 6
StopTunnelReq ReqType = 9

// sent from client to the server
SrvInfoReq ReqType = 8
)

const Version = "2"
var Version = []string{"3", "2"} // integers in priority order

// Match the error code in the format (ERR_NGROK_\d+).
var ngrokErrorCodeRegex = regexp.MustCompile(`(ERR_NGROK_\d+)`)
Expand Down Expand Up @@ -400,6 +401,13 @@ type UpdateResp struct {
Error string // an error, if one
}

// This request is sent from the server to the ngrok agent to request a tunnel to close, with a notice to display to the user
type StopTunnel struct {
ClientID string // the tunnel to stop
Message string // an message to display to the user
ErrorCode string // an error code to display to the user. empty on OK
}

type SrvInfo struct{}

type SrvInfoResp struct {
Expand Down
62 changes: 50 additions & 12 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ type Session interface {
// forwarded to a new HTTP server and handled by the provided HTTP handler.
ListenAndHandleHTTP(ctx context.Context, cfg config.Tunnel, handler *http.Handler) (Forwarder, error)

// Close the tunnel with this clientID, with an error that will be reported
// from the tunnel's Accept() method.
CloseTunnel(clientID string, err error) error

// Close ends the ngrok session. All Tunnel objects created by Listen
// on this session will be closed.
Close() error
Expand Down Expand Up @@ -99,6 +103,9 @@ type SessionHeartbeatHandler func(ctx context.Context, sess Session, latency tim
// ServerCommandHandler is the callback type for [WithStopHandler]
type ServerCommandHandler func(ctx context.Context, sess Session) error

// StopTunnelHandler is the callback type for [WithStopTunnelHandler]
type StopTunnelHandler func(ctx context.Context, sess Session, clientID string, err error) error

// ConnectOption is passed to [Connect] to customize session connection and establishment.
type ConnectOption func(*connectConfig)

Expand Down Expand Up @@ -178,9 +185,10 @@ type connectConfig struct {
DisconnectHandler SessionDisconnectHandler
HeartbeatHandler SessionHeartbeatHandler

StopHandler ServerCommandHandler
RestartHandler ServerCommandHandler
UpdateHandler ServerCommandHandler
StopHandler ServerCommandHandler
RestartHandler ServerCommandHandler
UpdateHandler ServerCommandHandler
StopTunnelHandler StopTunnelHandler

remoteStopErr *string
remoteRestartErr *string
Expand Down Expand Up @@ -400,6 +408,17 @@ func WithStopHandler(handler ServerCommandHandler) ConnectOption {
}
}

// WithStopTunnelHandler configures a function which is called if a tunnel
// is stopped by the ngrok service. Use this option to detect
// when the ngrok tunnel has gone offline.
//
// If this function returns a non-nil error then the session will be reconnected.
func WithStopTunnelHandler(handler StopTunnelHandler) ConnectOption {
return func(cfg *connectConfig) {
cfg.StopTunnelHandler = handler
}
}

// WithRestartHandler configures a function which is called when the ngrok service
// requests that this [Session] restarts. Your application may choose to interpret
// this callback as a request to reconnect the [Session] or restart the entire process.
Expand Down Expand Up @@ -548,11 +567,12 @@ func Connect(ctx context.Context, opts ...ConnectOption) (Session, error) {
stateChanges := make(chan error, 32)

callbackHandler := remoteCallbackHandler{
Logger: logger,
sess: session,
stopHandler: cfg.StopHandler,
restartHandler: cfg.RestartHandler,
updateHandler: cfg.UpdateHandler,
Logger: logger,
sess: session,
stopHandler: cfg.StopHandler,
restartHandler: cfg.RestartHandler,
updateHandler: cfg.UpdateHandler,
stopTunnelHandler: cfg.StopTunnelHandler,
}

rawDialer := func() (tunnel_client.RawSession, error) {
Expand Down Expand Up @@ -771,6 +791,10 @@ func (s *sessionImpl) setInner(raw *sessionInner) {
atomic.StorePointer(&s.raw, unsafe.Pointer(raw))
}

func (s *sessionImpl) CloseTunnel(clientID string, err error) error {
return s.inner().CloseTunnel(clientID, err)
}

func (s *sessionImpl) Close() error {
return s.inner().Close()
}
Expand Down Expand Up @@ -908,10 +932,11 @@ func (s *sessionImpl) Latency() <-chan time.Duration {

type remoteCallbackHandler struct {
log15.Logger
sess Session
stopHandler ServerCommandHandler
restartHandler ServerCommandHandler
updateHandler ServerCommandHandler
sess Session
stopHandler ServerCommandHandler
restartHandler ServerCommandHandler
updateHandler ServerCommandHandler
stopTunnelHandler StopTunnelHandler
}

func (rc remoteCallbackHandler) OnStop(_ *proto.Stop, respond tunnel_client.HandlerRespFunc) {
Expand Down Expand Up @@ -959,3 +984,16 @@ func (rc remoteCallbackHandler) OnUpdate(_ *proto.Update, respond tunnel_client.
}
}
}

func (rc remoteCallbackHandler) OnStopTunnel(stopTunnel *proto.StopTunnel, respond tunnel_client.HandlerRespFunc) error {
err := &ngrokError{Message: stopTunnel.Message, ErrCode: stopTunnel.ErrorCode}
if rc.stopTunnelHandler != nil {
// allow user to return an error and instigate a reconnect
user_err := rc.stopTunnelHandler(context.TODO(), rc.sess, stopTunnel.ClientID, err)
if user_err != nil {
return user_err
}
}
// default behavior is to close the tunnel and maintain the session
return rc.sess.CloseTunnel(stopTunnel.ClientID, err)
}
8 changes: 7 additions & 1 deletion tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,13 @@ type tunnelImpl struct {
func (t *tunnelImpl) Accept() (net.Conn, error) {
conn, err := t.Tunnel.Accept()
if err != nil {
return nil, errAcceptFailed{Inner: err}
err = errAcceptFailed{Inner: err}
if s, ok := t.Sess.(*sessionImpl); ok {
if si := s.inner(); si != nil {
si.Logger.Info(err.Error(), "clientid", t.Tunnel.ID())
}
}
return nil, err
}
return &connImpl{
Conn: conn.Conn,
Expand Down

0 comments on commit 56dcb5e

Please sign in to comment.