Skip to content

Commit

Permalink
simplify expandAlias function, move seperate logic out
Browse files Browse the repository at this point in the history
Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby authored and juanfont committed May 3, 2023
1 parent b23a915 commit 6de53e2
Show file tree
Hide file tree
Showing 4 changed files with 266 additions and 175 deletions.
261 changes: 153 additions & 108 deletions acls.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,8 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {

principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
for innerIndex, rawSrc := range sshACL.Sources {
expandedSrcs, err := expandAlias(
expandedSrcs, err := h.aclPolicy.expandAlias(
machines,
*h.aclPolicy,
rawSrc,
h.cfg.OIDC.StripEmaildomain,
)
Expand Down Expand Up @@ -391,16 +390,16 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {

func generateACLPolicySrc(
machines []Machine,
aclPolicy ACLPolicy,
pol ACLPolicy,
src string,
stripEmaildomain bool,
) ([]string, error) {
return expandAlias(machines, aclPolicy, src, stripEmaildomain)
return pol.expandAlias(machines, src, stripEmaildomain)
}

func generateACLPolicyDest(
machines []Machine,
aclPolicy ACLPolicy,
pol ACLPolicy,
dest string,
needsWildcard bool,
stripEmaildomain bool,
Expand Down Expand Up @@ -448,9 +447,8 @@ func generateACLPolicyDest(
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
}

expanded, err := expandAlias(
expanded, err := pol.expandAlias(
machines,
aclPolicy,
alias,
stripEmaildomain,
)
Expand Down Expand Up @@ -534,13 +532,11 @@ func parseProtocol(protocol string) ([]int, bool, error) {
// - an ip
// - a cidr
// and transform these in IPAddresses.
func expandAlias(
func (pol *ACLPolicy) expandAlias(
machines Machines,
aclPolicy ACLPolicy,
alias string,
stripEmailDomain bool,
) ([]string, error) {
ips := []string{}
if alias == "*" {
return []string{"*"}, nil
}
Expand All @@ -549,128 +545,56 @@ func expandAlias(
Str("alias", alias).
Msg("Expanding")

// if alias is a group
if strings.HasPrefix(alias, "group:") {
users, err := expandGroup(aclPolicy, alias, stripEmailDomain)
if err != nil {
return ips, err
}
for _, n := range users {
nodes := filterMachinesByUser(machines, n)
for _, node := range nodes {
ips = append(ips, node.IPAddresses.ToStringSlice()...)
}
}

return ips, nil
return pol.getIPsFromGroup(alias, machines, stripEmailDomain)
}

// if alias is a tag
if strings.HasPrefix(alias, "tag:") {
// check for forced tags
for _, machine := range machines {
if contains(machine.ForcedTags, alias) {
ips = append(ips, machine.IPAddresses.ToStringSlice()...)
}
}

// find tag owners
owners, err := expandTagOwners(aclPolicy, alias, stripEmailDomain)
if err != nil {
if errors.Is(err, errInvalidTag) {
if len(ips) == 0 {
return ips, fmt.Errorf(
"%w. %v isn't owned by a TagOwner and no forced tags are defined",
errInvalidTag,
alias,
)
}

return ips, nil
} else {
return ips, err
}
}

// filter out machines per tag owner
for _, user := range owners {
machines := filterMachinesByUser(machines, user)
for _, machine := range machines {
hi := machine.GetHostInfo()
if contains(hi.RequestTags, alias) {
ips = append(ips, machine.IPAddresses.ToStringSlice()...)
}
}
}

return ips, nil
return pol.getIPsFromTag(alias, machines, stripEmailDomain)
}

// if alias is a user
nodes := filterMachinesByUser(machines, alias)
nodes = excludeCorrectlyTaggedNodes(aclPolicy, nodes, alias, stripEmailDomain)

for _, n := range nodes {
ips = append(ips, n.IPAddresses.ToStringSlice()...)
}
if len(ips) > 0 {
if ips := pol.getIPsForUser(alias, machines, stripEmailDomain); len(ips) > 0 {
return ips, nil
}

// if alias is an host
if h, ok := aclPolicy.Hosts[alias]; ok {
if h, ok := pol.Hosts[alias]; ok {
log.Trace().Str("host", h.String()).Msg("expandAlias got hosts entry")

return expandAlias(machines, aclPolicy, h.String(), stripEmailDomain)
return pol.expandAlias(machines, h.String(), stripEmailDomain)
}

// if alias is an IP
if ip, err := netip.ParseAddr(alias); err == nil {
log.Trace().Str("ip", ip.String()).Msg("expandAlias got ip")
ips := []string{ip.String()}
matches := machines.FilterByIP(ip)

for _, machine := range matches {
ips = append(ips, machine.IPAddresses.ToStringSlice()...)
}

return lo.Uniq(ips), nil
return pol.getIPsFromSingleIP(ip, machines)
}

if cidr, err := netip.ParsePrefix(alias); err == nil {
log.Trace().Str("cidr", cidr.String()).Msg("expandAlias got cidr")
val := []string{cidr.String()}
// This is suboptimal and quite expensive, but if we only add the cidr, we will miss all the relevant IPv6
// addresses for the hosts that belong to tailscale. This doesnt really affect stuff like subnet routers.
for _, machine := range machines {
for _, ip := range machine.IPAddresses {
// log.Trace().
// Msgf("checking if machine ip (%s) is part of cidr (%s): %v, is single ip cidr (%v), addr: %s", ip.String(), cidr.String(), cidr.Contains(ip), cidr.IsSingleIP(), cidr.Addr().String())
if cidr.Contains(ip) {
val = append(val, machine.IPAddresses.ToStringSlice()...)
}
}
}

return lo.Uniq(val), nil
// if alias is an IP Prefix (CIDR)
if prefix, err := netip.ParsePrefix(alias); err == nil {
return pol.getIPsFromIPPrefix(prefix, machines)
}

log.Warn().Msgf("No IPs found with the alias %v", alias)

return ips, nil
return []string{}, nil
}

// excludeCorrectlyTaggedNodes will remove from the list of input nodes the ones
// that are correctly tagged since they should not be listed as being in the user
// we assume in this function that we only have nodes from 1 user.
func excludeCorrectlyTaggedNodes(
aclPolicy ACLPolicy,
aclPolicy *ACLPolicy,
nodes []Machine,
user string,
stripEmailDomain bool,
) []Machine {
out := []Machine{}
tags := []string{}
for tag := range aclPolicy.TagOwners {
owners, _ := expandTagOwners(aclPolicy, user, stripEmailDomain)
owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain)
ns := append(owners, user)
if contains(ns, user) {
tags = append(tags, tag)
Expand Down Expand Up @@ -758,15 +682,15 @@ func filterMachinesByUser(machines []Machine, user string) []Machine {
return out
}

// expandTagOwners will return a list of user. An owner can be either a user or a group
// getTagOwners will return a list of user. An owner can be either a user or a group
// a group cannot be composed of groups.
func expandTagOwners(
aclPolicy ACLPolicy,
func getTagOwners(
pol *ACLPolicy,
tag string,
stripEmailDomain bool,
) ([]string, error) {
var owners []string
ows, ok := aclPolicy.TagOwners[tag]
ows, ok := pol.TagOwners[tag]
if !ok {
return []string{}, fmt.Errorf(
"%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners",
Expand All @@ -776,7 +700,7 @@ func expandTagOwners(
}
for _, owner := range ows {
if strings.HasPrefix(owner, "group:") {
gs, err := expandGroup(aclPolicy, owner, stripEmailDomain)
gs, err := pol.getUsersInGroup(owner, stripEmailDomain)
if err != nil {
return []string{}, err
}
Expand All @@ -789,15 +713,15 @@ func expandTagOwners(
return owners, nil
}

// expandGroup will return the list of user inside the group
// getUsersInGroup will return the list of user inside the group
// after some validation.
func expandGroup(
aclPolicy ACLPolicy,
func (pol *ACLPolicy) getUsersInGroup(
group string,
stripEmailDomain bool,
) ([]string, error) {
outGroups := []string{}
aclGroups, ok := aclPolicy.Groups[group]
users := []string{}
log.Trace().Caller().Interface("pol", pol).Msg("test")
aclGroups, ok := pol.Groups[group]
if !ok {
return []string{}, fmt.Errorf(
"group %v isn't registered. %w",
Expand All @@ -820,8 +744,129 @@ func expandGroup(
errInvalidGroup,
)
}
outGroups = append(outGroups, grp)
users = append(users, grp)
}

return users, nil
}

func (pol *ACLPolicy) getIPsFromGroup(
group string,
machines Machines,
stripEmailDomain bool,
) ([]string, error) {
ips := []string{}

users, err := pol.getUsersInGroup(group, stripEmailDomain)
if err != nil {
return ips, err
}
for _, n := range users {
nodes := filterMachinesByUser(machines, n)
for _, node := range nodes {
ips = append(ips, node.IPAddresses.ToStringSlice()...)
}
}

return ips, nil
}

func (pol *ACLPolicy) getIPsFromTag(
alias string,
machines Machines,
stripEmailDomain bool,
) ([]string, error) {
ips := []string{}

// check for forced tags
for _, machine := range machines {
if contains(machine.ForcedTags, alias) {
ips = append(ips, machine.IPAddresses.ToStringSlice()...)
}
}

// find tag owners
owners, err := getTagOwners(pol, alias, stripEmailDomain)
if err != nil {
if errors.Is(err, errInvalidTag) {
if len(ips) == 0 {
return ips, fmt.Errorf(
"%w. %v isn't owned by a TagOwner and no forced tags are defined",
errInvalidTag,
alias,
)
}

return ips, nil
} else {
return ips, err
}
}

// filter out machines per tag owner
for _, user := range owners {
machines := filterMachinesByUser(machines, user)
for _, machine := range machines {
hi := machine.GetHostInfo()
if contains(hi.RequestTags, alias) {
ips = append(ips, machine.IPAddresses.ToStringSlice()...)
}
}
}

return ips, nil
}

func (pol *ACLPolicy) getIPsForUser(
user string,
machines Machines,
stripEmailDomain bool,
) []string {
ips := []string{}

nodes := filterMachinesByUser(machines, user)
nodes = excludeCorrectlyTaggedNodes(pol, nodes, user, stripEmailDomain)

for _, n := range nodes {
ips = append(ips, n.IPAddresses.ToStringSlice()...)
}

return ips
}

func (pol *ACLPolicy) getIPsFromSingleIP(
ip netip.Addr,
machines Machines,
) ([]string, error) {
log.Trace().Str("ip", ip.String()).Msg("expandAlias got ip")

ips := []string{ip.String()}
matches := machines.FilterByIP(ip)

for _, machine := range matches {
ips = append(ips, machine.IPAddresses.ToStringSlice()...)
}

return lo.Uniq(ips), nil
}

func (pol *ACLPolicy) getIPsFromIPPrefix(
prefix netip.Prefix,
machines Machines,
) ([]string, error) {
log.Trace().Str("prefix", prefix.String()).Msg("expandAlias got prefix")
val := []string{prefix.String()}
// This is suboptimal and quite expensive, but if we only add the prefix, we will miss all the relevant IPv6
// addresses for the hosts that belong to tailscale. This doesnt really affect stuff like subnet routers.
for _, machine := range machines {
for _, ip := range machine.IPAddresses {
// log.Trace().
// Msgf("checking if machine ip (%s) is part of prefix (%s): %v, is single ip prefix (%v), addr: %s", ip.String(), prefix.String(), prefix.Contains(ip), prefix.IsSingleIP(), prefix.Addr().String())
if prefix.Contains(ip) {
val = append(val, machine.IPAddresses.ToStringSlice()...)
}
}
}

return outGroups, nil
return lo.Uniq(val), nil
}
Loading

0 comments on commit 6de53e2

Please sign in to comment.