diff --git a/discovery/srv.go b/discovery/srv.go index 34884ddcbd1..bac43ebb612 100644 --- a/discovery/srv.go +++ b/discovery/srv.go @@ -55,8 +55,8 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st return err } for _, srv := range addrs { - target := strings.TrimSuffix(srv.Target, ".") - host := net.JoinHostPort(target, fmt.Sprintf("%d", srv.Port)) + port := fmt.Sprintf("%d", srv.Port) + host := net.JoinHostPort(srv.Target, port) tcpAddr, err := resolveTCPAddr("tcp", host) if err != nil { plog.Warningf("couldn't resolve host %s during SRV discovery", host) @@ -72,8 +72,11 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st n = fmt.Sprintf("%d", tempName) tempName++ } - stringParts = append(stringParts, fmt.Sprintf("%s=%s%s", n, prefix, host)) - plog.Noticef("got bootstrap from DNS for %s at %s%s", service, prefix, host) + // SRV records have a trailing dot but URL shouldn't. + shortHost := strings.TrimSuffix(srv.Target, ".") + urlHost := net.JoinHostPort(shortHost, port) + stringParts = append(stringParts, fmt.Sprintf("%s=%s%s", n, prefix, urlHost)) + plog.Noticef("got bootstrap from DNS for %s at %s%s", service, prefix, urlHost) } return nil } diff --git a/discovery/srv_test.go b/discovery/srv_test.go index 4b8e2ed1e89..c90f9b682ec 100644 --- a/discovery/srv_test.go +++ b/discovery/srv_test.go @@ -17,6 +17,7 @@ package discovery import ( "errors" "net" + "strings" "testing" "github.com/coreos/etcd/pkg/testutil" @@ -29,11 +30,22 @@ func TestSRVGetCluster(t *testing.T) { }() name := "dnsClusterTest" + dns := map[string]string{ + "1.example.com.:2480": "10.0.0.1:2480", + "2.example.com.:2480": "10.0.0.2:2480", + "3.example.com.:2480": "10.0.0.3:2480", + "4.example.com.:2380": "10.0.0.3:2380", + } + srvAll := []*net.SRV{ + {Target: "1.example.com.", Port: 2480}, + {Target: "2.example.com.", Port: 2480}, + {Target: "3.example.com.", Port: 2480}, + } + tests := []struct { withSSL []*net.SRV withoutSSL []*net.SRV urls []string - dns map[string]string expected string }{ @@ -41,61 +53,50 @@ func TestSRVGetCluster(t *testing.T) { []*net.SRV{}, []*net.SRV{}, nil, - nil, "", }, { - []*net.SRV{ - {Target: "10.0.0.1", Port: 2480}, - {Target: "10.0.0.2", Port: 2480}, - {Target: "10.0.0.3", Port: 2480}, - }, + srvAll, []*net.SRV{}, nil, - nil, - "0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480", + "0=https://1.example.com:2480,1=https://2.example.com:2480,2=https://3.example.com:2480", }, { - []*net.SRV{ - {Target: "10.0.0.1", Port: 2480}, - {Target: "10.0.0.2", Port: 2480}, - {Target: "10.0.0.3", Port: 2480}, - }, - []*net.SRV{ - {Target: "10.0.0.1", Port: 2380}, - }, - nil, + srvAll, + []*net.SRV{{Target: "4.example.com.", Port: 2380}}, nil, - "0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480,3=http://10.0.0.1:2380", + + "0=https://1.example.com:2480,1=https://2.example.com:2480,2=https://3.example.com:2480,3=http://4.example.com:2380", }, { - []*net.SRV{ - {Target: "10.0.0.1", Port: 2480}, - {Target: "10.0.0.2", Port: 2480}, - {Target: "10.0.0.3", Port: 2480}, - }, - []*net.SRV{ - {Target: "10.0.0.1", Port: 2380}, - }, + srvAll, + []*net.SRV{{Target: "4.example.com.", Port: 2380}}, []string{"https://10.0.0.1:2480"}, - nil, - "dnsClusterTest=https://10.0.0.1:2480,0=https://10.0.0.2:2480,1=https://10.0.0.3:2480,2=http://10.0.0.1:2380", + + "dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480,2=http://4.example.com:2380", }, // matching local member with resolved addr and return unresolved hostnames { - []*net.SRV{ - {Target: "1.example.com.", Port: 2480}, - {Target: "2.example.com.", Port: 2480}, - {Target: "3.example.com.", Port: 2480}, - }, + srvAll, nil, []string{"https://10.0.0.1:2480"}, - map[string]string{"1.example.com:2480": "10.0.0.1:2480", "2.example.com:2480": "10.0.0.2:2480", "3.example.com:2480": "10.0.0.3:2480"}, "dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480", }, + // invalid + } + + resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) { + if strings.Contains(addr, "10.0.0.") { + // accept IP addresses when resolving apurls + return net.ResolveTCPAddr(network, addr) + } + if dns[addr] == "" { + return nil, errors.New("missing dns record") + } + return net.ResolveTCPAddr(network, dns[addr]) } for i, tt := range tests { @@ -108,12 +109,6 @@ func TestSRVGetCluster(t *testing.T) { } return "", nil, errors.New("Unknown service in mock") } - resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) { - if tt.dns == nil || tt.dns[addr] == "" { - return net.ResolveTCPAddr(network, addr) - } - return net.ResolveTCPAddr(network, tt.dns[addr]) - } urls := testutil.MustNewURLs(t, tt.urls) str, token, err := SRVGetCluster(name, "example.com", "token", urls) if err != nil {