Skip to content

Commit

Permalink
dialers,icmp: impl Probe() for ICMP Echo
Browse files Browse the repository at this point in the history
  • Loading branch information
ignoramous committed Aug 11, 2024
1 parent 0917fd2 commit 1a403a4
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 45 deletions.
105 changes: 74 additions & 31 deletions intra/dialers/rdial.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,57 @@ import (
"strconv"
"time"

"github.com/celzero/firestack/intra/core"
"github.com/celzero/firestack/intra/log"
"github.com/celzero/firestack/intra/protect"
"golang.org/x/net/icmp"
)

type connectFunc func(*protect.RDial, string, netip.Addr, int) (net.Conn, error)
// rconn is a union type for net.UDPConn, net.TCPConn, icmp.PacketConn, net.TCPListener
type rconn interface {
*net.Conn | *icmp.PacketConn | *net.UDPConn | *net.TCPConn | *net.TCPListener
}

// adapt adapts a mkconn to a mkrconn
func adapt(f mkconn) mkrconn[*net.Conn] {
return func(d *protect.RDial, network string, ip netip.Addr, port int) (*net.Conn, error) {
c, err := f(d, network, ip, port)

defer func() {
if err != nil && c != nil {
clos(c)
}
}()

if err != nil {
return nil, err
}
if c == nil || core.IsNil(c) { // go.dev/play/p/SsmqM00d2oH
return nil, errNilConn
}
return &c, nil
}
}

// asConn returns a net.Conn from a *net.Conn
func asConn(c *net.Conn, err error) (net.Conn, error) {
defer func() {
if err != nil && c != nil {
clos(*c)
}
}()

if err != nil {
return nil, err
}
if c == nil || *c == nil || core.IsNil(*c) {
return nil, errNilConn
}
return *c, nil
}

type mkrconn[C rconn] func(*protect.RDial, string, netip.Addr, int) (C, error)
type mkconn func(*protect.RDial, string, netip.Addr, int) (net.Conn, error)

const dialRetryTimeout = 1 * time.Minute

Expand Down Expand Up @@ -75,19 +121,6 @@ func ipConnect(d *protect.RDial, proto string, ip netip.Addr, port int) (net.Con
}
}

// ipConnect2 dials into ip:port using the provided dialer and returns a net.Conn
// net.Conn may not be any among net.UDPConn or net.TCPConn or core.UDPConn or core.TCPConn
func ipConnect2(d *protect.RDial, proto string, ip netip.Addr, port int) (net.Conn, error) {
if d == nil {
log.E("rdial: ipConnect2: nil dialer")
return nil, errNoDialer
} else if !ipok(ip) {
log.E("rdial: ipConnect2: invalid ip", ip)
return nil, errNoIps
}
return d.Dial(proto, addrstr(ip, port))
}

func doSplit(ip netip.Addr, port int) bool {
// HTTPS or DoT
return !ip.IsPrivate() && (port == 443 || port == 853)
Expand Down Expand Up @@ -159,7 +192,19 @@ func desyncIpConnect(d *protect.RDial, proto string, ip netip.Addr, port int) (n
}
}

