From d0113732fe1a3248456da84259f01967b4d09b49 Mon Sep 17 00:00:00 2001 From: Philipp Krivanec Date: Wed, 26 Apr 2023 06:02:54 +0200 Subject: [PATCH] optimize generateACLPeerCacheMap (#1377) --- acls.go | 15 +++++------ app.go | 2 +- machine.go | 73 ++++++++++++++++++++++++++++++++++++------------------ 3 files changed, 56 insertions(+), 34 deletions(-) diff --git a/acls.go b/acls.go index 2073ee8483..53bb702333 100644 --- a/acls.go +++ b/acls.go @@ -163,23 +163,20 @@ func (h *Headscale) UpdateACLRules() error { // generateACLPeerCacheMap takes a list of Tailscale filter rules and generates a map // of which Sources ("*" and IPs) can access destinations. This is to speed up the // process of generating MapResponses when deciding which Peers to inform nodes about. -func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string]map[string]struct{} { - aclCachePeerMap := make(map[string]map[string]struct{}) +func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string][]string { + aclCachePeerMap := make(map[string][]string) for _, rule := range rules { for _, srcIP := range rule.SrcIPs { for _, ip := range expandACLPeerAddr(srcIP) { if data, ok := aclCachePeerMap[ip]; ok { for _, dstPort := range rule.DstPorts { - for _, dstIP := range expandACLPeerAddr(dstPort.IP) { - data[dstIP] = struct{}{} - } + data = append(data, dstPort.IP) } + aclCachePeerMap[ip] = data } else { - dstPortsMap := make(map[string]struct{}, len(rule.DstPorts)) + dstPortsMap := make([]string, 0) for _, dstPort := range rule.DstPorts { - for _, dstIP := range expandACLPeerAddr(dstPort.IP) { - dstPortsMap[dstIP] = struct{}{} - } + dstPortsMap = append(dstPortsMap, dstPort.IP) } aclCachePeerMap[ip] = dstPortsMap } diff --git a/app.go b/app.go index 26a8e23bdd..480689bc22 100644 --- a/app.go +++ b/app.go @@ -87,7 +87,7 @@ type Headscale struct { aclPolicy *ACLPolicy aclRules []tailcfg.FilterRule aclPeerCacheMapRW sync.RWMutex - aclPeerCacheMap map[string]map[string]struct{} + aclPeerCacheMap map[string][]string sshPolicy *tailcfg.SSHPolicy lastStateChange *xsync.MapOf[string, time.Time] diff --git a/machine.go b/machine.go index 6dfa9501af..1b70b1e207 100644 --- a/machine.go +++ b/machine.go @@ -4,6 +4,7 @@ import ( "database/sql/driver" "errors" "fmt" + "net" "net/netip" "sort" "strconv" @@ -172,7 +173,7 @@ func filterMachinesByACL( machine *Machine, machines Machines, lock *sync.RWMutex, - aclPeerCacheMap map[string]map[string]struct{}, + aclPeerCacheMap map[string][]string, ) Machines { log.Trace(). Caller(). @@ -197,27 +198,38 @@ func filterMachinesByACL( if dstMap, ok := aclPeerCacheMap["*"]; ok { // match source and all destination - if _, dstOk := dstMap["*"]; dstOk { - peers[peer.ID] = peer - continue + for _, dst := range dstMap { + if dst == "*" { + peers[peer.ID] = peer + + continue + } } // match source and all destination for _, peerIP := range peerIPs { - if _, dstOk := dstMap[peerIP]; dstOk { - peers[peer.ID] = peer + for _, dst := range dstMap { + _, cdr, _ := net.ParseCIDR(dst) + ip := net.ParseIP(peerIP) + if dst == peerIP || (cdr != nil && ip != nil && cdr.Contains(ip)) { + peers[peer.ID] = peer - continue + continue + } } } // match all sources and source for _, machineIP := range machineIPs { - if _, dstOk := dstMap[machineIP]; dstOk { - peers[peer.ID] = peer + for _, dst := range dstMap { + _, cdr, _ := net.ParseCIDR(dst) + ip := net.ParseIP(machineIP) + if dst == machineIP || (cdr != nil && ip != nil && cdr.Contains(ip)) { + peers[peer.ID] = peer - continue + continue + } } } } @@ -225,18 +237,24 @@ func filterMachinesByACL( for _, machineIP := range machineIPs { if dstMap, ok := aclPeerCacheMap[machineIP]; ok { // match source and all destination - if _, dstOk := dstMap["*"]; dstOk { - peers[peer.ID] = peer + for _, dst := range dstMap { + if dst == "*" { + peers[peer.ID] = peer - continue + continue + } } // match source and destination for _, peerIP := range peerIPs { - if _, dstOk := dstMap[peerIP]; dstOk { - peers[peer.ID] = peer - - continue + for _, dst := range dstMap { + _, cdr, _ := net.ParseCIDR(dst) + ip := net.ParseIP(peerIP) + if dst == peerIP || (cdr != nil && ip != nil && cdr.Contains(ip)) { + peers[peer.ID] = peer + + continue + } } } } @@ -245,17 +263,24 @@ func filterMachinesByACL( for _, peerIP := range peerIPs { if dstMap, ok := aclPeerCacheMap[peerIP]; ok { // match source and all destination - if _, dstOk := dstMap["*"]; dstOk { - peers[peer.ID] = peer + for _, dst := range dstMap { + if dst == "*" { + peers[peer.ID] = peer - continue + continue + } } + // match return path for _, machineIP := range machineIPs { - if _, dstOk := dstMap[machineIP]; dstOk { - peers[peer.ID] = peer - - continue + for _, dst := range dstMap { + _, cdr, _ := net.ParseCIDR(dst) + ip := net.ParseIP(machineIP) + if dst == machineIP || (cdr != nil && ip != nil && cdr.Contains(ip)) { + peers[peer.ID] = peer + + continue + } } } }