-
Notifications
You must be signed in to change notification settings - Fork 164
/
Copy pathlinuxreachprober.go
202 lines (189 loc) · 5.93 KB
/
linuxreachprober.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
// Copyright (c) 2024 Zededa, Inc.
// SPDX-License-Identifier: Apache-2.0
package portprober
import (
"context"
"fmt"
"net"
"time"
"github.com/lf-edge/eve/pkg/pillar/types"
"github.com/lf-edge/eve/pkg/pillar/utils/netutils"
"github.com/tatsushid/go-fastping"
)
// HostnameAddr is an implementation of net.Addr that can be used to store hostname
// and optionally also port.
// Other implementations provided by Golang can only store already resolved IP address.
type HostnameAddr struct {
Hostname string
Port uint16
}
// Network returns the address's network type.
func (h *HostnameAddr) Network() string {
if h.Port == 0 {
return "ip"
}
// We do not provide UDP-based probing method for now.
return "tcp"
}
// String returns the address in the form "hostname:port".
func (h *HostnameAddr) String() string {
if h.Port == 0 {
return h.Hostname
}
return fmt.Sprintf("%s:%d", h.Hostname, h.Port)
}
// LinuxReachabilityProberICMP is an implementation of ReachabilityProber
// for the ICMP-based probing method and the Linux TCP/IP network stack.
type LinuxReachabilityProberICMP struct{}
// Probe reachability of <dstAddr> using ICMP ping sent via the given port.
func (p *LinuxReachabilityProberICMP) Probe(ctx context.Context, portIfName string,
srcIP net.IP, dstAddr net.Addr, dnsServers []net.IP) error {
// Do not use DNS servers other than those that belong to the probed
// network port.
customResolver := &dnsResolver{
dnsServers: dnsServers,
srcIP: srcIP,
ifName: portIfName,
}
resolver := customResolver.getNetResolver()
var dstIPs []*net.IPAddr
switch addr := dstAddr.(type) {
case *net.IPAddr:
// Resolver is not needed, dstAddr is already an IP address.
dstIPs = append(dstIPs, addr)
case *HostnameAddr:
// Try to resolve destination hostname.
ips, err := resolver.LookupIP(ctx, "ip", addr.Hostname)
if err != nil {
return fmt.Errorf("failed to resolve %s: %w", dstAddr, err)
}
if len(ips) == 0 {
return fmt.Errorf("resolver returned no IPs for %s", dstAddr)
}
for _, ip := range ips {
dstIPs = append(dstIPs, &net.IPAddr{IP: ip})
}
default:
return fmt.Errorf("unexpected dstAddr type for ICMP probe: %T", dstAddr)
}
for i, dstIP := range dstIPs {
// Determine timeout for the ping based on the context.
var pingTimeout time.Duration
if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
pingTimeout = deadline.Sub(time.Now())
if pingTimeout <= 0 {
return fmt.Errorf("ping timeout expired")
}
// Leave some time to try remaining IPs.
pingTimeout = pingTimeout / time.Duration(len(dstIPs)-i)
}
var pingSuccess bool
pinger := fastping.NewPinger()
pinger.AddIPAddr(dstIP)
_, err := pinger.Source(srcIP.String())
if err != nil {
// Should be unreachable, source IP is valid.
return err
}
if pingTimeout != 0 {
pinger.MaxRTT = pingTimeout
}
pinger.OnRecv = func(ip *net.IPAddr, d time.Duration) {
if ip != nil && ip.IP.Equal(dstIP.IP) {
pingSuccess = true
}
}
err = pinger.Run()
if err != nil {
// Check remaining time and try the next IP.
continue
}
if pingSuccess {
return nil
}
}
return fmt.Errorf("no ping response received from %v", dstAddr)
}
// LinuxReachabilityProberTCP is an implementation of ReachabilityProber
// for the TCP-based probing method and the Linux TCP/IP network stack.
type LinuxReachabilityProberTCP struct{}
// Probe reachability of <dstAddr> using TCP handshake initiated via the given port.
func (p *LinuxReachabilityProberTCP) Probe(ctx context.Context, portIfName string,
srcIP net.IP, dstAddr net.Addr, dnsServers []net.IP) error {
// Do not use DNS servers other than those that belong to the probed
// network port.
customResolver := &dnsResolver{
dnsServers: dnsServers,
srcIP: srcIP,
ifName: portIfName,
}
resolver := customResolver.getNetResolver()
switch dstAddr.(type) {
case *net.TCPAddr:
// Resolver is not needed, dstAddr is already an IP address.
resolver = nil
case *HostnameAddr:
// Continue...
default:
return fmt.Errorf("unexpected dstAddr type for TCP probe: %T", dstAddr)
}
tcpDialer := &net.Dialer{
LocalAddr: &net.TCPAddr{IP: srcIP},
Resolver: resolver,
}
conn, err := tcpDialer.DialContext(ctx, "tcp", dstAddr.String())
if err != nil {
return fmt.Errorf("TCP connect request to %v failed: %w", dstAddr, err)
}
// TCP handshake succeeded.
_ = conn.Close()
return nil
}
// dnsResolver restricts hostname resolution to the configured <dnsServers>.
// The Go resolver it produces may only dial servers from that list, and
// every DNS connection is bound to <srcIP>.
type dnsResolver struct {
	srcIP      net.IP   // local address that DNS connections are bound to
	dnsServers []net.IP // the only DNS servers that may be contacted
	ifName     string   // probed port name, reported in returned errors
}