func commondial(d *protect.RDial, network, addr string, connect connectFunc) (net.Conn, error) {
func tcpListen(d *protect.RDial, network string, ip netip.Addr, port int) (*net.TCPListener, error) {
return d.AcceptTCP(network, addrstr(ip, port))
}

func udpListen(d *protect.RDial, network string, ip netip.Addr, port int) (*net.UDPConn, error) {
return d.AnnounceUDP(network, addrstr(ip, port))
}

func icmpListen(d *protect.RDial, network string, ip netip.Addr, port int) (*icmp.PacketConn, error) {
return d.ProbeICMP(network, addrstr(ip, port))
}

func commondial[C rconn](d *protect.RDial, network, addr string, connect mkrconn[C]) (C, error) {
start := time.Now()

log.D("rdial: commondial: dialing (host:port) %s", addr)
Expand All @@ -177,7 +222,7 @@ func commondial(d *protect.RDial, network, addr string, connect connectFunc) (ne
return nil, err
}

var conn net.Conn
var conn C
var errs error
ips := ipm.Get(domain)
dontretry := ips.OneIPOnly() // just one IP, no retries possible
Expand Down Expand Up @@ -254,8 +299,7 @@ func ListenPacket(d *protect.RDial, network, local string) (net.PacketConn, erro
log.E("rdial: ListenPacket: nil dialer")
return nil, errNoListener
}
// todo: resolve local if hostname
return d.AnnounceUDP(network, local)
return commondial(d, network, local, udpListen)
}

// Listen listens on for TCP connections on the local address using d.
Expand All @@ -264,24 +308,23 @@ func Listen(d *protect.RDial, network, local string) (net.Listener, error) {
log.E("rdial: Listen: nil dialer")
return nil, errNoListener
}
// todo: resolve local if hostname
return d.AcceptTCP(network, local)
return commondial(d, network, local, tcpListen)
}

// Probe sends and accepts ICMP packets on addr using d over a net.PacketConn.
func Probe(d *protect.RDial, network, addr string) (net.PacketConn, error) {
return commondial(d, network, addr, icmpListen)
}

// Dial dials into addr using the provided dialer and returns a net.Conn,
// which is guaranteed to be either net.UDPConn or net.TCPConn
func Dial(d *protect.RDial, network, addr string) (net.Conn, error) {
return commondial(d, network, addr, ipConnect)
}

// Dial2 dials into addr using the provided dialer and returns a net.Conn
func Dial2(d *protect.RDial, network, addr string) (net.Conn, error) {
return commondial(d, network, addr, ipConnect2)
return asConn(commondial(d, network, addr, adapt(ipConnect)))
}

