Skip to content

Commit

Permalink
dialers,protect,ipn: impl DialBind
Browse files Browse the repository at this point in the history
  • Loading branch information
ignoramous committed Oct 23, 2024
1 parent fe63bec commit bd409fd
Show file tree
Hide file tree
Showing 23 changed files with 614 additions and 295 deletions.
46 changes: 28 additions & 18 deletions intra/dialers/cdial.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,27 @@ func maybeFilter(ips []netip.Addr, alwaysExclude netip.Addr) ([]netip.Addr, bool
return filtered, !failingopen
}

func commondial[D rdial, C rconn](d D, network, addr string, connect dialFn[D, C]) (C, error) {
func commondial[D rdials, C rconns](d D, network, addr string, connect dialFn[D, C]) (C, error) {
return commondial2(d, network, "", addr, connect)
}

func commondial2[D rdials, C rconns](d D, network, laddr, raddr string, connect dialFn[D, C]) (C, error) {
start := time.Now()

log.D("rdial: commondial: dialing (host:port) %s", addr)
domain, portstr, err := net.SplitHostPort(addr)
local, lerr := netip.ParseAddrPort(laddr) // okay if local is invalid
domain, portstr, err := net.SplitHostPort(raddr)

log.D("rdial: commondial: dialing (host:port) %s=>%s; errs? %v %v",
laddr, raddr, lerr, err)

if err != nil {
return nil, err
}

// cannot dial into a wildcard address
// while, listen is unsupported
if len(domain) == 0 {
return nil, net.InvalidAddrError(addr)
return nil, net.InvalidAddrError(raddr)
}
port, err := strconv.Atoi(portstr)
if err != nil {
Expand All @@ -79,29 +88,30 @@ func commondial[D rdial, C rconn](d D, network, addr string, connect dialFn[D, C

defer func() {
dur := time.Since(start)
log.D("rdial: duration: %s; addr %s; confirmed? %s, sz: %d", dur, addr, confirmed, ips.Size())
log.D("rdial: duration: %s; addr %s; confirmed? %s, sz: %d", dur, raddr, confirmed, ips.Size())
}()

if confirmedIPOK {
log.V("rdial: commondial: dialing confirmed ip %s for %s", confirmed, addr)
conn, err = connect(d, network, confirmed, port)
remote := netip.AddrPortFrom(confirmed, uint16(port))
log.V("rdial: commondial: dialing confirmed ip %s for %s", confirmed, remote)
conn, err = connect(d, network, local, remote)
// nilaway: tx.socks5 returns nil conn even if err == nil
if conn == nil && err == nil {
err = errNoConn
}
if err == nil {
log.V("rdial: commondial: ip %s works for %s", confirmed, addr)
log.V("rdial: commondial: ip %s works for %s", confirmed, remote)
return conn, nil
}
errs = errors.Join(errs, err)
ips.Disconfirm(confirmed)
logwd(err)("rdial: commondial: confirmed %s for %s failed; err %v",
confirmed, addr, err)
confirmed, remote, err)
}

if dontretry {
if !confirmedIPOK {
log.E("rdial: ip %s not ok for %s", confirmed, addr)
log.E("rdial: ip %s not ok for %s", confirmed, raddr)
errs = errors.Join(errs, errNoIps)
}
return nil, errs
Expand All @@ -115,32 +125,32 @@ func commondial[D rdial, C rconn](d D, network, addr string, connect dialFn[D, C
ipset = ips.Addrs()
allips, failingopen = maybeFilter(ipset, confirmed)
}
log.D("rdial: renew ips for %s; ok? %t, failingopen? %t", addr, ok, failingopen)
log.D("rdial: renew ips for %s; ok? %t, failingopen? %t", raddr, ok, failingopen)
}
log.D("rdial: commondial: trying all ips %d %v for %s, failingopen? %t",
len(allips), allips, addr, failingopen)
len(allips), allips, raddr, failingopen)
for _, ip := range allips {
end := time.Since(start)
if end > dialRetryTimeout {
log.D("rdial: commondial: timeout %s for %s", end, addr)
log.D("rdial: commondial: timeout %s for %s", end, raddr)
break
}
if ipok(ip) {
conn, err = connect(d, network, ip, port)
remote := netip.AddrPortFrom(ip, uint16(port))
conn, err = connect(d, network, local, remote)
// nilaway: tx.socks5 returns nil conn even if err == nil
if conn == nil && err == nil {
err = errNoConn
}
if err == nil {
log.V("rdial: commondial: dialing ip %s for %s", ip, addr)
confirm(ips, ip)
log.I("rdial: commondial: ip %s works for %s", ip, addr)
log.I("rdial: commondial: ip %s works for %s", ip, remote)
return conn, nil
}
errs = errors.Join(errs, err)
logwd(err)("rdial: commondial: ip %s for %s failed; err %v", ip, addr, err)
logwd(err)("rdial: commondial: ip %s for %s failed; err %v", ip, remote, err)
} else {
log.W("rdial: commondial: ip %s not ok for %s", ip, addr)
log.W("rdial: commondial: ip %s not ok for %s", ip, raddr)
}
}

Expand Down
24 changes: 0 additions & 24 deletions intra/dialers/direct_split.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (

"github.com/celzero/firestack/intra/core"
"github.com/celzero/firestack/intra/log"
"github.com/celzero/firestack/intra/protect"
"github.com/celzero/firestack/intra/settings"
)

Expand All @@ -35,29 +34,6 @@ type splitter struct {

var _ core.DuplexConn = (*splitter)(nil)

// dialWithSplitStrat returns a TCP connection that always splits the initial upstream segment
// using the specified strategy, strat, which is one of the settings.Split* constants.
func dialWithSplitStrat(dialStrat int32, d *protect.RDial, addr *net.TCPAddr) (core.DuplexConn, error) {
switch dialStrat {
case settings.SplitNever:
return d.DialTCP(addr.Network(), nil, addr)
case settings.SplitDesync:
return dialWithSplitAndDesync(d, addr.AddrPort())
case settings.SplitTCP, settings.SplitTCPOrTLS:
fallthrough
default:
}
conn, err := d.DialTCP(addr.Network(), nil, addr)
if err != nil {
return nil, err
}
if conn == nil {
return nil, errNoConn
}
// todo: strat must be tcp or tls
return &splitter{conn: conn, strat: dialStrat}, nil
}

// Write implements DuplexConn.
func (s *splitter) Write(b []byte) (n int, err error) {
if s.used.Load() {
Expand Down
9 changes: 5 additions & 4 deletions intra/dialers/pdial.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@ import (
"golang.org/x/net/proxy"
)

func proxyConnect(d *proxy.Dialer, proto string, ip netip.Addr, port int) (net.Conn, error) {
// todo: dial bound to the local address if specified
func proxyConnect(d *proxy.Dialer, proto string, local, remote netip.AddrPort) (net.Conn, error) {
if d == nil { // unlikely
log.E("pdial: proxyConnect: nil dialer")
return nil, errNoDialer
} else if !ipok(ip) {
log.E("pdial: proxyConnect: invalid ip", ip)
} else if !ipok(remote.Addr()) {
log.E("pdial: proxyConnect: invalid ip", remote)
return nil, errNoIps
}

return (*d).Dial(proto, addrstr(ip, port))
return (*d).Dial(proto, remote.String())
}

// ProxyDial tries to connect to addr using d
Expand Down
115 changes: 82 additions & 33 deletions intra/dialers/rdial.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,63 +18,99 @@ import (
utls "github.com/refraction-networking/utls"
)

func netConnect2(d *protect.RDialer, proto string, ip netip.Addr, port int) (net.Conn, error) {
func netConnect2(d *protect.RDialer, proto string, laddr, raddr netip.AddrPort) (net.Conn, error) {
if d == nil {
log.E("rdial: netConnect: nil dialer")
return nil, errNoDialer
} else if !ipok(ip) {
log.E("rdial: netConnect: invalid ip", ip)
} else if !ipok(raddr.Addr()) {
log.E("rdial: netConnect: invalid ip", raddr)
return nil, errNoIps
}

return (*d).Dial(proto, addrstr(ip, port))
if laddr.IsValid() {
return (*d).DialBind(proto, laddr.String(), raddr.String())
} else {
return (*d).Dial(proto, raddr.String())
}
}

// ipConnect dials into ip:port using the provided dialer and returns a net.Conn
// net.Conn is guaranteed to be either net.UDPConn or net.TCPConn
func ipConnect(d *protect.RDial, proto string, ip netip.Addr, port int) (net.Conn, error) {
func ipConnect(d *protect.RDial, proto string, laddr, raddr netip.AddrPort) (net.Conn, error) {
if d == nil {
log.E("rdial: ipConnect: nil dialer")
return nil, errNoDialer
} else if !ipok(ip) {
log.E("rdial: ipConnect: invalid ip", ip)
} else if !ipok(raddr.Addr()) {
log.E("rdial: ipConnect: invalid ip", raddr)
return nil, errNoIps
}

switch proto {
case "tcp", "tcp4", "tcp6":
return d.DialTCP(proto, nil, tcpaddr(ip, port))
case "udp", "udp4", "udp6":
return d.DialUDP(proto, nil, udpaddr(ip, port))
default:
return d.Dial(proto, addrstr(ip, port))
if laddr.IsValid() {
switch proto {
case "tcp", "tcp4", "tcp6":
return d.DialTCP(proto, net.TCPAddrFromAddrPort(laddr), net.TCPAddrFromAddrPort(raddr))
case "udp", "udp4", "udp6":
return d.DialUDP(proto, net.UDPAddrFromAddrPort(laddr), net.UDPAddrFromAddrPort(raddr))
default:
return d.DialBind(proto, laddr.String(), raddr.String())
}
} else {
switch proto {
case "tcp", "tcp4", "tcp6":
return d.DialTCP(proto, nil, net.TCPAddrFromAddrPort(raddr))
case "udp", "udp4", "udp6":
return d.DialUDP(proto, nil, net.UDPAddrFromAddrPort(raddr))
default:
return d.Dial(proto, raddr.String())
}
}
}

func doSplit(ip netip.Addr, port int) bool {
func doSplit(ipp netip.AddrPort) bool {
ip := ipp.Addr()
port := ipp.Port()
// HTTPS or DoT
return !ip.IsPrivate() && (port == 443 || port == 853)
}

func splitIpConnect(d *protect.RDial, proto string, ip netip.Addr, port int) (net.Conn, error) {
func splitIpConnect(d *protect.RDial, proto string, laddr, raddr netip.AddrPort) (net.Conn, error) {
if d == nil {
log.E("rdial: splitIpConnect: nil dialer")
return nil, errNoDialer
} else if !ipok(ip) {
log.E("rdial: splitIpConnect: invalid ip", ip)
} else if !ipok(raddr.Addr()) {
log.E("rdial: splitIpConnect: invalid ip", raddr)
return nil, errNoIps
}

switch proto {
case "tcp", "tcp4", "tcp6":
if doSplit(ip, port) {
return DialWithSplitRetry(d, tcpaddr(ip, port))
if laddr.IsValid() {
switch proto {
case "tcp", "tcp4", "tcp6":
remote := net.TCPAddrFromAddrPort(raddr)
local := net.TCPAddrFromAddrPort(laddr)
if doSplit(raddr) {
return DialWithSplitRetry(d, local, remote)
}
return d.DialTCP(proto, local, remote)
case "udp", "udp4", "udp6":
remote := net.UDPAddrFromAddrPort(raddr)
local := net.UDPAddrFromAddrPort(laddr)
return d.DialUDP(proto, local, remote)
default:
return d.DialBind(proto, laddr.String(), raddr.String())
}
} else {
switch proto {
case "tcp", "tcp4", "tcp6":
tcpaddr := net.TCPAddrFromAddrPort(raddr)
if doSplit(raddr) {
return DialWithSplitRetry(d, nil, tcpaddr)
}
return d.DialTCP(proto, nil, tcpaddr)
case "udp", "udp4", "udp6":
return d.DialUDP(proto, nil, net.UDPAddrFromAddrPort(raddr))
default:
return d.Dial(proto, raddr.String())
}
return d.DialTCP(proto, nil, tcpaddr(ip, port))
case "udp", "udp4", "udp6":
return d.DialUDP(proto, nil, udpaddr(ip, port))
default:
return d.Dial(proto, addrstr(ip, port))
}
}

Expand Down Expand Up @@ -117,19 +153,32 @@ func SplitDial(d *protect.RDial, network, addr string) (net.Conn, error) {
return unPtr(commondial(d, network, addr, adaptRDial(splitIpConnect)))
}

func DialBind(d *protect.RDial, network, local, remote string) (net.Conn, error) {
return unPtr(commondial2(d, network, local, remote, adaptRDial(ipConnect)))
}

func SplitDialBind(d *protect.RDial, network, local, remote string) (net.Conn, error) {
return unPtr(commondial2(d, network, local, remote, adaptRDial(splitIpConnect)))
}

// DialWithTls dials into addr using the provided dialer and returns a tls.Conn
func DialWithTls(d protect.RDialer, cfg *tls.Config, network, addr string) (net.Conn, error) {
return dialtls(&d, cfg, network, addr, adaptRDialer(netConnect2))
return dialtls(&d, cfg, network, "", addr, adaptRDialer(netConnect2))
}

// DialWithTls dials into addr using the provided dialer and returns a tls.Conn
func DialBindWithTls(d protect.RDialer, cfg *tls.Config, network, local, remote string) (net.Conn, error) {
return dialtls(&d, cfg, network, local, remote, adaptRDialer(netConnect2))
}

func dialtls[D rdial](d D, cfg *tls.Config, network, addr string, how dialFn[D, *net.Conn]) (net.Conn, error) {
c, err := unPtr(commondial(d, "tcp", addr, how))
func dialtls[D rdials](d D, cfg *tls.Config, network, local, remote string, how dialFn[D, *net.Conn]) (net.Conn, error) {
c, err := unPtr(commondial2(d, network, local, remote, how))
if err != nil {
clos(c)
return nil, err
}

tlsconn, err := tlsHello(c, cfg, addr)
tlsconn, err := tlsHello(c, cfg, remote)

if eerr := new(tls.ECHRejectionError); errors.As(err, &eerr) {
clos(tlsconn)
Expand All @@ -138,12 +187,12 @@ func dialtls[D rdial](d D, cfg *tls.Config, network, addr string, how dialFn[D,
log.I("rdial: tls: ech rejected; new? %d, err: %v", len(ech), eerr)
if len(ech) > 0 { // retry with new ech
cfg.EncryptedClientHelloConfigList = ech
c, err = unPtr(commondial(d, network, addr, how))
c, err = unPtr(commondial2(d, network, local, remote, how))
if err != nil {
clos(c)
return nil, err
}
tlsconn, err = tlsHello(c, cfg, addr)
tlsconn, err = tlsHello(c, cfg, remote)
}
}
if err != nil {
Expand Down
Loading

1 comment on commit bd409fd

@ignoramous
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#81

Please sign in to comment.