diff --git a/src/internal/util/util.go b/src/internal/util/util.go index 0176877f..8c883ea7 100644 --- a/src/internal/util/util.go +++ b/src/internal/util/util.go @@ -147,3 +147,24 @@ func RemoveDuplicates[T comparable](slice []T) []T { } return result } + +// Concat returns a new slice concatenating the passed in slices. +// +// Avoids a gotcha in Go where since append modifies the underlying memory of the input slice, doing +// newSlice := append(slice1, slice2) can modify slice1. See https://go.dev/doc/effective_go#append +// A std. library concat was added in go 1.22, but this is for backwards compatibility. https://pkg.go.dev/slices#Concat +// This is mostly similiar to the std. library concat, but with a few differences so it compiles on go 1.20. +func Concat[S ~[]E, E any](slices ...S) S { + size := 0 + for _, s := range slices { + size += len(s) + if size < 0 { + panic("len out of range") + } + } + newSlice := make([]E, 0, size) + for _, s := range slices { + newSlice = append(newSlice, s...) + } + return newSlice +} diff --git a/src/internal/util/util_test.go b/src/internal/util/util_test.go index bb4a9af4..c8ad0c3f 100644 --- a/src/internal/util/util_test.go +++ b/src/internal/util/util_test.go @@ -143,3 +143,28 @@ func TestRemoveDuplicates(t *testing.T) { }) } } + +func TestConcat(t *testing.T) { + inputSlice1 := make([]int, 0, 10) + for i := 0; i < 3; i++ { + inputSlice1 = append(inputSlice1, i) + } + // Test naive append + newSlice1 := append(inputSlice1, 4) + newSlice2 := append(inputSlice1, 5) + newSlice3 := append(inputSlice1, 6) + // Shows test is working, you'd think that this would be Equal but it isn't. append() is modifying the inputSlice1 + require.NotEqual(t, []int{0, 1, 2, 4}, newSlice1) + require.NotEqual(t, []int{0, 1, 2, 5}, newSlice2) + require.Equal(t, []int{0, 1, 2, 6}, newSlice3) + // Now try with new Concat + newSlice1 = Concat(inputSlice1, []int{4}) + newSlice2 = Concat(inputSlice1, []int{5}) + newSlice3 = Concat(inputSlice1, []int{6}) + require.Len(t, inputSlice1, 3) + require.Equal(t, 10, cap(inputSlice1)) + require.Equal(t, []int{0, 1, 2}, inputSlice1) + require.Equal(t, []int{0, 1, 2, 4}, newSlice1) + require.Equal(t, []int{0, 1, 2, 5}, newSlice2) + require.Equal(t, []int{0, 1, 2, 6}, newSlice3) +} diff --git a/src/zdns/alookup.go b/src/zdns/alookup.go index 9c72ee6c..68900cf9 100644 --- a/src/zdns/alookup.go +++ b/src/zdns/alookup.go @@ -18,6 +18,8 @@ import ( "github.com/pkg/errors" "github.com/zmap/dns" + + "github.com/zmap/zdns/src/internal/util" ) // DoTargetedLookup performs a lookup of the given domain name against the given nameserver, looking up both IPv4 and IPv6 addresses @@ -57,7 +59,7 @@ func (r *Resolver) DoTargetedLookup(name, nameServer string, ipMode IPVersionMod } } - combinedTrace := append(ipv4Trace, ipv6Trace...) + combinedTrace := util.Concat(ipv4Trace, ipv6Trace) // In case we get no IPs and a non-NOERROR status from either // IPv4 or IPv6 lookup, we return that status. diff --git a/src/zdns/lookup.go b/src/zdns/lookup.go index 30cd9138..eb91fed2 100644 --- a/src/zdns/lookup.go +++ b/src/zdns/lookup.go @@ -292,7 +292,7 @@ func (r *Resolver) LookupAllNameservers(q *Question, nameServer string) (*Combin for _, nserver := range nsResults.Servers { // Use all the ipv4 and ipv6 addresses of each nameserver nameserver := nserver.Name - ips := append(nserver.IPv4Addresses, nserver.IPv6Addresses...) + ips := util.Concat(nserver.IPv4Addresses, nserver.IPv6Addresses) for _, ip := range ips { curServer = net.JoinHostPort(ip, "53") res, trace, status, _ := r.ExternalLookup(q, curServer)