From 8b716d7a8ad9096ca4b198ce6052eb687f94819c Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Mon, 24 Aug 2020 13:43:42 -0700 Subject: [PATCH] Provide NetworkEndpoints with an NetworkInterface interface Instead of just passing the NIC ID of a NIC, pass an interface so the network endpoint may query other information about the NIC such as whether or not it is a loopback device. PiperOrigin-RevId: 328202343 --- pkg/tcpip/network/arp/arp.go | 23 +- pkg/tcpip/network/ip_test.go | 26 +- pkg/tcpip/network/ipv4/BUILD | 1 + pkg/tcpip/network/ipv4/ipv4.go | 197 +++- pkg/tcpip/network/ipv6/BUILD | 1 + pkg/tcpip/network/ipv6/icmp.go | 16 +- pkg/tcpip/network/ipv6/icmp_test.go | 16 +- pkg/tcpip/network/ipv6/ipv6.go | 303 +++++- pkg/tcpip/network/ipv6/ndp_test.go | 2 +- pkg/tcpip/stack/BUILD | 2 + pkg/tcpip/stack/addressable_endpoint.go | 676 +++++++++++++ pkg/tcpip/stack/forwarder_test.go | 21 +- pkg/tcpip/stack/group_addressable_endpoint.go | 125 +++ pkg/tcpip/stack/ndp.go | 36 +- pkg/tcpip/stack/nic.go | 936 +++++------------- pkg/tcpip/stack/nic_test.go | 37 +- pkg/tcpip/stack/registration.go | 10 +- pkg/tcpip/stack/stack.go | 24 +- pkg/tcpip/stack/stack_test.go | 21 +- pkg/tcpip/transport/tcp/tcp_test.go | 4 +- pkg/tcpip/transport/udp/udp_test.go | 16 +- 21 files changed, 1696 insertions(+), 797 deletions(-) create mode 100644 pkg/tcpip/stack/addressable_endpoint.go create mode 100644 pkg/tcpip/stack/group_addressable_endpoint.go diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index cbbe5b77f2..5f03b7b8dc 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -42,6 +42,8 @@ const ( // endpoint implements stack.NetworkEndpoint. type endpoint struct { + stack.AddressableEndpoint + protocol *protocol nicID tcpip.NICID linkEP stack.LinkEndpoint @@ -49,6 +51,14 @@ type endpoint struct { nud stack.NUDHandler } +func (*endpoint) Enable() *tcpip.Error { + return nil +} + +func (*endpoint) Disable() *tcpip.Error { + return nil +} + // DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint. func (e *endpoint) DefaultTTL() uint8 { return 0 @@ -168,13 +178,14 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress } -func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { return &endpoint{ - protocol: p, - nicID: nicID, - linkEP: sender, - linkAddrCache: linkAddrCache, - nud: nud, + AddressableEndpoint: stack.NewAddressableEndpoint(), + protocol: p, + nicID: nic.ID(), + linkEP: sender, + linkAddrCache: linkAddrCache, + nud: nud, } } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index e45dd17f89..79f0c58eaa 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -247,10 +247,22 @@ func buildDummyStack(t *testing.T) *stack.Stack { return s } +var testNIC stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct{} + +func (*testInterface) ID() tcpip.NICID { + return nicID +} + +func (*testInterface) IsLoopback() bool { + return false +} + func TestIPv4Send(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, nil, &o, buildDummyStack(t)) + ep := proto.NewEndpoint(testNIC, nil, nil, nil, &o, buildDummyStack(t)) defer ep.Close() // Allocate and initialize the payload view. @@ -287,7 +299,7 @@ func TestIPv4Send(t *testing.T) { func TestIPv4Receive(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + ep := proto.NewEndpoint(testNIC, nil, nil, &o, nil, buildDummyStack(t)) defer ep.Close() totalLen := header.IPv4MinimumSize + 30 @@ -357,7 +369,7 @@ func TestIPv4ReceiveControl(t *testing.T) { t.Run(c.name, func(t *testing.T) { o := testObject{t: t} proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + ep := proto.NewEndpoint(testNIC, nil, nil, &o, nil, buildDummyStack(t)) defer ep.Close() const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize @@ -418,7 +430,7 @@ func TestIPv4ReceiveControl(t *testing.T) { func TestIPv4FragmentationReceive(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + ep := proto.NewEndpoint(testNIC, nil, nil, &o, nil, buildDummyStack(t)) defer ep.Close() totalLen := header.IPv4MinimumSize + 24 @@ -495,7 +507,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { func TestIPv6Send(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, channel.New(0, 1280, ""), buildDummyStack(t)) + ep := proto.NewEndpoint(testNIC, nil, nil, &o, channel.New(0, 1280, ""), buildDummyStack(t)) defer ep.Close() // Allocate and initialize the payload view. @@ -532,7 +544,7 @@ func TestIPv6Send(t *testing.T) { func TestIPv6Receive(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + ep := proto.NewEndpoint(testNIC, nil, nil, &o, nil, buildDummyStack(t)) defer ep.Close() totalLen := header.IPv6MinimumSize + 30 @@ -611,7 +623,7 @@ func TestIPv6ReceiveControl(t *testing.T) { t.Run(c.name, func(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + ep := proto.NewEndpoint(testNIC, nil, nil, &o, nil, buildDummyStack(t)) defer ep.Close() dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index d142b4ffaa..6a1b7045ef 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -10,6 +10,7 @@ go_library( ], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 55ca94268c..035e7ebf28 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -23,6 +23,7 @@ package ipv4 import ( "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -50,23 +51,84 @@ const ( fragmentblockSize = 8 ) +var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) +var _ stack.AddressableEndpoint = (*endpoint)(nil) +var _ stack.NetworkEndpoint = (*endpoint)(nil) + type endpoint struct { - nicID tcpip.NICID + nic stack.NetworkInterface linkEP stack.LinkEndpoint dispatcher stack.TransportDispatcher protocol *protocol stack *stack.Stack + + mu struct { + sync.RWMutex + ep stack.AddressableEndpoint + gep stack.GroupAddressableEndpoint + } } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { - return &endpoint{ - nicID: nicID, +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { + e := &endpoint{ + nic: nic, linkEP: linkEP, dispatcher: dispatcher, protocol: p, stack: st, } + e.mu.ep = stack.NewAddressableEndpointWithLock(&e.mu) + e.mu.gep = stack.NewGroupAddressableEndpoint(e.mu.ep) + return e +} + +var ipv4BroadcastAddr = tcpip.AddressWithPrefix{ + Address: header.IPv4Broadcast, + PrefixLen: 8 * header.IPv4AddressSize, +} + +// Enable implements stack.NetworkEndpoint. +func (e *endpoint) Enable() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // Create an endpoint to receive broadcast packets on this interface. + if _, err := e.mu.ep.AddAddress(ipv4BroadcastAddr, stack.AddAddressOptions{ + Deprecated: false, + ConfigType: stack.AddressConfigStatic, + Kind: stack.Permanent, + PEB: stack.NeverPrimaryEndpoint, + }); err != nil { + return err + } + + // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts + // multicast group. Note, the IANA calls the all-hosts multicast group the + // all-systems multicast group. + if _, err := e.mu.gep.JoinGroup(header.IPv4AllSystems); err != nil { + return err + } + + return nil +} + +// Disable implements stack.NetworkEndpoint. +func (e *endpoint) Disable() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // The NIC may have already left the multicast group. + if _, err := e.mu.gep.LeaveGroup(header.IPv4AllSystems, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + + // The address may have already been removed.o + if err := e.mu.ep.RemoveAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + + return nil } // DefaultTTL is the default time-to-live value for this endpoint. @@ -80,14 +142,14 @@ func (e *endpoint) MTU() uint32 { return calculateMTU(e.linkEP.MTU()) } -// Capabilities implements stack.NetworkEndpoint.Capabilities. +// Capabilities implements stack.NetworkEndpoint. func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { return e.linkEP.Capabilities() } // NICID returns the ID of the NIC this endpoint belongs to. func (e *endpoint) NICID() tcpip.NICID { - return e.nicID + return e.nic.ID() } // MaxHeaderLength returns the maximum length needed by ipv4 headers (and @@ -452,6 +514,129 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // Close cleans up resources associated with the endpoint. func (e *endpoint) Close() {} +// AddAddress implements stack.AddressableEndpoint. +func (e *endpoint) AddAddress(addr tcpip.AddressWithPrefix, opts stack.AddAddressOptions) (stack.AddressEndpoint, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.ep.AddAddress(addr, opts) +} + +// RemoveAddress implements stack.AddressableEndpoint. +func (e *endpoint) RemoveAddress(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.ep.RemoveAddress(addr) +} + +// HasAddress implements stack.AddressableEndpoint. +func (e *endpoint) HasAddress(addr tcpip.Address) bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.HasAddress(addr) +} + +// PrimaryEndpoints implements stack.AddressableEndpoint. +func (e *endpoint) PrimaryEndpoints() []stack.AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.PrimaryEndpoints() +} + +// AllEndpoints implements stack.AddressableEndpoint. +func (e *endpoint) AllEndpoints() []stack.AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.AllEndpoints() +} + +// GetEndpoint implements stack.AddressableEndpoint. +func (e *endpoint) GetEndpoint(localAddr tcpip.Address) stack.AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.GetEndpoint(localAddr) +} + +// GetAssignedEndpoint implements stack.AddressableEndpoint. +func (e *endpoint) GetAssignedEndpoint(localAddr tcpip.Address, allowAnyInSubnet, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { + e.mu.Lock() + defer e.mu.Unlock() + + if r := e.mu.ep.GetAssignedEndpoint(localAddr, allowAnyInSubnet, allowTemp, tempPEB); r != nil { + return r + } + + eps := e.mu.ep.AllEndpoints() + for _, r := range eps { + addr := r.AddressWithPrefix() + subnet := addr.Subnet() + if subnet.IsBroadcast(localAddr) && r.IsAssigned(allowTemp) && r.IncRef() { + return r + } + } + + return nil +} + +// PrimaryEndpoint implements stack.AddressableEndpoint. +func (e *endpoint) PrimaryEndpoint(remoteAddr tcpip.Address, spoofingOrPromiscuous bool) stack.AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.PrimaryEndpoint(remoteAddr, spoofingOrPromiscuous) +} + +// PrimaryAddresses implements stack.AddressableEndpoint. +func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.PrimaryAddresses() +} + +// AllAddresses implements stack.AddressableEndpoint. +func (e *endpoint) AllAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.AllAddresses() +} + +// RemoveAllAddresses implements stack.AddressableEndpoint. +func (e *endpoint) RemoveAllAddresses() *tcpip.Error { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.RemoveAllAddresses() +} + +// JoinGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { + if !header.IsV4MulticastAddress(addr) { + return false, tcpip.ErrBadAddress + } + + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.gep.JoinGroup(addr) +} + +// LeaveGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) LeaveGroup(addr tcpip.Address, force bool) (bool, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.gep.LeaveGroup(addr, force) +} + +// IsInGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) IsInGroup(addr tcpip.Address) bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.gep.IsInGroup(addr) +} + +// LeaveAllGroups implements stack.GroupAddressableEndpoint. +func (e *endpoint) LeaveAllGroups() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.gep.LeaveAllGroups() +} + type protocol struct { ids []uint32 hashIV uint32 diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index bcc64994e3..87a519d85c 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -10,6 +10,7 @@ go_library( ], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 2b83c421e0..46cfcf1998 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -208,7 +208,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } s := r.Stack() - if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil { + if isTentative, err := s.IsAddrTentative(e.NICID(), targetAddr); err != nil { // We will only get an error if the NIC is unrecognized, which should not // happen. For now, drop this packet. // @@ -227,7 +227,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // stack know so it can handle such a scenario and do nothing further with // the NS. if r.RemoteAddress == header.IPv6Any { - s.DupTentativeAddrDetected(e.nicID, targetAddr) + s.DupTentativeAddrDetected(e.NICID(), targetAddr) } // Do not handle neighbor solicitations targeted to an address that is @@ -240,7 +240,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // section 5.4.3. // Is the NS targeting us? - if s.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 { + if s.CheckLocalAddress(e.NICID(), ProtocolNumber, targetAddr) == 0 { return } @@ -275,7 +275,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } else if e.nud != nil { e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { - e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr) + e.linkAddrCache.AddLinkAddress(e.NICID(), r.RemoteAddress, sourceLinkAddr) } // ICMPv6 Neighbor Solicit messages are always sent to @@ -354,7 +354,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme targetAddr := na.TargetAddress() s := r.Stack() - if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil { + if isTentative, err := s.IsAddrTentative(e.NICID(), targetAddr); err != nil { // We will only get an error if the NIC is unrecognized, which should not // happen. For now short-circuit this packet. // @@ -365,7 +365,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // DAD on, implying the address is not unique. In this case we let the // stack know so it can handle such a scenario and do nothing furthur with // the NDP NA. - s.DupTentativeAddrDetected(e.nicID, targetAddr) + s.DupTentativeAddrDetected(e.NICID(), targetAddr) return } @@ -395,7 +395,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // address cache with the link address for the target of the message. if len(targetLinkAddr) != 0 { if e.nud == nil { - e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr) + e.linkAddrCache.AddLinkAddress(e.NICID(), targetAddr, targetLinkAddr) return } @@ -568,7 +568,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // Tell the NIC to handle the RA. stack := r.Stack() - stack.HandleNDPRA(e.nicID, routerAddr, ra) + stack.HandleNDPRA(e.NICID(), routerAddr, ra) case header.ICMPv6RedirectMsg: // TODO(gvisor.dev/issue/2285): Call `e.nud.HandleProbe` after validating diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 8112ed0518..da074edb03 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -102,6 +102,18 @@ func (*stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.Lin func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) { } +var testNIC stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct{} + +func (*testInterface) ID() tcpip.NICID { + return 0 +} + +func (*testInterface) IsLoopback() bool { + return false +} + func TestICMPCounts(t *testing.T) { tests := []struct { name string @@ -149,7 +161,7 @@ func TestICMPCounts(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s) + ep := netProto.NewEndpoint(testNIC, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s) defer ep.Close() r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) @@ -287,7 +299,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(0, nil, &stubNUDHandler{}, &stubDispatcher{}, nil, s) + ep := netProto.NewEndpoint(testNIC, nil, &stubNUDHandler{}, &stubDispatcher{}, nil, s) defer ep.Close() r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 36fbbebf09..ecaf13b947 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -22,8 +22,10 @@ package ipv6 import ( "fmt" + "sort" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -44,14 +46,53 @@ const ( DefaultTTL = 64 ) +var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) +var _ stack.AddressableEndpoint = (*endpoint)(nil) +var _ stack.NetworkEndpoint = (*endpoint)(nil) + type endpoint struct { - nicID tcpip.NICID + nic stack.NetworkInterface linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache nud stack.NUDHandler dispatcher stack.TransportDispatcher protocol *protocol stack *stack.Stack + + mu struct { + sync.RWMutex + ep stack.AddressableEndpoint + gep stack.GroupAddressableEndpoint + } +} + +// Enable implements stack.NetworkEndpoint. +func (e *endpoint) Enable() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // Join the All-Nodes multicast group before starting DAD as responses to DAD + // (NDP NS) messages may be sent to the All-Nodes multicast group if the + // source address of the NDP NS is the unspecified address, as per RFC 4861 + // section 7.2.4. + if _, err := e.mu.gep.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil { + return err + } + + return nil +} + +// Disable implements stack.NetworkEndpoint. +func (e *endpoint) Disable() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // The NIC may have already left the multicast group. + if _, err := e.mu.gep.LeaveGroup(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + + return nil } // DefaultTTL is the default hop limit for this endpoint. @@ -67,10 +108,10 @@ func (e *endpoint) MTU() uint32 { // NICID returns the ID of the NIC this endpoint belongs to. func (e *endpoint) NICID() tcpip.NICID { - return e.nicID + return e.nic.ID() } -// Capabilities implements stack.NetworkEndpoint.Capabilities. +// Capabilities implements stack.NetworkEndpoint. func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { return e.linkEP.Capabilities() } @@ -426,6 +467,253 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } +// AddAddress implements stack.AddressableEndpoint. +func (e *endpoint) AddAddress(addr tcpip.AddressWithPrefix, opts stack.AddAddressOptions) (stack.AddressEndpoint, *tcpip.Error) { + // TODO: add checks here after making sure b/140943433 won't happen. + + e.mu.Lock() + defer e.mu.Unlock() + + nep, err := e.mu.ep.AddAddress(addr, opts) + if err != nil { + return nil, err + } + + if !header.IsV6UnicastAddress(addr.Address) { + return nep, nil + } + + snmc := header.SolicitedNodeAddr(addr.Address) + if _, err := e.mu.gep.JoinGroup(snmc); err != nil { + return nil, err + } + + return nep, nil +} + +// RemoveAddress implements stack.AddressableEndpoint. +func (e *endpoint) RemoveAddress(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + return e.removeAddressLocked(addr) +} + +func (e *endpoint) removeAddressLocked(addr tcpip.Address) *tcpip.Error { + if err := e.mu.ep.RemoveAddress(addr); err != nil { + return err + } + + if !header.IsV6UnicastAddress(addr) { + return nil + } + + snmc := header.SolicitedNodeAddr(addr) + if _, err := e.mu.gep.LeaveGroup(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + + return nil +} + +// HasAddress implements stack.AddressableEndpoint. +func (e *endpoint) HasAddress(addr tcpip.Address) bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.HasAddress(addr) +} + +// PrimaryEndpoints implements stack.AddressableEndpoint. +func (e *endpoint) PrimaryEndpoints() []stack.AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.PrimaryEndpoints() +} + +// AllEndpoints implements stack.AddressableEndpoint. +func (e *endpoint) AllEndpoints() []stack.AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.AllEndpoints() +} + +// GetEndpoint implements stack.AddressableEndpoint. +func (e *endpoint) GetEndpoint(localAddr tcpip.Address) stack.AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.GetEndpoint(localAddr) +} + +// GetAssignedEndpoint implements stack.AddressableEndpoint. +func (e *endpoint) GetAssignedEndpoint(localAddr tcpip.Address, allowAnyInSubnet, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.ep.GetAssignedEndpoint(localAddr, allowAnyInSubnet, allowTemp, tempPEB) +} + +// PrimaryEndpoint implements stack.AddressableEndpoint. +func (e *endpoint) PrimaryEndpoint(remoteAddr tcpip.Address, spoofingOrPromiscuous bool) stack.AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + + // ipv6AddrCandidate is an IPv6 candidate for Source Address Selection (RFC + // 6724 section 5). + type ipv6AddrCandidate struct { + ref stack.AddressEndpoint + scope header.IPv6AddressScope + } + + if len(remoteAddr) == 0 { + return e.mu.ep.PrimaryEndpoint(remoteAddr, spoofingOrPromiscuous) + } + + primaryAddrs := e.mu.ep.PrimaryEndpoints() + + if len(primaryAddrs) == 0 { + return nil + } + + // Create a candidate set of available addresses we can potentially use as a + // source address. + cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs)) + for _, r := range primaryAddrs { + // If r is not valid for outgoing connections, it is not a valid endpoint. + if !r.IsAssigned(spoofingOrPromiscuous) { + continue + } + + addr := r.AddressWithPrefix().Address + scope, err := header.ScopeForIPv6Address(addr) + if err != nil { + // Should never happen as we got r from the primary IPv6 endpoint list and + // ScopeForIPv6Address only returns an error if addr is not an IPv6 + // address. + panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err)) + } + + cs = append(cs, ipv6AddrCandidate{ + ref: r, + scope: scope, + }) + } + + remoteScope, err := header.ScopeForIPv6Address(remoteAddr) + if err != nil { + // primaryIPv6Endpoint should never be called with an invalid IPv6 address. + panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)) + } + + // Sort the addresses as per RFC 6724 section 5 rules 1-3. + // + // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5. + sort.Slice(cs, func(i, j int) bool { + sa := cs[i] + sb := cs[j] + + // Prefer same address as per RFC 6724 section 5 rule 1. + if sa.ref.AddressWithPrefix().Address == remoteAddr { + return true + } + if sb.ref.AddressWithPrefix().Address == remoteAddr { + return false + } + + // Prefer appropriate scope as per RFC 6724 section 5 rule 2. + if sa.scope < sb.scope { + return sa.scope >= remoteScope + } else if sb.scope < sa.scope { + return sb.scope < remoteScope + } + + // Avoid deprecated addresses as per RFC 6724 section 5 rule 3. + if saDep, sbDep := sa.ref.Deprecated(), sb.ref.Deprecated(); saDep != sbDep { + // If sa is not deprecated, it is preferred over sb. + return sbDep + } + + // Prefer temporary addresses as per RFC 6724 section 5 rule 7. + if saTemp, sbTemp := sa.ref.ConfigType() == stack.AddressConfigSlaacTemp, sb.ref.ConfigType() == stack.AddressConfigSlaacTemp; saTemp != sbTemp { + return saTemp + } + + // sa and sb are equal, return the endpoint that is closest to the front of + // the primary endpoint list. + return i < j + }) + + // Return the most preferred address that can have its reference count + // incremented. + for _, c := range cs { + if r := c.ref; r.IncRef() { + return r + } + } + + return nil +} + +// PrimaryAddresses implements stack.AddressableEndpoint. +func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.PrimaryAddresses() +} + +// AllAddresses implements stack.AddressableEndpoint. +func (e *endpoint) AllAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.AllAddresses() +} + +// RemoveAllAddresses implements stack.AddressableEndpoint. +func (e *endpoint) RemoveAllAddresses() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + var err *tcpip.Error + for _, r := range e.mu.ep.AllEndpoints() { + switch r.GetKind() { + case stack.PermanentTentative, stack.Permanent: + if tempErr := e.removeAddressLocked(r.AddressWithPrefix().Address); tempErr != nil && err == nil { + err = tempErr + } + } + } + return err +} + +// JoinGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { + if !header.IsV6MulticastAddress(addr) { + return false, tcpip.ErrBadAddress + } + + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.gep.JoinGroup(addr) +} + +// LeaveGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) LeaveGroup(addr tcpip.Address, force bool) (bool, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.gep.LeaveGroup(addr, force) +} + +// IsInGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) IsInGroup(addr tcpip.Address) bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.gep.IsInGroup(addr) +} + +// LeaveAllGroups implements stack.GroupAddressableEndpoint. +func (e *endpoint) LeaveAllGroups() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.gep.LeaveAllGroups() +} + type protocol struct { // defaultTTL is the current default TTL for the protocol. Only the // uint8 portion of it is meaningful and it must be accessed @@ -456,9 +744,9 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { - return &endpoint{ - nicID: nicID, +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { + e := &endpoint{ + nic: nic, linkEP: linkEP, linkAddrCache: linkAddrCache, nud: nud, @@ -466,6 +754,9 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddres protocol: p, stack: st, } + e.mu.ep = stack.NewAddressableEndpointWithLock(&e.mu) + e.mu.gep = stack.NewGroupAddressableEndpoint(e.mu.ep) + return e } // SetOption implements NetworkProtocol.SetOption. diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 480c495fa0..7c7d5083f3 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -65,7 +65,7 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s) + ep := netProto.NewEndpoint(testNIC, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s) return s, ep } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 900938dd19..a2a59a9f3c 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -54,9 +54,11 @@ go_template_instance( go_library( name = "stack", srcs = [ + "addressable_endpoint.go", "conntrack.go", "dhcpv6configurationfromndpra_string.go", "forwarder.go", + "group_addressable_endpoint.go", "headertype_string.go", "icmp_rate_limit.go", "iptables.go", diff --git a/pkg/tcpip/stack/addressable_endpoint.go b/pkg/tcpip/stack/addressable_endpoint.go new file mode 100644 index 0000000000..f5205c2773 --- /dev/null +++ b/pkg/tcpip/stack/addressable_endpoint.go @@ -0,0 +1,676 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +// AddressableEndpoint is an endpoint that supports addressing. +type AddressableEndpoint interface { + // AddAddress adds the specifid address with the given options. + // + // Returns a AddressEndpoint for the added address with a ref count + // of 1. If a Temporary addrss was added, it must be used without incrmenting + // the AddressEndpoint's ref count. Other endpoints' ref count will + // need to be incremented. + AddAddress(addr tcpip.AddressWithPrefix, opts AddAddressOptions) (AddressEndpoint, *tcpip.Error) + + // RemoveAddress removes a permanent address. + RemoveAddress(addr tcpip.Address) *tcpip.Error + + // HasAddress returns true if the endpoint has the specified permanent + // address. + HasAddress(addr tcpip.Address) bool + + // PrimaryEndpoints returns all the primary endpoints. + PrimaryEndpoints() []AddressEndpoint + + // AllEndpoints returns all the endpoints. + AllEndpoints() []AddressEndpoint + + // GetEndpoint returns an endpoint for the specified local address. + // + // Returns nil if the specified address is not local to this endpoint. + GetEndpoint(localAddr tcpip.Address) AddressEndpoint + + // GetAssignedEndpoint returns an assigned endpoint for the specified local + // address, optionally creating a temporary endpoint if requested. + // + // Returns nil if the specified address is not local to this endpoint. + GetAssignedEndpoint(localAddr tcpip.Address, allowAnyInSubnet, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint + + // PrimaryEndpoint returns a primary endpooint to use when communicating with + // the specified remote address. + PrimaryEndpoint(remoteAddr tcpip.Address, spoofingOrPromiscuous bool) AddressEndpoint + + // PrimaryAddresses returns the primary addresses. + PrimaryAddresses() []tcpip.AddressWithPrefix + + // AllAddresses returns all the addresses. + AllAddresses() []tcpip.AddressWithPrefix + + // RemoveAllAddresses removes all permanent addresses. + RemoveAllAddresses() *tcpip.Error +} + +// AddressConfigType is the way an address was configured. +type AddressConfigType int32 + +const ( + // AddressConfigStatic is a statically configured address endpoint that was + // added by some user-specified action (adding an explicit address, joining a + // multicast group). + AddressConfigStatic AddressConfigType = iota + + // AddressConfigSlaac is an address endpoint added by SLAAC, as per RFC 4862 + // section 5.5.3. + AddressConfigSlaac + + // AddressConfigSlaacTemp is a temporary address endpoint added by SLAAC as + // per RFC 4941. Temporary SLAAC addresses are short-lived and are not + // to be valid (or preferred) forever; hence the term temporary. + AddressConfigSlaacTemp +) + +// AddAddressOptions are options when adding an address. +type AddAddressOptions struct { + Deprecated bool + ConfigType AddressConfigType + Kind AddressKind + PEB PrimaryEndpointBehavior +} + +// AddressEndpoint is an endpoint representing an address assigned to an +// AddressableEndpoint. +type AddressEndpoint interface { + // AddressWithPrefix returns the endpoint's address. + AddressWithPrefix() tcpip.AddressWithPrefix + + // IsAssigned returns whether or not th endpoint is considered bound + // to its AddressableEndpoint. + IsAssigned(spoofingOrPromiscuous bool) bool + + // GetKind returns the AddressKind for this endpoint. + GetKind() AddressKind + + // SetKind sets the AddressKind for this endpoint. + SetKind(AddressKind) + + // IncRef increments this endpoint's reference count. + // + // Returns true if it was successfully incremented. If it returns false, then + // the endpoint is considered expired and should no longer be used. + IncRef() bool + + // DecRef decrements this endpoint's reference count. + // + // If it returns true, then the endpoint has been released and must no longer + // be used. + DecRef() bool + + // ConfigType returns the method used to add this endpoint to its + // AddressableEndpoint. + ConfigType() AddressConfigType + + // Deprecated returns whether or not this endpoint is deprecated. + Deprecated() bool + + // SetDeprecated sets this endpoint's deprecated status. + SetDeprecated(bool) +} + +// AddressKind is the kind of of an address. +// +// See the values of AddressKind for more details. +type AddressKind int32 + +const ( + // PermanentTentative is a permanent address endpoint that is not yet + // considered to be fully bound to an interface in the traditional + // sense. That is, the address is associated with a NIC, but packets + // destined to the address MUST NOT be accepted and MUST be silently + // dropped, and the address MUST NOT be used as a source address for + // outgoing packets. For IPv6, addresses will be of this kind until + // NDP's Duplicate Address Detection has resolved, or be deleted if + // the process results in detecting a duplicate address. + PermanentTentative AddressKind = iota + + // Permanent is a permanent endpoint (vs. a temporary one) assigned to the + // NIC. Its reference count is biased by 1 to avoid removal when no route + // holds a reference to it. It is removed by explicitly removing the address + // from the NIC. + Permanent + + // PermanentExpired is a permanent endpoint that had its address removed from + // the NIC, and it is waiting to be removed once no references to it are held. + // + // If the address is re-added before the endpoint is removed, its type + // changes back to Permanent. + PermanentExpired + + // Temporary is an endpoint, created on a one-off basis to temporarily + // consider the NIC bound an an address that it is not explictiy bound to + // (such as a permanent address). Its reference count must not be biased by 1 + // so that the address is removed immediately when references to it are no + // longer held. + // + // A temporary endpoint may be promoted to permanent if the address is added + // permanently. + Temporary +) + +// NewAddressableEndpoint returns an AddressableEndpoint that is protected by a +// lock. +// +// Useful when specialization of an AddressableEndpoint is not required. +func NewAddressableEndpoint() AddressableEndpoint { + l := &lockedAddressableEndpointState{} + l.mu.ep = addressableEndpointState{ + lock: &l.mu, + endpoints: make(map[tcpip.Address]*addressState), + } + return l +} + +// NewAddressableEndpointWithLock returns an AddressableEndpoint that requires +// the specified lock to be held before calling any methods on itself. +// +// The returned AddressableEndpoint will not obtain the lock before doing any +// work. +// +// Useful when an implementation may want to specialize some functions of the +// AddressableEndpoint. +func NewAddressableEndpointWithLock(lock sync.Locker) AddressableEndpoint { + return &addressableEndpointState{ + lock: lock, + endpoints: make(map[tcpip.Address]*addressState), + } +} + +var _ AddressableEndpoint = (*lockedAddressableEndpointState)(nil) + +// lockedAddressableEndpointState is an implementation of AddressableEndpoint +// that protects an inner AddressableEndpoint with a mutex. +type lockedAddressableEndpointState struct { + mu struct { + sync.RWMutex + ep addressableEndpointState + } +} + +// AddAddress implements AddressableEndpoint. +func (e *lockedAddressableEndpointState) AddAddress(addr tcpip.AddressWithPrefix, opts AddAddressOptions) (AddressEndpoint, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.ep.AddAddress(addr, opts) +} + +// RemoveAddress implements AddressableEndpoint. +func (e *lockedAddressableEndpointState) RemoveAddress(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.ep.RemoveAddress(addr) +} + +// HasAddress implements AddressableEndpoint. +func (e *lockedAddressableEndpointState) HasAddress(addr tcpip.Address) bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.HasAddress(addr) +} + +// PrimaryEndpoints implements AddressableEndpoint. +func (e *lockedAddressableEndpointState) PrimaryEndpoints() []AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.PrimaryEndpoints() +} + +// AllEndpoints implements AddressableEndpoint. +func (e *lockedAddressableEndpointState) AllEndpoints() []AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.AllEndpoints() +} + +// GetEndpoint implements AddressableEndpoint. +func (e *lockedAddressableEndpointState) GetEndpoint(localAddr tcpip.Address) AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.GetEndpoint(localAddr) +} + +// GetAssignedEndpoint implements AddressableEndpoint. +func (e *lockedAddressableEndpointState) GetAssignedEndpoint(localAddr tcpip.Address, allowAnyInSubnet, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.ep.GetAssignedEndpoint(localAddr, allowAnyInSubnet, allowTemp, tempPEB) +} + +// PrimaryEndpoint implements AddressableEndpoint. +func (e *lockedAddressableEndpointState) PrimaryEndpoint(remoteAddr tcpip.Address, spoofingOrPromiscuous bool) AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.PrimaryEndpoint(remoteAddr, spoofingOrPromiscuous) +} + +// PrimaryAddresses implements AddressableEndpoint. +func (e *lockedAddressableEndpointState) PrimaryAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.PrimaryAddresses() +} + +// AllAddresses implements AddressableEndpoint. +func (e *lockedAddressableEndpointState) AllAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.ep.AllAddresses() +} + +// RemoveAllAddresses implements AddressableEndpoint. +func (e *lockedAddressableEndpointState) RemoveAllAddresses() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.ep.RemoveAllAddresses() +} + +var _ AddressableEndpoint = (*addressableEndpointState)(nil) + +// addressableEndpointState is an implementation of an AddressableEndpoint that +// does not perform any locking before doing work defined by +// AddressableEndpoint. +type addressableEndpointState struct { + lock sync.Locker + endpoints map[tcpip.Address]*addressState + primary []*addressState +} + +func (s *addressableEndpointState) takeLockAndReleaseAddressState(addrState *addressState) { + s.lock.Lock() + defer s.lock.Unlock() + s.releaseAddressState(addrState) +} + +// releaseAddressState removes addrState from s's address state (primary and endpoints list). +func (s *addressableEndpointState) releaseAddressState(addrState *addressState) { + oldPrimary := s.primary + for i, a := range s.primary { + if a == addrState { + s.primary = append(s.primary[:i], s.primary[i+1:]...) + oldPrimary[len(oldPrimary)-1] = nil + break + } + } + delete(s.endpoints, addrState.addr.Address) +} + +// AddAddress implements AddressableEndpoint. +func (s *addressableEndpointState) AddAddress(addr tcpip.AddressWithPrefix, opts AddAddressOptions) (AddressEndpoint, *tcpip.Error) { + addToPrimary := func(addrState *addressState, peb PrimaryEndpointBehavior) { + switch peb { + case CanBePrimaryEndpoint: + s.primary = append(s.primary, addrState) + case FirstPrimaryEndpoint: + s.primary = append([]*addressState{addrState}, s.primary...) + } + } + + if addrState, ok := s.endpoints[addr.Address]; ok { + // Address already exists. + if opts.Kind != Permanent { + return nil, tcpip.ErrDuplicateAddress + } + + switch addrState.GetKind() { + case PermanentTentative, Permanent: + return nil, tcpip.ErrDuplicateAddress + case PermanentExpired, Temporary: + if addrState.IncRef() { + addrState.SetKind(Permanent) + addrState.deprecated = opts.Deprecated + addrState.configType = opts.ConfigType + + for i, a := range s.primary { + if a == addrState { + switch opts.PEB { + case CanBePrimaryEndpoint: + return addrState, nil + case FirstPrimaryEndpoint: + if i == 0 { + return addrState, nil + } + s.primary = append(s.primary[:i], s.primary[i+1:]...) + case NeverPrimaryEndpoint: + s.primary = append(s.primary[:i], s.primary[i+1:]...) + return addrState, nil + } + } + } + + addToPrimary(addrState, opts.PEB) + + return addrState, nil + } + + s.releaseAddressState(addrState) + } + } + + addrState := &addressState{ + networkState: s, + addr: addr, + refs: 1, + kind: opts.Kind, + configType: opts.ConfigType, + deprecated: opts.Deprecated, + } + + s.endpoints[addr.Address] = addrState + addToPrimary(addrState, opts.PEB) + + return addrState, nil +} + +// RemoveAddress implements AddressableEndpoint. +func (s *addressableEndpointState) RemoveAddress(addr tcpip.Address) *tcpip.Error { + addrState, ok := s.endpoints[addr] + if !ok { + return tcpip.ErrBadLocalAddress + } + + if kind := addrState.GetKind(); kind != Permanent && kind != PermanentTentative { + return tcpip.ErrBadLocalAddress + } + + addrState.SetKind(PermanentExpired) + s.decAddressRef(addrState) + + return nil +} + +func (s *addressableEndpointState) decAddressRef(addrState *addressState) { + if addrState.decRef() { + s.releaseAddressState(addrState) + } +} + +// HasAddress implements AddressableEndpoint. +func (s *addressableEndpointState) HasAddress(addr tcpip.Address) bool { + addrState, ok := s.endpoints[addr] + if !ok { + return false + } + + kind := addrState.GetKind() + return kind == Permanent || kind == PermanentTentative +} + +// AllEndpoints implements AddressableEndpoint. +func (s *addressableEndpointState) AllEndpoints() []AddressEndpoint { + eps := make([]AddressEndpoint, 0, len(s.endpoints)) + for _, e := range s.endpoints { + eps = append(eps, e) + } + return eps +} + +// PrimaryEndpoints implements AddressableEndpoint. +func (s *addressableEndpointState) PrimaryEndpoints() []AddressEndpoint { + eps := make([]AddressEndpoint, 0, len(s.primary)) + for _, e := range s.primary { + eps = append(eps, e) + } + return eps +} + +// GetEndpoint implements AddressableEndpoint. +func (s *addressableEndpointState) GetEndpoint(localAddr tcpip.Address) AddressEndpoint { + if r, ok := s.endpoints[localAddr]; ok && r.GetKind() != PermanentExpired { + return r + } + + return nil +} + +// GetAssignedEndpoint implements AddressableEndpoint. +func (s *addressableEndpointState) GetAssignedEndpoint(localAddr tcpip.Address, allowAnyInSubnet, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { + if r, ok := s.endpoints[localAddr]; ok { + if !r.IsAssigned(allowTemp) { + return nil + } + + if r.IncRef() { + return r + } + + s.releaseAddressState(r) + } + + if !allowTemp && allowAnyInSubnet { + for _, r := range s.endpoints { + if r.GetKind() == PermanentExpired { + continue + } + + subnet := r.AddressWithPrefix().Subnet() + if subnet.Contains(localAddr) { + allowTemp = true + break + } + } + } + + if !allowTemp { + return nil + } + + r, _ := s.AddAddress(tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: len(localAddr) * 8, + }, AddAddressOptions{ + Deprecated: false, + ConfigType: AddressConfigStatic, + Kind: Temporary, + PEB: tempPEB, + }) + return r +} + +// PrimaryEndpoint implements AddressableEndpoint. +func (s *addressableEndpointState) PrimaryEndpoint(remoteAddr tcpip.Address, spoofingOrPromiscuous bool) AddressEndpoint { + var deprecatedEndpoint *addressState + for _, r := range s.primary { + if !r.IsAssigned(spoofingOrPromiscuous) { + continue + } + + if !r.Deprecated() { + if r.IncRef() { + // r is not deprecated, so return it immediately. + // + // If we kept track of a deprecated endpoint, decrement its reference + // count since it was incremented when we decided to keep track of it. + if deprecatedEndpoint != nil { + s.decAddressRef(deprecatedEndpoint) + deprecatedEndpoint = nil + } + + return r + } + } else if deprecatedEndpoint == nil && r.IncRef() { + // We prefer an endpoint that is not deprecated, but we keep track of r in + // case n doesn't have any non-deprecated endpoints. + // + // If we end up finding a more preferred endpoint, r's reference count + // will be decremented when such an endpoint is found. + deprecatedEndpoint = r + } + } + + // n doesn't have any valid non-deprecated endpoints, so return + // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated + // endpoints either). + if deprecatedEndpoint == nil { + return nil + } + return deprecatedEndpoint +} + +// PrimaryAddresses implements AddressableEndpoint. +func (s *addressableEndpointState) PrimaryAddresses() []tcpip.AddressWithPrefix { + var addrs []tcpip.AddressWithPrefix + for _, r := range s.primary { + // Don't include tentative, expired or tempory endpoints + // to avoid confusion and prevent the caller from using + // those. + switch r.GetKind() { + case PermanentTentative, PermanentExpired, Temporary: + continue + } + + addrs = append(addrs, r.AddressWithPrefix()) + } + + return addrs +} + +// AllAddresses implements AddressableEndpoint. +func (s *addressableEndpointState) AllAddresses() []tcpip.AddressWithPrefix { + var addrs []tcpip.AddressWithPrefix + for _, r := range s.endpoints { + // Don't include tentative, expired or tempory endpoints + // to avoid confusion and prevent the caller from using + // those. + switch r.GetKind() { + case PermanentExpired, Temporary: + continue + } + + addrs = append(addrs, r.AddressWithPrefix()) + } + + return addrs +} + +// RemoveAllAddresses implements AddressableEndpoint. +func (s *addressableEndpointState) RemoveAllAddresses() *tcpip.Error { + var err *tcpip.Error + for a, r := range s.endpoints { + switch r.GetKind() { + case PermanentTentative, Permanent: + if tempErr := s.RemoveAddress(a); tempErr != nil && err == nil { + err = tempErr + } + } + } + return err +} + +var _ AddressEndpoint = (*addressState)(nil) + +// addressState holds state for an address. +type addressState struct { + networkState *addressableEndpointState + addr tcpip.AddressWithPrefix + refs int32 + + kind AddressKind + configType AddressConfigType + deprecated bool +} + +// AddressWithPrefix implements AddressEndpoint. +func (s *addressState) AddressWithPrefix() tcpip.AddressWithPrefix { + return s.addr +} + +// GetKind implements AddressEndpoint. +func (s *addressState) GetKind() AddressKind { + return AddressKind(atomic.LoadInt32((*int32)(&s.kind))) +} + +// SetKind implements AddressEndpoint. +func (s *addressState) SetKind(kind AddressKind) { + atomic.StoreInt32((*int32)(&s.kind), int32(kind)) +} + +// IsAssigned implements AddressEndpoint. +func (s *addressState) IsAssigned(spoofingOrPromiscuous bool) bool { + switch s.GetKind() { + case PermanentTentative: + return false + case PermanentExpired: + return spoofingOrPromiscuous + default: + return true + } +} + +// IncRef implements AddressEndpoint. +func (s *addressState) IncRef() bool { + for { + v := atomic.LoadInt32(&s.refs) + if v == 0 { + return false + } + + if atomic.CompareAndSwapInt32(&s.refs, v, v+1) { + return true + } + } +} + +// DecRef implements AddressEndpoint. +func (s *addressState) DecRef() bool { + if s.decRef() { + s.networkState.takeLockAndReleaseAddressState(s) + return true + } + + return false +} + +func (s *addressState) decRef() bool { + return atomic.AddInt32(&s.refs, -1) == 0 +} + +// ConfigType implements AddressEndpoint. +func (s *addressState) ConfigType() AddressConfigType { + // Currently this is protected by the NIC lock. + // TODO: protect this with the network endpoint lock once the NIC stops + // writing to this. + return s.configType +} + +// SetDeprecated implements AddressEndpoint. +func (s *addressState) SetDeprecated(d bool) { + // Currently this is protected by the NIC lock. + // TODO: protect this with the network endpoint lock once the NIC stops + // writing to this. + s.deprecated = d +} + +// Deprecated implements AddressEndpoint. +func (s *addressState) Deprecated() bool { + // Currently this is protected by the NIC lock. + // TODO: protect this with the network endpoint lock once the NIC stops + // writing to this. + return s.deprecated +} diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 91165ebc7e..4f551a8aa8 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -45,6 +45,8 @@ const ( // use the first three: destination address, source address, and transport // protocol. They're all one byte fields to simplify parsing. type fwdTestNetworkEndpoint struct { + AddressableEndpoint + nicID tcpip.NICID proto *fwdTestNetworkProtocol dispatcher TransportDispatcher @@ -53,6 +55,14 @@ type fwdTestNetworkEndpoint struct { var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) +func (*fwdTestNetworkEndpoint) Enable() *tcpip.Error { + return nil +} + +func (*fwdTestNetworkEndpoint) Disable() *tcpip.Error { + return nil +} + func (f *fwdTestNetworkEndpoint) MTU() uint32 { return f.ep.MTU() - uint32(f.MaxHeaderLength()) } @@ -145,12 +155,13 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true } -func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint { +func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint { return &fwdTestNetworkEndpoint{ - nicID: nicID, - proto: f, - dispatcher: dispatcher, - ep: ep, + AddressableEndpoint: NewAddressableEndpoint(), + nicID: nic.ID(), + proto: f, + dispatcher: dispatcher, + ep: ep, } } diff --git a/pkg/tcpip/stack/group_addressable_endpoint.go b/pkg/tcpip/stack/group_addressable_endpoint.go new file mode 100644 index 0000000000..16b773cd98 --- /dev/null +++ b/pkg/tcpip/stack/group_addressable_endpoint.go @@ -0,0 +1,125 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) + +// GroupAddressableEndpoint is an endpoint that supports group addressing. +// +// This endpoint is expected to reference count joins so that a group is only +// left once each join is matched with a leave. +type GroupAddressableEndpoint interface { + // JoinGroup joins the spcified group. + // + // If the endoint is already a member of the group, the group's join count + // will be incremented. + // + // Returns true if the group was newly joined. + JoinGroup(group tcpip.Address) (bool, *tcpip.Error) + + // LeaveGroup decrements the join count and leaves the specified group once + // the join count reaches 0. + // + // If force is true, the group will be immediately left, even if there are + // outstanding joins. + // + // Returns true if the group was left (join count hit 0). + LeaveGroup(group tcpip.Address, force bool) (bool, *tcpip.Error) + + // IsInGroup returns true if the endpoint is a member of the specified group. + IsInGroup(group tcpip.Address) bool + + // LeaveAllGroups forcefully leaves all groups. + LeaveAllGroups() *tcpip.Error +} + +// NewGroupAddressableEndpoint returns a new GroupAddressableEndpoint that +// depends on an AddressableEndpoint to join groups. +// +// The returned GroupAddressableEndpoint does not obtain any locks before +// modifying any state. If locking is required callers must do so before +// invoking methods on the returned endpoint. +func NewGroupAddressableEndpoint(addressableEndpoint AddressableEndpoint) GroupAddressableEndpoint { + return &groupAddressableEndpointState{ + joins: make(map[tcpip.Address]uint32), + addressableEndpoint: addressableEndpoint, + } +} + +var _ GroupAddressableEndpoint = (*groupAddressableEndpointState)(nil) + +type groupAddressableEndpointState struct { + joins map[tcpip.Address]uint32 + addressableEndpoint AddressableEndpoint +} + +// JoinGroup implements GroupAddressableEndpoint.JoinGroup. +func (s *groupAddressableEndpointState) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { + // TODO: don't add groups like a normal address. + joins := s.joins[addr] + if joins == 0 { + _, err := s.addressableEndpoint.AddAddress(tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: len(addr) * 8, + }, AddAddressOptions{ + Deprecated: false, + ConfigType: AddressConfigStatic, + Kind: Permanent, + PEB: NeverPrimaryEndpoint, + }) + if err != nil { + return false, err + } + } + + s.joins[addr] = joins + 1 + return joins == 0, nil +} + +// LeaveGroup implements GroupAddressableEndpoint.LeaveGroup. +func (s *groupAddressableEndpointState) LeaveGroup(addr tcpip.Address, force bool) (bool, *tcpip.Error) { + joins, ok := s.joins[addr] + if !ok { + return false, tcpip.ErrBadLocalAddress + } + + s.joins[addr] = joins - 1 + if force || joins == 1 { + if err := s.addressableEndpoint.RemoveAddress(addr); err != nil { + return false, err + } + delete(s.joins, addr) + } + + return force || joins == 1, nil +} + +// IsInGroup implements GroupAddressableEndpoint.IsInGroup. +func (s *groupAddressableEndpointState) IsInGroup(addr tcpip.Address) bool { + return s.joins[addr] != 0 +} + +// LeaveAllGroups implements GroupAddressableEndpoint.LeaveAllGroups. +func (s *groupAddressableEndpointState) LeaveAllGroups() *tcpip.Error { + var errRet *tcpip.Error + for g := range s.joins { + if _, err := s.LeaveGroup(g, true /* force */); err != nil && errRet == nil { + errRet = err + } + } + return errRet +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index b0873d1af7..d41842dd43 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -604,7 +604,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref return tcpip.ErrAddressFamilyNotSupported } - if ref.getKind() != permanentTentative { + if ref.getKind() != PermanentTentative { // The endpoint should be marked as tentative since we are starting DAD. panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())) } @@ -623,7 +623,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref remaining := ndp.configs.DupAddrDetectTransmits if remaining == 0 { - ref.setKind(permanent) + ref.setKind(Permanent) // Consider DAD to have resolved even if no DAD messages were actually // transmitted. @@ -652,7 +652,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref return } - if ref.getKind() != permanentTentative { + if ref.getKind() != PermanentTentative { // The endpoint should still be marked as tentative since we are still // performing DAD on it. panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID())) @@ -663,7 +663,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref var err *tcpip.Error if !dadDone { // Use the unspecified address as the source address when performing DAD. - ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) + ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, true /* createTemp */, NeverPrimaryEndpoint) // Do not hold the lock when sending packets which may be a long running // task or may block link address resolution. We know this is safe @@ -684,7 +684,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref if dadDone { // DAD has resolved. - ref.setKind(permanent) + ref.setKind(Permanent) } else if err == nil { // DAD is not done and we had no errors when sending the last NDP NS, // schedule the next DAD timer. @@ -704,7 +704,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // If DAD resolved for a stable SLAAC address, attempt generation of a // temporary SLAAC address. - if dadDone && ref.configType == slaac { + if dadDone && ref.configType() == AddressConfigSlaac { // Reset the generation attempts counter as we are starting the generation // of a new address for the SLAAC prefix. ndp.regenerateTempSLAACAddr(ref.addrWithPrefix().Subnet(), true /* resetGenAttempts */) @@ -1189,7 +1189,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { } // If the address is assigned (DAD resolved), generate a temporary address. - if state.stableAddr.ref.getKind() == permanent { + if state.stableAddr.ref.getKind() == Permanent { // Reset the generation attempts counter as we are starting the generation // of a new address for the SLAAC prefix. ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */) @@ -1201,7 +1201,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // addSLAACAddr adds a SLAAC address to the NIC. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType networkEndpointConfigType, deprecated bool) *referencedNetworkEndpoint { +func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType AddressConfigType, deprecated bool) *referencedNetworkEndpoint { // Inform the integrator that we have a new SLAAC address. ndpDisp := ndp.nic.stack.ndpDisp if ndpDisp == nil { @@ -1218,7 +1218,7 @@ func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType netwo AddressWithPrefix: addr, } - ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, configType, deprecated) + ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, Permanent, configType, deprecated) if err != nil { panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", protocolAddr, err)) } @@ -1298,7 +1298,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt state.stableAddr.localGenerationFailures++ } - if ref := ndp.addSLAACAddr(generatedAddr, slaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); ref != nil { + if ref := ndp.addSLAACAddr(generatedAddr, AddressConfigSlaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); ref != nil { state.stableAddr.ref = ref state.generationAttempts++ return true @@ -1410,7 +1410,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla // As per RFC RFC 4941 section 3.3 step 5, we MUST NOT create a temporary // address with a zero preferred lifetime. The checks above ensure this // so we know the address is not deprecated. - ref := ndp.addSLAACAddr(generatedAddr, slaacTemp, false /* deprecated */) + ref := ndp.addSLAACAddr(generatedAddr, AddressConfigSlaacTemp, false /* deprecated */) if ref == nil { return false } @@ -1503,7 +1503,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat if deprecated { ndp.deprecateSLAACAddress(prefixState.stableAddr.ref) } else { - prefixState.stableAddr.ref.deprecated = false + prefixState.stableAddr.ref.setDeprecated(false) } // If prefix was preferred for some finite lifetime before, cancel the @@ -1565,7 +1565,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // If DAD is not yet complete on the stable address, there is no need to do // work with temporary addresses. - if prefixState.stableAddr.ref.getKind() != permanent { + if prefixState.stableAddr.ref.getKind() != Permanent { return } @@ -1610,7 +1610,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat if newPreferredLifetime <= 0 { ndp.deprecateSLAACAddress(tempAddrState.ref) } else { - tempAddrState.ref.deprecated = false + tempAddrState.ref.setDeprecated(false) tempAddrState.deprecationJob.Schedule(newPreferredLifetime) } @@ -1654,11 +1654,11 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) { - if ref.deprecated { + if ref.deprecated() { return } - ref.deprecated = true + ref.setDeprecated(true) if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), ref.addrWithPrefix()) } @@ -1861,9 +1861,9 @@ func (ndp *ndpState) startSolicitingRouters() { // As per RFC 4861 section 4.1, the source of the RS is an address assigned // to the sending interface, or the unspecified address if no address is // assigned to the sending interface. - ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress) + ref := ndp.nic.primaryEndpointRLocked(header.IPv6ProtocolNumber, header.IPv6AllRoutersMulticastAddress) if ref == nil { - ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) + ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, true /* createTemp */, NeverPrimaryEndpoint) } ndp.nic.mu.Unlock() diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 0c811efdb0..c199a5a8ad 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -15,11 +15,8 @@ package stack import ( - "fmt" "math/rand" "reflect" - "sort" - "sync/atomic" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" @@ -28,14 +25,17 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) -var ipv4BroadcastAddr = tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: header.IPv4Broadcast, - PrefixLen: 8 * header.IPv4AddressSize, - }, +// NetworkInterface is an interface that can be used by a NetworkEndpoint +type NetworkInterface interface { + // ID returns the NetworkInterface's ID. + ID() tcpip.NICID + + // IsLoopback returns true if the NetworkInterface is a loopback interface. + IsLoopback() bool } +var _ NetworkInterface = (*NIC)(nil) + // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { @@ -54,9 +54,6 @@ type NIC struct { enabled bool spoofing bool promiscuous bool - primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint - endpoints map[NetworkEndpointID]*referencedNetworkEndpoint - mcastJoins map[NetworkEndpointID]uint32 // packetEPs is protected by mu, but the contained PacketEndpoint // values are not. packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint @@ -122,9 +119,6 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), } - nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint) - nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint) - nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) nic.mu.ndp = ndpState{ nic: nic, @@ -137,6 +131,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic.mu.ndp.initializeTempAddrState() // Check for Neighbor Unreachability Detection support. + var nud NUDHandler if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 && stack.useNeighborCache { rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds())) nic.neigh = &neighborCache{ @@ -144,6 +139,12 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC state: NewNUDState(stack.nudConfigs, rng), cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), } + // An interface value that holds a nil concrete value is itself non-nil. + // For this reason, n.neigh cannot be passed directly to NewEndpoint so + // NetworkEndpoints don't confuse it for non-nil. + // + // See https://golang.org/doc/faq#nil_error for more information. + nud = nic.neigh } // Register supported packet endpoint protocols. @@ -153,7 +154,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC for _, netProto := range stack.networkProtocols { netNum := netProto.Number() nic.mu.packetEPs[netNum] = nil - nic.networkEndpoints[netNum] = netProto.NewEndpoint(id, stack, nic.neigh, nic, ep, stack) + nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic, ep, stack) } nic.linkEP.Attach(nic) @@ -201,32 +202,21 @@ func (n *NIC) disableLocked() *tcpip.Error { // again, and applications may not know that the underlying NIC was ever // disabled. - if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { + if ep, ok := n.networkEndpoints[header.IPv6ProtocolNumber]; ok { n.mu.ndp.stopSolicitingRouters() n.mu.ndp.cleanupState(false /* hostOnly */) // Stop DAD for all the unicast IPv6 endpoints that are in the // permanentTentative state. - for _, r := range n.mu.endpoints { - if addr := r.address(); r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) { + for _, r := range ep.AllEndpoints() { + if addr := r.AddressWithPrefix().Address; r.GetKind() == PermanentTentative && header.IsV6UnicastAddress(addr) { n.mu.ndp.stopDuplicateAddressDetection(addr) } } - - // The NIC may have already left the multicast group. - if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { - return err - } } - if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { - // The NIC may have already left the multicast group. - if err := n.leaveGroupLocked(header.IPv4AllSystems, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { - return err - } - - // The address may have already been removed. - if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress { + for _, ep := range n.networkEndpoints { + if err := ep.Disable(); err != nil { return err } } @@ -258,16 +248,8 @@ func (n *NIC) enable() *tcpip.Error { n.mu.enabled = true - // Create an endpoint to receive broadcast packets on this interface. - if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { - if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { - return err - } - - // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts - // multicast group. Note, the IANA calls the all-hosts multicast group the - // all-systems multicast group. - if err := n.joinGroupLocked(header.IPv4ProtocolNumber, header.IPv4AllSystems); err != nil { + for _, ep := range n.networkEndpoints { + if err := ep.Enable(); err != nil { return err } } @@ -284,38 +266,31 @@ func (n *NIC) enable() *tcpip.Error { // link address if it is configured to do so. Note, each interface is // required to have IPv6 link-local unicast address, as per RFC 4291 // section 2.1. - _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber] + ep, ok := n.networkEndpoints[header.IPv6ProtocolNumber] if !ok { return nil } - // Join the All-Nodes multicast group before starting DAD as responses to DAD - // (NDP NS) messages may be sent to the All-Nodes multicast group if the - // source address of the NDP NS is the unspecified address, as per RFC 4861 - // section 7.2.4. - if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { - return err - } - // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent // state. // // Addresses may have aleady completed DAD but in the time since the NIC was // last enabled, other devices may have acquired the same addresses. - for _, r := range n.mu.endpoints { - addr := r.address() - if k := r.getKind(); (k != permanent && k != permanentTentative) || !header.IsV6UnicastAddress(addr) { + for _, r := range ep.AllEndpoints() { + addr := r.AddressWithPrefix().Address + if k := r.GetKind(); (k != Permanent && k != PermanentTentative) || !header.IsV6UnicastAddress(addr) { continue } - r.setKind(permanentTentative) - if err := n.mu.ndp.startDuplicateAddressDetection(addr, r); err != nil { + ref := n.nepToRef(header.IPv6ProtocolNumber, ep, r) + ref.setKind(PermanentTentative) + if err := n.mu.ndp.startDuplicateAddressDetection(addr, ref); err != nil { return err } } // Do not auto-generate an IPv6 link-local address for loopback devices. - if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() { + if n.stack.autoGenIPv6LinkLocal && !n.IsLoopback() { // The valid and preferred lifetime is infinite for the auto-generated // link-local address. n.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) @@ -349,24 +324,22 @@ func (n *NIC) remove() *tcpip.Error { var err *tcpip.Error // Forcefully leave multicast groups. - for nid := range n.mu.mcastJoins { - if tempErr := n.leaveGroupLocked(nid.LocalAddress, true /* force */); tempErr != nil && err == nil { - err = tempErr - } - } - - // Remove permanent and permanentTentative addresses, so no packet goes out. - for nid, ref := range n.mu.endpoints { - switch ref.getKind() { - case permanentTentative, permanent: - if tempErr := n.removePermanentAddressLocked(nid.LocalAddress); tempErr != nil && err == nil { + for _, ep := range n.networkEndpoints { + gep, ok := ep.(GroupAddressableEndpoint) + if ok { + // Let ep.close handle this + if tempErr := gep.LeaveAllGroups(); tempErr != nil && err == nil { err = tempErr } } - } - // Release any resources the network endpoint may hold. - for _, ep := range n.networkEndpoints { + // Remove permanent and permanentTentative addresses, so no packet goes out. + // Release any resources the network endpoint may hold. + // Let ep.Close handle this. + if tempErr := ep.RemoveAllAddresses(); tempErr != nil && err == nil { + err = tempErr + } + ep.Close() } @@ -414,7 +387,8 @@ func (n *NIC) isPromiscuousMode() bool { return rv } -func (n *NIC) isLoopback() bool { +// IsLoopback implements NetworkInterface. +func (n *NIC) IsLoopback() bool { return n.linkEP.Capabilities()&CapabilityLoopback != 0 } @@ -432,188 +406,74 @@ func (n *NIC) setSpoofing(enable bool) { // If an IPv6 primary endpoint is requested, Source Address Selection (as // defined by RFC 6724 section 5) will be performed. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint { - if protocol == header.IPv6ProtocolNumber && len(remoteAddr) != 0 { - return n.primaryIPv6Endpoint(remoteAddr) - } - n.mu.RLock() defer n.mu.RUnlock() - var deprecatedEndpoint *referencedNetworkEndpoint - for _, r := range n.mu.primary[protocol] { - if !r.isValidForOutgoingRLocked() { - continue - } - - if !r.deprecated { - if r.tryIncRef() { - // r is not deprecated, so return it immediately. - // - // If we kept track of a deprecated endpoint, decrement its reference - // count since it was incremented when we decided to keep track of it. - if deprecatedEndpoint != nil { - deprecatedEndpoint.decRefLocked() - deprecatedEndpoint = nil - } - - return r - } - } else if deprecatedEndpoint == nil && r.tryIncRef() { - // We prefer an endpoint that is not deprecated, but we keep track of r in - // case n doesn't have any non-deprecated endpoints. - // - // If we end up finding a more preferred endpoint, r's reference count - // will be decremented when such an endpoint is found. - deprecatedEndpoint = r - } - } - - // n doesn't have any valid non-deprecated endpoints, so return - // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated - // endpoints either). - return deprecatedEndpoint -} - -// ipv6AddrCandidate is an IPv6 candidate for Source Address Selection (RFC -// 6724 section 5). -type ipv6AddrCandidate struct { - ref *referencedNetworkEndpoint - scope header.IPv6AddressScope -} - -// primaryIPv6Endpoint returns an IPv6 endpoint following Source Address -// Selection (RFC 6724 section 5). -// -// Note, only rules 1-3 and 7 are followed. -// -// remoteAddr must be a valid IPv6 address. -func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint { - n.mu.RLock() - ref := n.primaryIPv6EndpointRLocked(remoteAddr) - n.mu.RUnlock() - return ref + return n.primaryEndpointRLocked(protocol, remoteAddr) } -// primaryIPv6EndpointLocked returns an IPv6 endpoint following Source Address -// Selection (RFC 6724 section 5). +// primaryEndpointRLocked is like primaryEndpoint but without the locking +// requirements. // -// Note, only rules 1-3 and 7 are followed. -// -// remoteAddr must be a valid IPv6 address. -// -// n.mu MUST be read locked. -func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint { - primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber] - - if len(primaryAddrs) == 0 { +// n.mu MUST be rad locked. +func (n *NIC) primaryEndpointRLocked(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint { + ep, ok := n.networkEndpoints[protocol] + if !ok { return nil } - // Create a candidate set of available addresses we can potentially use as a - // source address. - cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs)) - for _, r := range primaryAddrs { - // If r is not valid for outgoing connections, it is not a valid endpoint. - if !r.isValidForOutgoingRLocked() { - continue - } - - addr := r.address() - scope, err := header.ScopeForIPv6Address(addr) - if err != nil { - // Should never happen as we got r from the primary IPv6 endpoint list and - // ScopeForIPv6Address only returns an error if addr is not an IPv6 - // address. - panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err)) - } - - cs = append(cs, ipv6AddrCandidate{ - ref: r, - scope: scope, - }) - } - - remoteScope, err := header.ScopeForIPv6Address(remoteAddr) - if err != nil { - // primaryIPv6Endpoint should never be called with an invalid IPv6 address. - panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)) - } - - // Sort the addresses as per RFC 6724 section 5 rules 1-3. - // - // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5. - sort.Slice(cs, func(i, j int) bool { - sa := cs[i] - sb := cs[j] - - // Prefer same address as per RFC 6724 section 5 rule 1. - if sa.ref.address() == remoteAddr { - return true - } - if sb.ref.address() == remoteAddr { - return false - } - - // Prefer appropriate scope as per RFC 6724 section 5 rule 2. - if sa.scope < sb.scope { - return sa.scope >= remoteScope - } else if sb.scope < sa.scope { - return sb.scope < remoteScope - } - - // Avoid deprecated addresses as per RFC 6724 section 5 rule 3. - if saDep, sbDep := sa.ref.deprecated, sb.ref.deprecated; saDep != sbDep { - // If sa is not deprecated, it is preferred over sb. - return sbDep - } - - // Prefer temporary addresses as per RFC 6724 section 5 rule 7. - if saTemp, sbTemp := sa.ref.configType == slaacTemp, sb.ref.configType == slaacTemp; saTemp != sbTemp { - return saTemp - } - - // sa and sb are equal, return the endpoint that is closest to the front of - // the primary endpoint list. - return i < j - }) - - // Return the most preferred address that can have its reference count - // incremented. - for _, c := range cs { - if r := c.ref; r.tryIncRef() { - return r - } + nep := ep.PrimaryEndpoint(remoteAddr, n.mu.spoofing) + if nep == nil { + return nil } - return nil + return n.nepToRef(protocol, ep, nep) } // hasPermanentAddrLocked returns true if n has a permanent (including currently // tentative) address, addr. func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { - ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] - - if !ok { - return false + for _, ep := range n.networkEndpoints { + if ep.HasAddress(addr) { + return true + } } - - kind := ref.getKind() - - return kind == permanent || kind == permanentTentative + return false } type getRefBehaviour int const ( + none getRefBehaviour = iota + // spoofing indicates that the NIC's spoofing flag should be observed when // getting a NIC's referenced network endpoint. - spoofing getRefBehaviour = iota + spoofing // promiscuous indicates that the NIC's promiscuous flag should be observed // when getting a NIC's referenced network endpoint. promiscuous ) +func (n *NIC) nepToRef(p tcpip.NetworkProtocolNumber, ep NetworkEndpoint, nep AddressEndpoint) *referencedNetworkEndpoint { + ref := &referencedNetworkEndpoint{ + ep: ep, + nep: nep, + protocol: p, + nic: n, + } + + // Set up cache if link address resolution exists for this protocol. + if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 { + if linkRes, ok := n.stack.linkAddrResolvers[p]; ok { + ref.linkCache = n.stack + ref.linkRes = linkRes + } + } + + return ref +} + func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) } @@ -633,8 +493,7 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A // If the address is the IPv4 broadcast address for an endpoint's network, that // endpoint will be returned. func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint { - n.mu.RLock() - + n.mu.Lock() var spoofingOrPromiscuous bool switch tempRef { case spoofing: @@ -643,79 +502,11 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t spoofingOrPromiscuous = n.mu.promiscuous } - if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { - // An endpoint with this id exists, check if it can be used and return it. - if !ref.isAssignedRLocked(spoofingOrPromiscuous) { - n.mu.RUnlock() - return nil - } - - if ref.tryIncRef() { - n.mu.RUnlock() - return ref - } - } - - // Check if address is a broadcast address for the endpoint's network. - // - // Only IPv4 has a notion of broadcast addresses. - if protocol == header.IPv4ProtocolNumber { - if ref := n.getRefForBroadcastRLocked(address); ref != nil { - n.mu.RUnlock() - return ref - } - } - - // A usable reference was not found, create a temporary one if requested by - // the caller or if the address is found in the NIC's subnets and the NIC is - // a loopback interface. - createTempEP := spoofingOrPromiscuous - if !createTempEP && n.isLoopback() { - for _, r := range n.mu.endpoints { - addr := r.addrWithPrefix() - subnet := addr.Subnet() - if subnet.Contains(address) { - createTempEP = true - break - } - } - } - n.mu.RUnlock() - - if !createTempEP { - return nil - } - - // Try again with the lock in exclusive mode. If we still can't get the - // endpoint, create a new "temporary" endpoint. It will only exist while - // there's a route through it. - n.mu.Lock() - ref := n.getRefOrCreateTempLocked(protocol, address, peb) + ref := n.getRefOrCreateTempLocked(protocol, address, spoofingOrPromiscuous, peb) n.mu.Unlock() return ref } -// getRefForBroadcastLocked returns an endpoint where address is the IPv4 -// broadcast address for the endpoint's network. -// -// n.mu MUST be read locked. -func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetworkEndpoint { - for _, ref := range n.mu.endpoints { - // Only IPv4 has a notion of broadcast addresses. - if ref.protocol != header.IPv4ProtocolNumber { - continue - } - - addr := ref.addrWithPrefix() - subnet := addr.Subnet() - if subnet.IsBroadcast(address) && ref.tryIncRef() { - return ref - } - } - - return nil -} - /// getRefOrCreateTempLocked returns an existing endpoint for address or creates /// and returns a temporary endpoint. // @@ -723,41 +514,23 @@ func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetwor // endpoint will be returned. // // n.mu must be write locked. -func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { - if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { - // No need to check the type as we are ok with expired endpoints at this - // point. - if ref.tryIncRef() { - return ref +func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { + if protocol == 0 { + for p, ep := range n.networkEndpoints { + if nep := ep.GetAssignedEndpoint(address, n.IsLoopback(), createTemp, peb); nep != nil { + return n.nepToRef(p, ep, nep) + } } - // tryIncRef failing means the endpoint is scheduled to be removed once the - // lock is released. Remove it here so we can create a new (temporary) one. - // The removal logic waiting for the lock handles this case. - n.removeEndpointLocked(ref) - } - - // Check if address is a broadcast address for an endpoint's network. - // - // Only IPv4 has a notion of broadcast addresses. - if protocol == header.IPv4ProtocolNumber { - if ref := n.getRefForBroadcastRLocked(address); ref != nil { - return ref + } else { + ep, ok := n.networkEndpoints[protocol] + if ok { + if nep := ep.GetAssignedEndpoint(address, n.IsLoopback(), createTemp, peb); nep != nil { + return n.nepToRef(protocol, ep, nep) + } } } - // Add a new temporary endpoint. - netProto, ok := n.stack.networkProtocols[protocol] - if !ok { - return nil - } - ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, peb, temporary, static, false) - return ref + return nil } // addAddressLocked adds a new protocolAddress to n. @@ -765,122 +538,36 @@ func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, add // If n already has the address in a non-permanent state, and the kind given is // permanent, that address will be promoted in place and its properties set to // the properties provided. Otherwise, it returns tcpip.ErrDuplicateAddress. -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) { - // TODO(b/141022673): Validate IP addresses before adding them. - - // Sanity check. - id := NetworkEndpointID{LocalAddress: protocolAddress.AddressWithPrefix.Address} - if ref, ok := n.mu.endpoints[id]; ok { - // Endpoint already exists. - if kind != permanent { - return nil, tcpip.ErrDuplicateAddress - } - switch ref.getKind() { - case permanentTentative, permanent: - // The NIC already have a permanent endpoint with that address. - return nil, tcpip.ErrDuplicateAddress - case permanentExpired, temporary: - // Promote the endpoint to become permanent and respect the new peb, - // configType and deprecated status. - if ref.tryIncRef() { - // TODO(b/147748385): Perform Duplicate Address Detection when promoting - // an IPv6 endpoint to permanent. - ref.setKind(permanent) - ref.deprecated = deprecated - ref.configType = configType - - refs := n.mu.primary[ref.protocol] - for i, r := range refs { - if r == ref { - switch peb { - case CanBePrimaryEndpoint: - return ref, nil - case FirstPrimaryEndpoint: - if i == 0 { - return ref, nil - } - n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) - case NeverPrimaryEndpoint: - n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) - return ref, nil - } - } - } - - n.insertPrimaryEndpointLocked(ref, peb) - - return ref, nil - } - // tryIncRef failing means the endpoint is scheduled to be removed once - // the lock is released. Remove it here so we can create a new - // (permanent) one. The removal logic waiting for the lock handles this - // case. - n.removeEndpointLocked(ref) - } - } - - netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol] +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind AddressKind, configType AddressConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) { + ep, ok := n.networkEndpoints[protocolAddress.Protocol] if !ok { return nil, tcpip.ErrUnknownProtocol } - var nud NUDHandler - if n.neigh != nil { - // An interface value that holds a nil concrete value is itself non-nil. - // For this reason, n.neigh cannot be passed directly to NewEndpoint so - // NetworkEndpoints don't confuse it for non-nil. - // - // See https://golang.org/doc/faq#nil_error for more information. - nud = n.neigh - } - - // Create the new network endpoint. - ep := netProto.NewEndpoint(n.id, n.stack, nud, n, n.linkEP, n.stack) - - isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) - // If the address is an IPv6 address and it is a permanent address, // mark it as tentative so it goes through the DAD process if the NIC is // enabled. If the NIC is not enabled, DAD will be started when the NIC is // enabled. - if isIPv6Unicast && kind == permanent { - kind = permanentTentative - } - ref := &referencedNetworkEndpoint{ - refs: 1, - addr: protocolAddress.AddressWithPrefix, - ep: ep, - nic: n, - protocol: protocolAddress.Protocol, - kind: kind, - configType: configType, - deprecated: deprecated, - } - - // Set up resolver if link address resolution exists for this protocol. - if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 { - if linkRes, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok { - ref.linkCache = n.stack - ref.linkRes = linkRes - } + nep, err := ep.AddAddress(protocolAddress.AddressWithPrefix, AddAddressOptions{ + Deprecated: deprecated, + ConfigType: AddressConfigType(configType), + Kind: AddressKind(kind), + PEB: peb, + }) + if err != nil { + return nil, err } + ref := n.nepToRef(protocolAddress.Protocol, ep, nep) - // If we are adding an IPv6 unicast address, join the solicited-node - // multicast address. - if isIPv6Unicast { - snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address) - if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil { - return nil, err - } + isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) + if isIPv6Unicast && kind == Permanent { + kind = PermanentTentative + ref.setKind(kind) } - n.mu.endpoints[id] = ref - - n.insertPrimaryEndpointLocked(ref, peb) - // If we are adding a tentative IPv6 address, start DAD if the NIC is enabled. - if isIPv6Unicast && kind == permanentTentative && n.mu.enabled { + if isIPv6Unicast && kind == PermanentTentative && n.mu.enabled { if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { return nil, err } @@ -894,7 +581,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { // Add the endpoint. n.mu.Lock() - _, err := n.addAddressLocked(protocolAddress, peb, permanent, static, false /* deprecated */) + _, err := n.addAddressLocked(protocolAddress, peb, Permanent, AddressConfigStatic, false /* deprecated */) n.mu.Unlock() return err @@ -906,19 +593,11 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { n.mu.RLock() defer n.mu.RUnlock() - addrs := make([]tcpip.ProtocolAddress, 0, len(n.mu.endpoints)) - for _, ref := range n.mu.endpoints { - // Don't include tentative, expired or temporary endpoints to - // avoid confusion and prevent the caller from using those. - switch ref.getKind() { - case permanentExpired, temporary: - continue + var addrs []tcpip.ProtocolAddress + for p, ep := range n.networkEndpoints { + for _, a := range ep.AllAddresses() { + addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } - - addrs = append(addrs, tcpip.ProtocolAddress{ - Protocol: ref.protocol, - AddressWithPrefix: ref.addrWithPrefix(), - }) } return addrs } @@ -929,20 +608,9 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { defer n.mu.RUnlock() var addrs []tcpip.ProtocolAddress - for proto, list := range n.mu.primary { - for _, ref := range list { - // Don't include tentative, expired or tempory endpoints - // to avoid confusion and prevent the caller from using - // those. - switch ref.getKind() { - case permanentTentative, permanentExpired, temporary: - continue - } - - addrs = append(addrs, tcpip.ProtocolAddress{ - Protocol: proto, - AddressWithPrefix: ref.addrWithPrefix(), - }) + for p, ep := range n.networkEndpoints { + for _, a := range ep.PrimaryAddresses() { + addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } } return addrs @@ -954,108 +622,47 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { // address exists. If no non-deprecated address exists, the first deprecated // address will be returned. func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix { - n.mu.RLock() - defer n.mu.RUnlock() - - list, ok := n.mu.primary[proto] - if !ok { + ref := n.primaryEndpoint(proto, "") + if ref == nil { return tcpip.AddressWithPrefix{} } + addr := ref.addrWithPrefix() + ref.decRef() + return addr +} - var deprecatedEndpoint *referencedNetworkEndpoint - for _, ref := range list { - // Don't include tentative, expired or tempory endpoints to avoid confusion - // and prevent the caller from using those. - switch ref.getKind() { - case permanentTentative, permanentExpired, temporary: +func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { + for p, ep := range n.networkEndpoints { + nep := ep.GetEndpoint(addr) + if nep == nil { continue } - if !ref.deprecated { - return ref.addrWithPrefix() + addrWithPrefix := nep.AddressWithPrefix() + if addrWithPrefix.Address != addr { + continue } - if deprecatedEndpoint == nil { - deprecatedEndpoint = ref + kind := nep.GetKind() + if kind != Permanent && kind != PermanentTentative { + return tcpip.ErrBadLocalAddress } - } - - if deprecatedEndpoint != nil { - return deprecatedEndpoint.addrWithPrefix() - } - - return tcpip.AddressWithPrefix{} -} - -// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required -// by peb. -// -// n MUST be locked. -func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) { - switch peb { - case CanBePrimaryEndpoint: - n.mu.primary[r.protocol] = append(n.mu.primary[r.protocol], r) - case FirstPrimaryEndpoint: - n.mu.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.mu.primary[r.protocol]...) - } -} - -func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { - id := NetworkEndpointID{LocalAddress: r.address()} - - // Nothing to do if the reference has already been replaced with a different - // one. This happens in the case where 1) this endpoint's ref count hit zero - // and was waiting (on the lock) to be removed and 2) the same address was - // re-added in the meantime by removing this endpoint from the list and - // adding a new one. - if n.mu.endpoints[id] != r { - return - } - - if r.getKind() == permanent { - panic("Reference count dropped to zero before being removed") - } - delete(n.mu.endpoints, id) - refs := n.mu.primary[r.protocol] - for i, ref := range refs { - if ref == r { - n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) - refs[len(refs)-1] = nil - break + ref := n.nepToRef(p, ep, nep) + switch p { + case header.IPv6ProtocolNumber: + return n.removePermanentIPv6EndpointLocked(ref, true /* allowSLAACInvalidation */) + default: + ref.expireLocked() + return nil } } -} -func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { - n.mu.Lock() - n.removeEndpointLocked(r) - n.mu.Unlock() -} - -func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { - r, ok := n.mu.endpoints[NetworkEndpointID{addr}] - if !ok { - return tcpip.ErrBadLocalAddress - } - - kind := r.getKind() - if kind != permanent && kind != permanentTentative { - return tcpip.ErrBadLocalAddress - } - - switch r.protocol { - case header.IPv6ProtocolNumber: - return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAACInvalidation */) - default: - r.expireLocked() - return nil - } + return tcpip.ErrBadLocalAddress } func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACInvalidation bool) *tcpip.Error { addr := r.addrWithPrefix() - isIPv6Unicast := header.IsV6UnicastAddress(addr.Address) if isIPv6Unicast { @@ -1063,10 +670,10 @@ func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, al // If we are removing an address generated via SLAAC, cleanup // its SLAAC resources and notify the integrator. - switch r.configType { - case slaac: + switch r.configType() { + case AddressConfigSlaac: n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) - case slaacTemp: + case AddressConfigSlaacTemp: n.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) } } @@ -1075,18 +682,6 @@ func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, al // At this point the endpoint is deleted. - // If we are removing an IPv6 unicast address, leave the solicited-node - // multicast address. - // - // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node - // multicast group may be left by user action. - if isIPv6Unicast { - snmc := header.SolicitedNodeAddr(addr.Address) - if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { - return err - } - } - return nil } @@ -1160,34 +755,27 @@ func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.A // as an MLD packet's source address must be a link-local address as // outlined in RFC 3810 section 5. - id := NetworkEndpointID{addr} - joins := n.mu.mcastJoins[id] - if joins == 0 { - netProto, ok := n.stack.networkProtocols[protocol] - if !ok { - return tcpip.ErrUnknownProtocol - } - if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { - return err - } + ep, ok := n.networkEndpoints[protocol] + if !ok { + return tcpip.ErrNotSupported } - n.mu.mcastJoins[id] = joins + 1 - return nil + + gep, ok := ep.(GroupAddressableEndpoint) + if !ok { + return tcpip.ErrNotSupported + } + + _, err := gep.JoinGroup(addr) + return err } // leaveGroup decrements the count for the given multicast address, and when it // reaches zero removes the endpoint for this address. -func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { +func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - return n.leaveGroupLocked(addr, false /* force */) + return n.leaveGroupLocked(protocol, addr, false /* force */) } // leaveGroupLocked decrements the count for the given multicast address, and @@ -1196,32 +784,38 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { // // If force is true, then the count for the multicast addres is ignored and the // endpoint will be removed immediately. -func (n *NIC) leaveGroupLocked(addr tcpip.Address, force bool) *tcpip.Error { - id := NetworkEndpointID{addr} - joins, ok := n.mu.mcastJoins[id] +func (n *NIC) leaveGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, force bool) *tcpip.Error { + ep, ok := n.networkEndpoints[protocol] + if !ok { + return tcpip.ErrNotSupported + } + + gep, ok := ep.(GroupAddressableEndpoint) if !ok { - // There are no joins with this address on this NIC. - return tcpip.ErrBadLocalAddress + return tcpip.ErrNotSupported } - joins-- - if force || joins == 0 { - // There are no outstanding joins or we are forced to leave, clean up. - delete(n.mu.mcastJoins, id) - return n.removePermanentAddressLocked(addr) + if _, err := gep.LeaveGroup(addr, force); err != nil { + return err } - n.mu.mcastJoins[id] = joins return nil } // isInGroup returns true if n has joined the multicast group addr. func (n *NIC) isInGroup(addr tcpip.Address) bool { - n.mu.RLock() - joins := n.mu.mcastJoins[NetworkEndpointID{addr}] - n.mu.RUnlock() + for _, ep := range n.networkEndpoints { + gep, ok := ep.(GroupAddressableEndpoint) + if !ok { + continue + } + + if gep.IsInGroup(addr) { + return true + } + } - return joins != 0 + return false } func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) { @@ -1297,7 +891,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) - if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { + if n.stack.handleLocal && !n.IsLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a // packet like this unless handleLocal is false. Loopback also calls this // function even though the packets didn't come from the physical interface @@ -1308,7 +902,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. // Loopback traffic skips the prerouting chain. - if protocol == header.IPv4ProtocolNumber && !n.isLoopback() { + if protocol == header.IPv4ProtocolNumber && !n.IsLoopback() { // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) @@ -1337,8 +931,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // Found a NIC. n := r.ref.nic n.mu.RLock() - ref, ok := n.mu.endpoints[NetworkEndpointID{dst}] - ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef() + ref := n.getRefOrCreateTempLocked(protocol, dst, false, NeverPrimaryEndpoint) + ok := ref != nil && ref.isValidForOutgoingRLocked() n.mu.RUnlock() if ok { r.LocalLinkAddress = n.linkEP.LinkAddress() @@ -1512,7 +1106,7 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp } } -// ID returns the identifier of n. +// ID implements NetworkInterface. func (n *NIC) ID() tcpip.NICID { return n.id } @@ -1538,15 +1132,12 @@ func (n *NIC) LinkEndpoint() LinkEndpoint { // false. It will only return true if the address is associated with the NIC // AND it is tentative. func (n *NIC) isAddrTentative(addr tcpip.Address) bool { - n.mu.RLock() - defer n.mu.RUnlock() - - ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] - if !ok { + nep := n.networkEndpoints[header.IPv6ProtocolNumber].GetEndpoint(addr) + if nep == nil { return false } - - return ref.getKind() == permanentTentative + kind := nep.GetKind() + return kind == PermanentTentative } // dupTentativeAddrDetected attempts to inform n that a tentative addr is a @@ -1559,12 +1150,14 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] - if !ok { + ep := n.networkEndpoints[header.IPv6ProtocolNumber] + nep := ep.GetEndpoint(addr) + if nep == nil { return tcpip.ErrBadAddress } - if ref.getKind() != permanentTentative { + ref := n.nepToRef(header.IPv6ProtocolNumber, ep, nep) + if ref.getKind() != PermanentTentative { return tcpip.ErrInvalidEndpointState } @@ -1576,10 +1169,10 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { prefix := ref.addrWithPrefix().Subnet() - switch ref.configType { - case slaac: + switch ref.configType() { + case AddressConfigSlaac: n.mu.ndp.regenerateSLAACAddr(prefix) - case slaacTemp: + case AddressConfigSlaacTemp: // Do not reset the generation attempts counter for the prefix as the // temporary address is being regenerated in response to a DAD conflict. n.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) @@ -1629,41 +1222,6 @@ func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) { n.mu.ndp.handleRA(ip, ra) } -type networkEndpointKind int32 - -const ( - // A permanentTentative endpoint is a permanent address that is not yet - // considered to be fully bound to an interface in the traditional - // sense. That is, the address is associated with a NIC, but packets - // destined to the address MUST NOT be accepted and MUST be silently - // dropped, and the address MUST NOT be used as a source address for - // outgoing packets. For IPv6, addresses will be of this kind until - // NDP's Duplicate Address Detection has resolved, or be deleted if - // the process results in detecting a duplicate address. - permanentTentative networkEndpointKind = iota - - // A permanent endpoint is created by adding a permanent address (vs. a - // temporary one) to the NIC. Its reference count is biased by 1 to avoid - // removal when no route holds a reference to it. It is removed by explicitly - // removing the permanent address from the NIC. - permanent - - // An expired permanent endpoint is a permanent endpoint that had its address - // removed from the NIC, and it is waiting to be removed once no more routes - // hold a reference to it. This is achieved by decreasing its reference count - // by 1. If its address is re-added before the endpoint is removed, its type - // changes back to permanent and its reference count increases by 1 again. - permanentExpired - - // A temporary endpoint is created for spoofing outgoing packets, or when in - // promiscuous mode and accepting incoming packets that don't match any - // permanent endpoint. Its reference count is not biased by 1 and the - // endpoint is removed immediately when no more route holds a reference to - // it. A temporary endpoint can be promoted to permanent if its address - // is added permanently. - temporary -) - func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -1694,27 +1252,9 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep } } -type networkEndpointConfigType int32 - -const ( - // A statically configured endpoint is an address that was added by - // some user-specified action (adding an explicit address, joining a - // multicast group). - static networkEndpointConfigType = iota - - // A SLAAC configured endpoint is an IPv6 endpoint that was added by - // SLAAC as per RFC 4862 section 5.5.3. - slaac - - // A temporary SLAAC configured endpoint is an IPv6 endpoint that was added by - // SLAAC as per RFC 4941. Temporary SLAAC addresses are short-lived and are - // not expected to be valid (or preferred) forever; hence the term temporary. - slaacTemp -) - type referencedNetworkEndpoint struct { ep NetworkEndpoint - addr tcpip.AddressWithPrefix + nep AddressEndpoint nic *NIC protocol tcpip.NetworkProtocolNumber @@ -1725,39 +1265,22 @@ type referencedNetworkEndpoint struct { // linkRes is set if link address resolution is enabled for this protocol. // Set to nil otherwise. linkRes LinkAddressResolver - - // refs is counting references held for this endpoint. When refs hits zero it - // triggers the automatic removal of the endpoint from the NIC. - refs int32 - - // networkEndpointKind must only be accessed using {get,set}Kind(). - kind networkEndpointKind - - // configType is the method that was used to configure this endpoint. - // This must never change except during endpoint creation and promotion to - // permanent. - configType networkEndpointConfigType - - // deprecated indicates whether or not the endpoint should be considered - // deprecated. That is, when deprecated is true, other endpoints that are not - // deprecated should be preferred. - deprecated bool } func (r *referencedNetworkEndpoint) address() tcpip.Address { - return r.addr.Address + return r.nep.AddressWithPrefix().Address } func (r *referencedNetworkEndpoint) addrWithPrefix() tcpip.AddressWithPrefix { - return r.addr + return r.nep.AddressWithPrefix() } -func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { - return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind))) +func (r *referencedNetworkEndpoint) getKind() AddressKind { + return r.nep.GetKind() } -func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) { - atomic.StoreInt32((*int32)(&r.kind), int32(kind)) +func (r *referencedNetworkEndpoint) setKind(kind AddressKind) { + r.nep.SetKind(kind) } // isValidForOutgoing returns true if the endpoint can be used to send out a @@ -1784,60 +1307,51 @@ func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { // // r.nic.mu must be read locked. func (r *referencedNetworkEndpoint) isAssignedRLocked(spoofingOrPromiscuous bool) bool { - switch r.getKind() { - case permanentTentative: - return false - case permanentExpired: - return spoofingOrPromiscuous - default: - return true - } + return r.nep.IsAssigned(spoofingOrPromiscuous) } // expireLocked decrements the reference count and marks the permanent endpoint // as expired. func (r *referencedNetworkEndpoint) expireLocked() { - r.setKind(permanentExpired) - r.decRefLocked() + _ = r.ep.RemoveAddress(r.address()) } // decRef decrements the ref count and cleans up the endpoint once it reaches // zero. func (r *referencedNetworkEndpoint) decRef() { - if atomic.AddInt32(&r.refs, -1) == 0 { - r.nic.removeEndpoint(r) - } + _ = r.nep.DecRef() } // decRefLocked is the same as decRef but assumes that the NIC.mu mutex is // locked. func (r *referencedNetworkEndpoint) decRefLocked() { - if atomic.AddInt32(&r.refs, -1) == 0 { - r.nic.removeEndpointLocked(r) - } + _ = r.nep.DecRef() } // incRef increments the ref count. It must only be called when the caller is // known to be holding a reference to the endpoint, otherwise tryIncRef should // be used. func (r *referencedNetworkEndpoint) incRef() { - atomic.AddInt32(&r.refs, 1) + _ = r.tryIncRef() } // tryIncRef attempts to increment the ref count from n to n+1, but only if n is // not zero. That is, it will increment the count if the endpoint is still // alive, and do nothing if it has already been clean up. func (r *referencedNetworkEndpoint) tryIncRef() bool { - for { - v := atomic.LoadInt32(&r.refs) - if v == 0 { - return false - } + return r.nep.IncRef() +} - if atomic.CompareAndSwapInt32(&r.refs, v, v+1) { - return true - } - } +func (r *referencedNetworkEndpoint) setDeprecated(d bool) { + r.nep.SetDeprecated(d) +} + +func (r *referencedNetworkEndpoint) deprecated() bool { + return r.nep.Deprecated() +} + +func (r *referencedNetworkEndpoint) configType() AddressConfigType { + return r.nep.ConfigType() } // stack returns the Stack instance that owns the underlying endpoint. diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 1e065b5c1f..7d7a0004fe 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -94,6 +94,8 @@ func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol t panic("not implemented") } +var _ GroupAddressableEndpoint = (*testIPv6Endpoint)(nil) +var _ AddressableEndpoint = (*testIPv6Endpoint)(nil) var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) // An IPv6 NetworkEndpoint that throws away outgoing packets. @@ -101,11 +103,21 @@ var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) // We use this instead of ipv6.endpoint because the ipv6 package depends on // the stack package which this test lives in, causing a cyclic dependency. type testIPv6Endpoint struct { + AddressableEndpoint + nicID tcpip.NICID linkEP LinkEndpoint protocol *testIPv6Protocol } +func (*testIPv6Endpoint) Enable() *tcpip.Error { + return nil +} + +func (*testIPv6Endpoint) Disable() *tcpip.Error { + return nil +} + // DefaultTTL implements NetworkEndpoint.DefaultTTL. func (*testIPv6Endpoint) DefaultTTL() uint8 { return 0 @@ -161,6 +173,22 @@ func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return header.IPv6ProtocolNumber } +func (*testIPv6Endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { + return false, nil +} + +func (*testIPv6Endpoint) LeaveGroup(addr tcpip.Address, force bool) (bool, *tcpip.Error) { + return false, nil +} + +func (*testIPv6Endpoint) IsInGroup(addr tcpip.Address) bool { + return false +} + +func (*testIPv6Endpoint) LeaveAllGroups() *tcpip.Error { + return nil +} + var _ NetworkProtocol = (*testIPv6Protocol)(nil) // An IPv6 NetworkProtocol that supports the bare minimum to make a stack @@ -192,11 +220,12 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) } // NewEndpoint implements NetworkProtocol.NewEndpoint. -func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) NetworkEndpoint { +func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) NetworkEndpoint { return &testIPv6Endpoint{ - nicID: nicID, - linkEP: linkEP, - protocol: p, + AddressableEndpoint: NewAddressableEndpoint(), + nicID: nic.ID(), + linkEP: linkEP, + protocol: p, } } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 21ac38583d..7e31dbcaeb 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -229,6 +229,14 @@ type NetworkHeaderParams struct { // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { + AddressableEndpoint + + // Enable enables the endpoint. + Enable() *tcpip.Error + + // Disable disables the endpoint. + Disable() *tcpip.Error + // DefaultTTL is the default time-to-live value (or hop limit, in ipv6) // for this endpoint. DefaultTTL() uint8 @@ -298,7 +306,7 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint + NewEndpoint(nic NetworkInterface, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 7f5ed9e83d..64e55f638b 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1157,7 +1157,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { Up: true, // Netstack interfaces are always up. Running: nic.enabled(), Promiscuous: nic.isPromiscuousMode(), - Loopback: nic.isLoopback(), + Loopback: nic.IsLoopback(), } nics[id] = NICInfo{ Name: nic.name, @@ -1291,7 +1291,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok && nic.enabled() { if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { - return makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil + return makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil } } } else { @@ -1307,7 +1307,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n remoteAddr = ref.address() } - r := makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()) + r := makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()) r.directedBroadcast = route.Destination.IsBroadcast(remoteAddr) if len(route.Gateway) > 0 { @@ -1838,7 +1838,7 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC defer s.mu.RUnlock() if nic, ok := s.nics[nicID]; ok { - return nic.leaveGroup(multicastAddr) + return nic.leaveGroup(protocol, multicastAddr) } return tcpip.ErrUnknownNICID } @@ -2025,16 +2025,14 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres defer s.mu.RUnlock() for _, nic := range s.nics { - id := NetworkEndpointID{address} - - if ref, ok := nic.mu.endpoints[id]; ok { - nic.mu.RLock() - defer nic.mu.RUnlock() - - // An endpoint with this id exists, check if it can be - // used and return it. - return ref.ep, nil + ref := nic.getRefOrCreateTemp(netProto, address, NeverPrimaryEndpoint, none) + if ref == nil { + continue } + + ep := ref.ep + ref.decRef() + return ep, nil } return nil, tcpip.ErrBadAddress } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 1deeccb898..60b68ab7cd 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -68,12 +68,22 @@ const ( // use the first three: destination address, source address, and transport // protocol. They're all one byte fields to simplify parsing. type fakeNetworkEndpoint struct { + stack.AddressableEndpoint + nicID tcpip.NICID proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher ep stack.LinkEndpoint } +func (*fakeNetworkEndpoint) Enable() *tcpip.Error { + return nil +} + +func (*fakeNetworkEndpoint) Disable() *tcpip.Error { + return nil +} + func (f *fakeNetworkEndpoint) MTU() uint32 { return f.ep.MTU() - uint32(f.MaxHeaderLength()) } @@ -197,12 +207,13 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) stack.NetworkEndpoint { +func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) stack.NetworkEndpoint { return &fakeNetworkEndpoint{ - nicID: nicID, - proto: f, - dispatcher: dispatcher, - ep: ep, + AddressableEndpoint: stack.NewAddressableEndpoint(), + nicID: nic.ID(), + proto: f, + dispatcher: dispatcher, + ep: ep, } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index adb32e4288..6026975d3a 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -5296,8 +5296,8 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { // TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a // non unicast IPv6 address are not accepted. func TestListenNoAcceptNonUnicastV6(t *testing.T) { - multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") - otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") + multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") + otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") tests := []struct { name string diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 0cbc045d85..78224e86b0 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1462,6 +1462,18 @@ func TestNoChecksum(t *testing.T) { } } +var testNIC stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct{} + +func (*testInterface) ID() tcpip.NICID { + return 0 +} + +func (*testInterface) IsLoopback() bool { + return false +} + func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1485,7 +1497,7 @@ func TestTTL(t *testing.T) { } else { p = ipv6.NewProtocol() } - ep := p.NewEndpoint(0, nil, nil, nil, nil, stack.New(stack.Options{ + ep := p.NewEndpoint(testNIC, nil, nil, nil, nil, stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, })) @@ -1518,7 +1530,7 @@ func TestSetTTL(t *testing.T) { } else { p = ipv6.NewProtocol() } - ep := p.NewEndpoint(0, nil, nil, nil, nil, stack.New(stack.Options{ + ep := p.NewEndpoint(testNIC, nil, nil, nil, nil, stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, }))