// getNetResolver returns a pure-Go net.Resolver that opens every DNS server
// connection through resolverDial. StrictErrors is explicitly disabled.
func (r *dnsResolver) getNetResolver() *net.Resolver {
	res := &net.Resolver{
		PreferGo:     true,
		StrictErrors: false,
	}
	res.Dial = r.resolverDial
	return res
}

// resolverDial is the Dial hook of the resolver built by getNetResolver.
// The target address is accepted only if it parses as an IP (with or without
// a port), is not a loopback address, and appears in r.dnsServers. The
// connection is then opened from r.srcIP over UDP or TCP, as requested by
// network. Any other network value is rejected.
func (r *dnsResolver) resolverDial(
	ctx context.Context, network, address string) (net.Conn, error) {
	// The address normally carries a port; fall back to the raw string otherwise.
	dnsHost := address
	if host, _, splitErr := net.SplitHostPort(address); splitErr == nil {
		dnsHost = host
	}

	dnsIP := net.ParseIP(dnsHost)
	switch {
	case dnsIP == nil:
		return nil, fmt.Errorf("failed to parse DNS server IP address '%s'", dnsHost)
	case dnsIP.IsLoopback():
		// The Go resolver falls back to 127.0.0.1:53 when resolv.conf lists
		// no nameservers (see defaultNS in net/dnsconfig_unix.go). EVE never
		// runs a DNS server on loopback, so report that DNS is unavailable.
		return nil, &types.DNSNotAvailError{IfName: r.ifName}
	}

	allowed := false
	for _, server := range r.dnsServers {
		if allowed = netutils.EqualIPs(server, dnsIP); allowed {
			break
		}
	}
	if !allowed {
		return nil, fmt.Errorf("DNS server %s is not valid for port %s", dnsIP, r.ifName)
	}

	// Pick the local address type matching the requested transport.
	var local net.Addr
	switch network {
	case "udp", "udp4", "udp6":
		local = &net.UDPAddr{IP: r.srcIP}
	case "tcp", "tcp4", "tcp6":
		local = &net.TCPAddr{IP: r.srcIP}
	default:
		return nil, fmt.Errorf("unsupported address type: %v", network)
	}
	dialer := net.Dialer{LocalAddr: local}
	return dialer.DialContext(ctx, network, address)
}