Skip to content

Commit

Permalink
Merge pull request #686 from adam-p/clientlib-dial
Browse files Browse the repository at this point in the history
Add Dial method

(cherry picked from commit 1150f12)
  • Loading branch information
rod-hynes committed Jun 14, 2024
1 parent c1011b0 commit d60bc23
Show file tree
Hide file tree
Showing 2 changed files with 313 additions and 44 deletions.
146 changes: 105 additions & 41 deletions ClientLibrary/clientlib/clientlib.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ import (
"encoding/json"
std_errors "errors"
"fmt"
"io"
"net"
"path/filepath"
"sync"
"sync/atomic"

"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
Expand Down Expand Up @@ -65,14 +68,22 @@ type Parameters struct {
// notices to noticeReceiver. Has no effect unless the tunnel
// config.EmitDiagnosticNotices flag is set.
EmitDiagnosticNoticesToFiles bool

// DisableLocalSocksProxy disables running the local SOCKS proxy.
DisableLocalSocksProxy *bool

// DisableLocalHTTPProxy disables running the local HTTP proxy.
DisableLocalHTTPProxy *bool
}

// PsiphonTunnel is the tunnel object. It can be used for stopping the tunnel and
// retrieving proxy ports.
type PsiphonTunnel struct {
mu sync.Mutex
stop func()
embeddedServerListWaitGroup sync.WaitGroup
controllerWaitGroup sync.WaitGroup
stopController context.CancelFunc
controllerDial func(string, net.Conn) (net.Conn, error)

// The port on which the HTTP proxy is running
HTTPProxyPort int
Expand All @@ -95,6 +106,10 @@ type NoticeEvent struct {

// ErrTimeout is returned when the tunnel establishment attempt fails due to timeout
var ErrTimeout = std_errors.New("clientlib: tunnel establishment timeout")
var errMultipleStart = std_errors.New("clientlib: StartTunnel called multiple times")

// started is used to ensure that only one tunnel is started at a time
var started atomic.Bool

// StartTunnel establishes a Psiphon tunnel. It returns an error if the establishment
// was not successful. If the returned error is nil, the returned tunnel can be used
Expand Down Expand Up @@ -122,6 +137,10 @@ func StartTunnel(
paramsDelta ParametersDelta,
noticeReceiver func(NoticeEvent)) (retTunnel *PsiphonTunnel, retErr error) {

if !started.CompareAndSwap(false, true) {
return nil, errMultipleStart
}

config, err := psiphon.LoadConfig(configJSON)
if err != nil {
return nil, errors.TraceMsg(err, "failed to load config file")
Expand Down Expand Up @@ -156,6 +175,14 @@ func StartTunnel(
}
} // else use the value in the config

if params.DisableLocalSocksProxy != nil {
config.DisableLocalSocksProxy = *params.DisableLocalSocksProxy
} // else use the value in the config

if params.DisableLocalHTTPProxy != nil {
config.DisableLocalHTTPProxy = *params.DisableLocalHTTPProxy
} // else use the value in the config

// config.Commit must be called before calling config.SetParameters
// or attempting to connect.
err = config.Commit(true)
Expand All @@ -167,15 +194,14 @@ func StartTunnel(
if len(paramsDelta) > 0 {
err = config.SetParameters("", false, paramsDelta)
if err != nil {
return nil, errors.TraceMsg(
err, fmt.Sprintf("SetParameters failed for delta: %v", paramsDelta))
return nil, errors.TraceMsg(err, fmt.Sprintf("SetParameters failed for delta: %v", paramsDelta))
}
}

// Will receive a value when the tunnel has successfully connected.
connected := make(chan struct{}, 1)
// Will receive a value if an error occurs during the connection sequence.
errored := make(chan error, 1)
// Will be closed when the tunnel has successfully connected
connectedSignal := make(chan struct{})
// Will receive a value if an error occurs during the connection sequence
erroredCh := make(chan error, 1)

// Create the tunnel object
tunnel := new(PsiphonTunnel)
Expand All @@ -190,7 +216,7 @@ func StartTunnel(
// We'll interpret it as a connection error and abort.
err = errors.TraceMsg(err, "failed to unmarshal notice JSON")
select {
case errored <- err:
case erroredCh <- err:
default:
}
return
Expand All @@ -204,16 +230,13 @@ func StartTunnel(
tunnel.SOCKSProxyPort = int(port)
} else if event.Type == "EstablishTunnelTimeout" {
select {
case errored <- ErrTimeout:
case erroredCh <- ErrTimeout:
default:
}
} else if event.Type == "Tunnels" {
count := event.Data["count"].(float64)
if count > 0 {
select {
case connected <- struct{}{}:
default:
}
close(connectedSignal)
}
}

Expand All @@ -228,19 +251,30 @@ func StartTunnel(
if err != nil {
return nil, errors.TraceMsg(err, "failed to open data store")
}
// Make sure we close the datastore in case of error

// Create a cancelable context that will be used for stopping the tunnel
tunnelCtx, cancelTunnelCtx := context.WithCancel(ctx)

// Because the tunnel object is only returned on success, there are at least two
// problems that we don't need to worry about:
// 1. This stop function is called both by the error-defer here and by a call to the
// tunnel's Stop method.
// 2. This stop function is called via the tunnel's Stop method before the WaitGroups
// are incremented (causing a race condition).
tunnel.stop = func() {
cancelTunnelCtx()
tunnel.embeddedServerListWaitGroup.Wait()
tunnel.controllerWaitGroup.Wait()
psiphon.CloseDataStore()
started.Store(false)
}

defer func() {
if retErr != nil {
tunnel.controllerWaitGroup.Wait()
tunnel.embeddedServerListWaitGroup.Wait()
psiphon.CloseDataStore()
tunnel.stop()
}
}()

// Create a cancelable context that will be used for stopping the tunnel
var controllerCtx context.Context
controllerCtx, tunnel.stopController = context.WithCancel(ctx)

// If specified, the embedded server list is loaded and stored. When there
// are no server candidates at all, we wait for this import to complete
// before starting the Psiphon controller. Otherwise, we import while
Expand All @@ -258,7 +292,7 @@ func StartTunnel(
defer tunnel.embeddedServerListWaitGroup.Done()

err := psiphon.ImportEmbeddedServerEntries(
controllerCtx,
tunnelCtx,
config,
"",
embeddedServerEntryList)
Expand All @@ -275,60 +309,90 @@ func StartTunnel(
// Create the Psiphon controller
controller, err := psiphon.NewController(config)
if err != nil {
tunnel.stopController()
tunnel.embeddedServerListWaitGroup.Wait()
return nil, errors.TraceMsg(err, "psiphon.NewController failed")
}

tunnel.controllerDial = controller.Dial

// Begin tunnel connection
tunnel.controllerWaitGroup.Add(1)
go func() {
defer tunnel.controllerWaitGroup.Done()

// Start the tunnel. Only returns on error (or internal timeout).
controller.Run(controllerCtx)
controller.Run(tunnelCtx)

// controller.Run does not exit until the goroutine that posts
// EstablishTunnelTimeout has terminated; so, if there was a
// EstablishTunnelTimeout event, ErrTimeout is guaranteed to be sent to
// errord before this next error and will be the StartTunnel return value.
// errored before this next error and will be the StartTunnel return value.

var err error
switch ctx.Err() {
err := ctx.Err()
switch err {
case context.DeadlineExceeded:
err = ErrTimeout
case context.Canceled:
err = errors.TraceNew("StartTunnel canceled")
err = errors.TraceMsg(err, "StartTunnel canceled")
default:
err = errors.TraceNew("controller.Run exited unexpectedly")
err = errors.TraceMsg(err, "controller.Run exited unexpectedly")
}
select {
case errored <- err:
case erroredCh <- err:
default:
}
}()

// Wait for an active tunnel or error
select {
case <-connected:
case <-connectedSignal:
return tunnel, nil
case err := <-errored:
tunnel.Stop()
case err := <-erroredCh:
if err != ErrTimeout {
err = errors.TraceMsg(err, "tunnel start produced error")
}
return nil, err
}
}

// Stop stops/disconnects/shuts down the tunnel. It is safe to call when not connected.
// Not safe to call concurrently with Start.
// Stop stops/disconnects/shuts down the tunnel.
// It is safe to call Stop multiple times.
// It is safe to call concurrently with Dial and with itself.
func (tunnel *PsiphonTunnel) Stop() {
if tunnel.stopController == nil {
// Holding a lock while calling the stop function ensures that any concurrent call
// to Stop will wait for the first call to finish before returning, rather than
// returning immediately (because tunnel.stop is nil) and thereby indicating
// (erroneously) that the tunnel has been stopped.
// Stopping a tunnel happens quickly enough that this processing block shouldn't be
// a problem.
tunnel.mu.Lock()
defer tunnel.mu.Unlock()

if tunnel.stop == nil {
return
}
tunnel.stopController()
tunnel.controllerWaitGroup.Wait()
tunnel.embeddedServerListWaitGroup.Wait()
psiphon.CloseDataStore()

tunnel.stop()
tunnel.stop = nil
tunnel.controllerDial = nil

// Clear our notice receiver, as it is no longer needed and we should let it be
// garbage-collected.
psiphon.SetNoticeWriter(io.Discard)
}

// Dial connects to the specified address through the Psiphon tunnel.
// It is safe to call Dial after the tunnel has been stopped.
// It is safe to call Dial concurrently with Stop.
func (tunnel *PsiphonTunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
// Ensure the dial is accessed in a thread-safe manner, without holding the lock
// while calling the dial function.
// Note that it is safe for controller.Dial to be called even after or during a tunnel
// shutdown (i.e., if the context has been canceled).
tunnel.mu.Lock()
dial := tunnel.controllerDial
tunnel.mu.Unlock()
if dial == nil {
return nil, errors.TraceNew("tunnel not started")
}
return dial(remoteAddr, nil)
}
Loading

0 comments on commit d60bc23

Please sign in to comment.