From 56dcb5ec8e61d5e138cfabc71168438fc3a09a96 Mon Sep 17 00:00:00 2001 From: bobzilladev Date: Wed, 1 Nov 2023 18:28:55 +0000 Subject: [PATCH] Handle StopTunnel request --- errors.go | 23 ++++++++++ examples/http-full/main.go | 4 ++ examples/ngrok-forward-lite/main.go | 27 ++++++++--- internal/tunnel/client/raw_session.go | 14 +++++- internal/tunnel/client/reconnecting.go | 4 +- internal/tunnel/client/session.go | 13 ++++++ internal/tunnel/client/tunnel.go | 17 +++++-- internal/tunnel/proto/msg.go | 18 +++++--- session.go | 62 +++++++++++++++++++++----- tunnel.go | 8 +++- 10 files changed, 160 insertions(+), 30 deletions(-) diff --git a/errors.go b/errors.go index cb22d4a6..44485381 100644 --- a/errors.go +++ b/errors.go @@ -132,3 +132,26 @@ func (e errSessionDial) Is(target error) bool { _, ok := target.(errSessionDial) return ok } + +// Generic ngrok error that requires no parsing +type ngrokError struct { + Message string + ErrCode string +} + +func (m ngrokError) Error() string { + return m.Message + "\n\n" + m.ErrCode +} + +func (m ngrokError) Msg() string { + return m.Message +} + +func (m ngrokError) ErrorCode() string { + return m.ErrCode +} + +func (e ngrokError) Is(target error) bool { + _, ok := target.(ngrokError) + return ok +} diff --git a/examples/http-full/main.go b/examples/http-full/main.go index 2834ddba..43017faa 100644 --- a/examples/http-full/main.go +++ b/examples/http-full/main.go @@ -77,6 +77,10 @@ func run(ctx context.Context) error { log.Println("session update:", sess) return nil }), + ngrok.WithStopTunnelHandler(func(ctx context.Context, sess ngrok.Session, clientID string, err error) error { + log.Println("tunnel stop:", sess, "clientid:", clientID, "error:", err) + return nil + }), ) if err != nil { return err diff --git a/examples/ngrok-forward-lite/main.go b/examples/ngrok-forward-lite/main.go index 4f592a93..c3490a5c 100644 --- a/examples/ngrok-forward-lite/main.go +++ b/examples/ngrok-forward-lite/main.go @@ -57,9 +57,7 @@ func main() { } func run(ctx context.Context, backend *url.URL) error { - fwd, err := ngrok.ListenAndForward(ctx, - backend, - config.HTTPEndpoint(), + sess, err := ngrok.Connect(ctx, ngrok.WithAuthtokenFromEnv(), ngrok.WithLogger(&logger{lvl: ngrok_log.LogLevelDebug}), ) @@ -67,9 +65,24 @@ func run(ctx context.Context, backend *url.URL) error { return err } - l.Log(ctx, ngrok_log.LogLevelInfo, "tunnel created", map[string]any{ - "url": fwd.URL(), - }) + for { + fwd, err := sess.ListenAndForward(ctx, + backend, + config.HTTPEndpoint(), + ) + if err != nil { + return err + } - return fwd.Wait() + l.Log(ctx, ngrok_log.LogLevelInfo, "tunnel created", map[string]any{ + "url": fwd.URL(), + }) + + err = fwd.Wait() + if err == nil { + return nil + } + l.Log(ctx, ngrok_log.LogLevelWarn, "tunnel accept error. now setting up a new listener.", + map[string]any{"err": err}) + } } diff --git a/internal/tunnel/client/raw_session.go b/internal/tunnel/client/raw_session.go index 224dba2f..60ea0287 100644 --- a/internal/tunnel/client/raw_session.go +++ b/internal/tunnel/client/raw_session.go @@ -37,6 +37,7 @@ type SessionHandler interface { OnStop(*proto.Stop, HandlerRespFunc) OnRestart(*proto.Restart, HandlerRespFunc) OnUpdate(*proto.Update, HandlerRespFunc) + OnStopTunnel(*proto.StopTunnel, HandlerRespFunc) error } // A RawSession is a client session which handles authorization with the tunnel @@ -75,7 +76,7 @@ func (s *rawSession) Auth(id string, extra proto.AuthExtra) (resp proto.AuthResp req := proto.Auth{ ClientID: id, Extra: extra, - Version: []string{proto.Version}, + Version: proto.Version, } if err = s.rpc(proto.AuthReq, &req, &resp); err != nil { return @@ -201,6 +202,17 @@ func (s *rawSession) Accept() (netx.LoggedConn, error) { if deserialize(&req) { go s.handler.OnUpdate(&req, respFunc) } + case proto.StopTunnelReq: + var req proto.StopTunnel + if deserialize(&req) { + // allow the handler to return an error to reconnect the session + err := s.handler.OnStopTunnel(&req, respFunc) + if err != nil { + // close the connection as this is going to reconnect + s.mux.Close() + } + return nil, err + } default: return netx.NewLoggedConn(s.Logger, raw, "type", "proxy", "sess", s.id), nil } diff --git a/internal/tunnel/client/reconnecting.go b/internal/tunnel/client/reconnecting.go index edc5b8ad..9bbea933 100644 --- a/internal/tunnel/client/reconnecting.go +++ b/internal/tunnel/client/reconnecting.go @@ -167,7 +167,9 @@ func (s *reconnectingSession) receive() { // accept the next proxy connection proxy, err := s.raw.Accept() if err == nil { - go s.handleProxy(proxy) + if proxy != nil { + go s.handleProxy(proxy) + } continue } diff --git a/internal/tunnel/client/session.go b/internal/tunnel/client/session.go index 398dfade..701e4e25 100644 --- a/internal/tunnel/client/session.go +++ b/internal/tunnel/client/session.go @@ -67,6 +67,10 @@ type Session interface { // Latency updates Latency() <-chan time.Duration + // Close the tunnel with this clientID, with an error that will be reported + // from the tunnel's Accept() method. + CloseTunnel(clientID string, err error) error + // Closes the session Close() error } @@ -176,6 +180,15 @@ func (s *session) SrvInfo() (proto.SrvInfoResp, error) { return s.raw.SrvInfo() } +func (s *session) CloseTunnel(clientId string, err error) error { + t, ok := s.getTunnel(clientId) + if !ok { + return proto.StringError("no tunnel found for client id " + clientId) + } + t.CloseWithError(err) + return nil +} + func (s *session) Close() error { return s.raw.Close() } diff --git a/internal/tunnel/client/tunnel.go b/internal/tunnel/client/tunnel.go index 5cb18711..b3ff8fc7 100644 --- a/internal/tunnel/client/tunnel.go +++ b/internal/tunnel/client/tunnel.go @@ -35,8 +35,9 @@ type tunnel struct { labels map[string]string forwardsTo string - accept chan *ProxyConn // new connections come on this channel - unlisten func() error // call this function to close the tunnel + accept chan *ProxyConn // new connections come on this channel + unlisten func() error // call this function to close the tunnel + closeError error // error to use on accept error after a tunnel close shut shutdown // for clean shutdowns } @@ -54,6 +55,7 @@ func newTunnel(resp proto.BindResp, extra proto.BindExtra, s *session, forwardsT accept: make(chan *ProxyConn), unlisten: func() error { return s.unlisten(resp.ClientID) }, forwardsTo: forwardsTo, + closeError: errors.New("Tunnel closed"), } } @@ -69,6 +71,7 @@ func newTunnelLabel(resp proto.StartTunnelWithLabelResp, metadata string, labels accept: make(chan *ProxyConn), unlisten: func() error { return s.unlisten(resp.ID) }, forwardsTo: forwardsTo, + closeError: errors.New("Tunnel closed"), } } @@ -83,11 +86,19 @@ func (t *tunnel) handleConn(r *ProxyConn) { func (t *tunnel) Accept() (*ProxyConn, error) { conn, ok := <-t.accept if !ok { - return nil, errors.New("Tunnel closed") + return nil, t.closeError } return conn, nil } +func (t *tunnel) CloseWithError(closeError error) { + t.closeError = closeError + // Skips the call to unlisten, since the remote has already rejected it. + t.shut.Shut(func() { + close(t.accept) + }) +} + // Closes the Tunnel by asking the remote machine to deallocate its listener, or // an error if the request failed. func (t *tunnel) Close() (err error) { diff --git a/internal/tunnel/proto/msg.go b/internal/tunnel/proto/msg.go index 34b17562..13fa2594 100644 --- a/internal/tunnel/proto/msg.go +++ b/internal/tunnel/proto/msg.go @@ -23,16 +23,17 @@ const ( StartTunnelWithLabelReq ReqType = 7 // sent from the server to the client - ProxyReq ReqType = 3 - RestartReq ReqType = 4 - StopReq ReqType = 5 - UpdateReq ReqType = 6 + ProxyReq ReqType = 3 + RestartReq ReqType = 4 + StopReq ReqType = 5 + UpdateReq ReqType = 6 + StopTunnelReq ReqType = 9 // sent from client to the server SrvInfoReq ReqType = 8 ) -const Version = "2" +var Version = []string{"3", "2"} // integers in priority order // Match the error code in the format (ERR_NGROK_\d+). var ngrokErrorCodeRegex = regexp.MustCompile(`(ERR_NGROK_\d+)`) @@ -400,6 +401,13 @@ type UpdateResp struct { Error string // an error, if one } +// This request is sent from the server to the ngrok agent to request a tunnel to close, with a notice to display to the user +type StopTunnel struct { + ClientID string // the tunnel to stop + Message string // an message to display to the user + ErrorCode string // an error code to display to the user. empty on OK +} + type SrvInfo struct{} type SrvInfoResp struct { diff --git a/session.go b/session.go index 64163e1f..0da0c03d 100644 --- a/session.go +++ b/session.go @@ -66,6 +66,10 @@ type Session interface { // forwarded to a new HTTP server and handled by the provided HTTP handler. ListenAndHandleHTTP(ctx context.Context, cfg config.Tunnel, handler *http.Handler) (Forwarder, error) + // Close the tunnel with this clientID, with an error that will be reported + // from the tunnel's Accept() method. + CloseTunnel(clientID string, err error) error + // Close ends the ngrok session. All Tunnel objects created by Listen // on this session will be closed. Close() error @@ -99,6 +103,9 @@ type SessionHeartbeatHandler func(ctx context.Context, sess Session, latency tim // ServerCommandHandler is the callback type for [WithStopHandler] type ServerCommandHandler func(ctx context.Context, sess Session) error +// StopTunnelHandler is the callback type for [WithStopTunnelHandler] +type StopTunnelHandler func(ctx context.Context, sess Session, clientID string, err error) error + // ConnectOption is passed to [Connect] to customize session connection and establishment. type ConnectOption func(*connectConfig) @@ -178,9 +185,10 @@ type connectConfig struct { DisconnectHandler SessionDisconnectHandler HeartbeatHandler SessionHeartbeatHandler - StopHandler ServerCommandHandler - RestartHandler ServerCommandHandler - UpdateHandler ServerCommandHandler + StopHandler ServerCommandHandler + RestartHandler ServerCommandHandler + UpdateHandler ServerCommandHandler + StopTunnelHandler StopTunnelHandler remoteStopErr *string remoteRestartErr *string @@ -400,6 +408,17 @@ func WithStopHandler(handler ServerCommandHandler) ConnectOption { } } +// WithStopTunnelHandler configures a function which is called if a tunnel +// is stopped by the ngrok service. Use this option to detect +// when the ngrok tunnel has gone offline. +// +// If this function returns a non-nil error then the session will be reconnected. +func WithStopTunnelHandler(handler StopTunnelHandler) ConnectOption { + return func(cfg *connectConfig) { + cfg.StopTunnelHandler = handler + } +} + // WithRestartHandler configures a function which is called when the ngrok service // requests that this [Session] restarts. Your application may choose to interpret // this callback as a request to reconnect the [Session] or restart the entire process. @@ -548,11 +567,12 @@ func Connect(ctx context.Context, opts ...ConnectOption) (Session, error) { stateChanges := make(chan error, 32) callbackHandler := remoteCallbackHandler{ - Logger: logger, - sess: session, - stopHandler: cfg.StopHandler, - restartHandler: cfg.RestartHandler, - updateHandler: cfg.UpdateHandler, + Logger: logger, + sess: session, + stopHandler: cfg.StopHandler, + restartHandler: cfg.RestartHandler, + updateHandler: cfg.UpdateHandler, + stopTunnelHandler: cfg.StopTunnelHandler, } rawDialer := func() (tunnel_client.RawSession, error) { @@ -771,6 +791,10 @@ func (s *sessionImpl) setInner(raw *sessionInner) { atomic.StorePointer(&s.raw, unsafe.Pointer(raw)) } +func (s *sessionImpl) CloseTunnel(clientID string, err error) error { + return s.inner().CloseTunnel(clientID, err) +} + func (s *sessionImpl) Close() error { return s.inner().Close() } @@ -908,10 +932,11 @@ func (s *sessionImpl) Latency() <-chan time.Duration { type remoteCallbackHandler struct { log15.Logger - sess Session - stopHandler ServerCommandHandler - restartHandler ServerCommandHandler - updateHandler ServerCommandHandler + sess Session + stopHandler ServerCommandHandler + restartHandler ServerCommandHandler + updateHandler ServerCommandHandler + stopTunnelHandler StopTunnelHandler } func (rc remoteCallbackHandler) OnStop(_ *proto.Stop, respond tunnel_client.HandlerRespFunc) { @@ -959,3 +984,16 @@ func (rc remoteCallbackHandler) OnUpdate(_ *proto.Update, respond tunnel_client. } } } + +func (rc remoteCallbackHandler) OnStopTunnel(stopTunnel *proto.StopTunnel, respond tunnel_client.HandlerRespFunc) error { + err := &ngrokError{Message: stopTunnel.Message, ErrCode: stopTunnel.ErrorCode} + if rc.stopTunnelHandler != nil { + // allow user to return an error and instigate a reconnect + user_err := rc.stopTunnelHandler(context.TODO(), rc.sess, stopTunnel.ClientID, err) + if user_err != nil { + return user_err + } + } + // default behavior is to close the tunnel and maintain the session + return rc.sess.CloseTunnel(stopTunnel.ClientID, err) +} diff --git a/tunnel.go b/tunnel.go index 4b9db688..a8783288 100644 --- a/tunnel.go +++ b/tunnel.go @@ -149,7 +149,13 @@ type tunnelImpl struct { func (t *tunnelImpl) Accept() (net.Conn, error) { conn, err := t.Tunnel.Accept() if err != nil { - return nil, errAcceptFailed{Inner: err} + err = errAcceptFailed{Inner: err} + if s, ok := t.Sess.(*sessionImpl); ok { + if si := s.inner(); si != nil { + si.Logger.Info(err.Error(), "clientid", t.Tunnel.ID()) + } + } + return nil, err } return &connImpl{ Conn: conn.Conn,