// DialWithTls dials into addr using the provided dialer and returns a tls.Conn
func DialWithTls(d *protect.RDial, cfg *tls.Config, addr string) (net.Conn, error) {
c, err := commondial(d, "tcp", addr, ipConnect)
c, err := asConn(commondial(d, "tcp", addr, adapt(ipConnect)))
if err != nil {
return c, err
}
Expand All @@ -294,22 +337,22 @@ func DialWithTls(d *protect.RDial, cfg *tls.Config, addr string) (net.Conn, erro
// is unsuccessful. Using the provided dialer it returns a net.Conn,
// which may not be net.UDPConn or net.TCPConn
func SplitDial(d *protect.RDial, network, addr string) (net.Conn, error) {
return commondial(d, network, addr, splitIpConnect)
return asConn(commondial(d, network, addr, adapt(splitIpConnect)))
}

// SplitAlwaysDial is like SplitDial except it splits ClientHello in all TLS connections.
func SplitAlwaysDial(d *protect.RDial, network, addr string) (net.Conn, error) {
return commondial(d, network, addr, splitAlwaysIpConnect)
return asConn(commondial(d, network, addr, adapt(splitAlwaysIpConnect)))
}

// DesyncDial attempts TCP desync.
func DesyncDial(d *protect.RDial, network, addr string) (net.Conn, error) {
return commondial(d, network, addr, desyncIpConnect)
return asConn(commondial(d, network, addr, adapt(desyncIpConnect)))
}

// SplitDialWithTls dials into addr using the provided dialer and returns a tls.Conn
func SplitDialWithTls(d *protect.RDial, cfg *tls.Config, addr string) (net.Conn, error) {
c, err := commondial(d, "tcp", addr, splitIpConnect)
c, err := asConn(commondial(d, "tcp", addr, adapt(splitIpConnect)))
if err != nil {
return c, err
}
Expand Down
4 changes: 1 addition & 3 deletions intra/dns53/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,7 @@ func (t *transport) pxdial(network, pid string) (conn *dns.Conn, err error) {

func (t *transport) dial(network string) (*dns.Conn, error) {
// protect.dialers resolves t.addrport, if necessary
// dialers.Dial fails to dial into tcp/udp conns w/ proxies like wgproxy
// which only dial out to generic net.Conn for UDP and core.TCPConn for tcp
c, err := dialers.Dial2(t.dialer, network, t.addrport)
c, err := dialers.Dial(t.dialer, network, t.addrport)
if err != nil {
return nil, err
} else if c == nil || core.IsNil(c) {
Expand Down
28 changes: 19 additions & 9 deletions intra/icmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ func (h *icmpHandler) Ping(source, target netip.AddrPort, msg []byte) (echoed bo
defer queueSummary(h.smmch, h.done, smm.done(err)) // err may be nil

if block {
err = errIcmpFirewalled

Check failure on line 146 in intra/icmp.go

View workflow job for this annotation

GitHub Actions / 🧭 Lint

ineffectual assignment to err (ineffassign)
log.I("t.icmp: egress: firewalled %s -> %s", source, target)
// sleep for a while to avoid busy conns? will also block netstack
// see: netstack/dispatcher.go:newReadvDispatcher
Expand All @@ -155,22 +156,24 @@ func (h *icmpHandler) Ping(source, target netip.AddrPort, msg []byte) (echoed bo
return false // denied
}

anyaddr := ":0"
dst := oneRealIp(realips, target)
uc, err := px.Dialer().Dial("udp", dst.String())
uc, err := px.Dialer().Probe("udp", anyaddr)
ucnil := uc == nil || core.IsNil(uc)
smm.Target = dst.Addr().String()
if err != nil || ucnil { // nilaway: tx.socks5 returns nil conn even if err == nil
if err == nil {
err = unix.ENETUNREACH
}
log.E("t.icmp: egress: dial(%s); hasConn? %s(%t); err %v", dst, pid, ucnil, err)
return // unhandled
return false // unhandled
}

defer clos(uc)
defer core.Close(uc)

extend(uc, icmptimeout)

_, err = uc.Write(msg)
extendp(uc, icmptimeout)
// todo: construct ICMP header? github.com/prometheus-community/pro-bing/blob/0bacb2d5e7/ping.go#L717
_, err = uc.WriteTo(msg, net.UDPAddrFromAddrPort(dst))
logei(err, "t.icmp: egress: write(%v <- %v) ping; done %d; err? %v", dst, source, len(msg), err)
if err != nil {
return false // write error
Expand All @@ -184,9 +187,10 @@ func (h *icmpHandler) Ping(source, target netip.AddrPort, msg []byte) (echoed bo
core.Recycle(bptr)
}()

extend(uc, icmptimeout)
_, err = uc.Read(b)
logei(err, "t.icmp: ingress: read(%v <- %v) ping done; err? %v", source, dst, err)
extendp(uc, icmptimeout)
_, from, err := uc.ReadFrom(b) // todo: assert from == dst
// todo: ignore non-ICMP replies in b: github.com/prometheus-community/pro-bing/blob/0bacb2d5e7/ping.go#L630
logei(err, "t.icmp: ingress: read(%v <- %v / %v) ping done; err? %v", source, from, dst, err)

return true // echoed; even if err != nil
}
Expand All @@ -197,6 +201,12 @@ func extend(c net.Conn, t time.Duration) {
}
}

func extendp(c net.PacketConn, t time.Duration) {
if c != nil && core.IsNotNil(c) {
_ = c.SetDeadline(time.Now().Add(t))
}
}

func logei(err error, msg string, args ...any) {
f := log.E
if err == nil {
Expand Down
5 changes: 3 additions & 2 deletions intra/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ const (
)

var (
errUdpFirewalled = errors.New("udp: firewalled")
errUdpSetupConn = errors.New("udp: could not create conn")
errIcmpFirewalled = errors.New("icmp: firewalled")
errUdpFirewalled = errors.New("udp: firewalled")
errUdpSetupConn = errors.New("udp: could not create conn")
)

var (
Expand Down

1 comment on commit 1a403a4

@ignoramous
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.