Skip to content

Commit

Permalink
net/netip: add AddrPort.Compare and Prefix.Compare
Browse files Browse the repository at this point in the history
  • Loading branch information
danderson committed Aug 31, 2023
1 parent 9aaf523 commit 1551dfd
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 1 deletion.
26 changes: 26 additions & 0 deletions src/net/netip/netip.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
package netip

import (
"cmp"
"errors"
"math"
"strconv"
Expand Down Expand Up @@ -1102,6 +1103,16 @@ func MustParseAddrPort(s string) AddrPort {
// All ports are valid, including zero.
func (p AddrPort) IsValid() bool { return p.ip.IsValid() }

// Compare returns an integer comparing two AddrPorts.
// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2.
// AddrPorts sort first by IP address, then port.
func (p AddrPort) Compare(p2 AddrPort) int {
if c := p.Addr().Compare(p2.Addr()); c != 0 {
return c
}
return cmp.Compare(p.Port(), p2.Port())
}

func (p AddrPort) String() string {
switch p.ip.z {
case z0:
Expand Down Expand Up @@ -1261,6 +1272,21 @@ func (p Prefix) isZero() bool { return p == Prefix{} }
// IsSingleIP reports whether p contains exactly one IP.
func (p Prefix) IsSingleIP() bool { return p.IsValid() && p.Bits() == p.ip.BitLen() }

// Compare returns an integer comparing two prefixes.
// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2.
// Prefixes sort first by validity (invalid before valid), then
// address family (IPv4 before IPv6), then prefix length, then
// address.
func (p Prefix) Compare(p2 Prefix) int {
if c := cmp.Compare(p.Addr().BitLen(), p2.Addr().BitLen()); c != 0 {
return c
}
if c := cmp.Compare(p.Bits(), p2.Bits()); c != 0 {
return c
}
return p.Addr().Compare(p2.Addr())
}

// ParsePrefix parses s as an IP address prefix.
// The string can be in the form "192.168.1.0/24" or "2001:db8::/32",
// the CIDR notation defined in RFC 4632 and RFC 4291.
Expand Down
106 changes: 105 additions & 1 deletion src/net/netip/netip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"net"
. "net/netip"
"reflect"
"slices"
"sort"
"strings"
"testing"
Expand Down Expand Up @@ -812,7 +813,7 @@ func TestAddrWellKnown(t *testing.T) {
}
}

func TestLessCompare(t *testing.T) {
func TestAddrLessCompare(t *testing.T) {
tests := []struct {
a, b Addr
want bool
Expand Down Expand Up @@ -882,6 +883,109 @@ func TestLessCompare(t *testing.T) {
}
}

func TestAddrPortCompare(t *testing.T) {
tests := []struct {
a, b AddrPort
want int
}{
{AddrPort{}, AddrPort{}, 0},
{AddrPort{}, mustIPPort("1.2.3.4:80"), -1},

{mustIPPort("1.2.3.4:80"), mustIPPort("1.2.3.4:80"), 0},
{mustIPPort("[::1]:80"), mustIPPort("[::1]:80"), 0},

{mustIPPort("1.2.3.4:80"), mustIPPort("2.3.4.5:22"), -1},
{mustIPPort("[::1]:80"), mustIPPort("[::2]:22"), -1},

{mustIPPort("1.2.3.4:80"), mustIPPort("1.2.3.4:443"), -1},
{mustIPPort("[::1]:80"), mustIPPort("[::1]:443"), -1},

{mustIPPort("1.2.3.4:80"), mustIPPort("[0102:0304::0]:80"), -1},
}
for _, tt := range tests {
got := tt.a.Compare(tt.b)
if got != tt.want {
t.Errorf("Compare(%q, %q) = %v; want %v", tt.a, tt.b, got, tt.want)
}

// Also check inverse.
if got == tt.want {
got2 := tt.b.Compare(tt.a)
if want2 := -1 * tt.want; got2 != want2 {
t.Errorf("Compare(%q, %q) was correctly %v, but Compare(%q, %q) was %v", tt.a, tt.b, got, tt.b, tt.a, got2)
}
}
}

// And just sort.
values := []AddrPort{
mustIPPort("[::1]:80"),
mustIPPort("[::2]:80"),
AddrPort{},
mustIPPort("1.2.3.4:443"),
mustIPPort("8.8.8.8:8080"),
mustIPPort("[::1%foo]:1024"),
}
slices.SortFunc(values, func(a, b AddrPort) int { return a.Compare(b) })
got := fmt.Sprintf("%s", values)
want := `[invalid AddrPort 1.2.3.4:443 8.8.8.8:8080 [::1]:80 [::1%foo]:1024 [::2]:80]`
if got != want {
t.Errorf("unexpected sort\n got: %s\nwant: %s\n", got, want)
}
}

func TestPrefixCompare(t *testing.T) {
tests := []struct {
a, b Prefix
want int
}{
{Prefix{}, Prefix{}, 0},
{Prefix{}, mustPrefix("1.2.3.0/24"), -1},

{mustPrefix("1.2.3.0/24"), mustPrefix("1.2.3.0/24"), 0},
{mustPrefix("fe80::/64"), mustPrefix("fe80::/64"), 0},

{mustPrefix("1.2.3.0/24"), mustPrefix("1.2.4.0/24"), -1},
{mustPrefix("fe80::/64"), mustPrefix("fe90::/64"), -1},

{mustPrefix("1.2.0.0/16"), mustPrefix("1.2.0.0/24"), -1},
{mustPrefix("fe80::/48"), mustPrefix("fe80::/64"), -1},

{mustPrefix("1.2.3.0/24"), mustPrefix("fe80::/8"), -1},
}
for _, tt := range tests {
got := tt.a.Compare(tt.b)
if got != tt.want {
t.Errorf("Compare(%q, %q) = %v; want %v", tt.a, tt.b, got, tt.want)
}

// Also check inverse.
if got == tt.want {
got2 := tt.b.Compare(tt.a)
if want2 := -1 * tt.want; got2 != want2 {
t.Errorf("Compare(%q, %q) was correctly %v, but Compare(%q, %q) was %v", tt.a, tt.b, got, tt.b, tt.a, got2)
}
}
}

// And just sort.
values := []Prefix{
mustPrefix("1.2.3.0/24"),
mustPrefix("fe90::/64"),
mustPrefix("fe80::/64"),
mustPrefix("1.2.0.0/16"),
Prefix{},
mustPrefix("fe80::/48"),
mustPrefix("1.2.0.0/24"),
}
slices.SortFunc(values, func(a, b Prefix) int { return a.Compare(b) })
got := fmt.Sprintf("%s", values)
want := `[invalid Prefix 1.2.0.0/16 1.2.0.0/24 1.2.3.0/24 fe80::/48 fe80::/64 fe90::/64]`
if got != want {
t.Errorf("unexpected sort\n got: %s\nwant: %s\n", got, want)
}
}

func TestIPStringExpanded(t *testing.T) {
tests := []struct {
ip Addr
Expand Down

0 comments on commit 1551dfd

Please sign in to comment.