diff --git a/daemoncfg/daemon_config.go b/daemoncfg/daemon_config.go index 530f2d32..26f00eee 100644 --- a/daemoncfg/daemon_config.go +++ b/daemoncfg/daemon_config.go @@ -10,6 +10,7 @@ package daemoncfg import ( "net" "os" + "strconv" "strings" log "github.com/cihub/seelog" @@ -25,6 +26,7 @@ var tcpKey = "tcp" /// string from "AWS_TRACING_DAEMON_ADDRESS" and then from recorder's configuration for DaemonAddr. /// A notation of '127.0.0.1:2000' or 'tcp:127.0.0.1:2000 udp:127.0.0.2:2001' or 'udp:127.0.0.1:2000 tcp:127.0.0.2:2001' /// are both acceptable. The first one means UDP and TCP are running at the same address. +/// Notation 'hostname:2000' or 'tcp:hostname:2000 udp:hostname:2001' or 'udp:hostname:2000 tcp:hostname:2001' are also acceptable. /// By default it assumes a X-Ray daemon running at 127.0.0.1:2000 listening to both UDP and TCP traffic. type DaemonEndpoints struct { // UDPAddr represents UDP endpoint for segments to be sent by emitter. @@ -95,9 +97,18 @@ func parseDoubleForm(addr []string) (*DaemonEndpoints, error) { addr1 := strings.Split(addr[0], ":") // tcp:127.0.0.1:2000 or udp:127.0.0.1:2000 addr2 := strings.Split(addr[1], ":") // tcp:127.0.0.1:2000 or udp:127.0.0.1:2000 - if len(addr1) != 3 || len(addr2) != 3 || net.ParseIP(addr1[1]) == nil || net.ParseIP(addr2[1]) == nil { + if len(addr1) != 3 || len(addr2) != 3 { return nil, errors.New("invalid daemon address: " + addr[0] + " " + addr[1]) } + + // validate ports + _, pErr1 := strconv.Atoi(addr1[2]) + _, pErr2 := strconv.Atoi(addr1[2]) + + if pErr1 != nil || pErr2 != nil { + return nil, errors.New("invalid daemon address port") + } + addrMap := make(map[string]string) addrMap[addr1[0]] = addr1[1] + ":" + addr1[2] @@ -124,11 +135,19 @@ func parseDoubleForm(addr []string) (*DaemonEndpoints, error) { } func parseSingleForm(addr string) (*DaemonEndpoints, error) { // format = "ip:port" - ip := strings.Split(addr, ":")[0] // get ip - if net.ParseIP(ip) == nil { + a := strings.Split(addr, ":") // 127.0.0.1:2000 + + if len(a) != 2 { return nil, errors.New("invalid daemon address: " + addr) } + // validate port + _, pErr1 := strconv.Atoi(a[1]) + + if pErr1 != nil { + return nil, errors.New("invalid daemon address port") + } + udpAddr, uErr := resolveUDPAddr(addr) if uErr != nil { return nil, uErr @@ -137,6 +156,7 @@ func parseSingleForm(addr string) (*DaemonEndpoints, error) { // format = "ip:po if tErr != nil { return nil, tErr } + return &DaemonEndpoints{ UDPAddr: udpAddr, TCPAddr: tcpAddr, diff --git a/daemoncfg/daemon_config_test.go b/daemoncfg/daemon_config_test.go index ea0309c8..bf8f738b 100644 --- a/daemoncfg/daemon_config_test.go +++ b/daemoncfg/daemon_config_test.go @@ -8,12 +8,18 @@ package daemoncfg import ( + "fmt" "os" + "strings" "testing" "github.com/stretchr/testify/assert" ) +var portErr = "invalid daemon address port" +var addrErr = "invalid daemon address" +var hostErr = "no such host" + func TestGetDaemonEndpoints1(t *testing.T) { // default address set to udp and tcp udpAddr := "127.0.0.1:2000" tcpAddr := "127.0.0.1:2000" @@ -123,6 +129,7 @@ func TestGetDaemonEndpointsFromStringInvalid1(t *testing.T) { // "udp:127.0.0.5: dEndpt, err := GetDaemonEndpointsFromString(dAddr) assert.NotNil(t, err) + assert.True(t, strings.Contains(fmt.Sprint(err), addrErr)) assert.Nil(t, dEndpt) } @@ -133,6 +140,7 @@ func TestGetDaemonEndpointsFromStringInvalid2(t *testing.T) { // "tcp:127.0.0.5: dEndpt, err := GetDaemonEndpointsFromString(dAddr) assert.NotNil(t, err) + assert.True(t, strings.Contains(fmt.Sprint(err), addrErr)) assert.Nil(t, dEndpt) } @@ -145,16 +153,17 @@ func TestGetDaemonEndpointsFromStringInvalid3(t *testing.T) { // env variable se dEndpt, err := GetDaemonEndpointsFromString(dAddr) assert.NotNil(t, err) + assert.True(t, strings.Contains(fmt.Sprint(err), addrErr)) assert.Nil(t, dEndpt) } func TestGetDaemonEndpointsFromStringInvalid4(t *testing.T) { - udpAddr := "127.0.02:2001" // error in resolving address + udpAddr := "1.2.1:2a" // error in resolving address port tcpAddr := "127.0.0.1:2000" dAddr := "udp:" + udpAddr + " tcp:" + tcpAddr dEndpt, err := GetDaemonEndpointsFromString(dAddr) - + assert.True(t, strings.Contains(fmt.Sprint(err), portErr)) assert.NotNil(t, err) assert.Nil(t, dEndpt) } @@ -167,6 +176,7 @@ func TestGetDaemonEndpointsFromStringInvalid5(t *testing.T) { dEndpt, err := GetDaemonEndpointsFromString(dAddr) assert.NotNil(t, err) + assert.True(t, strings.Contains(fmt.Sprint(err), hostErr)) assert.Nil(t, dEndpt) } @@ -176,6 +186,7 @@ func TestGetDaemonEndpointsFromStringInvalid6(t *testing.T) { dEndpt, err := GetDaemonEndpointsFromString(dAddr) assert.NotNil(t, err) + assert.True(t, strings.Contains(fmt.Sprint(err), addrErr)) assert.Nil(t, dEndpt) } @@ -186,3 +197,60 @@ func TestGetDaemonEndpointsFromStringInvalid7(t *testing.T) { assert.Nil(t, err) assert.Nil(t, dEndpt) } + +func TestGetDaemonEndpointsForHostname1(t *testing.T) { // parsing hostname - single form + udpAddr := "127.0.0.1:2000" + tcpAddr := "127.0.0.1:2000" + udpEndpt, _ := resolveUDPAddr(udpAddr) + tcpEndpt, _ := resolveTCPAddr(tcpAddr) + dEndpt, _ := GetDaemonEndpointsFromString("localhost:2000") + + assert.Equal(t, dEndpt.UDPAddr, udpEndpt) + assert.Equal(t, dEndpt.TCPAddr, tcpEndpt) +} + +func TestGetDaemonEndpointsForHostname2(t *testing.T) { // Invalid hostname - single form + dEndpt, err := GetDaemonEndpointsFromString("XYZ:2000") + assert.NotNil(t, err) + assert.True(t, strings.Contains(fmt.Sprint(err), hostErr)) + assert.Nil(t, dEndpt) +} + +func TestGetDaemonEndpointsForHostname3(t *testing.T) { // parsing hostname - double form + udpAddr := "127.0.0.1:2000" + tcpAddr := "127.0.0.1:2000" + udpEndpt, _ := resolveUDPAddr(udpAddr) + tcpEndpt, _ := resolveTCPAddr(tcpAddr) + dEndpt, _ := GetDaemonEndpointsFromString("tcp:localhost:2000 udp:localhost:2000") + + assert.Equal(t, dEndpt.UDPAddr, udpEndpt) + assert.Equal(t, dEndpt.TCPAddr, tcpEndpt) +} + +func TestGetDaemonEndpointsForHostname4(t *testing.T) { // Invalid hostname - double form + dEndpt, err := GetDaemonEndpointsFromString("tcp:ABC:2000 udp:XYZ:2000") + assert.NotNil(t, err) + assert.True(t, strings.Contains(fmt.Sprint(err), hostErr)) + assert.Nil(t, dEndpt) +} + +func TestGetDaemonEndpointsForHostname5(t *testing.T) { // Invalid hostname - double form + dEndpt, err := GetDaemonEndpointsFromString("tcp:localhost:2000 tcp:localhost:2000") + assert.NotNil(t, err) + assert.True(t, strings.Contains(fmt.Sprint(err), addrErr)) + assert.Nil(t, dEndpt) +} + +func TestGetDaemonEndpointsForHostname6(t *testing.T) { // Invalid port - single form + dEndpt, err := GetDaemonEndpointsFromString("localhost:") + assert.NotNil(t, err) + assert.True(t, strings.Contains(fmt.Sprint(err), portErr)) + assert.Nil(t, dEndpt) +} + +func TestGetDaemonEndpointsForHostname7(t *testing.T) { // Invalid port - double form + dEndpt, err := GetDaemonEndpointsFromString("tcp:localhost:r4 tcp:localhost:2000") + assert.NotNil(t, err) + assert.True(t, strings.Contains(fmt.Sprint(err), portErr)) + assert.Nil(t, dEndpt) +}