Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new Concat utility and use it to fix slice.append gotcha #416

Merged
merged 6 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/internal/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,24 @@ func RemoveDuplicates[T comparable](slice []T) []T {
}
return result
}

// Concat returns a new slice concatenating the passed in slices.
//
// Avoids a gotcha in Go where since append modifies the underlying memory of the input slice, doing
// newSlice := append(slice1, slice2) can modify slice1. See https://go.dev/doc/effective_go#append
// A std. library concat was added in go 1.22, but this is for backwards compatibility. https://pkg.go.dev/slices#Concat
// This is mostly similiar to the std. library concat, but with a few differences so it compiles on go 1.20.
func Concat[S ~[]E, E any](slices ...S) S {
size := 0
for _, s := range slices {
size += len(s)
if size < 0 {
panic("len out of range")
}
}
newSlice := make([]E, 0, size)
for _, s := range slices {
newSlice = append(newSlice, s...)
}
return newSlice
}
25 changes: 25 additions & 0 deletions src/internal/util/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,28 @@ func TestRemoveDuplicates(t *testing.T) {
})
}
}

func TestConcat(t *testing.T) {
inputSlice1 := make([]int, 0, 10)
for i := 0; i < 3; i++ {
inputSlice1 = append(inputSlice1, i)
}
// Test naive append
newSlice1 := append(inputSlice1, 4)
newSlice2 := append(inputSlice1, 5)
newSlice3 := append(inputSlice1, 6)
// Shows test is working, you'd think that this would be Equal but it isn't. append() is modifying the inputSlice1
require.NotEqual(t, []int{0, 1, 2, 4}, newSlice1)
require.NotEqual(t, []int{0, 1, 2, 5}, newSlice2)
require.Equal(t, []int{0, 1, 2, 6}, newSlice3)
// Now try with new Concat
newSlice1 = Concat(inputSlice1, []int{4})
newSlice2 = Concat(inputSlice1, []int{5})
newSlice3 = Concat(inputSlice1, []int{6})
require.Len(t, inputSlice1, 3)
require.Equal(t, 10, cap(inputSlice1))
require.Equal(t, []int{0, 1, 2}, inputSlice1)
require.Equal(t, []int{0, 1, 2, 4}, newSlice1)
require.Equal(t, []int{0, 1, 2, 5}, newSlice2)
require.Equal(t, []int{0, 1, 2, 6}, newSlice3)
}
4 changes: 3 additions & 1 deletion src/zdns/alookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (

"github.com/pkg/errors"
"github.com/zmap/dns"

"github.com/zmap/zdns/src/internal/util"
)

// DoTargetedLookup performs a lookup of the given domain name against the given nameserver, looking up both IPv4 and IPv6 addresses
Expand Down Expand Up @@ -57,7 +59,7 @@ func (r *Resolver) DoTargetedLookup(name, nameServer string, ipMode IPVersionMod
}
}

combinedTrace := append(ipv4Trace, ipv6Trace...)
combinedTrace := util.Concat(ipv4Trace, ipv6Trace)

// In case we get no IPs and a non-NOERROR status from either
// IPv4 or IPv6 lookup, we return that status.
Expand Down
2 changes: 1 addition & 1 deletion src/zdns/lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func (r *Resolver) LookupAllNameservers(q *Question, nameServer string) (*Combin
for _, nserver := range nsResults.Servers {
// Use all the ipv4 and ipv6 addresses of each nameserver
nameserver := nserver.Name
ips := append(nserver.IPv4Addresses, nserver.IPv6Addresses...)
ips := util.Concat(nserver.IPv4Addresses, nserver.IPv6Addresses)
for _, ip := range ips {
curServer = net.JoinHostPort(ip, "53")
res, trace, status, _ := r.ExternalLookup(q, curServer)
Expand Down