From 92a7b8e30ff3812d37d89605227eace705861a7b Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 11 May 2023 09:09:18 +0200 Subject: [PATCH 1/2] create DB struct This is step one in detaching the Database layer from Headscale (h). The ultimate goal is to have all function that does database operations in its own package, and keep the business logic and writing separate. Signed-off-by: Kristoffer Dalby --- .gitignore | 2 + cmd/headscale/cli/api_key.go | 4 +- cmd/headscale/cli/debug.go | 4 +- cmd/headscale/cli/nodes.go | 8 +- cmd/headscale/cli/users.go | 6 +- cmd/headscale/cli/utils.go | 5 +- cmd/headscale/headscale_test.go | 5 +- hscontrol/acls.go | 38 +- hscontrol/acls_test.go | 574 +++++++++--------- hscontrol/addresses.go | 98 +++ .../{utils_test.go => addresses_test.go} | 64 +- hscontrol/api.go | 25 +- hscontrol/api_common.go | 11 +- hscontrol/api_key.go | 38 +- hscontrol/api_key_test.go | 32 +- hscontrol/app.go | 81 ++- hscontrol/app_test.go | 24 +- hscontrol/config.go | 15 +- hscontrol/db.go | 182 +++--- hscontrol/dns_test.go | 64 +- hscontrol/grpcv1.go | 78 +-- hscontrol/machine.go | 256 ++++---- hscontrol/machine_test.go | 257 ++++---- hscontrol/oidc.go | 51 +- hscontrol/preauth_keys.go | 59 +- hscontrol/preauth_keys_test.go | 82 +-- hscontrol/protocol_common.go | 53 +- hscontrol/protocol_common_poll.go | 27 +- hscontrol/protocol_common_utils.go | 14 +- hscontrol/protocol_legacy.go | 5 +- hscontrol/protocol_legacy_poll.go | 7 +- hscontrol/protocol_noise_poll.go | 6 +- hscontrol/routes.go | 105 ++-- hscontrol/routes_test.go | 189 +++--- hscontrol/users.go | 65 +- hscontrol/users_test.go | 90 +-- hscontrol/util/addr.go | 42 ++ hscontrol/util/file.go | 43 ++ hscontrol/util/key.go | 117 ++++ hscontrol/util/net.go | 12 + hscontrol/util/string.go | 85 +++ hscontrol/util/string_test.go | 15 + hscontrol/utils.go | 361 ----------- integration/auth_oidc_test.go | 3 +- integration/embedded_derp_test.go | 4 +- integration/hsic/hsic.go | 5 +- integration/scenario.go | 4 +- integration/tsic/tsic.go | 4 +- 48 files changed, 1739 insertions(+), 1580 deletions(-) create mode 100644 hscontrol/addresses.go rename hscontrol/{utils_test.go => addresses_test.go} (75%) create mode 100644 hscontrol/util/addr.go create mode 100644 hscontrol/util/file.go create mode 100644 hscontrol/util/key.go create mode 100644 hscontrol/util/net.go create mode 100644 hscontrol/util/string.go create mode 100644 hscontrol/util/string_test.go delete mode 100644 hscontrol/utils.go diff --git a/.gitignore b/.gitignore index bcbc9b2ab1..0ba8193185 100644 --- a/.gitignore +++ b/.gitignore @@ -41,3 +41,5 @@ integration_test/etc/config.dump.yaml # MkDocs .cache /site + +__debug_bin diff --git a/cmd/headscale/cli/api_key.go b/cmd/headscale/cli/api_key.go index f7c7e3a264..37ef423514 100644 --- a/cmd/headscale/cli/api_key.go +++ b/cmd/headscale/cli/api_key.go @@ -6,7 +6,7 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/prometheus/common/model" "github.com/pterm/pterm" "github.com/rs/zerolog/log" @@ -83,7 +83,7 @@ var listAPIKeys = &cobra.Command{ } tableData = append(tableData, []string{ - strconv.FormatUint(key.GetId(), hscontrol.Base10), + strconv.FormatUint(key.GetId(), util.Base10), key.GetPrefix(), expiration, key.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat), diff --git a/cmd/headscale/cli/debug.go b/cmd/headscale/cli/debug.go index f2c8028f25..7e8e92dc38 100644 --- a/cmd/headscale/cli/debug.go +++ b/cmd/headscale/cli/debug.go @@ -4,7 +4,7 @@ import ( "fmt" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/spf13/cobra" "google.golang.org/grpc/status" @@ -93,7 +93,7 @@ var createNodeCmd = &cobra.Command{ return } - if !hscontrol.NodePublicKeyRegex.Match([]byte(machineKey)) { + if !util.NodePublicKeyRegex.Match([]byte(machineKey)) { err = errPreAuthKeyMalformed ErrorOutput( err, diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 772b428e82..31a0677394 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -10,7 +10,7 @@ import ( survey "github.com/AlecAivazis/survey/v2" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/pterm/pterm" "github.com/spf13/cobra" "google.golang.org/grpc/status" @@ -529,7 +529,7 @@ func nodesToPtables( var machineKey key.MachinePublic err := machineKey.UnmarshalText( - []byte(hscontrol.MachinePublicKeyEnsurePrefix(machine.MachineKey)), + []byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)), ) if err != nil { machineKey = key.MachinePublic{} @@ -537,7 +537,7 @@ func nodesToPtables( var nodeKey key.NodePublic err = nodeKey.UnmarshalText( - []byte(hscontrol.NodePublicKeyEnsurePrefix(machine.NodeKey)), + []byte(util.NodePublicKeyEnsurePrefix(machine.NodeKey)), ) if err != nil { return nil, err @@ -596,7 +596,7 @@ func nodesToPtables( } nodeData := []string{ - strconv.FormatUint(machine.Id, hscontrol.Base10), + strconv.FormatUint(machine.Id, util.Base10), machine.Name, machine.GetGivenName(), machineKey.ShortString(), diff --git a/cmd/headscale/cli/users.go b/cmd/headscale/cli/users.go index 3724fe98e4..3132e99503 100644 --- a/cmd/headscale/cli/users.go +++ b/cmd/headscale/cli/users.go @@ -1,11 +1,11 @@ package cli import ( + "errors" "fmt" survey "github.com/AlecAivazis/survey/v2" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol" "github.com/pterm/pterm" "github.com/rs/zerolog/log" "github.com/spf13/cobra" @@ -20,9 +20,7 @@ func init() { userCmd.AddCommand(renameUserCmd) } -const ( - errMissingParameter = hscontrol.Error("missing parameters") -) +var errMissingParameter = errors.New("missing parameters") var userCmd = &cobra.Command{ Use: "users", diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index a2a5d59251..2831dbf775 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -10,6 +10,7 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -39,7 +40,7 @@ func getHeadscaleApp() (*hscontrol.Headscale, error) { // We are doing this here, as in the future could be cool to have it also hot-reload if cfg.ACL.PolicyPath != "" { - aclPath := hscontrol.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath) + aclPath := util.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath) err = app.LoadACLPolicyFromPath(aclPath) if err != nil { log.Fatal(). @@ -98,7 +99,7 @@ func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc. grpcOptions = append( grpcOptions, grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithContextDialer(hscontrol.GrpcSocketDialer), + grpc.WithContextDialer(util.GrpcSocketDialer), ) } else { // If we are not connecting to a local server, require an API key for authentication diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index 1b987313bb..89fd775440 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/spf13/viper" "gopkg.in/check.v1" ) @@ -64,7 +65,7 @@ func (*Suite) TestConfigFileLoading(c *check.C) { c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1") c.Assert( - hscontrol.GetFileMode("unix_socket_permission"), + util.GetFileMode("unix_socket_permission"), check.Equals, fs.FileMode(0o770), ) @@ -107,7 +108,7 @@ func (*Suite) TestConfigLoading(c *check.C) { c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1") c.Assert( - hscontrol.GetFileMode("unix_socket_permission"), + util.GetFileMode("unix_socket_permission"), check.Equals, fs.FileMode(0o770), ) diff --git a/hscontrol/acls.go b/hscontrol/acls.go index 449c7ffd3d..2c81046a49 100644 --- a/hscontrol/acls.go +++ b/hscontrol/acls.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/tailscale/hujson" "go4.org/netipx" @@ -20,21 +21,16 @@ import ( "tailscale.com/tailcfg" ) -const ( - errEmptyPolicy = Error("empty policy") - errInvalidAction = Error("invalid action") - errInvalidGroup = Error("invalid group") - errInvalidTag = Error("invalid tag") - errInvalidPortFormat = Error("invalid port format") - errWildcardIsNeeded = Error("wildcard as port is required for the protocol") +var ( + errEmptyPolicy = errors.New("empty policy") + errInvalidAction = errors.New("invalid action") + errInvalidGroup = errors.New("invalid group") + errInvalidTag = errors.New("invalid tag") + errInvalidPortFormat = errors.New("invalid port format") + errWildcardIsNeeded = errors.New("wildcard as port is required for the protocol") ) const ( - Base8 = 8 - Base10 = 10 - BitSize16 = 16 - BitSize32 = 32 - BitSize64 = 64 portRangeBegin = 0 portRangeEnd = 65535 expectedTokenItems = 2 @@ -123,7 +119,7 @@ func (h *Headscale) LoadACLPolicyFromBytes(acl []byte, format string) error { } func (h *Headscale) UpdateACLRules() error { - machines, err := h.ListMachines() + machines, err := h.db.ListMachines() if err != nil { return err } @@ -230,7 +226,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { return nil, errEmptyPolicy } - machines, err := h.ListMachines() + machines, err := h.db.ListMachines() if err != nil { return nil, err } @@ -570,7 +566,7 @@ func excludeCorrectlyTaggedNodes( for tag := range aclPolicy.TagOwners { owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain) ns := append(owners, user) - if contains(ns, user) { + if util.StringOrPrefixListContains(ns, user) { tags = append(tags, tag) } } @@ -580,7 +576,7 @@ func excludeCorrectlyTaggedNodes( found := false for _, t := range hi.RequestTags { - if contains(tags, t) { + if util.StringOrPrefixListContains(tags, t) { found = true break @@ -614,7 +610,7 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err rang := strings.Split(portStr, "-") switch len(rang) { case 1: - port, err := strconv.ParseUint(rang[0], Base10, BitSize16) + port, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16) if err != nil { return nil, err } @@ -624,11 +620,11 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err }) case expectedTokenItems: - start, err := strconv.ParseUint(rang[0], Base10, BitSize16) + start, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16) if err != nil { return nil, err } - last, err := strconv.ParseUint(rang[1], Base10, BitSize16) + last, err := strconv.ParseUint(rang[1], util.Base10, util.BitSize16) if err != nil { return nil, err } @@ -754,7 +750,7 @@ func (pol *ACLPolicy) getIPsFromTag( // check for forced tags for _, machine := range machines { - if contains(machine.ForcedTags, alias) { + if util.StringOrPrefixListContains(machine.ForcedTags, alias) { machine.IPAddresses.AppendToIPSet(&build) } } @@ -783,7 +779,7 @@ func (pol *ACLPolicy) getIPsFromTag( machines := filterMachinesByUser(machines, user) for _, machine := range machines { hi := machine.GetHostInfo() - if contains(hi.RequestTags, alias) { + if util.StringOrPrefixListContains(hi.RequestTags, alias) { machine.IPAddresses.AppendToIPSet(&build) } } diff --git a/hscontrol/acls_test.go b/hscontrol/acls_test.go index 095597f273..70a57b81ab 100644 --- a/hscontrol/acls_test.go +++ b/hscontrol/acls_test.go @@ -30,8 +30,8 @@ func (s *Suite) TestBrokenHuJson(c *check.C) { func (s *Suite) TestInvalidPolicyHuson(c *check.C) { acl := []byte(` { - "valid_json": true, - "but_a_policy_though": false + "valid_json": true, + "but_a_policy_though": false } `) err := app.LoadACLPolicyFromBytes(acl, "hujson") @@ -60,129 +60,129 @@ func (s *Suite) TestParseInvalidCIDR(c *check.C) { func (s *Suite) TestRuleInvalidGeneration(c *check.C) { acl := []byte(` { - // Declare static groups of users beyond those in the identity service. - "groups": { - "group:example": [ - "user1@example.com", - "user2@example.com", - ], - }, - // Declare hostname aliases to use in place of IP addresses or subnets. - "hosts": { - "example-host-1": "100.100.100.100", - "example-host-2": "100.100.101.100/24", - }, - // Define who is allowed to use which tags. - "tagOwners": { - // Everyone in the montreal-admins or global-admins group are - // allowed to tag servers as montreal-webserver. - "tag:montreal-webserver": [ - "group:montreal-admins", - "group:global-admins", - ], - // Only a few admins are allowed to create API servers. - "tag:api-server": [ - "group:global-admins", - "example-host-1", - ], - }, - // Access control lists. - "acls": [ - // Engineering users, plus the president, can access port 22 (ssh) - // and port 3389 (remote desktop protocol) on all servers, and all - // ports on git-server or ci-server. - { - "action": "accept", - "src": [ - "group:engineering", - "president@example.com" - ], - "dst": [ - "*:22,3389", - "git-server:*", - "ci-server:*" - ], - }, - // Allow engineer users to access any port on a device tagged with - // tag:production. - { - "action": "accept", - "src": [ - "group:engineers" - ], - "dst": [ - "tag:production:*" - ], - }, - // Allow servers in the my-subnet host and 192.168.1.0/24 to access hosts - // on both networks. - { - "action": "accept", - "src": [ - "my-subnet", - "192.168.1.0/24" - ], - "dst": [ - "my-subnet:*", - "192.168.1.0/24:*" - ], - }, - // Allow every user of your network to access anything on the network. - // Comment out this section if you want to define specific ACL - // restrictions above. - { - "action": "accept", - "src": [ - "*" - ], - "dst": [ - "*:*" - ], - }, - // All users in Montreal are allowed to access the Montreal web - // servers. - { - "action": "accept", - "src": [ - "group:montreal-users" - ], - "dst": [ - "tag:montreal-webserver:80,443" - ], - }, - // Montreal web servers are allowed to make outgoing connections to - // the API servers, but only on https port 443. - // In contrast, this doesn't grant API servers the right to initiate - // any connections. - { - "action": "accept", - "src": [ - "tag:montreal-webserver" - ], - "dst": [ - "tag:api-server:443" - ], - }, - ], - // Declare tests to check functionality of ACL rules - "tests": [ - { - "src": "user1@example.com", - "accept": [ - "example-host-1:22", - "example-host-2:80" - ], - "deny": [ - "exapmle-host-2:100" - ], - }, - { - "src": "user2@example.com", - "accept": [ - "100.60.3.4:22" - ], - }, - ], + // Declare static groups of users beyond those in the identity service. + "groups": { + "group:example": [ + "user1@example.com", + "user2@example.com", + ], + }, + // Declare hostname aliases to use in place of IP addresses or subnets. + "hosts": { + "example-host-1": "100.100.100.100", + "example-host-2": "100.100.101.100/24", + }, + // Define who is allowed to use which tags. + "tagOwners": { + // Everyone in the montreal-admins or global-admins group are + // allowed to tag servers as montreal-webserver. + "tag:montreal-webserver": [ + "group:montreal-admins", + "group:global-admins", + ], + // Only a few admins are allowed to create API servers. + "tag:api-server": [ + "group:global-admins", + "example-host-1", + ], + }, + // Access control lists. + "acls": [ + // Engineering users, plus the president, can access port 22 (ssh) + // and port 3389 (remote desktop protocol) on all servers, and all + // ports on git-server or ci-server. + { + "action": "accept", + "src": [ + "group:engineering", + "president@example.com" + ], + "dst": [ + "*:22,3389", + "git-server:*", + "ci-server:*" + ], + }, + // Allow engineer users to access any port on a device tagged with + // tag:production. + { + "action": "accept", + "src": [ + "group:engineers" + ], + "dst": [ + "tag:production:*" + ], + }, + // Allow servers in the my-subnet host and 192.168.1.0/24 to access hosts + // on both networks. + { + "action": "accept", + "src": [ + "my-subnet", + "192.168.1.0/24" + ], + "dst": [ + "my-subnet:*", + "192.168.1.0/24:*" + ], + }, + // Allow every user of your network to access anything on the network. + // Comment out this section if you want to define specific ACL + // restrictions above. + { + "action": "accept", + "src": [ + "*" + ], + "dst": [ + "*:*" + ], + }, + // All users in Montreal are allowed to access the Montreal web + // servers. + { + "action": "accept", + "src": [ + "group:montreal-users" + ], + "dst": [ + "tag:montreal-webserver:80,443" + ], + }, + // Montreal web servers are allowed to make outgoing connections to + // the API servers, but only on https port 443. + // In contrast, this doesn't grant API servers the right to initiate + // any connections. + { + "action": "accept", + "src": [ + "tag:montreal-webserver" + ], + "dst": [ + "tag:api-server:443" + ], + }, + ], + // Declare tests to check functionality of ACL rules + "tests": [ + { + "src": "user1@example.com", + "accept": [ + "example-host-1:22", + "example-host-2:80" + ], + "deny": [ + "exapmle-host-2:100" + ], + }, + { + "src": "user2@example.com", + "accept": [ + "100.60.3.4:22" + ], + }, + ], } `) err := app.LoadACLPolicyFromBytes(acl, "hujson") @@ -192,24 +192,24 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { func (s *Suite) TestBasicRule(c *check.C) { acl := []byte(` { - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "subnet-1", - "192.168.1.0/24" - ], - "dst": [ - "*:22,3389", - "host-1:*", - ], - }, - ], + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "subnet-1", + "192.168.1.0/24" + ], + "dst": [ + "*:22,3389", + "host-1:*", + ], + }, + ], } `) err := app.LoadACLPolicyFromBytes(acl, "hujson") @@ -238,13 +238,13 @@ func (s *Suite) TestInvalidAction(c *check.C) { func (s *Suite) TestSshRules(c *check.C) { envknob.Setenv("HEADSCALE_EXPERIMENTAL_FEATURE_SSH", "1") - user, err := app.CreateUser("user1") + user, err := app.db.CreateUser("user1") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("user1", "testmachine") + _, err = app.db.GetMachine("user1", "testmachine") c.Assert(err, check.NotNil) hostInfo := tailcfg.Hostinfo{ OS: "centos", @@ -264,7 +264,7 @@ func (s *Suite) TestSshRules(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo), } - app.db.Save(&machine) + app.db.db.Save(&machine) app.aclPolicy = &ACLPolicy{ Groups: Groups{ @@ -348,13 +348,13 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) { // match properly the IP's of the related hosts. The owner is valid and the tag is also valid. // the tag is matched in the Sources section. func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) { - user, err := app.CreateUser("user1") + user, err := app.db.CreateUser("user1") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("user1", "testmachine") + _, err = app.db.GetMachine("user1", "testmachine") c.Assert(err, check.NotNil) hostInfo := tailcfg.Hostinfo{ OS: "centos", @@ -374,7 +374,7 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo), } - app.db.Save(&machine) + app.db.db.Save(&machine) app.aclPolicy = &ACLPolicy{ Groups: Groups{"group:test": []string{"user1", "user2"}}, @@ -398,13 +398,13 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) { // match properly the IP's of the related hosts. The owner is valid and the tag is also valid. // the tag is matched in the Destinations section. func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) { - user, err := app.CreateUser("user1") + user, err := app.db.CreateUser("user1") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("user1", "testmachine") + _, err = app.db.GetMachine("user1", "testmachine") c.Assert(err, check.NotNil) hostInfo := tailcfg.Hostinfo{ OS: "centos", @@ -424,7 +424,7 @@ func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo), } - app.db.Save(&machine) + app.db.db.Save(&machine) app.aclPolicy = &ACLPolicy{ Groups: Groups{"group:test": []string{"user1", "user2"}}, @@ -448,13 +448,13 @@ func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) { // tag on a host that isn't owned by a tag owners. So the user // of the host should be valid. func (s *Suite) TestInvalidTagValidUser(c *check.C) { - user, err := app.CreateUser("user1") + user, err := app.db.CreateUser("user1") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("user1", "testmachine") + _, err = app.db.GetMachine("user1", "testmachine") c.Assert(err, check.NotNil) hostInfo := tailcfg.Hostinfo{ OS: "centos", @@ -474,7 +474,7 @@ func (s *Suite) TestInvalidTagValidUser(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo), } - app.db.Save(&machine) + app.db.db.Save(&machine) app.aclPolicy = &ACLPolicy{ TagOwners: TagOwners{"tag:test": []string{"user1"}}, @@ -497,13 +497,13 @@ func (s *Suite) TestInvalidTagValidUser(c *check.C) { // an ACL rule is matching the tag to a user. It should not be valid since the // host should be tied to the tag now. func (s *Suite) TestValidTagInvalidUser(c *check.C) { - user, err := app.CreateUser("user1") + user, err := app.db.CreateUser("user1") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("user1", "webserver") + _, err = app.db.GetMachine("user1", "webserver") c.Assert(err, check.NotNil) hostInfo := tailcfg.Hostinfo{ OS: "centos", @@ -523,8 +523,8 @@ func (s *Suite) TestValidTagInvalidUser(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo), } - app.db.Save(&machine) - _, err = app.GetMachine("user1", "user") + app.db.db.Save(&machine) + _, err = app.db.GetMachine("user1", "user") hostInfo2 := tailcfg.Hostinfo{ OS: "debian", Hostname: "Hostname", @@ -542,7 +542,7 @@ func (s *Suite) TestValidTagInvalidUser(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo2), } - app.db.Save(&machine) + app.db.db.Save(&machine) app.aclPolicy = &ACLPolicy{ TagOwners: TagOwners{"tag:webapp": []string{"user1"}}, @@ -571,22 +571,22 @@ func (s *Suite) TestValidTagInvalidUser(c *check.C) { func (s *Suite) TestPortRange(c *check.C) { acl := []byte(` { - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "subnet-1", - ], - "dst": [ - "host-1:5400-5500", - ], - }, - ], + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "subnet-1", + ], + "dst": [ + "host-1:5400-5500", + ], + }, + ], } `) err := app.LoadACLPolicyFromBytes(acl, "hujson") @@ -605,43 +605,43 @@ func (s *Suite) TestPortRange(c *check.C) { func (s *Suite) TestProtocolParsing(c *check.C) { acl := []byte(` { - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "Action": "accept", - "src": [ - "*", - ], - "proto": "tcp", - "dst": [ - "host-1:*", - ], - }, - { - "Action": "accept", - "src": [ - "*", - ], - "proto": "udp", - "dst": [ - "host-1:53", - ], - }, - { - "Action": "accept", - "src": [ - "*", - ], - "proto": "icmp", - "dst": [ - "host-1:*", - ], - }, - ], + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "Action": "accept", + "src": [ + "*", + ], + "proto": "tcp", + "dst": [ + "host-1:*", + ], + }, + { + "Action": "accept", + "src": [ + "*", + ], + "proto": "udp", + "dst": [ + "host-1:53", + ], + }, + { + "Action": "accept", + "src": [ + "*", + ], + "proto": "icmp", + "dst": [ + "host-1:*", + ], + }, + ], } `) err := app.LoadACLPolicyFromBytes(acl, "hujson") @@ -660,22 +660,22 @@ func (s *Suite) TestProtocolParsing(c *check.C) { func (s *Suite) TestPortWildcard(c *check.C) { acl := []byte(` { - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "Action": "accept", - "src": [ - "*", - ], - "dst": [ - "host-1:*", - ], - }, - ], + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "Action": "accept", + "src": [ + "*", + ], + "dst": [ + "host-1:*", + ], + }, + ], } `) err := app.LoadACLPolicyFromBytes(acl, "hujson") @@ -694,8 +694,7 @@ func (s *Suite) TestPortWildcard(c *check.C) { } func (s *Suite) TestPortWildcardYAML(c *check.C) { - acl := []byte(` ---- + acl := []byte(`--- hosts: host-1: 100.100.100.100/32 subnet-1: 100.100.101.100/24 @@ -704,8 +703,7 @@ acls: src: - "*" dst: - - host-1:* -`) + - host-1:*`) err := app.LoadACLPolicyFromBytes(acl, "yaml") c.Assert(err, check.IsNil) @@ -722,15 +720,15 @@ acls: } func (s *Suite) TestPortUser(c *check.C) { - user, err := app.CreateUser("testuser") + user, err := app.db.CreateUser("testuser") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("testuser", "testmachine") + _, err = app.db.GetMachine("testuser", "testmachine") c.Assert(err, check.NotNil) - ips, _ := app.getAvailableIPs() + ips, _ := app.db.getAvailableIPs() machine := Machine{ ID: 0, MachineKey: "12345", @@ -742,32 +740,32 @@ func (s *Suite) TestPortUser(c *check.C) { IPAddresses: ips, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) acl := []byte(` { - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "testuser", - ], - "dst": [ - "host-1:*", - ], - }, - ], + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "testuser", + ], + "dst": [ + "host-1:*", + ], + }, + ], } `) err = app.LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.IsNil) - machines, err := app.ListMachines() + machines, err := app.db.ListMachines() c.Assert(err, check.IsNil) rules, err := app.aclPolicy.generateFilterRules(machines, false) @@ -785,15 +783,15 @@ func (s *Suite) TestPortUser(c *check.C) { } func (s *Suite) TestPortGroup(c *check.C) { - user, err := app.CreateUser("testuser") + user, err := app.db.CreateUser("testuser") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("testuser", "testmachine") + _, err = app.db.GetMachine("testuser", "testmachine") c.Assert(err, check.NotNil) - ips, _ := app.getAvailableIPs() + ips, _ := app.db.getAvailableIPs() machine := Machine{ ID: 0, MachineKey: "foo", @@ -805,38 +803,38 @@ func (s *Suite) TestPortGroup(c *check.C) { IPAddresses: ips, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) acl := []byte(` { - "groups": { - "group:example": [ - "testuser", - ], - }, - - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "group:example", - ], - "dst": [ - "host-1:*", - ], - }, - ], + "groups": { + "group:example": [ + "testuser", + ], + }, + + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "group:example", + ], + "dst": [ + "host-1:*", + ], + }, + ], } `) err = app.LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.IsNil) - machines, err := app.ListMachines() + machines, err := app.db.ListMachines() c.Assert(err, check.IsNil) rules, err := app.aclPolicy.generateFilterRules(machines, false) diff --git a/hscontrol/addresses.go b/hscontrol/addresses.go new file mode 100644 index 0000000000..7f78935f8e --- /dev/null +++ b/hscontrol/addresses.go @@ -0,0 +1,98 @@ +// Codehere is mostly taken from github.com/tailscale/tailscale +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package hscontrol + +import ( + "errors" + "fmt" + "net/netip" + + "github.com/juanfont/headscale/hscontrol/util" + "go4.org/netipx" +) + +var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP") + +func (hsdb *HSDatabase) getAvailableIPs() (MachineAddresses, error) { + var ips MachineAddresses + var err error + for _, ipPrefix := range hsdb.ipPrefixes { + var ip *netip.Addr + ip, err = hsdb.getAvailableIP(ipPrefix) + if err != nil { + return ips, err + } + ips = append(ips, *ip) + } + + return ips, err +} + +func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) { + usedIps, err := hsdb.getUsedIPs() + if err != nil { + return nil, err + } + + ipPrefixNetworkAddress, ipPrefixBroadcastAddress := util.GetIPPrefixEndpoints(ipPrefix) + + // Get the first IP in our prefix + ip := ipPrefixNetworkAddress.Next() + + for { + if !ipPrefix.Contains(ip) { + return nil, ErrCouldNotAllocateIP + } + + switch { + case ip.Compare(ipPrefixBroadcastAddress) == 0: + fallthrough + case usedIps.Contains(ip): + fallthrough + case ip == netip.Addr{} || ip.IsLoopback(): + ip = ip.Next() + + continue + + default: + return &ip, nil + } + } +} + +func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) { + // FIXME: This really deserves a better data model, + // but this was quick to get running and it should be enough + // to begin experimenting with a dual stack tailnet. + var addressesSlices []string + hsdb.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices) + + var ips netipx.IPSetBuilder + for _, slice := range addressesSlices { + var machineAddresses MachineAddresses + err := machineAddresses.Scan(slice) + if err != nil { + return &netipx.IPSet{}, fmt.Errorf( + "failed to read ip from database: %w", + err, + ) + } + + for _, ip := range machineAddresses { + ips.Add(ip) + } + } + + ipSet, err := ips.IPSet() + if err != nil { + return &netipx.IPSet{}, fmt.Errorf( + "failed to build IP Set: %w", + err, + ) + } + + return ipSet, nil +} diff --git a/hscontrol/utils_test.go b/hscontrol/addresses_test.go similarity index 75% rename from hscontrol/utils_test.go rename to hscontrol/addresses_test.go index 436df8ac88..f3be93aab1 100644 --- a/hscontrol/utils_test.go +++ b/hscontrol/addresses_test.go @@ -8,7 +8,7 @@ import ( ) func (s *Suite) TestGetAvailableIp(c *check.C) { - ips, err := app.getAvailableIPs() + ips, err := app.db.getAvailableIPs() c.Assert(err, check.IsNil) @@ -19,16 +19,16 @@ func (s *Suite) TestGetAvailableIp(c *check.C) { } func (s *Suite) TestGetUsedIps(c *check.C) { - ips, err := app.getAvailableIPs() + ips, err := app.db.getAvailableIPs() c.Assert(err, check.IsNil) - user, err := app.CreateUser("test-ip") + user, err := app.db.CreateUser("test-ip") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "testmachine") + _, err = app.db.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) machine := Machine{ @@ -42,9 +42,9 @@ func (s *Suite) TestGetUsedIps(c *check.C) { AuthKeyID: uint(pak.ID), IPAddresses: ips, } - app.db.Save(&machine) + app.db.db.Save(&machine) - usedIps, err := app.getUsedIPs() + usedIps, err := app.db.getUsedIPs() c.Assert(err, check.IsNil) @@ -56,7 +56,7 @@ func (s *Suite) TestGetUsedIps(c *check.C) { c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true) c.Assert(usedIps.Contains(expected), check.Equals, true) - machine1, err := app.GetMachineByID(0) + machine1, err := app.db.GetMachineByID(0) c.Assert(err, check.IsNil) c.Assert(len(machine1.IPAddresses), check.Equals, 1) @@ -64,19 +64,19 @@ func (s *Suite) TestGetUsedIps(c *check.C) { } func (s *Suite) TestGetMultiIp(c *check.C) { - user, err := app.CreateUser("test-ip-multi") + user, err := app.db.CreateUser("test-ip-multi") c.Assert(err, check.IsNil) for index := 1; index <= 350; index++ { - app.ipAllocationMutex.Lock() + app.db.ipAllocationMutex.Lock() - ips, err := app.getAvailableIPs() + ips, err := app.db.getAvailableIPs() c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "testmachine") + _, err = app.db.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) machine := Machine{ @@ -90,12 +90,12 @@ func (s *Suite) TestGetMultiIp(c *check.C) { AuthKeyID: uint(pak.ID), IPAddresses: ips, } - app.db.Save(&machine) + app.db.db.Save(&machine) - app.ipAllocationMutex.Unlock() + app.db.ipAllocationMutex.Unlock() } - usedIps, err := app.getUsedIPs() + usedIps, err := app.db.getUsedIPs() c.Assert(err, check.IsNil) expected0 := netip.MustParseAddr("10.27.0.1") @@ -117,7 +117,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { c.Assert(usedIps.Contains(expected300), check.Equals, true) // Check that we can read back the IPs - machine1, err := app.GetMachineByID(1) + machine1, err := app.db.GetMachineByID(1) c.Assert(err, check.IsNil) c.Assert(len(machine1.IPAddresses), check.Equals, 1) c.Assert( @@ -126,7 +126,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { netip.MustParseAddr("10.27.0.1"), ) - machine50, err := app.GetMachineByID(50) + machine50, err := app.db.GetMachineByID(50) c.Assert(err, check.IsNil) c.Assert(len(machine50.IPAddresses), check.Equals, 1) c.Assert( @@ -136,7 +136,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { ) expectedNextIP := netip.MustParseAddr("10.27.1.95") - nextIP, err := app.getAvailableIPs() + nextIP, err := app.db.getAvailableIPs() c.Assert(err, check.IsNil) c.Assert(len(nextIP), check.Equals, 1) @@ -144,7 +144,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { // If we call get Available again, we should receive // the same IP, as it has not been reserved. - nextIP2, err := app.getAvailableIPs() + nextIP2, err := app.db.getAvailableIPs() c.Assert(err, check.IsNil) c.Assert(len(nextIP2), check.Equals, 1) @@ -152,7 +152,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { } func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { - ips, err := app.getAvailableIPs() + ips, err := app.db.getAvailableIPs() c.Assert(err, check.IsNil) expected := netip.MustParseAddr("10.27.0.1") @@ -160,13 +160,13 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { c.Assert(len(ips), check.Equals, 1) c.Assert(ips[0].String(), check.Equals, expected.String()) - user, err := app.CreateUser("test-ip") + user, err := app.db.CreateUser("test-ip") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "testmachine") + _, err = app.db.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) machine := Machine{ @@ -179,23 +179,11 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) - ips2, err := app.getAvailableIPs() + ips2, err := app.db.getAvailableIPs() c.Assert(err, check.IsNil) c.Assert(len(ips2), check.Equals, 1) c.Assert(ips2[0].String(), check.Equals, expected.String()) } - -func (s *Suite) TestGenerateRandomStringDNSSafe(c *check.C) { - for i := 0; i < 100000; i++ { - str, err := GenerateRandomStringDNSSafe(8) - if err != nil { - c.Error(err) - } - if len(str) != 8 { - c.Error("invalid length", len(str), str) - } - } -} diff --git a/hscontrol/api.go b/hscontrol/api.go index f8b1496f64..8e3014199a 100644 --- a/hscontrol/api.go +++ b/hscontrol/api.go @@ -3,25 +3,28 @@ package hscontrol import ( "bytes" "encoding/json" + "errors" "html/template" "net/http" "time" "github.com/gorilla/mux" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "tailscale.com/types/key" ) const ( // TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed. - registrationHoldoff = time.Second * 5 - reservedResponseHeaderSize = 4 - RegisterMethodAuthKey = "authkey" - RegisterMethodOIDC = "oidc" - RegisterMethodCLI = "cli" - ErrRegisterMethodCLIDoesNotSupportExpire = Error( - "machines registered with CLI does not support expire", - ) + registrationHoldoff = time.Second * 5 + reservedResponseHeaderSize = 4 + RegisterMethodAuthKey = "authkey" + RegisterMethodOIDC = "oidc" + RegisterMethodCLI = "cli" +) + +var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New( + "machines registered with CLI does not support expire", ) func (h *Headscale) HealthHandler( @@ -53,7 +56,7 @@ func (h *Headscale) HealthHandler( } } - if err := h.pingDB(req.Context()); err != nil { + if err := h.db.pingDB(req.Context()); err != nil { respond(err) return @@ -95,7 +98,7 @@ func (h *Headscale) RegisterWebAPI( vars := mux.Vars(req) nodeKeyStr, ok := vars["nkey"] - if !NodePublicKeyRegex.Match([]byte(nodeKeyStr)) { + if !util.NodePublicKeyRegex.Match([]byte(nodeKeyStr)) { log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") @@ -116,7 +119,7 @@ func (h *Headscale) RegisterWebAPI( // the template and log an error. var nodeKey key.NodePublic err := nodeKey.UnmarshalText( - []byte(NodePublicKeyEnsurePrefix(nodeKeyStr)), + []byte(util.NodePublicKeyEnsurePrefix(nodeKeyStr)), ) if !ok || nodeKeyStr == "" || err != nil { diff --git a/hscontrol/api_common.go b/hscontrol/api_common.go index 3dd65ac6fa..f1b3fd8300 100644 --- a/hscontrol/api_common.go +++ b/hscontrol/api_common.go @@ -3,6 +3,7 @@ package hscontrol import ( "time" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" ) @@ -15,7 +16,7 @@ func (h *Headscale) generateMapResponse( Str("func", "generateMapResponse"). Str("machine", mapRequest.Hostinfo.Hostname). Msg("Creating Map response") - node, err := h.toNode(*machine, h.cfg.BaseDomain, h.cfg.DNSConfig) + node, err := h.db.toNode(*machine, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig) if err != nil { log.Error(). Caller(). @@ -26,7 +27,7 @@ func (h *Headscale) generateMapResponse( return nil, err } - peers, err := h.getValidPeers(machine) + peers, err := h.db.getValidPeers(h.aclPolicy, h.aclRules, machine) if err != nil { log.Error(). Caller(). @@ -37,9 +38,9 @@ func (h *Headscale) generateMapResponse( return nil, err } - profiles := h.getMapResponseUserProfiles(*machine, peers) + profiles := h.db.getMapResponseUserProfiles(*machine, peers) - nodePeers, err := h.toNodes(peers, h.cfg.BaseDomain, h.cfg.DNSConfig) + nodePeers, err := h.db.toNodes(peers, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig) if err != nil { log.Error(). Caller(). @@ -107,7 +108,7 @@ func (h *Headscale) generateMapResponse( Str("func", "generateMapResponse"). Str("machine", mapRequest.Hostinfo.Hostname). // Interface("payload", resp). - Msgf("Generated map response: %s", tailMapResponseToString(resp)) + Msgf("Generated map response: %s", util.TailMapResponseToString(resp)) return &resp, nil } diff --git a/hscontrol/api_key.go b/hscontrol/api_key.go index 6382a33193..bf2ccf3942 100644 --- a/hscontrol/api_key.go +++ b/hscontrol/api_key.go @@ -1,11 +1,13 @@ package hscontrol import ( + "errors" "fmt" "strings" "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/util" "golang.org/x/crypto/bcrypt" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -13,10 +15,10 @@ import ( const ( apiPrefixLength = 7 apiKeyLength = 32 - - ErrAPIKeyFailedToParse = Error("Failed to parse ApiKey") ) +var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") + // APIKey describes the datamodel for API keys used to remotely authenticate with // headscale. type APIKey struct { @@ -30,15 +32,15 @@ type APIKey struct { } // CreateAPIKey creates a new ApiKey in a user, and returns it. -func (h *Headscale) CreateAPIKey( +func (hsdb *HSDatabase) CreateAPIKey( expiration *time.Time, ) (string, *APIKey, error) { - prefix, err := GenerateRandomStringURLSafe(apiPrefixLength) + prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) if err != nil { return "", nil, err } - toBeHashed, err := GenerateRandomStringURLSafe(apiKeyLength) + toBeHashed, err := util.GenerateRandomStringURLSafe(apiKeyLength) if err != nil { return "", nil, err } @@ -57,7 +59,7 @@ func (h *Headscale) CreateAPIKey( Expiration: expiration, } - if err := h.db.Save(&key).Error; err != nil { + if err := hsdb.db.Save(&key).Error; err != nil { return "", nil, fmt.Errorf("failed to save API key to database: %w", err) } @@ -65,9 +67,9 @@ func (h *Headscale) CreateAPIKey( } // ListAPIKeys returns the list of ApiKeys for a user. -func (h *Headscale) ListAPIKeys() ([]APIKey, error) { +func (hsdb *HSDatabase) ListAPIKeys() ([]APIKey, error) { keys := []APIKey{} - if err := h.db.Find(&keys).Error; err != nil { + if err := hsdb.db.Find(&keys).Error; err != nil { return nil, err } @@ -75,9 +77,9 @@ func (h *Headscale) ListAPIKeys() ([]APIKey, error) { } // GetAPIKey returns a ApiKey for a given key. -func (h *Headscale) GetAPIKey(prefix string) (*APIKey, error) { +func (hsdb *HSDatabase) GetAPIKey(prefix string) (*APIKey, error) { key := APIKey{} - if result := h.db.First(&key, "prefix = ?", prefix); result.Error != nil { + if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { return nil, result.Error } @@ -85,9 +87,9 @@ func (h *Headscale) GetAPIKey(prefix string) (*APIKey, error) { } // GetAPIKeyByID returns a ApiKey for a given id. -func (h *Headscale) GetAPIKeyByID(id uint64) (*APIKey, error) { +func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*APIKey, error) { key := APIKey{} - if result := h.db.Find(&APIKey{ID: id}).First(&key); result.Error != nil { + if result := hsdb.db.Find(&APIKey{ID: id}).First(&key); result.Error != nil { return nil, result.Error } @@ -96,8 +98,8 @@ func (h *Headscale) GetAPIKeyByID(id uint64) (*APIKey, error) { // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey // does not exist. -func (h *Headscale) DestroyAPIKey(key APIKey) error { - if result := h.db.Unscoped().Delete(key); result.Error != nil { +func (hsdb *HSDatabase) DestroyAPIKey(key APIKey) error { + if result := hsdb.db.Unscoped().Delete(key); result.Error != nil { return result.Error } @@ -105,21 +107,21 @@ func (h *Headscale) DestroyAPIKey(key APIKey) error { } // ExpireAPIKey marks a ApiKey as expired. -func (h *Headscale) ExpireAPIKey(key *APIKey) error { - if err := h.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { +func (hsdb *HSDatabase) ExpireAPIKey(key *APIKey) error { + if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { return err } return nil } -func (h *Headscale) ValidateAPIKey(keyStr string) (bool, error) { +func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { prefix, hash, found := strings.Cut(keyStr, ".") if !found { return false, ErrAPIKeyFailedToParse } - key, err := h.GetAPIKey(prefix) + key, err := hsdb.GetAPIKey(prefix) if err != nil { return false, fmt.Errorf("failed to validate api key: %w", err) } diff --git a/hscontrol/api_key_test.go b/hscontrol/api_key_test.go index fd4fa00db9..007b5d1642 100644 --- a/hscontrol/api_key_test.go +++ b/hscontrol/api_key_test.go @@ -7,7 +7,7 @@ import ( ) func (*Suite) TestCreateAPIKey(c *check.C) { - apiKeyStr, apiKey, err := app.CreateAPIKey(nil) + apiKeyStr, apiKey, err := app.db.CreateAPIKey(nil) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) @@ -16,74 +16,74 @@ func (*Suite) TestCreateAPIKey(c *check.C) { c.Assert(apiKey.Hash, check.NotNil) c.Assert(apiKeyStr, check.Not(check.Equals), "") - _, err = app.ListAPIKeys() + _, err = app.db.ListAPIKeys() c.Assert(err, check.IsNil) - keys, err := app.ListAPIKeys() + keys, err := app.db.ListAPIKeys() c.Assert(err, check.IsNil) c.Assert(len(keys), check.Equals, 1) } func (*Suite) TestAPIKeyDoesNotExist(c *check.C) { - key, err := app.GetAPIKey("does-not-exist") + key, err := app.db.GetAPIKey("does-not-exist") c.Assert(err, check.NotNil) c.Assert(key, check.IsNil) } func (*Suite) TestValidateAPIKeyOk(c *check.C) { nowPlus2 := time.Now().Add(2 * time.Hour) - apiKeyStr, apiKey, err := app.CreateAPIKey(&nowPlus2) + apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) - valid, err := app.ValidateAPIKey(apiKeyStr) + valid, err := app.db.ValidateAPIKey(apiKeyStr) c.Assert(err, check.IsNil) c.Assert(valid, check.Equals, true) } func (*Suite) TestValidateAPIKeyNotOk(c *check.C) { nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour) - apiKeyStr, apiKey, err := app.CreateAPIKey(&nowMinus2) + apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowMinus2) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) - valid, err := app.ValidateAPIKey(apiKeyStr) + valid, err := app.db.ValidateAPIKey(apiKeyStr) c.Assert(err, check.IsNil) c.Assert(valid, check.Equals, false) now := time.Now() - apiKeyStrNow, apiKey, err := app.CreateAPIKey(&now) + apiKeyStrNow, apiKey, err := app.db.CreateAPIKey(&now) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) - validNow, err := app.ValidateAPIKey(apiKeyStrNow) + validNow, err := app.db.ValidateAPIKey(apiKeyStrNow) c.Assert(err, check.IsNil) c.Assert(validNow, check.Equals, false) - validSilly, err := app.ValidateAPIKey("nota.validkey") + validSilly, err := app.db.ValidateAPIKey("nota.validkey") c.Assert(err, check.NotNil) c.Assert(validSilly, check.Equals, false) - validWithErr, err := app.ValidateAPIKey("produceerrorkey") + validWithErr, err := app.db.ValidateAPIKey("produceerrorkey") c.Assert(err, check.NotNil) c.Assert(validWithErr, check.Equals, false) } func (*Suite) TestExpireAPIKey(c *check.C) { nowPlus2 := time.Now().Add(2 * time.Hour) - apiKeyStr, apiKey, err := app.CreateAPIKey(&nowPlus2) + apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) - valid, err := app.ValidateAPIKey(apiKeyStr) + valid, err := app.db.ValidateAPIKey(apiKeyStr) c.Assert(err, check.IsNil) c.Assert(valid, check.Equals, true) - err = app.ExpireAPIKey(apiKey) + err = app.db.ExpireAPIKey(apiKey) c.Assert(err, check.IsNil) c.Assert(apiKey.Expiration, check.NotNil) - notValid, err := app.ValidateAPIKey(apiKeyStr) + notValid, err := app.db.ValidateAPIKey(apiKeyStr) c.Assert(err, check.IsNil) c.Assert(notValid, check.Equals, false) } diff --git a/hscontrol/app.go b/hscontrol/app.go index b8dceba8ae..38d4ec8cca 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -23,6 +23,7 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/juanfont/headscale" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/util" "github.com/patrickmn/go-cache" zerolog "github.com/philip-bui/grpc-zerolog" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -41,24 +42,21 @@ import ( "google.golang.org/grpc/peer" "google.golang.org/grpc/reflection" "google.golang.org/grpc/status" - "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" "tailscale.com/types/key" ) -const ( - errSTUNAddressNotSet = Error("STUN address not set") - errUnsupportedDatabase = Error("unsupported DB") - errUnsupportedLetsEncryptChallengeType = Error( +var ( + errSTUNAddressNotSet = errors.New("STUN address not set") + errUnsupportedDatabase = errors.New("unsupported DB") + errUnsupportedLetsEncryptChallengeType = errors.New( "unknown value for Lets Encrypt challenge type", ) ) const ( AuthPrefix = "Bearer " - Postgres = "postgres" - Sqlite = "sqlite3" updateInterval = 5000 HTTPReadTimeout = 30 * time.Second HTTPShutdownTimeout = 3 * time.Second @@ -75,7 +73,7 @@ const ( // Headscale represents the base app of the service. type Headscale struct { cfg *Config - db *gorm.DB + db *HSDatabase dbString string dbType string dbDebug bool @@ -96,10 +94,11 @@ type Headscale struct { registrationCache *cache.Cache - ipAllocationMutex sync.Mutex - shutdownChan chan struct{} pollNetMapStreamWG sync.WaitGroup + + stateUpdateChan chan struct{} + cancelStateUpdateChan chan struct{} } func NewHeadscale(cfg *Config) (*Headscale, error) { @@ -164,13 +163,27 @@ func NewHeadscale(cfg *Config) (*Headscale, error) { registrationCache: registrationCache, pollNetMapStreamWG: sync.WaitGroup{}, lastStateChange: xsync.NewMapOf[time.Time](), + + stateUpdateChan: make(chan struct{}), + cancelStateUpdateChan: make(chan struct{}), } - err = app.initDB() + go app.watchStateChannel() + + db, err := NewHeadscaleDatabase( + cfg.DBtype, + dbString, + cfg.OIDC.StripEmaildomain, + app.dbDebug, + app.stateUpdateChan, + cfg.IPPrefixes, + cfg.BaseDomain) if err != nil { return nil, err } + app.db = db + if cfg.OIDC.Issuer != "" { err = app.initOIDC() if err != nil { @@ -231,7 +244,7 @@ func (h *Headscale) expireExpiredMachines(milliSeconds int64) { func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) { ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) for range ticker.C { - err := h.handlePrimarySubnetFailover() + err := h.db.handlePrimarySubnetFailover() if err != nil { log.Error().Err(err).Msg("failed to handle primary subnet failover") } @@ -239,7 +252,7 @@ func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) { } func (h *Headscale) expireEphemeralNodesWorker() { - users, err := h.ListUsers() + users, err := h.db.ListUsers() if err != nil { log.Error().Err(err).Msg("Error listing users") @@ -247,7 +260,7 @@ func (h *Headscale) expireEphemeralNodesWorker() { } for _, user := range users { - machines, err := h.ListMachinesByUser(user.Name) + machines, err := h.db.ListMachinesByUser(user.Name) if err != nil { log.Error(). Err(err). @@ -267,7 +280,7 @@ func (h *Headscale) expireEphemeralNodesWorker() { Str("machine", machine.Hostname). Msg("Ephemeral client removed from database") - err = h.db.Unscoped().Delete(machine).Error + err = h.db.db.Unscoped().Delete(machine).Error if err != nil { log.Error(). Err(err). @@ -284,7 +297,7 @@ func (h *Headscale) expireEphemeralNodesWorker() { } func (h *Headscale) expireExpiredMachinesWorker() { - users, err := h.ListUsers() + users, err := h.db.ListUsers() if err != nil { log.Error().Err(err).Msg("Error listing users") @@ -292,7 +305,7 @@ func (h *Headscale) expireExpiredMachinesWorker() { } for _, user := range users { - machines, err := h.ListMachinesByUser(user.Name) + machines, err := h.db.ListMachinesByUser(user.Name) if err != nil { log.Error(). Err(err). @@ -308,7 +321,7 @@ func (h *Headscale) expireExpiredMachinesWorker() { machine.Expiry.After(h.getLastStateChange(user)) { expiredFound = true - err := h.ExpireMachine(&machines[index]) + err := h.db.ExpireMachine(&machines[index]) if err != nil { log.Error(). Err(err). @@ -387,7 +400,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, ) } - valid, err := h.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix)) + valid, err := h.db.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix)) if err != nil { log.Error(). Caller(). @@ -438,7 +451,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler return } - valid, err := h.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix)) + valid, err := h.db.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix)) if err != nil { log.Error(). Caller(). @@ -597,7 +610,7 @@ func (h *Headscale) Serve() error { h.cfg.UnixSocket, []grpc.DialOption{ grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithContextDialer(GrpcSocketDialer), + grpc.WithContextDialer(util.GrpcSocketDialer), }..., ) if err != nil { @@ -760,7 +773,7 @@ func (h *Headscale) Serve() error { // TODO(kradalby): Reload config on SIGHUP if h.cfg.ACL.PolicyPath != "" { - aclPath := AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath) + aclPath := util.AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath) err := h.LoadACLPolicyFromPath(aclPath) if err != nil { log.Error().Err(err).Msg("Failed to reload ACL policy") @@ -778,6 +791,7 @@ func (h *Headscale) Serve() error { Msg("Received signal to stop, shutting down gracefully") close(h.shutdownChan) + h.pollNetMapStreamWG.Wait() // Gracefully shut down servers @@ -806,8 +820,12 @@ func (h *Headscale) Serve() error { // Stop listening (and unlink the socket if unix type): socketListener.Close() + <-h.cancelStateUpdateChan + close(h.stateUpdateChan) + close(h.cancelStateUpdateChan) + // Close db connections - db, err := h.db.DB() + db, err := h.db.db.DB() if err != nil { log.Error().Err(err).Msg("Failed to get db handle") } @@ -905,12 +923,25 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { } } +// TODO(kradalby): baby steps, make this more robust. +func (h *Headscale) watchStateChannel() { + for { + select { + case <-h.stateUpdateChan: + h.setLastStateChangeToNow() + + case <-h.cancelStateUpdateChan: + return + } + } +} + func (h *Headscale) setLastStateChangeToNow() { var err error now := time.Now().UTC() - users, err := h.ListUsers() + users, err := h.db.ListUsers() if err != nil { log.Error(). Caller(). @@ -1002,7 +1033,7 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { } trimmedPrivateKey := strings.TrimSpace(string(privateKey)) - privateKeyEnsurePrefix := PrivateKeyEnsurePrefix(trimmedPrivateKey) + privateKeyEnsurePrefix := util.PrivateKeyEnsurePrefix(trimmedPrivateKey) var machineKey key.MachinePrivate if err = machineKey.UnmarshalText([]byte(privateKeyEnsurePrefix)); err != nil { diff --git a/hscontrol/app_test.go b/hscontrol/app_test.go index 7d3907d3f8..1b4e91e827 100644 --- a/hscontrol/app_test.go +++ b/hscontrol/app_test.go @@ -42,18 +42,32 @@ func (s *Suite) ResetDB(c *check.C) { IPPrefixes: []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, + OIDC: OIDCConfig{ + StripEmaildomain: false, + }, } + // TODO(kradalby): make this use NewHeadscale properly so it doesnt drift app = Headscale{ cfg: &cfg, dbType: "sqlite3", dbString: tmpDir + "/headscale_test.db", + + stateUpdateChan: make(chan struct{}), + cancelStateUpdateChan: make(chan struct{}), } - err = app.initDB() - if err != nil { - c.Fatal(err) - } - db, err := app.openDB() + + go app.watchStateChannel() + + db, err := NewHeadscaleDatabase( + app.dbType, + app.dbString, + cfg.OIDC.StripEmaildomain, + false, + app.stateUpdateChan, + cfg.IPPrefixes, + "", + ) if err != nil { c.Fatal(err) } diff --git a/hscontrol/config.go b/hscontrol/config.go index 0e83a1c2d9..63deace05e 100644 --- a/hscontrol/config.go +++ b/hscontrol/config.go @@ -11,6 +11,7 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/juanfont/headscale/hscontrol/util" "github.com/prometheus/common/model" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -271,15 +272,15 @@ func GetTLSConfig() TLSConfig { LetsEncrypt: LetsEncryptConfig{ Hostname: viper.GetString("tls_letsencrypt_hostname"), Listen: viper.GetString("tls_letsencrypt_listen"), - CacheDir: AbsolutePathFromConfigPath( + CacheDir: util.AbsolutePathFromConfigPath( viper.GetString("tls_letsencrypt_cache_dir"), ), ChallengeType: viper.GetString("tls_letsencrypt_challenge_type"), }, - CertPath: AbsolutePathFromConfigPath( + CertPath: util.AbsolutePathFromConfigPath( viper.GetString("tls_cert_path"), ), - KeyPath: AbsolutePathFromConfigPath( + KeyPath: util.AbsolutePathFromConfigPath( viper.GetString("tls_key_path"), ), } @@ -585,10 +586,10 @@ func GetHeadscaleConfig() (*Config, error) { DisableUpdateCheck: viper.GetBool("disable_check_updates"), IPPrefixes: prefixes, - PrivateKeyPath: AbsolutePathFromConfigPath( + PrivateKeyPath: util.AbsolutePathFromConfigPath( viper.GetString("private_key_path"), ), - NoisePrivateKeyPath: AbsolutePathFromConfigPath( + NoisePrivateKeyPath: util.AbsolutePathFromConfigPath( viper.GetString("noise.private_key_path"), ), BaseDomain: baseDomain, @@ -604,7 +605,7 @@ func GetHeadscaleConfig() (*Config, error) { ), DBtype: viper.GetString("db_type"), - DBpath: AbsolutePathFromConfigPath(viper.GetString("db_path")), + DBpath: util.AbsolutePathFromConfigPath(viper.GetString("db_path")), DBhost: viper.GetString("db_host"), DBport: viper.GetInt("db_port"), DBname: viper.GetString("db_name"), @@ -620,7 +621,7 @@ func GetHeadscaleConfig() (*Config, error) { ACMEURL: viper.GetString("acme_url"), UnixSocket: viper.GetString("unix_socket"), - UnixSocketPermission: GetFileMode("unix_socket_permission"), + UnixSocketPermission: util.GetFileMode("unix_socket_permission"), OIDC: OIDCConfig{ OnlyStartIfOIDCIsAvailable: viper.GetBool( diff --git a/hscontrol/db.go b/hscontrol/db.go index 14df4b3bf1..e80a3c3ed9 100644 --- a/hscontrol/db.go +++ b/hscontrol/db.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "net/netip" + "sync" "time" "github.com/glebarez/sqlite" @@ -19,55 +20,90 @@ import ( const ( dbVersion = "1" + Postgres = "postgres" + Sqlite = "sqlite3" +) - errValueNotFound = Error("not found") - ErrCannotParsePrefix = Error("cannot parse prefix") +var ( + errValueNotFound = errors.New("not found") + ErrCannotParsePrefix = errors.New("cannot parse prefix") + errDatabaseNotSupported = errors.New("database type not supported") ) // KV is a key-value store in a psql table. For future use... +// TODO(kradalby): Is this used for anything? type KV struct { Key string Value string } -func (h *Headscale) initDB() error { - db, err := h.openDB() +type HSDatabase struct { + db *gorm.DB + notifyStateChan chan<- struct{} + + ipAllocationMutex sync.Mutex + + ipPrefixes []netip.Prefix + baseDomain string + stripEmailDomain bool +} + +// TODO(kradalby): assemble this struct from toptions or something typed +// rather than arguments. +func NewHeadscaleDatabase( + dbType, connectionAddr string, + stripEmailDomain, debug bool, + notifyStateChan chan<- struct{}, + ipPrefixes []netip.Prefix, + baseDomain string, +) (*HSDatabase, error) { + dbConn, err := openDB(dbType, connectionAddr, debug) if err != nil { - return err + return nil, err } - h.db = db - if h.dbType == Postgres { - db.Exec(`create extension if not exists "uuid-ossp";`) + db := HSDatabase{ + db: dbConn, + notifyStateChan: notifyStateChan, + + ipPrefixes: ipPrefixes, + baseDomain: baseDomain, + stripEmailDomain: stripEmailDomain, } - _ = db.Migrator().RenameTable("namespaces", "users") + log.Debug().Msgf("database %#v", dbConn) - err = db.AutoMigrate(&User{}) + if dbType == Postgres { + dbConn.Exec(`create extension if not exists "uuid-ossp";`) + } + + _ = dbConn.Migrator().RenameTable("namespaces", "users") + + err = dbConn.AutoMigrate(User{}) if err != nil { - return err + return nil, err } - _ = db.Migrator().RenameColumn(&Machine{}, "namespace_id", "user_id") - _ = db.Migrator().RenameColumn(&PreAuthKey{}, "namespace_id", "user_id") + _ = dbConn.Migrator().RenameColumn(&Machine{}, "namespace_id", "user_id") + _ = dbConn.Migrator().RenameColumn(&PreAuthKey{}, "namespace_id", "user_id") - _ = db.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses") - _ = db.Migrator().RenameColumn(&Machine{}, "name", "hostname") + _ = dbConn.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses") + _ = dbConn.Migrator().RenameColumn(&Machine{}, "name", "hostname") // GivenName is used as the primary source of DNS names, make sure // the field is populated and normalized if it was not when the // machine was registered. - _ = db.Migrator().RenameColumn(&Machine{}, "nickname", "given_name") + _ = dbConn.Migrator().RenameColumn(&Machine{}, "nickname", "given_name") // If the Machine table has a column for registered, // find all occourences of "false" and drop them. Then // remove the column. - if db.Migrator().HasColumn(&Machine{}, "registered") { + if dbConn.Migrator().HasColumn(&Machine{}, "registered") { log.Info(). Msg(`Database has legacy "registered" column in machine, removing...`) machines := Machines{} - if err := h.db.Not("registered").Find(&machines).Error; err != nil { + if err := dbConn.Not("registered").Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") } @@ -76,7 +112,7 @@ func (h *Headscale) initDB() error { Str("machine", machine.Hostname). Str("machine_key", machine.MachineKey). Msg("Deleting unregistered machine") - if err := h.db.Delete(&Machine{}, machine.ID).Error; err != nil { + if err := dbConn.Delete(&Machine{}, machine.ID).Error; err != nil { log.Error(). Err(err). Str("machine", machine.Hostname). @@ -85,18 +121,18 @@ func (h *Headscale) initDB() error { } } - err := db.Migrator().DropColumn(&Machine{}, "registered") + err := dbConn.Migrator().DropColumn(&Machine{}, "registered") if err != nil { log.Error().Err(err).Msg("Error dropping registered column") } } - err = db.AutoMigrate(&Route{}) + err = dbConn.AutoMigrate(&Route{}) if err != nil { - return err + return nil, err } - if db.Migrator().HasColumn(&Machine{}, "enabled_routes") { + if dbConn.Migrator().HasColumn(&Machine{}, "enabled_routes") { log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...") type MachineAux struct { @@ -105,7 +141,7 @@ func (h *Headscale) initDB() error { } machinesAux := []MachineAux{} - err := db.Table("machines").Select("id, enabled_routes").Scan(&machinesAux).Error + err := dbConn.Table("machines").Select("id, enabled_routes").Scan(&machinesAux).Error if err != nil { log.Fatal().Err(err).Msg("Error accessing db") } @@ -120,7 +156,7 @@ func (h *Headscale) initDB() error { continue } - err = db.Preload("Machine"). + err = dbConn.Preload("Machine"). Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)). First(&Route{}). Error @@ -138,7 +174,7 @@ func (h *Headscale) initDB() error { Enabled: true, Prefix: IPPrefix(prefix), } - if err := h.db.Create(&route).Error; err != nil { + if err := dbConn.Create(&route).Error; err != nil { log.Error().Err(err).Msg("Error creating route") } else { log.Info(). @@ -149,20 +185,20 @@ func (h *Headscale) initDB() error { } } - err = db.Migrator().DropColumn(&Machine{}, "enabled_routes") + err = dbConn.Migrator().DropColumn(&Machine{}, "enabled_routes") if err != nil { log.Error().Err(err).Msg("Error dropping enabled_routes column") } } - err = db.AutoMigrate(&Machine{}) + err = dbConn.AutoMigrate(&Machine{}) if err != nil { - return err + return nil, err } - if db.Migrator().HasColumn(&Machine{}, "given_name") { + if dbConn.Migrator().HasColumn(&Machine{}, "given_name") { machines := Machines{} - if err := h.db.Find(&machines).Error; err != nil { + if err := dbConn.Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") } @@ -170,7 +206,7 @@ func (h *Headscale) initDB() error { if machine.GivenName == "" { normalizedHostname, err := NormalizeToFQDNRules( machine.Hostname, - h.cfg.OIDC.StripEmaildomain, + stripEmailDomain, ) if err != nil { log.Error(). @@ -180,7 +216,7 @@ func (h *Headscale) initDB() error { Msg("Failed to normalize machine hostname in DB migration") } - err = h.RenameMachine(&machines[item], normalizedHostname) + err = db.RenameMachine(&machines[item], normalizedHostname) if err != nil { log.Error(). Caller(). @@ -192,51 +228,51 @@ func (h *Headscale) initDB() error { } } - err = db.AutoMigrate(&KV{}) + err = dbConn.AutoMigrate(&KV{}) if err != nil { - return err + return nil, err } - err = db.AutoMigrate(&PreAuthKey{}) + err = dbConn.AutoMigrate(&PreAuthKey{}) if err != nil { - return err + return nil, err } - err = db.AutoMigrate(&PreAuthKeyACLTag{}) + err = dbConn.AutoMigrate(&PreAuthKeyACLTag{}) if err != nil { - return err + return nil, err } - _ = db.Migrator().DropTable("shared_machines") + _ = dbConn.Migrator().DropTable("shared_machines") - err = db.AutoMigrate(&APIKey{}) + err = dbConn.AutoMigrate(&APIKey{}) if err != nil { - return err + return nil, err } - err = h.setValue("db_version", dbVersion) + // TODO(kradalby): is this needed? + err = db.setValue("db_version", dbVersion) - return err + return &db, err } -func (h *Headscale) openDB() (*gorm.DB, error) { - var db *gorm.DB - var err error +func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { + log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database") - var log logger.Interface - if h.dbDebug { - log = logger.Default + var dbLogger logger.Interface + if debug { + dbLogger = logger.Default } else { - log = logger.Default.LogMode(logger.Silent) + dbLogger = logger.Default.LogMode(logger.Silent) } - switch h.dbType { + switch dbType { case Sqlite: - db, err = gorm.Open( - sqlite.Open(h.dbString+"?_synchronous=1&_journal_mode=WAL"), + db, err := gorm.Open( + sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, - Logger: log, + Logger: dbLogger, }, ) @@ -250,24 +286,30 @@ func (h *Headscale) openDB() (*gorm.DB, error) { sqlDB.SetMaxOpenConns(1) sqlDB.SetConnMaxIdleTime(time.Hour) + return db, err + case Postgres: - db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{ + return gorm.Open(postgres.Open(connectionAddr), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, - Logger: log, + Logger: dbLogger, }) } - if err != nil { - return nil, err - } + return nil, fmt.Errorf( + "database of type %s is not supported: %w", + dbType, + errDatabaseNotSupported, + ) +} - return db, nil +func (hsdb *HSDatabase) notifyStateChange() { + hsdb.notifyStateChan <- struct{}{} } // getValue returns the value for the given key in KV. -func (h *Headscale) getValue(key string) (string, error) { +func (hsdb *HSDatabase) getValue(key string) (string, error) { var row KV - if result := h.db.First(&row, "key = ?", key); errors.Is( + if result := hsdb.db.First(&row, "key = ?", key); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { @@ -278,34 +320,34 @@ func (h *Headscale) getValue(key string) (string, error) { } // setValue sets value for the given key in KV. -func (h *Headscale) setValue(key string, value string) error { +func (hsdb *HSDatabase) setValue(key string, value string) error { keyValue := KV{ Key: key, Value: value, } - if _, err := h.getValue(key); err == nil { - h.db.Model(&keyValue).Where("key = ?", key).Update("value", value) + if _, err := hsdb.getValue(key); err == nil { + hsdb.db.Model(&keyValue).Where("key = ?", key).Update("value", value) return nil } - if err := h.db.Create(keyValue).Error; err != nil { + if err := hsdb.db.Create(keyValue).Error; err != nil { return fmt.Errorf("failed to create key value pair in the database: %w", err) } return nil } -func (h *Headscale) pingDB(ctx context.Context) error { +func (hsdb *HSDatabase) pingDB(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - db, err := h.db.DB() + sqlDB, err := hsdb.db.DB() if err != nil { return err } - return db.PingContext(ctx) + return sqlDB.PingContext(ctx) } // This is a "wrapper" type around tailscales diff --git a/hscontrol/dns_test.go b/hscontrol/dns_test.go index b825721913..671a712f45 100644 --- a/hscontrol/dns_test.go +++ b/hscontrol/dns_test.go @@ -112,16 +112,16 @@ func (s *Suite) TestMagicDNSRootDomainsIPv6SingleMultiple(c *check.C) { } func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { - userShared1, err := app.CreateUser("shared1") + userShared1, err := app.db.CreateUser("shared1") c.Assert(err, check.IsNil) - userShared2, err := app.CreateUser("shared2") + userShared2, err := app.db.CreateUser("shared2") c.Assert(err, check.IsNil) - userShared3, err := app.CreateUser("shared3") + userShared3, err := app.db.CreateUser("shared3") c.Assert(err, check.IsNil) - preAuthKeyInShared1, err := app.CreatePreAuthKey( + preAuthKeyInShared1, err := app.db.CreatePreAuthKey( userShared1.Name, false, false, @@ -130,7 +130,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { ) c.Assert(err, check.IsNil) - preAuthKeyInShared2, err := app.CreatePreAuthKey( + preAuthKeyInShared2, err := app.db.CreatePreAuthKey( userShared2.Name, false, false, @@ -139,7 +139,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { ) c.Assert(err, check.IsNil) - preAuthKeyInShared3, err := app.CreatePreAuthKey( + preAuthKeyInShared3, err := app.db.CreatePreAuthKey( userShared3.Name, false, false, @@ -148,7 +148,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { ) c.Assert(err, check.IsNil) - PreAuthKey2InShared1, err := app.CreatePreAuthKey( + PreAuthKey2InShared1, err := app.db.CreatePreAuthKey( userShared1.Name, false, false, @@ -157,7 +157,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { ) c.Assert(err, check.IsNil) - _, err = app.GetMachine(userShared1.Name, "test_get_shared_nodes_1") + _, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") c.Assert(err, check.NotNil) machineInShared1 := &Machine{ @@ -172,9 +172,9 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, AuthKeyID: uint(preAuthKeyInShared1.ID), } - app.db.Save(machineInShared1) + app.db.db.Save(machineInShared1) - _, err = app.GetMachine(userShared1.Name, machineInShared1.Hostname) + _, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) c.Assert(err, check.IsNil) machineInShared2 := &Machine{ @@ -189,9 +189,9 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, AuthKeyID: uint(preAuthKeyInShared2.ID), } - app.db.Save(machineInShared2) + app.db.db.Save(machineInShared2) - _, err = app.GetMachine(userShared2.Name, machineInShared2.Hostname) + _, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) c.Assert(err, check.IsNil) machineInShared3 := &Machine{ @@ -206,9 +206,9 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, AuthKeyID: uint(preAuthKeyInShared3.ID), } - app.db.Save(machineInShared3) + app.db.db.Save(machineInShared3) - _, err = app.GetMachine(userShared3.Name, machineInShared3.Hostname) + _, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) c.Assert(err, check.IsNil) machine2InShared1 := &Machine{ @@ -223,7 +223,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, AuthKeyID: uint(PreAuthKey2InShared1.ID), } - app.db.Save(machine2InShared1) + app.db.db.Save(machine2InShared1) baseDomain := "foobar.headscale.net" dnsConfigOrig := tailcfg.DNSConfig{ @@ -232,7 +232,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Proxied: true, } - peersOfMachineInShared1, err := app.getPeers(machineInShared1) + peersOfMachineInShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) c.Assert(err, check.IsNil) dnsConfig := getMapResponseDNSConfig( @@ -259,16 +259,16 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { } func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { - userShared1, err := app.CreateUser("shared1") + userShared1, err := app.db.CreateUser("shared1") c.Assert(err, check.IsNil) - userShared2, err := app.CreateUser("shared2") + userShared2, err := app.db.CreateUser("shared2") c.Assert(err, check.IsNil) - userShared3, err := app.CreateUser("shared3") + userShared3, err := app.db.CreateUser("shared3") c.Assert(err, check.IsNil) - preAuthKeyInShared1, err := app.CreatePreAuthKey( + preAuthKeyInShared1, err := app.db.CreatePreAuthKey( userShared1.Name, false, false, @@ -277,7 +277,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { ) c.Assert(err, check.IsNil) - preAuthKeyInShared2, err := app.CreatePreAuthKey( + preAuthKeyInShared2, err := app.db.CreatePreAuthKey( userShared2.Name, false, false, @@ -286,7 +286,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { ) c.Assert(err, check.IsNil) - preAuthKeyInShared3, err := app.CreatePreAuthKey( + preAuthKeyInShared3, err := app.db.CreatePreAuthKey( userShared3.Name, false, false, @@ -295,7 +295,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { ) c.Assert(err, check.IsNil) - preAuthKey2InShared1, err := app.CreatePreAuthKey( + preAuthKey2InShared1, err := app.db.CreatePreAuthKey( userShared1.Name, false, false, @@ -304,7 +304,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { ) c.Assert(err, check.IsNil) - _, err = app.GetMachine(userShared1.Name, "test_get_shared_nodes_1") + _, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") c.Assert(err, check.NotNil) machineInShared1 := &Machine{ @@ -319,9 +319,9 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, AuthKeyID: uint(preAuthKeyInShared1.ID), } - app.db.Save(machineInShared1) + app.db.db.Save(machineInShared1) - _, err = app.GetMachine(userShared1.Name, machineInShared1.Hostname) + _, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) c.Assert(err, check.IsNil) machineInShared2 := &Machine{ @@ -336,9 +336,9 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, AuthKeyID: uint(preAuthKeyInShared2.ID), } - app.db.Save(machineInShared2) + app.db.db.Save(machineInShared2) - _, err = app.GetMachine(userShared2.Name, machineInShared2.Hostname) + _, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) c.Assert(err, check.IsNil) machineInShared3 := &Machine{ @@ -353,9 +353,9 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, AuthKeyID: uint(preAuthKeyInShared3.ID), } - app.db.Save(machineInShared3) + app.db.db.Save(machineInShared3) - _, err = app.GetMachine(userShared3.Name, machineInShared3.Hostname) + _, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) c.Assert(err, check.IsNil) machine2InShared1 := &Machine{ @@ -370,7 +370,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, AuthKeyID: uint(preAuthKey2InShared1.ID), } - app.db.Save(machine2InShared1) + app.db.db.Save(machine2InShared1) baseDomain := "foobar.headscale.net" dnsConfigOrig := tailcfg.DNSConfig{ @@ -379,7 +379,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Proxied: false, } - peersOfMachine1Shared1, err := app.getPeers(machineInShared1) + peersOfMachine1Shared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) c.Assert(err, check.IsNil) dnsConfig := getMapResponseDNSConfig( diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index a65a380503..4a26d08eb7 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -8,6 +8,7 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -30,7 +31,7 @@ func (api headscaleV1APIServer) GetUser( ctx context.Context, request *v1.GetUserRequest, ) (*v1.GetUserResponse, error) { - user, err := api.h.GetUser(request.GetName()) + user, err := api.h.db.GetUser(request.GetName()) if err != nil { return nil, err } @@ -42,7 +43,7 @@ func (api headscaleV1APIServer) CreateUser( ctx context.Context, request *v1.CreateUserRequest, ) (*v1.CreateUserResponse, error) { - user, err := api.h.CreateUser(request.GetName()) + user, err := api.h.db.CreateUser(request.GetName()) if err != nil { return nil, err } @@ -54,12 +55,12 @@ func (api headscaleV1APIServer) RenameUser( ctx context.Context, request *v1.RenameUserRequest, ) (*v1.RenameUserResponse, error) { - err := api.h.RenameUser(request.GetOldName(), request.GetNewName()) + err := api.h.db.RenameUser(request.GetOldName(), request.GetNewName()) if err != nil { return nil, err } - user, err := api.h.GetUser(request.GetNewName()) + user, err := api.h.db.GetUser(request.GetNewName()) if err != nil { return nil, err } @@ -71,7 +72,7 @@ func (api headscaleV1APIServer) DeleteUser( ctx context.Context, request *v1.DeleteUserRequest, ) (*v1.DeleteUserResponse, error) { - err := api.h.DestroyUser(request.GetName()) + err := api.h.db.DestroyUser(request.GetName()) if err != nil { return nil, err } @@ -83,7 +84,7 @@ func (api headscaleV1APIServer) ListUsers( ctx context.Context, request *v1.ListUsersRequest, ) (*v1.ListUsersResponse, error) { - users, err := api.h.ListUsers() + users, err := api.h.db.ListUsers() if err != nil { return nil, err } @@ -116,7 +117,7 @@ func (api headscaleV1APIServer) CreatePreAuthKey( } } - preAuthKey, err := api.h.CreatePreAuthKey( + preAuthKey, err := api.h.db.CreatePreAuthKey( request.GetUser(), request.GetReusable(), request.GetEphemeral(), @@ -134,12 +135,12 @@ func (api headscaleV1APIServer) ExpirePreAuthKey( ctx context.Context, request *v1.ExpirePreAuthKeyRequest, ) (*v1.ExpirePreAuthKeyResponse, error) { - preAuthKey, err := api.h.GetPreAuthKey(request.GetUser(), request.Key) + preAuthKey, err := api.h.db.GetPreAuthKey(request.GetUser(), request.Key) if err != nil { return nil, err } - err = api.h.ExpirePreAuthKey(preAuthKey) + err = api.h.db.ExpirePreAuthKey(preAuthKey) if err != nil { return nil, err } @@ -151,7 +152,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys( ctx context.Context, request *v1.ListPreAuthKeysRequest, ) (*v1.ListPreAuthKeysResponse, error) { - preAuthKeys, err := api.h.ListPreAuthKeys(request.GetUser()) + preAuthKeys, err := api.h.db.ListPreAuthKeys(request.GetUser()) if err != nil { return nil, err } @@ -173,7 +174,8 @@ func (api headscaleV1APIServer) RegisterMachine( Str("node_key", request.GetKey()). Msg("Registering machine") - machine, err := api.h.RegisterMachineFromAuthCallback( + machine, err := api.h.db.RegisterMachineFromAuthCallback( + api.h.registrationCache, request.GetKey(), request.GetUser(), nil, @@ -190,7 +192,7 @@ func (api headscaleV1APIServer) GetMachine( ctx context.Context, request *v1.GetMachineRequest, ) (*v1.GetMachineResponse, error) { - machine, err := api.h.GetMachineByID(request.GetMachineId()) + machine, err := api.h.db.GetMachineByID(request.GetMachineId()) if err != nil { return nil, err } @@ -202,7 +204,7 @@ func (api headscaleV1APIServer) SetTags( ctx context.Context, request *v1.SetTagsRequest, ) (*v1.SetTagsResponse, error) { - machine, err := api.h.GetMachineByID(request.GetMachineId()) + machine, err := api.h.db.GetMachineByID(request.GetMachineId()) if err != nil { return nil, err } @@ -216,7 +218,7 @@ func (api headscaleV1APIServer) SetTags( } } - err = api.h.SetTags(machine, request.GetTags()) + err = api.h.db.SetTags(machine, request.GetTags(), api.h.UpdateACLRules) if err != nil { return &v1.SetTagsResponse{ Machine: nil, @@ -248,12 +250,12 @@ func (api headscaleV1APIServer) DeleteMachine( ctx context.Context, request *v1.DeleteMachineRequest, ) (*v1.DeleteMachineResponse, error) { - machine, err := api.h.GetMachineByID(request.GetMachineId()) + machine, err := api.h.db.GetMachineByID(request.GetMachineId()) if err != nil { return nil, err } - err = api.h.DeleteMachine( + err = api.h.db.DeleteMachine( machine, ) if err != nil { @@ -267,12 +269,12 @@ func (api headscaleV1APIServer) ExpireMachine( ctx context.Context, request *v1.ExpireMachineRequest, ) (*v1.ExpireMachineResponse, error) { - machine, err := api.h.GetMachineByID(request.GetMachineId()) + machine, err := api.h.db.GetMachineByID(request.GetMachineId()) if err != nil { return nil, err } - api.h.ExpireMachine( + api.h.db.ExpireMachine( machine, ) @@ -288,12 +290,12 @@ func (api headscaleV1APIServer) RenameMachine( ctx context.Context, request *v1.RenameMachineRequest, ) (*v1.RenameMachineResponse, error) { - machine, err := api.h.GetMachineByID(request.GetMachineId()) + machine, err := api.h.db.GetMachineByID(request.GetMachineId()) if err != nil { return nil, err } - err = api.h.RenameMachine( + err = api.h.db.RenameMachine( machine, request.GetNewName(), ) @@ -314,7 +316,7 @@ func (api headscaleV1APIServer) ListMachines( request *v1.ListMachinesRequest, ) (*v1.ListMachinesResponse, error) { if request.GetUser() != "" { - machines, err := api.h.ListMachinesByUser(request.GetUser()) + machines, err := api.h.db.ListMachinesByUser(request.GetUser()) if err != nil { return nil, err } @@ -327,7 +329,7 @@ func (api headscaleV1APIServer) ListMachines( return &v1.ListMachinesResponse{Machines: response}, nil } - machines, err := api.h.ListMachines() + machines, err := api.h.db.ListMachines() if err != nil { return nil, err } @@ -352,12 +354,12 @@ func (api headscaleV1APIServer) MoveMachine( ctx context.Context, request *v1.MoveMachineRequest, ) (*v1.MoveMachineResponse, error) { - machine, err := api.h.GetMachineByID(request.GetMachineId()) + machine, err := api.h.db.GetMachineByID(request.GetMachineId()) if err != nil { return nil, err } - err = api.h.SetMachineUser(machine, request.GetUser()) + err = api.h.db.SetMachineUser(machine, request.GetUser()) if err != nil { return nil, err } @@ -369,7 +371,7 @@ func (api headscaleV1APIServer) GetRoutes( ctx context.Context, request *v1.GetRoutesRequest, ) (*v1.GetRoutesResponse, error) { - routes, err := api.h.GetRoutes() + routes, err := api.h.db.GetRoutes() if err != nil { return nil, err } @@ -383,7 +385,7 @@ func (api headscaleV1APIServer) EnableRoute( ctx context.Context, request *v1.EnableRouteRequest, ) (*v1.EnableRouteResponse, error) { - err := api.h.EnableRoute(request.GetRouteId()) + err := api.h.db.EnableRoute(request.GetRouteId()) if err != nil { return nil, err } @@ -395,7 +397,7 @@ func (api headscaleV1APIServer) DisableRoute( ctx context.Context, request *v1.DisableRouteRequest, ) (*v1.DisableRouteResponse, error) { - err := api.h.DisableRoute(request.GetRouteId()) + err := api.h.db.DisableRoute(request.GetRouteId()) if err != nil { return nil, err } @@ -407,12 +409,12 @@ func (api headscaleV1APIServer) GetMachineRoutes( ctx context.Context, request *v1.GetMachineRoutesRequest, ) (*v1.GetMachineRoutesResponse, error) { - machine, err := api.h.GetMachineByID(request.GetMachineId()) + machine, err := api.h.db.GetMachineByID(request.GetMachineId()) if err != nil { return nil, err } - routes, err := api.h.GetMachineRoutes(machine) + routes, err := api.h.db.GetMachineRoutes(machine) if err != nil { return nil, err } @@ -426,7 +428,7 @@ func (api headscaleV1APIServer) DeleteRoute( ctx context.Context, request *v1.DeleteRouteRequest, ) (*v1.DeleteRouteResponse, error) { - err := api.h.DeleteRoute(request.GetRouteId()) + err := api.h.db.DeleteRoute(request.GetRouteId()) if err != nil { return nil, err } @@ -443,7 +445,7 @@ func (api headscaleV1APIServer) CreateApiKey( expiration = request.GetExpiration().AsTime() } - apiKey, _, err := api.h.CreateAPIKey( + apiKey, _, err := api.h.db.CreateAPIKey( &expiration, ) if err != nil { @@ -460,12 +462,12 @@ func (api headscaleV1APIServer) ExpireApiKey( var apiKey *APIKey var err error - apiKey, err = api.h.GetAPIKey(request.Prefix) + apiKey, err = api.h.db.GetAPIKey(request.Prefix) if err != nil { return nil, err } - err = api.h.ExpireAPIKey(apiKey) + err = api.h.db.ExpireAPIKey(apiKey) if err != nil { return nil, err } @@ -477,7 +479,7 @@ func (api headscaleV1APIServer) ListApiKeys( ctx context.Context, request *v1.ListApiKeysRequest, ) (*v1.ListApiKeysResponse, error) { - apiKeys, err := api.h.ListAPIKeys() + apiKeys, err := api.h.db.ListAPIKeys() if err != nil { return nil, err } @@ -495,12 +497,12 @@ func (api headscaleV1APIServer) DebugCreateMachine( ctx context.Context, request *v1.DebugCreateMachineRequest, ) (*v1.DebugCreateMachineResponse, error) { - user, err := api.h.GetUser(request.GetUser()) + user, err := api.h.db.GetUser(request.GetUser()) if err != nil { return nil, err } - routes, err := stringToIPPrefix(request.GetRoutes()) + routes, err := util.StringToIPPrefix(request.GetRoutes()) if err != nil { return nil, err } @@ -517,7 +519,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( Hostname: "DebugTestMachine", } - givenName, err := api.h.GenerateGivenName(request.GetKey(), request.GetName()) + givenName, err := api.h.db.GenerateGivenName(request.GetKey(), request.GetName()) if err != nil { return nil, err } @@ -542,7 +544,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( } api.h.registrationCache.Set( - NodePublicKeyStripPrefix(nodeKey), + util.NodePublicKeyStripPrefix(nodeKey), newMachine, registerCacheExpiration, ) diff --git a/hscontrol/machine.go b/hscontrol/machine.go index 9f04d8ce30..846112b15b 100644 --- a/hscontrol/machine.go +++ b/hscontrol/machine.go @@ -11,6 +11,8 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" "github.com/samber/lo" "go4.org/netipx" @@ -21,23 +23,23 @@ import ( ) const ( - ErrMachineNotFound = Error("machine not found") - ErrMachineRouteIsNotAvailable = Error("route is not available on machine") - ErrMachineAddressesInvalid = Error("failed to parse machine addresses") - ErrMachineNotFoundRegistrationCache = Error( - "machine not found in registration cache", - ) - ErrCouldNotConvertMachineInterface = Error("failed to convert machine interface") - ErrHostnameTooLong = Error("Hostname too long") - ErrDifferentRegisteredUser = Error( - "machine was previously registered with a different user", - ) MachineGivenNameHashLength = 8 MachineGivenNameTrimSize = 2 + maxHostnameLength = 255 ) -const ( - maxHostnameLength = 255 +var ( + ErrMachineNotFound = errors.New("machine not found") + ErrMachineRouteIsNotAvailable = errors.New("route is not available on machine") + ErrMachineAddressesInvalid = errors.New("failed to parse machine addresses") + ErrMachineNotFoundRegistrationCache = errors.New( + "machine not found in registration cache", + ) + ErrCouldNotConvertMachineInterface = errors.New("failed to convert machine interface") + ErrHostnameTooLong = errors.New("hostname too long") + ErrDifferentRegisteredUser = errors.New( + "machine was previously registered with a different user", + ) ) // Machine is a Headscale client. @@ -188,8 +190,10 @@ func (machine *Machine) canAccess(filter []tailcfg.FilterRule, machine2 *Machine // filterMachinesByACL wrapper function to not have devs pass around locks and maps // related to the application outside of tests. -func (h *Headscale) filterMachinesByACL(currentMachine *Machine, peers Machines) Machines { - return filterMachinesByACL(currentMachine, peers, h.aclRules) +func (hsdb *HSDatabase) filterMachinesByACL( + aclRules []tailcfg.FilterRule, + currentMachine *Machine, peers Machines) Machines { + return filterMachinesByACL(currentMachine, peers, aclRules) } // filterMachinesByACL returns the list of peers authorized to be accessed from a given machine. @@ -213,14 +217,14 @@ func filterMachinesByACL( return result } -func (h *Headscale) ListPeers(machine *Machine) (Machines, error) { +func (hsdb *HSDatabase) ListPeers(machine *Machine) (Machines, error) { log.Trace(). Caller(). Str("machine", machine.Hostname). Msg("Finding direct peers") machines := Machines{} - if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("node_key <> ?", + if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("node_key <> ?", machine.NodeKey).Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") @@ -237,23 +241,27 @@ func (h *Headscale) ListPeers(machine *Machine) (Machines, error) { return machines, nil } -func (h *Headscale) getPeers(machine *Machine) (Machines, error) { +func (hsdb *HSDatabase) getPeers( + aclPolicy *ACLPolicy, + aclRules []tailcfg.FilterRule, + machine *Machine, +) (Machines, error) { var peers Machines var err error // If ACLs rules are defined, filter visible host list with the ACLs // else use the classic user scope - if h.aclPolicy != nil { + if aclPolicy != nil { var machines []Machine - machines, err = h.ListMachines() + machines, err = hsdb.ListMachines() if err != nil { log.Error().Err(err).Msg("Error retrieving list of machines") return Machines{}, err } - peers = h.filterMachinesByACL(machine, machines) + peers = hsdb.filterMachinesByACL(aclRules, machine, machines) } else { - peers, err = h.ListPeers(machine) + peers, err = hsdb.ListPeers(machine) if err != nil { log.Error(). Caller(). @@ -275,10 +283,14 @@ func (h *Headscale) getPeers(machine *Machine) (Machines, error) { return peers, nil } -func (h *Headscale) getValidPeers(machine *Machine) (Machines, error) { +func (hsdb *HSDatabase) getValidPeers( + aclPolicy *ACLPolicy, + aclRules []tailcfg.FilterRule, + machine *Machine, +) (Machines, error) { validPeers := make(Machines, 0) - peers, err := h.getPeers(machine) + peers, err := hsdb.getPeers(aclPolicy, aclRules, machine) if err != nil { return Machines{}, err } @@ -292,18 +304,18 @@ func (h *Headscale) getValidPeers(machine *Machine) (Machines, error) { return validPeers, nil } -func (h *Headscale) ListMachines() ([]Machine, error) { +func (hsdb *HSDatabase) ListMachines() ([]Machine, error) { machines := []Machine{} - if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil { + if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil { return nil, err } return machines, nil } -func (h *Headscale) ListMachinesByGivenName(givenName string) ([]Machine, error) { +func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) ([]Machine, error) { machines := []Machine{} - if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("given_name = ?", givenName).Find(&machines).Error; err != nil { + if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("given_name = ?", givenName).Find(&machines).Error; err != nil { return nil, err } @@ -311,8 +323,8 @@ func (h *Headscale) ListMachinesByGivenName(givenName string) ([]Machine, error) } // GetMachine finds a Machine by name and user and returns the Machine struct. -func (h *Headscale) GetMachine(user string, name string) (*Machine, error) { - machines, err := h.ListMachinesByUser(user) +func (hsdb *HSDatabase) GetMachine(user string, name string) (*Machine, error) { + machines, err := hsdb.ListMachinesByUser(user) if err != nil { return nil, err } @@ -327,8 +339,8 @@ func (h *Headscale) GetMachine(user string, name string) (*Machine, error) { } // GetMachineByGivenName finds a Machine by given name and user and returns the Machine struct. -func (h *Headscale) GetMachineByGivenName(user string, givenName string) (*Machine, error) { - machines, err := h.ListMachinesByUser(user) +func (hsdb *HSDatabase) GetMachineByGivenName(user string, givenName string) (*Machine, error) { + machines, err := hsdb.ListMachinesByUser(user) if err != nil { return nil, err } @@ -343,9 +355,9 @@ func (h *Headscale) GetMachineByGivenName(user string, givenName string) (*Machi } // GetMachineByID finds a Machine by ID and returns the Machine struct. -func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { +func (hsdb *HSDatabase) GetMachineByID(id uint64) (*Machine, error) { m := Machine{} - if result := h.db.Preload("AuthKey").Preload("User").Find(&Machine{ID: id}).First(&m); result.Error != nil { + if result := hsdb.db.Preload("AuthKey").Preload("User").Find(&Machine{ID: id}).First(&m); result.Error != nil { return nil, result.Error } @@ -353,11 +365,11 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { } // GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct. -func (h *Headscale) GetMachineByMachineKey( +func (hsdb *HSDatabase) GetMachineByMachineKey( machineKey key.MachinePublic, ) (*Machine, error) { m := Machine{} - if result := h.db.Preload("AuthKey").Preload("User").First(&m, "machine_key = ?", MachinePublicKeyStripPrefix(machineKey)); result.Error != nil { + if result := hsdb.db.Preload("AuthKey").Preload("User").First(&m, "machine_key = ?", util.MachinePublicKeyStripPrefix(machineKey)); result.Error != nil { return nil, result.Error } @@ -365,12 +377,12 @@ func (h *Headscale) GetMachineByMachineKey( } // GetMachineByNodeKey finds a Machine by its current NodeKey. -func (h *Headscale) GetMachineByNodeKey( +func (hsdb *HSDatabase) GetMachineByNodeKey( nodeKey key.NodePublic, ) (*Machine, error) { machine := Machine{} - if result := h.db.Preload("AuthKey").Preload("User").First(&machine, "node_key = ?", - NodePublicKeyStripPrefix(nodeKey)); result.Error != nil { + if result := hsdb.db.Preload("AuthKey").Preload("User").First(&machine, "node_key = ?", + util.NodePublicKeyStripPrefix(nodeKey)); result.Error != nil { return nil, result.Error } @@ -378,14 +390,14 @@ func (h *Headscale) GetMachineByNodeKey( } // GetMachineByAnyNodeKey finds a Machine by its MachineKey, its current NodeKey or the old one, and returns the Machine struct. -func (h *Headscale) GetMachineByAnyKey( +func (hsdb *HSDatabase) GetMachineByAnyKey( machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, ) (*Machine, error) { machine := Machine{} - if result := h.db.Preload("AuthKey").Preload("User").First(&machine, "machine_key = ? OR node_key = ? OR node_key = ?", - MachinePublicKeyStripPrefix(machineKey), - NodePublicKeyStripPrefix(nodeKey), - NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil { + if result := hsdb.db.Preload("AuthKey").Preload("User").First(&machine, "machine_key = ? OR node_key = ? OR node_key = ?", + util.MachinePublicKeyStripPrefix(machineKey), + util.NodePublicKeyStripPrefix(nodeKey), + util.NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil { return nil, result.Error } @@ -394,8 +406,8 @@ func (h *Headscale) GetMachineByAnyKey( // UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database // and updates it with the latest data from the database. -func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error { - if result := h.db.Find(machine).First(&machine); result.Error != nil { +func (hsdb *HSDatabase) UpdateMachineFromDatabase(machine *Machine) error { + if result := hsdb.db.Find(machine).First(&machine); result.Error != nil { return result.Error } @@ -403,20 +415,28 @@ func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error { } // SetTags takes a Machine struct pointer and update the forced tags. -func (h *Headscale) SetTags(machine *Machine, tags []string) error { +func (hsdb *HSDatabase) SetTags( + machine *Machine, + tags []string, + // TODO(kradalby): This is a temporary measure to be able to detach the + // database completely from the global h. In the future, as part of this + // reorg, the rules will be generated on a per node basis, and not be prone + // to throwing error at save. + updateACL func() error) error { newTags := []string{} for _, tag := range tags { - if !contains(newTags, tag) { + if !util.StringOrPrefixListContains(newTags, tag) { newTags = append(newTags, tag) } } machine.ForcedTags = newTags - if err := h.UpdateACLRules(); err != nil && !errors.Is(err, errEmptyPolicy) { + if err := updateACL(); err != nil && !errors.Is(err, errEmptyPolicy) { return err } - h.setLastStateChangeToNow() - if err := h.db.Save(machine).Error; err != nil { + hsdb.notifyStateChange() + + if err := hsdb.db.Save(machine).Error; err != nil { return fmt.Errorf("failed to update tags for machine in the database: %w", err) } @@ -424,13 +444,13 @@ func (h *Headscale) SetTags(machine *Machine, tags []string) error { } // ExpireMachine takes a Machine struct and sets the expire field to now. -func (h *Headscale) ExpireMachine(machine *Machine) error { +func (hsdb *HSDatabase) ExpireMachine(machine *Machine) error { now := time.Now() machine.Expiry = &now - h.setLastStateChangeToNow() + hsdb.notifyStateChange() - if err := h.db.Save(machine).Error; err != nil { + if err := hsdb.db.Save(machine).Error; err != nil { return fmt.Errorf("failed to expire machine in the database: %w", err) } @@ -439,7 +459,7 @@ func (h *Headscale) ExpireMachine(machine *Machine) error { // RenameMachine takes a Machine struct and a new GivenName for the machines // and renames it. -func (h *Headscale) RenameMachine(machine *Machine, newName string) error { +func (hsdb *HSDatabase) RenameMachine(machine *Machine, newName string) error { err := CheckForFQDNRules( newName, ) @@ -455,9 +475,9 @@ func (h *Headscale) RenameMachine(machine *Machine, newName string) error { } machine.GivenName = newName - h.setLastStateChangeToNow() + hsdb.notifyStateChange() - if err := h.db.Save(machine).Error; err != nil { + if err := hsdb.db.Save(machine).Error; err != nil { return fmt.Errorf("failed to rename machine in the database: %w", err) } @@ -465,15 +485,15 @@ func (h *Headscale) RenameMachine(machine *Machine, newName string) error { } // RefreshMachine takes a Machine struct and sets the expire field to now. -func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) error { +func (hsdb *HSDatabase) RefreshMachine(machine *Machine, expiry time.Time) error { now := time.Now() machine.LastSuccessfulUpdate = &now machine.Expiry = &expiry - h.setLastStateChangeToNow() + hsdb.notifyStateChange() - if err := h.db.Save(machine).Error; err != nil { + if err := hsdb.db.Save(machine).Error; err != nil { return fmt.Errorf( "failed to refresh machine (update expiration) in the database: %w", err, @@ -484,21 +504,21 @@ func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) error { } // DeleteMachine softs deletes a Machine from the database. -func (h *Headscale) DeleteMachine(machine *Machine) error { - err := h.DeleteMachineRoutes(machine) +func (hsdb *HSDatabase) DeleteMachine(machine *Machine) error { + err := hsdb.DeleteMachineRoutes(machine) if err != nil { return err } - if err := h.db.Delete(&machine).Error; err != nil { + if err := hsdb.db.Delete(&machine).Error; err != nil { return err } return nil } -func (h *Headscale) TouchMachine(machine *Machine) error { - return h.db.Updates(Machine{ +func (hsdb *HSDatabase) TouchMachine(machine *Machine) error { + return hsdb.db.Updates(Machine{ ID: machine.ID, LastSeen: machine.LastSeen, LastSuccessfulUpdate: machine.LastSuccessfulUpdate, @@ -506,13 +526,13 @@ func (h *Headscale) TouchMachine(machine *Machine) error { } // HardDeleteMachine hard deletes a Machine from the database. -func (h *Headscale) HardDeleteMachine(machine *Machine) error { - err := h.DeleteMachineRoutes(machine) +func (hsdb *HSDatabase) HardDeleteMachine(machine *Machine) error { + err := hsdb.DeleteMachineRoutes(machine) if err != nil { return err } - if err := h.db.Unscoped().Delete(&machine).Error; err != nil { + if err := hsdb.db.Unscoped().Delete(&machine).Error; err != nil { return err } @@ -524,8 +544,8 @@ func (machine *Machine) GetHostInfo() tailcfg.Hostinfo { return tailcfg.Hostinfo(machine.HostInfo) } -func (h *Headscale) isOutdated(machine *Machine) bool { - if err := h.UpdateMachineFromDatabase(machine); err != nil { +func (hsdb *HSDatabase) isOutdated(machine *Machine, lastChange time.Time) bool { + if err := hsdb.UpdateMachineFromDatabase(machine); err != nil { // It does not seem meaningful to propagate this error as the end result // will have to be that the machine has to be considered outdated. return true @@ -536,7 +556,6 @@ func (h *Headscale) isOutdated(machine *Machine) bool { // TODO(kradalby): Only request updates from users where we can talk to nodes // This would mostly be for a bit of performance, and can be calculated based on // ACLs. - lastChange := h.getLastStateChange() lastUpdate := machine.CreatedAt if machine.LastSuccessfulUpdate != nil { lastUpdate = *machine.LastSuccessfulUpdate @@ -576,15 +595,16 @@ func (machines MachinesP) String() string { return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) } -func (h *Headscale) toNodes( +func (hsdb *HSDatabase) toNodes( machines Machines, + aclPolicy *ACLPolicy, baseDomain string, dnsConfig *tailcfg.DNSConfig, ) ([]*tailcfg.Node, error) { nodes := make([]*tailcfg.Node, len(machines)) for index, machine := range machines { - node, err := h.toNode(machine, baseDomain, dnsConfig) + node, err := hsdb.toNode(machine, aclPolicy, baseDomain, dnsConfig) if err != nil { return nil, err } @@ -597,13 +617,14 @@ func (h *Headscale) toNodes( // toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes // as per the expected behaviour in the official SaaS. -func (h *Headscale) toNode( +func (hsdb *HSDatabase) toNode( machine Machine, + aclPolicy *ACLPolicy, baseDomain string, dnsConfig *tailcfg.DNSConfig, ) (*tailcfg.Node, error) { var nodeKey key.NodePublic - err := nodeKey.UnmarshalText([]byte(NodePublicKeyEnsurePrefix(machine.NodeKey))) + err := nodeKey.UnmarshalText([]byte(util.NodePublicKeyEnsurePrefix(machine.NodeKey))) if err != nil { log.Trace(). Caller(). @@ -617,7 +638,7 @@ func (h *Headscale) toNode( // MachineKey is only used in the legacy protocol if machine.MachineKey != "" { err = machineKey.UnmarshalText( - []byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)), + []byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)), ) if err != nil { return nil, fmt.Errorf("failed to parse machine public key: %w", err) @@ -627,7 +648,7 @@ func (h *Headscale) toNode( var discoKey key.DiscoPublic if machine.DiscoKey != "" { err := discoKey.UnmarshalText( - []byte(DiscoPublicKeyEnsurePrefix(machine.DiscoKey)), + []byte(util.DiscoPublicKeyEnsurePrefix(machine.DiscoKey)), ) if err != nil { return nil, fmt.Errorf("failed to parse disco public key: %w", err) @@ -646,13 +667,13 @@ func (h *Headscale) toNode( []netip.Prefix{}, addrs...) // we append the node own IP, as it is required by the clients - primaryRoutes, err := h.getMachinePrimaryRoutes(&machine) + primaryRoutes, err := hsdb.getMachinePrimaryRoutes(&machine) if err != nil { return nil, err } primaryPrefixes := Routes(primaryRoutes).toPrefixes() - machineRoutes, err := h.GetMachineRoutes(&machine) + machineRoutes, err := hsdb.GetMachineRoutes(&machine) if err != nil { return nil, err } @@ -699,13 +720,13 @@ func (h *Headscale) toNode( online := machine.isOnline() - tags, _ := getTags(h.aclPolicy, machine, h.cfg.OIDC.StripEmaildomain) + tags, _ := getTags(aclPolicy, machine, hsdb.stripEmailDomain) tags = lo.Uniq(append(tags, machine.ForcedTags...)) node := tailcfg.Node{ ID: tailcfg.NodeID(machine.ID), // this is the actual ID StableID: tailcfg.StableNodeID( - strconv.FormatUint(machine.ID, Base10), + strconv.FormatUint(machine.ID, util.Base10), ), // in headscale, unlike tailcontrol server, IDs are permanent Name: hostname, @@ -827,7 +848,8 @@ func getTags( return validTags, invalidTags } -func (h *Headscale) RegisterMachineFromAuthCallback( +func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( + cache *cache.Cache, nodeKeyStr string, userName string, machineExpiry *time.Time, @@ -846,9 +868,9 @@ func (h *Headscale) RegisterMachineFromAuthCallback( Str("expiresAt", fmt.Sprintf("%v", machineExpiry)). Msg("Registering machine from API/CLI or auth callback") - if machineInterface, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(nodeKey)); ok { + if machineInterface, ok := cache.Get(util.NodePublicKeyStripPrefix(nodeKey)); ok { if registrationMachine, ok := machineInterface.(Machine); ok { - user, err := h.GetUser(userName) + user, err := hsdb.GetUser(userName) if err != nil { return nil, fmt.Errorf( "failed to find user in register machine from auth callback, %w", @@ -869,12 +891,12 @@ func (h *Headscale) RegisterMachineFromAuthCallback( registrationMachine.Expiry = machineExpiry } - machine, err := h.RegisterMachine( + machine, err := hsdb.RegisterMachine( registrationMachine, ) if err == nil { - h.registrationCache.Delete(nodeKeyStr) + cache.Delete(nodeKeyStr) } return machine, err @@ -887,7 +909,7 @@ func (h *Headscale) RegisterMachineFromAuthCallback( } // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. -func (h *Headscale) RegisterMachine(machine Machine, +func (hsdb *HSDatabase) RegisterMachine(machine Machine, ) (*Machine, error) { log.Debug(). Str("machine", machine.Hostname). @@ -900,7 +922,7 @@ func (h *Headscale) RegisterMachine(machine Machine, // so we store the machine.Expire and machine.Nodekey that has been set when // adding it to the registrationCache if len(machine.IPAddresses) > 0 { - if err := h.db.Save(&machine).Error; err != nil { + if err := hsdb.db.Save(&machine).Error; err != nil { return nil, fmt.Errorf("failed register existing machine in the database: %w", err) } @@ -915,10 +937,10 @@ func (h *Headscale) RegisterMachine(machine Machine, return &machine, nil } - h.ipAllocationMutex.Lock() - defer h.ipAllocationMutex.Unlock() + hsdb.ipAllocationMutex.Lock() + defer hsdb.ipAllocationMutex.Unlock() - ips, err := h.getAvailableIPs() + ips, err := hsdb.getAvailableIPs() if err != nil { log.Error(). Caller(). @@ -931,7 +953,7 @@ func (h *Headscale) RegisterMachine(machine Machine, machine.IPAddresses = ips - if err := h.db.Save(&machine).Error; err != nil { + if err := hsdb.db.Save(&machine).Error; err != nil { return nil, fmt.Errorf("failed register(save) machine in the database: %w", err) } @@ -945,10 +967,10 @@ func (h *Headscale) RegisterMachine(machine Machine, } // GetAdvertisedRoutes returns the routes that are be advertised by the given machine. -func (h *Headscale) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error) { +func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error) { routes := []Route{} - err := h.db. + err := hsdb.db. Preload("Machine"). Where("machine_id = ? AND advertised = ?", machine.ID, true).Find(&routes).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { @@ -970,10 +992,10 @@ func (h *Headscale) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error } // GetEnabledRoutes returns the routes that are enabled for the machine. -func (h *Headscale) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) { +func (hsdb *HSDatabase) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) { routes := []Route{} - err := h.db. + err := hsdb.db. Preload("Machine"). Where("machine_id = ? AND advertised = ? AND enabled = ?", machine.ID, true, true). Find(&routes).Error @@ -995,13 +1017,13 @@ func (h *Headscale) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) { return prefixes, nil } -func (h *Headscale) IsRoutesEnabled(machine *Machine, routeStr string) bool { +func (hsdb *HSDatabase) IsRoutesEnabled(machine *Machine, routeStr string) bool { route, err := netip.ParsePrefix(routeStr) if err != nil { return false } - enabledRoutes, err := h.GetEnabledRoutes(machine) + enabledRoutes, err := hsdb.GetEnabledRoutes(machine) if err != nil { log.Error().Err(err).Msg("Could not get enabled routes") @@ -1018,7 +1040,7 @@ func (h *Headscale) IsRoutesEnabled(machine *Machine, routeStr string) bool { } // enableRoutes enables new routes based on a list of new routes. -func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error { +func (hsdb *HSDatabase) enableRoutes(machine *Machine, routeStrs ...string) error { newRoutes := make([]netip.Prefix, len(routeStrs)) for index, routeStr := range routeStrs { route, err := netip.ParsePrefix(routeStr) @@ -1029,13 +1051,13 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error { newRoutes[index] = route } - advertisedRoutes, err := h.GetAdvertisedRoutes(machine) + advertisedRoutes, err := hsdb.GetAdvertisedRoutes(machine) if err != nil { return err } for _, newRoute := range newRoutes { - if !contains(advertisedRoutes, newRoute) { + if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) { return fmt.Errorf( "route (%s) is not available on node %s: %w", machine.Hostname, @@ -1047,7 +1069,7 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error { // Separate loop so we don't leave things in a half-updated state for _, prefix := range newRoutes { route := Route{} - err := h.db.Preload("Machine"). + err := hsdb.db.Preload("Machine"). Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)). First(&route).Error if err == nil { @@ -1056,10 +1078,10 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error { // Mark already as primary if there is only this node offering this subnet // (and is not an exit route) if !route.isExitRoute() { - route.IsPrimary = h.isUniquePrefix(route) + route.IsPrimary = hsdb.isUniquePrefix(route) } - err = h.db.Save(&route).Error + err = hsdb.db.Save(&route).Error if err != nil { return fmt.Errorf("failed to enable route: %w", err) } @@ -1068,19 +1090,19 @@ func (h *Headscale) enableRoutes(machine *Machine, routeStrs ...string) error { } } - h.setLastStateChangeToNow() + hsdb.notifyStateChange() return nil } // EnableAutoApprovedRoutes enables any routes advertised by a machine that match the ACL autoApprovers policy. -func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error { +func (hsdb *HSDatabase) EnableAutoApprovedRoutes(aclPolicy *ACLPolicy, machine *Machine) error { if len(machine.IPAddresses) == 0 { return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs } routes := []Route{} - err := h.db. + err := hsdb.db. Preload("Machine"). Where("machine_id = ? AND advertised = true AND enabled = false", machine.ID). Find(&routes).Error @@ -1097,7 +1119,7 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error { approvedRoutes := []Route{} for _, advertisedRoute := range routes { - routeApprovers, err := h.aclPolicy.AutoApprovers.GetRouteApprovers( + routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers( netip.Prefix(advertisedRoute.Prefix), ) if err != nil { @@ -1113,7 +1135,7 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error { if approvedAlias == machine.User.Name { approvedRoutes = append(approvedRoutes, advertisedRoute) } else { - approvedIps, err := h.aclPolicy.expandAlias([]Machine{*machine}, approvedAlias, h.cfg.OIDC.StripEmaildomain) + approvedIps, err := aclPolicy.expandAlias([]Machine{*machine}, approvedAlias, hsdb.stripEmailDomain) if err != nil { log.Err(err). Str("alias", approvedAlias). @@ -1132,7 +1154,7 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error { for i, approvedRoute := range approvedRoutes { approvedRoutes[i].Enabled = true - err = h.db.Save(&approvedRoutes[i]).Error + err = hsdb.db.Save(&approvedRoutes[i]).Error if err != nil { log.Err(err). Str("approvedRoute", approvedRoute.String()). @@ -1146,10 +1168,10 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error { return nil } -func (h *Headscale) generateGivenName(suppliedName string, randomSuffix bool) (string, error) { +func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) { normalizedHostname, err := NormalizeToFQDNRules( suppliedName, - h.cfg.OIDC.StripEmaildomain, + hsdb.stripEmailDomain, ) if err != nil { return "", err @@ -1162,7 +1184,7 @@ func (h *Headscale) generateGivenName(suppliedName string, randomSuffix bool) (s normalizedHostname = normalizedHostname[:trimmedHostnameLength] } - suffix, err := GenerateRandomStringDNSSafe(MachineGivenNameHashLength) + suffix, err := util.GenerateRandomStringDNSSafe(MachineGivenNameHashLength) if err != nil { return "", err } @@ -1173,21 +1195,21 @@ func (h *Headscale) generateGivenName(suppliedName string, randomSuffix bool) (s return normalizedHostname, nil } -func (h *Headscale) GenerateGivenName(machineKey string, suppliedName string) (string, error) { - givenName, err := h.generateGivenName(suppliedName, false) +func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string) (string, error) { + givenName, err := hsdb.generateGivenName(suppliedName, false) if err != nil { return "", err } // Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ - machines, err := h.ListMachinesByGivenName(givenName) + machines, err := hsdb.ListMachinesByGivenName(givenName) if err != nil { return "", err } for _, machine := range machines { if machine.MachineKey != machineKey && machine.GivenName == givenName { - postfixedName, err := h.generateGivenName(suppliedName, true) + postfixedName, err := hsdb.generateGivenName(suppliedName, true) if err != nil { return "", err } diff --git a/hscontrol/machine_test.go b/hscontrol/machine_test.go index 3f11da4b2b..0e7d7dea6a 100644 --- a/hscontrol/machine_test.go +++ b/hscontrol/machine_test.go @@ -9,19 +9,20 @@ import ( "testing" "time" + "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" "tailscale.com/tailcfg" "tailscale.com/types/key" ) func (s *Suite) TestGetMachine(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "testmachine") + _, err = app.db.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) machine := &Machine{ @@ -34,20 +35,20 @@ func (s *Suite) TestGetMachine(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(machine) + app.db.db.Save(machine) - _, err = app.GetMachine("test", "testmachine") + _, err = app.db.GetMachine("test", "testmachine") c.Assert(err, check.IsNil) } func (s *Suite) TestGetMachineByID(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachineByID(0) + _, err = app.db.GetMachineByID(0) c.Assert(err, check.NotNil) machine := Machine{ @@ -60,20 +61,20 @@ func (s *Suite) TestGetMachineByID(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) - _, err = app.GetMachineByID(0) + _, err = app.db.GetMachineByID(0) c.Assert(err, check.IsNil) } func (s *Suite) TestGetMachineByNodeKey(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachineByID(0) + _, err = app.db.GetMachineByID(0) c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -81,28 +82,28 @@ func (s *Suite) TestGetMachineByNodeKey(c *check.C) { machine := Machine{ ID: 0, - MachineKey: MachinePublicKeyStripPrefix(machineKey.Public()), - NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()), + MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), + NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), DiscoKey: "faa", Hostname: "testmachine", UserID: user.ID, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) - _, err = app.GetMachineByNodeKey(nodeKey.Public()) + _, err = app.db.GetMachineByNodeKey(nodeKey.Public()) c.Assert(err, check.IsNil) } func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachineByID(0) + _, err = app.db.GetMachineByID(0) c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -112,22 +113,22 @@ func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) { machine := Machine{ ID: 0, - MachineKey: MachinePublicKeyStripPrefix(machineKey.Public()), - NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()), + MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), + NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), DiscoKey: "faa", Hostname: "testmachine", UserID: user.ID, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) - _, err = app.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) + _, err = app.db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) c.Assert(err, check.IsNil) } func (s *Suite) TestDeleteMachine(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) machine := Machine{ ID: 0, @@ -139,17 +140,17 @@ func (s *Suite) TestDeleteMachine(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(1), } - app.db.Save(&machine) + app.db.db.Save(&machine) - err = app.DeleteMachine(&machine) + err = app.db.DeleteMachine(&machine) c.Assert(err, check.IsNil) - _, err = app.GetMachine(user.Name, "testmachine") + _, err = app.db.GetMachine(user.Name, "testmachine") c.Assert(err, check.NotNil) } func (s *Suite) TestHardDeleteMachine(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) machine := Machine{ ID: 0, @@ -161,23 +162,23 @@ func (s *Suite) TestHardDeleteMachine(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(1), } - app.db.Save(&machine) + app.db.db.Save(&machine) - err = app.HardDeleteMachine(&machine) + err = app.db.HardDeleteMachine(&machine) c.Assert(err, check.IsNil) - _, err = app.GetMachine(user.Name, "testmachine3") + _, err = app.db.GetMachine(user.Name, "testmachine3") c.Assert(err, check.NotNil) } func (s *Suite) TestListPeers(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachineByID(0) + _, err = app.db.GetMachineByID(0) c.Assert(err, check.NotNil) for index := 0; index <= 10; index++ { @@ -191,13 +192,13 @@ func (s *Suite) TestListPeers(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) } - machine0ByID, err := app.GetMachineByID(0) + machine0ByID, err := app.db.GetMachineByID(0) c.Assert(err, check.IsNil) - peersOfMachine0, err := app.ListPeers(machine0ByID) + peersOfMachine0, err := app.db.ListPeers(machine0ByID) c.Assert(err, check.IsNil) c.Assert(len(peersOfMachine0), check.Equals, 9) @@ -215,14 +216,14 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { stor := make([]base, 0) for _, name := range []string{"test", "admin"} { - user, err := app.CreateUser(name) + user, err := app.db.CreateUser(name) c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) stor = append(stor, base{user, pak}) } - _, err := app.GetMachineByID(0) + _, err := app.db.GetMachineByID(0) c.Assert(err, check.NotNil) for index := 0; index <= 10; index++ { @@ -239,7 +240,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(stor[index%2].key.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) } app.aclPolicy = &ACLPolicy{ @@ -266,19 +267,19 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { err = app.UpdateACLRules() c.Assert(err, check.IsNil) - adminMachine, err := app.GetMachineByID(1) + adminMachine, err := app.db.GetMachineByID(1) c.Logf("Machine(%v), user: %v", adminMachine.Hostname, adminMachine.User) c.Assert(err, check.IsNil) - testMachine, err := app.GetMachineByID(2) + testMachine, err := app.db.GetMachineByID(2) c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User) c.Assert(err, check.IsNil) - machines, err := app.ListMachines() + machines, err := app.db.ListMachines() c.Assert(err, check.IsNil) - peersOfTestMachine := app.filterMachinesByACL(testMachine, machines) - peersOfAdminMachine := app.filterMachinesByACL(adminMachine, machines) + peersOfTestMachine := app.db.filterMachinesByACL(app.aclRules, testMachine, machines) + peersOfAdminMachine := app.db.filterMachinesByACL(app.aclRules, adminMachine, machines) c.Log(peersOfTestMachine) c.Assert(len(peersOfTestMachine), check.Equals, 9) @@ -294,13 +295,13 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { } func (s *Suite) TestExpireMachine(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "testmachine") + _, err = app.db.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) machine := &Machine{ @@ -314,15 +315,15 @@ func (s *Suite) TestExpireMachine(c *check.C) { AuthKeyID: uint(pak.ID), Expiry: &time.Time{}, } - app.db.Save(machine) + app.db.db.Save(machine) - machineFromDB, err := app.GetMachine("test", "testmachine") + machineFromDB, err := app.db.GetMachine("test", "testmachine") c.Assert(err, check.IsNil) c.Assert(machineFromDB, check.NotNil) c.Assert(machineFromDB.isExpired(), check.Equals, false) - err = app.ExpireMachine(machineFromDB) + err = app.db.ExpireMachine(machineFromDB) c.Assert(err, check.IsNil) c.Assert(machineFromDB.isExpired(), check.Equals, true) @@ -350,13 +351,13 @@ func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) { } func (s *Suite) TestGenerateGivenName(c *check.C) { - user1, err := app.CreateUser("user-1") + user1, err := app.db.CreateUser("user-1") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user1.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user1.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("user-1", "testmachine") + _, err = app.db.GetMachine("user-1", "testmachine") c.Assert(err, check.NotNil) machine := &Machine{ @@ -370,37 +371,37 @@ func (s *Suite) TestGenerateGivenName(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(machine) + app.db.db.Save(machine) - givenName, err := app.GenerateGivenName("machine-key-2", "hostname-2") + givenName, err := app.db.GenerateGivenName("machine-key-2", "hostname-2") comment := check.Commentf("Same user, unique machines, unique hostnames, no conflict") c.Assert(err, check.IsNil, comment) c.Assert(givenName, check.Equals, "hostname-2", comment) - givenName, err = app.GenerateGivenName("machine-key-1", "hostname-1") + givenName, err = app.db.GenerateGivenName("machine-key-1", "hostname-1") comment = check.Commentf("Same user, same machine, same hostname, no conflict") c.Assert(err, check.IsNil, comment) c.Assert(givenName, check.Equals, "hostname-1", comment) - givenName, err = app.GenerateGivenName("machine-key-2", "hostname-1") + givenName, err = app.db.GenerateGivenName("machine-key-2", "hostname-1") comment = check.Commentf("Same user, unique machines, same hostname, conflict") c.Assert(err, check.IsNil, comment) c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment) - givenName, err = app.GenerateGivenName("machine-key-2", "hostname-1") + givenName, err = app.db.GenerateGivenName("machine-key-2", "hostname-1") comment = check.Commentf("Unique users, unique machines, same hostname, conflict") c.Assert(err, check.IsNil, comment) c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment) } func (s *Suite) TestSetTags(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "testmachine") + _, err = app.db.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) machine := &Machine{ @@ -413,21 +414,21 @@ func (s *Suite) TestSetTags(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(machine) + app.db.db.Save(machine) // assign simple tags sTags := []string{"tag:test", "tag:foo"} - err = app.SetTags(machine, sTags) + err = app.db.SetTags(machine, sTags, app.UpdateACLRules) c.Assert(err, check.IsNil) - machine, err = app.GetMachine("test", "testmachine") + machine, err = app.db.GetMachine("test", "testmachine") c.Assert(err, check.IsNil) c.Assert(machine.ForcedTags, check.DeepEquals, StringList(sTags)) // assign duplicat tags, expect no errors but no doubles in DB eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} - err = app.SetTags(machine, eTags) + err = app.db.SetTags(machine, eTags, app.UpdateACLRules) c.Assert(err, check.IsNil) - machine, err = app.GetMachine("test", "testmachine") + machine, err = app.db.GetMachine("test", "testmachine") c.Assert(err, check.IsNil) c.Assert( machine.ForcedTags, @@ -562,7 +563,7 @@ func Test_getTags(t *testing.T) { test.args.stripEmailDomain, ) for _, valid := range gotValid { - if !contains(test.wantValid, valid) { + if !util.StringOrPrefixListContains(test.wantValid, valid) { t.Errorf( "valids: getTags() = %v, want %v", gotValid, @@ -573,7 +574,7 @@ func Test_getTags(t *testing.T) { } } for _, invalid := range gotInvalid { - if !contains(test.wantInvalid, invalid) { + if !util.StringOrPrefixListContains(test.wantInvalid, invalid) { t.Errorf( "invalids: getTags() = %v, want %v", gotInvalid, @@ -1061,19 +1062,15 @@ func TestHeadscale_generateGivenName(t *testing.T) { } tests := []struct { name string - h *Headscale + db *HSDatabase args args want *regexp.Regexp wantErr bool }{ { name: "simple machine name generation", - h: &Headscale{ - cfg: &Config{ - OIDC: OIDCConfig{ - StripEmaildomain: true, - }, - }, + db: &HSDatabase{ + stripEmailDomain: true, }, args: args{ suppliedName: "testmachine", @@ -1084,12 +1081,8 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 53 chars", - h: &Headscale{ - cfg: &Config{ - OIDC: OIDCConfig{ - StripEmaildomain: true, - }, - }, + db: &HSDatabase{ + stripEmailDomain: true, }, args: args{ suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", @@ -1100,12 +1093,8 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 63 chars", - h: &Headscale{ - cfg: &Config{ - OIDC: OIDCConfig{ - StripEmaildomain: true, - }, - }, + db: &HSDatabase{ + stripEmailDomain: true, }, args: args{ suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", @@ -1116,12 +1105,8 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 64 chars", - h: &Headscale{ - cfg: &Config{ - OIDC: OIDCConfig{ - StripEmaildomain: true, - }, - }, + db: &HSDatabase{ + stripEmailDomain: true, }, args: args{ suppliedName: "machineeee123456789012345678901234567890123456789012345678901234", @@ -1132,12 +1117,8 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 73 chars", - h: &Headscale{ - cfg: &Config{ - OIDC: OIDCConfig{ - StripEmaildomain: true, - }, - }, + db: &HSDatabase{ + stripEmailDomain: true, }, args: args{ suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123", @@ -1148,12 +1129,8 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with random suffix", - h: &Headscale{ - cfg: &Config{ - OIDC: OIDCConfig{ - StripEmaildomain: true, - }, - }, + db: &HSDatabase{ + stripEmailDomain: true, }, args: args{ suppliedName: "test", @@ -1164,12 +1141,8 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 63 chars with random suffix", - h: &Headscale{ - cfg: &Config{ - OIDC: OIDCConfig{ - StripEmaildomain: true, - }, - }, + db: &HSDatabase{ + stripEmailDomain: true, }, args: args{ suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", @@ -1181,7 +1154,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.h.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) + got, err := tt.db.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) if (err != nil) != tt.wantErr { t.Errorf( "Headscale.GenerateGivenName() error = %v, wantErr %v", @@ -1214,35 +1187,35 @@ func TestHeadscale_generateGivenName(t *testing.T) { func (s *Suite) TestAutoApproveRoutes(c *check.C) { acl := []byte(` { - "tagOwners": { - "tag:exit": ["test"], - }, - - "groups": { - "group:test": ["test"] - }, - - "acls": [ - {"action": "accept", "users": ["*"], "ports": ["*:*"]}, - ], - - "autoApprovers": { - "exitNode": ["tag:exit"], - "routes": { - "10.10.0.0/16": ["group:test"], - "10.11.0.0/16": ["test"], - } - } + "tagOwners": { + "tag:exit": ["test"], + }, + + "groups": { + "group:test": ["test"] + }, + + "acls": [ + {"action": "accept", "users": ["*"], "ports": ["*:*"]}, + ], + + "autoApprovers": { + "exitNode": ["tag:exit"], + "routes": { + "10.10.0.0/16": ["group:test"], + "10.11.0.0/16": ["test"], + } + } } `) err := app.LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.IsNil) - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) nodeKey := key.NewNode() @@ -1255,7 +1228,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { machine := Machine{ ID: 0, MachineKey: "foo", - NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()), + NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), DiscoKey: "faa", Hostname: "test", UserID: user.ID, @@ -1268,18 +1241,18 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, } - app.db.Save(&machine) + app.db.db.Save(&machine) - err = app.processMachineRoutes(&machine) + err = app.db.processMachineRoutes(&machine) c.Assert(err, check.IsNil) - machine0ByID, err := app.GetMachineByID(0) + machine0ByID, err := app.db.GetMachineByID(0) c.Assert(err, check.IsNil) - err = app.EnableAutoApprovedRoutes(machine0ByID) + err = app.db.EnableAutoApprovedRoutes(app.aclPolicy, machine0ByID) c.Assert(err, check.IsNil) - enabledRoutes, err := app.GetEnabledRoutes(machine0ByID) + enabledRoutes, err := app.db.GetEnabledRoutes(machine0ByID) c.Assert(err, check.IsNil) c.Assert(enabledRoutes, check.HasLen, 3) } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 332ce099de..c666594e5e 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -14,6 +14,7 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/mux" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "golang.org/x/oauth2" "tailscale.com/types/key" @@ -21,16 +22,22 @@ import ( const ( randomByteSize = 16 +) - errEmptyOIDCCallbackParams = Error("empty OIDC callback params") - errNoOIDCIDToken = Error("could not extract ID Token for OIDC callback") - errOIDCAllowedDomains = Error("authenticated principal does not match any allowed domain") - errOIDCAllowedGroups = Error("authenticated principal is not in any allowed group") - errOIDCAllowedUsers = Error("authenticated principal does not match any allowed user") - errOIDCInvalidMachineState = Error( +var ( + errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params") + errNoOIDCIDToken = errors.New("could not extract ID Token for OIDC callback") + errOIDCAllowedDomains = errors.New( + "authenticated principal does not match any allowed domain", + ) + errOIDCAllowedGroups = errors.New("authenticated principal is not in any allowed group") + errOIDCAllowedUsers = errors.New( + "authenticated principal does not match any allowed user", + ) + errOIDCInvalidMachineState = errors.New( "requested machine state key expired before authorisation completed", ) - errOIDCNodeKeyMissing = Error("could not get node key from cache") + errOIDCNodeKeyMissing = errors.New("could not get node key from cache") ) type IDTokenClaims struct { @@ -94,7 +101,7 @@ func (h *Headscale) RegisterOIDC( Bool("ok", ok). Msg("Received oidc register call") - if !NodePublicKeyRegex.Match([]byte(nodeKeyStr)) { + if !util.NodePublicKeyRegex.Match([]byte(nodeKeyStr)) { log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") @@ -115,7 +122,7 @@ func (h *Headscale) RegisterOIDC( // the template and log an error. var nodeKey key.NodePublic err := nodeKey.UnmarshalText( - []byte(NodePublicKeyEnsurePrefix(nodeKeyStr)), + []byte(util.NodePublicKeyEnsurePrefix(nodeKeyStr)), ) if !ok || nodeKeyStr == "" || err != nil { @@ -149,7 +156,11 @@ func (h *Headscale) RegisterOIDC( stateStr := hex.EncodeToString(randomBlob)[:32] // place the node key into the state cache, so it can be retrieved later - h.registrationCache.Set(stateStr, NodePublicKeyStripPrefix(nodeKey), registerCacheExpiration) + h.registrationCache.Set( + stateStr, + util.NodePublicKeyStripPrefix(nodeKey), + registerCacheExpiration, + ) // Add any extra parameter provided in the configuration to the Authorize Endpoint request extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)) @@ -406,7 +417,7 @@ func validateOIDCAllowedDomains( ) error { if len(allowedDomains) > 0 { if at := strings.LastIndex(claims.Email, "@"); at < 0 || - !IsStringInSlice(allowedDomains, claims.Email[at+1:]) { + !util.IsStringInSlice(allowedDomains, claims.Email[at+1:]) { log.Error().Msg("authenticated principal does not match any allowed domain") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) @@ -436,7 +447,7 @@ func validateOIDCAllowedGroups( ) error { if len(allowedGroups) > 0 { for _, group := range allowedGroups { - if IsStringInSlice(claims.Groups, group) { + if util.IsStringInSlice(claims.Groups, group) { return nil } } @@ -466,7 +477,7 @@ func validateOIDCAllowedUsers( claims *IDTokenClaims, ) error { if len(allowedUsers) > 0 && - !IsStringInSlice(allowedUsers, claims.Email) { + !util.IsStringInSlice(allowedUsers, claims.Email) { log.Error().Msg("authenticated principal does not match any allowed user") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) @@ -531,7 +542,7 @@ func (h *Headscale) validateMachineForOIDCCallback( } err := nodeKey.UnmarshalText( - []byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)), + []byte(util.NodePublicKeyEnsurePrefix(nodeKeyFromCache)), ) if err != nil { log.Error(). @@ -555,7 +566,7 @@ func (h *Headscale) validateMachineForOIDCCallback( // The error is not important, because if it does not // exist, then this is a new machine and we will move // on to registration. - machine, _ := h.GetMachineByNodeKey(nodeKey) + machine, _ := h.db.GetMachineByNodeKey(nodeKey) if machine != nil { log.Trace(). @@ -563,7 +574,7 @@ func (h *Headscale) validateMachineForOIDCCallback( Str("machine", machine.Hostname). Msg("machine already registered, reauthenticating") - err := h.RefreshMachine(machine, expiry) + err := h.db.RefreshMachine(machine, expiry) if err != nil { log.Error(). Caller(). @@ -653,9 +664,9 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback( writer http.ResponseWriter, userName string, ) (*User, error) { - user, err := h.GetUser(userName) + user, err := h.db.GetUser(userName) if errors.Is(err, ErrUserNotFound) { - user, err = h.CreateUser(userName) + user, err = h.db.CreateUser(userName) if err != nil { log.Error(). @@ -702,7 +713,9 @@ func (h *Headscale) registerMachineForOIDCCallback( nodeKey *key.NodePublic, expiry time.Time, ) error { - if _, err := h.RegisterMachineFromAuthCallback( + if _, err := h.db.RegisterMachineFromAuthCallback( + // TODO(kradalby): find a better way to use the cache across modules + h.registrationCache, nodeKey.String(), user.Name, &expiry, diff --git a/hscontrol/preauth_keys.go b/hscontrol/preauth_keys.go index 6cff90b001..1956762270 100644 --- a/hscontrol/preauth_keys.go +++ b/hscontrol/preauth_keys.go @@ -10,16 +10,17 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/util" "google.golang.org/protobuf/types/known/timestamppb" "gorm.io/gorm" ) -const ( - ErrPreAuthKeyNotFound = Error("AuthKey not found") - ErrPreAuthKeyExpired = Error("AuthKey expired") - ErrSingleUseAuthKeyHasBeenUsed = Error("AuthKey has already been used") - ErrUserMismatch = Error("user mismatch") - ErrPreAuthKeyACLTagInvalid = Error("AuthKey tag is invalid") +var ( + ErrPreAuthKeyNotFound = errors.New("AuthKey not found") + ErrPreAuthKeyExpired = errors.New("AuthKey expired") + ErrSingleUseAuthKeyHasBeenUsed = errors.New("AuthKey has already been used") + ErrUserMismatch = errors.New("user mismatch") + ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid") ) // PreAuthKey describes a pre-authorization key usable in a particular user. @@ -45,26 +46,30 @@ type PreAuthKeyACLTag struct { } // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. -func (h *Headscale) CreatePreAuthKey( +func (hsdb *HSDatabase) CreatePreAuthKey( userName string, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string, ) (*PreAuthKey, error) { - user, err := h.GetUser(userName) + user, err := hsdb.GetUser(userName) if err != nil { return nil, err } for _, tag := range aclTags { if !strings.HasPrefix(tag, "tag:") { - return nil, fmt.Errorf("%w: '%s' did not begin with 'tag:'", ErrPreAuthKeyACLTagInvalid, tag) + return nil, fmt.Errorf( + "%w: '%s' did not begin with 'tag:'", + ErrPreAuthKeyACLTagInvalid, + tag, + ) } } now := time.Now().UTC() - kstr, err := h.generateKey() + kstr, err := hsdb.generateKey() if err != nil { return nil, err } @@ -79,7 +84,7 @@ func (h *Headscale) CreatePreAuthKey( Expiration: expiration, } - err = h.db.Transaction(func(db *gorm.DB) error { + err = hsdb.db.Transaction(func(db *gorm.DB) error { if err := db.Save(&key).Error; err != nil { return fmt.Errorf("failed to create key in the database: %w", err) } @@ -111,14 +116,14 @@ func (h *Headscale) CreatePreAuthKey( } // ListPreAuthKeys returns the list of PreAuthKeys for a user. -func (h *Headscale) ListPreAuthKeys(userName string) ([]PreAuthKey, error) { - user, err := h.GetUser(userName) +func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]PreAuthKey, error) { + user, err := hsdb.GetUser(userName) if err != nil { return nil, err } keys := []PreAuthKey{} - if err := h.db.Preload("User").Preload("ACLTags").Where(&PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { + if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { return nil, err } @@ -126,8 +131,8 @@ func (h *Headscale) ListPreAuthKeys(userName string) ([]PreAuthKey, error) { } // GetPreAuthKey returns a PreAuthKey for a given key. -func (h *Headscale) GetPreAuthKey(user string, key string) (*PreAuthKey, error) { - pak, err := h.checkKeyValidity(key) +func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*PreAuthKey, error) { + pak, err := hsdb.checkKeyValidity(key) if err != nil { return nil, err } @@ -141,8 +146,8 @@ func (h *Headscale) GetPreAuthKey(user string, key string) (*PreAuthKey, error) // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // does not exist. -func (h *Headscale) DestroyPreAuthKey(pak PreAuthKey) error { - return h.db.Transaction(func(db *gorm.DB) error { +func (hsdb *HSDatabase) DestroyPreAuthKey(pak PreAuthKey) error { + return hsdb.db.Transaction(func(db *gorm.DB) error { if result := db.Unscoped().Where(PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&PreAuthKeyACLTag{}); result.Error != nil { return result.Error } @@ -156,8 +161,8 @@ func (h *Headscale) DestroyPreAuthKey(pak PreAuthKey) error { } // MarkExpirePreAuthKey marks a PreAuthKey as expired. -func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error { - if err := h.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { +func (hsdb *HSDatabase) ExpirePreAuthKey(k *PreAuthKey) error { + if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -165,9 +170,9 @@ func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error { } // UsePreAuthKey marks a PreAuthKey as used. -func (h *Headscale) UsePreAuthKey(k *PreAuthKey) error { +func (hsdb *HSDatabase) UsePreAuthKey(k *PreAuthKey) error { k.Used = true - if err := h.db.Save(k).Error; err != nil { + if err := hsdb.db.Save(k).Error; err != nil { return fmt.Errorf("failed to update key used status in the database: %w", err) } @@ -176,9 +181,9 @@ func (h *Headscale) UsePreAuthKey(k *PreAuthKey) error { // checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node // If returns no error and a PreAuthKey, it can be used. -func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { +func (hsdb *HSDatabase) checkKeyValidity(k string) (*PreAuthKey, error) { pak := PreAuthKey{} - if result := h.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( + if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { @@ -194,7 +199,7 @@ func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { } machines := []Machine{} - if err := h.db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { + if err := hsdb.db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { return nil, err } @@ -205,7 +210,7 @@ func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { return &pak, nil } -func (h *Headscale) generateKey() (string, error) { +func (hsdb *HSDatabase) generateKey() (string, error) { size := 24 bytes := make([]byte, size) if _, err := rand.Read(bytes); err != nil { @@ -218,7 +223,7 @@ func (h *Headscale) generateKey() (string, error) { func (key *PreAuthKey) toProto() *v1.PreAuthKey { protoKey := v1.PreAuthKey{ User: key.User.Name, - Id: strconv.FormatUint(key.ID, Base10), + Id: strconv.FormatUint(key.ID, util.Base10), Key: key.Key, Ephemeral: key.Ephemeral, Reusable: key.Reusable, diff --git a/hscontrol/preauth_keys_test.go b/hscontrol/preauth_keys_test.go index bd383cfd25..a85a6c6103 100644 --- a/hscontrol/preauth_keys_test.go +++ b/hscontrol/preauth_keys_test.go @@ -7,14 +7,14 @@ import ( ) func (*Suite) TestCreatePreAuthKey(c *check.C) { - _, err := app.CreatePreAuthKey("bogus", true, false, nil, nil) + _, err := app.db.CreatePreAuthKey("bogus", true, false, nil, nil) c.Assert(err, check.NotNil) - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - key, err := app.CreatePreAuthKey(user.Name, true, false, nil, nil) + key, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) // Did we get a valid key? @@ -24,10 +24,10 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { // Make sure the User association is populated c.Assert(key.User.Name, check.Equals, user.Name) - _, err = app.ListPreAuthKeys("bogus") + _, err = app.db.ListPreAuthKeys("bogus") c.Assert(err, check.NotNil) - keys, err := app.ListPreAuthKeys(user.Name) + keys, err := app.db.ListPreAuthKeys(user.Name) c.Assert(err, check.IsNil) c.Assert(len(keys), check.Equals, 1) @@ -36,41 +36,41 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { } func (*Suite) TestExpiredPreAuthKey(c *check.C) { - user, err := app.CreateUser("test2") + user, err := app.db.CreateUser("test2") c.Assert(err, check.IsNil) now := time.Now() - pak, err := app.CreatePreAuthKey(user.Name, true, false, &now, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, true, false, &now, nil) c.Assert(err, check.IsNil) - key, err := app.checkKeyValidity(pak.Key) + key, err := app.db.checkKeyValidity(pak.Key) c.Assert(err, check.Equals, ErrPreAuthKeyExpired) c.Assert(key, check.IsNil) } func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) { - key, err := app.checkKeyValidity("potatoKey") + key, err := app.db.checkKeyValidity("potatoKey") c.Assert(err, check.Equals, ErrPreAuthKeyNotFound) c.Assert(key, check.IsNil) } func (*Suite) TestValidateKeyOk(c *check.C) { - user, err := app.CreateUser("test3") + user, err := app.db.CreateUser("test3") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) - key, err := app.checkKeyValidity(pak.Key) + key, err := app.db.checkKeyValidity(pak.Key) c.Assert(err, check.IsNil) c.Assert(key.ID, check.Equals, pak.ID) } func (*Suite) TestAlreadyUsedKey(c *check.C) { - user, err := app.CreateUser("test4") + user, err := app.db.CreateUser("test4") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) machine := Machine{ @@ -83,18 +83,18 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) - key, err := app.checkKeyValidity(pak.Key) + key, err := app.db.checkKeyValidity(pak.Key) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) c.Assert(key, check.IsNil) } func (*Suite) TestReusableBeingUsedKey(c *check.C) { - user, err := app.CreateUser("test5") + user, err := app.db.CreateUser("test5") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) machine := Machine{ @@ -107,30 +107,30 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) - key, err := app.checkKeyValidity(pak.Key) + key, err := app.db.checkKeyValidity(pak.Key) c.Assert(err, check.IsNil) c.Assert(key.ID, check.Equals, pak.ID) } func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { - user, err := app.CreateUser("test6") + user, err := app.db.CreateUser("test6") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - key, err := app.checkKeyValidity(pak.Key) + key, err := app.db.checkKeyValidity(pak.Key) c.Assert(err, check.IsNil) c.Assert(key.ID, check.Equals, pak.ID) } func (*Suite) TestEphemeralKey(c *check.C) { - user, err := app.CreateUser("test7") + user, err := app.db.CreateUser("test7") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, true, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, true, nil, nil) c.Assert(err, check.IsNil) now := time.Now() @@ -145,65 +145,65 @@ func (*Suite) TestEphemeralKey(c *check.C) { LastSeen: &now, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) - _, err = app.checkKeyValidity(pak.Key) + _, err = app.db.checkKeyValidity(pak.Key) // Ephemeral keys are by definition reusable c.Assert(err, check.IsNil) - _, err = app.GetMachine("test7", "testest") + _, err = app.db.GetMachine("test7", "testest") c.Assert(err, check.IsNil) app.expireEphemeralNodesWorker() // The machine record should have been deleted - _, err = app.GetMachine("test7", "testest") + _, err = app.db.GetMachine("test7", "testest") c.Assert(err, check.NotNil) } func (*Suite) TestExpirePreauthKey(c *check.C) { - user, err := app.CreateUser("test3") + user, err := app.db.CreateUser("test3") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) c.Assert(pak.Expiration, check.IsNil) - err = app.ExpirePreAuthKey(pak) + err = app.db.ExpirePreAuthKey(pak) c.Assert(err, check.IsNil) c.Assert(pak.Expiration, check.NotNil) - key, err := app.checkKeyValidity(pak.Key) + key, err := app.db.checkKeyValidity(pak.Key) c.Assert(err, check.Equals, ErrPreAuthKeyExpired) c.Assert(key, check.IsNil) } func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { - user, err := app.CreateUser("test6") + user, err := app.db.CreateUser("test6") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) pak.Used = true - app.db.Save(&pak) + app.db.db.Save(&pak) - _, err = app.checkKeyValidity(pak.Key) + _, err = app.db.checkKeyValidity(pak.Key) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) } func (*Suite) TestPreAuthKeyACLTags(c *check.C) { - user, err := app.CreateUser("test8") + user, err := app.db.CreateUser("test8") c.Assert(err, check.IsNil) - _, err = app.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"}) + _, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"}) c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected tags := []string{"tag:test1", "tag:test2"} tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} - _, err = app.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate) + _, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate) c.Assert(err, check.IsNil) - listedPaks, err := app.ListPreAuthKeys("test8") + listedPaks, err := app.db.ListPreAuthKeys("test8") c.Assert(err, check.IsNil) c.Assert(listedPaks[0].toProto().AclTags, check.DeepEquals, tags) } diff --git a/hscontrol/protocol_common.go b/hscontrol/protocol_common.go index 97da464bb9..5cd0ddb4e8 100644 --- a/hscontrol/protocol_common.go +++ b/hscontrol/protocol_common.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "gorm.io/gorm" "tailscale.com/tailcfg" @@ -82,7 +83,7 @@ func (h *Headscale) KeyHandler( // Old clients don't send a 'v' parameter, so we send the legacy public key writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusOK) - _, err := writer.Write([]byte(MachinePublicKeyStripPrefix(h.privateKey.Public()))) + _, err := writer.Write([]byte(util.MachinePublicKeyStripPrefix(h.privateKey.Public()))) if err != nil { log.Error(). Caller(). @@ -102,7 +103,7 @@ func (h *Headscale) handleRegisterCommon( isNoise bool, ) { now := time.Now().UTC() - machine, err := h.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) + machine, err := h.db.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) if errors.Is(err, gorm.ErrRecordNotFound) { // If the machine has AuthKey set, handle registration via PreAuthKeys if registerRequest.Auth.AuthKey != "" { @@ -120,7 +121,7 @@ func (h *Headscale) handleRegisterCommon( // is that the client will hammer headscale with requests until it gets a // successful RegisterResponse. if registerRequest.Followup != "" { - if _, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok { + if _, ok := h.registrationCache.Get(util.NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok { log.Debug(). Caller(). Str("machine", registerRequest.Hostinfo.Hostname). @@ -152,7 +153,7 @@ func (h *Headscale) handleRegisterCommon( Bool("noise", isNoise). Msg("New machine not yet in the database") - givenName, err := h.GenerateGivenName( + givenName, err := h.db.GenerateGivenName( machineKey.String(), registerRequest.Hostinfo.Hostname, ) @@ -171,10 +172,10 @@ func (h *Headscale) handleRegisterCommon( // We create the machine and then keep it around until a callback // happens newMachine := Machine{ - MachineKey: MachinePublicKeyStripPrefix(machineKey), + MachineKey: util.MachinePublicKeyStripPrefix(machineKey), Hostname: registerRequest.Hostinfo.Hostname, GivenName: givenName, - NodeKey: NodePublicKeyStripPrefix(registerRequest.NodeKey), + NodeKey: util.NodePublicKeyStripPrefix(registerRequest.NodeKey), LastSeen: &now, Expiry: &time.Time{}, } @@ -210,11 +211,11 @@ func (h *Headscale) handleRegisterCommon( // So if we have a not valid MachineKey (but we were able to fetch the machine with the NodeKeys), we update it. var storedMachineKey key.MachinePublic err = storedMachineKey.UnmarshalText( - []byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)), + []byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)), ) if err != nil || storedMachineKey.IsZero() { - machine.MachineKey = MachinePublicKeyStripPrefix(machineKey) - if err := h.db.Save(&machine).Error; err != nil { + machine.MachineKey = util.MachinePublicKeyStripPrefix(machineKey) + if err := h.db.db.Save(&machine).Error; err != nil { log.Error(). Caller(). Str("func", "RegistrationHandler"). @@ -231,7 +232,7 @@ func (h *Headscale) handleRegisterCommon( // - Trying to log out (sending a expiry in the past) // - A valid, registered machine, looking for /map // - Expired machine wanting to reauthenticate - if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.NodeKey) { + if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.NodeKey) { // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 if !registerRequest.Expiry.IsZero() && @@ -251,7 +252,7 @@ func (h *Headscale) handleRegisterCommon( } // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration - if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.OldNodeKey) && + if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) && !machine.isExpired() { h.handleMachineRefreshKeyCommon( writer, @@ -282,9 +283,9 @@ func (h *Headscale) handleRegisterCommon( // we need to make sure the NodeKey matches the one in the request // TODO(juan): What happens when using fast user switching between two // headscale-managed tailnets? - machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey) + machine.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey) h.registrationCache.Set( - NodePublicKeyStripPrefix(registerRequest.NodeKey), + util.NodePublicKeyStripPrefix(registerRequest.NodeKey), *machine, registerCacheExpiration, ) @@ -311,7 +312,7 @@ func (h *Headscale) handleAuthKeyCommon( Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) resp := tailcfg.RegisterResponse{} - pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey) + pak, err := h.db.checkKeyValidity(registerRequest.Auth.AuthKey) if err != nil { log.Error(). Caller(). @@ -372,13 +373,13 @@ func (h *Headscale) handleAuthKeyCommon( Str("machine", registerRequest.Hostinfo.Hostname). Msg("Authentication key was valid, proceeding to acquire IP addresses") - nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey) + nodeKey := util.NodePublicKeyStripPrefix(registerRequest.NodeKey) // retrieve machine information if it exist // The error is not important, because if it does not // exist, then this is a new machine and we will move // on to registration. - machine, _ := h.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) + machine, _ := h.db.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) if machine != nil { log.Trace(). Caller(). @@ -388,7 +389,7 @@ func (h *Headscale) handleAuthKeyCommon( machine.NodeKey = nodeKey machine.AuthKeyID = uint(pak.ID) - err := h.RefreshMachine(machine, registerRequest.Expiry) + err := h.db.RefreshMachine(machine, registerRequest.Expiry) if err != nil { log.Error(). Caller(). @@ -403,7 +404,7 @@ func (h *Headscale) handleAuthKeyCommon( aclTags := pak.toProto().AclTags if len(aclTags) > 0 { // This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login - err = h.SetTags(machine, aclTags) + err = h.db.SetTags(machine, aclTags, h.UpdateACLRules) if err != nil { log.Error(). @@ -420,7 +421,7 @@ func (h *Headscale) handleAuthKeyCommon( } else { now := time.Now().UTC() - givenName, err := h.GenerateGivenName(MachinePublicKeyStripPrefix(machineKey), registerRequest.Hostinfo.Hostname) + givenName, err := h.db.GenerateGivenName(util.MachinePublicKeyStripPrefix(machineKey), registerRequest.Hostinfo.Hostname) if err != nil { log.Error(). Caller(). @@ -436,7 +437,7 @@ func (h *Headscale) handleAuthKeyCommon( Hostname: registerRequest.Hostinfo.Hostname, GivenName: givenName, UserID: pak.User.ID, - MachineKey: MachinePublicKeyStripPrefix(machineKey), + MachineKey: util.MachinePublicKeyStripPrefix(machineKey), RegisterMethod: RegisterMethodAuthKey, Expiry: ®isterRequest.Expiry, NodeKey: nodeKey, @@ -445,7 +446,7 @@ func (h *Headscale) handleAuthKeyCommon( ForcedTags: pak.toProto().AclTags, } - machine, err = h.RegisterMachine( + machine, err = h.db.RegisterMachine( machineToRegister, ) if err != nil { @@ -462,7 +463,7 @@ func (h *Headscale) handleAuthKeyCommon( } } - err = h.UsePreAuthKey(pak) + err = h.db.UsePreAuthKey(pak) if err != nil { log.Error(). Caller(). @@ -591,7 +592,7 @@ func (h *Headscale) handleMachineLogOutCommon( Str("machine", machine.Hostname). Msg("Client requested logout") - err := h.ExpireMachine(&machine) + err := h.db.ExpireMachine(&machine) if err != nil { log.Error(). Caller(). @@ -634,7 +635,7 @@ func (h *Headscale) handleMachineLogOutCommon( } if machine.isEphemeral() { - err = h.HardDeleteMachine(&machine) + err = h.db.HardDeleteMachine(&machine) if err != nil { log.Error(). Err(err). @@ -720,9 +721,9 @@ func (h *Headscale) handleMachineRefreshKeyCommon( Bool("noise", isNoise). Str("machine", machine.Hostname). Msg("We have the OldNodeKey in the database. This is a key refresh") - machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey) + machine.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey) - if err := h.db.Save(&machine).Error; err != nil { + if err := h.db.db.Save(&machine).Error; err != nil { log.Error(). Caller(). Err(err). diff --git a/hscontrol/protocol_common_poll.go b/hscontrol/protocol_common_poll.go index f267c9999a..502c633a8f 100644 --- a/hscontrol/protocol_common_poll.go +++ b/hscontrol/protocol_common_poll.go @@ -6,6 +6,7 @@ import ( "net/http" "time" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" ) @@ -29,10 +30,10 @@ func (h *Headscale) handlePollCommon( ) { machine.Hostname = mapRequest.Hostinfo.Hostname machine.HostInfo = HostInfo(*mapRequest.Hostinfo) - machine.DiscoKey = DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) + machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) now := time.Now().UTC() - err := h.processMachineRoutes(machine) + err := h.db.processMachineRoutes(machine) if err != nil { log.Error(). Caller(). @@ -53,7 +54,7 @@ func (h *Headscale) handlePollCommon( } // update routes with peer information - err = h.EnableAutoApprovedRoutes(machine) + err = h.db.EnableAutoApprovedRoutes(h.aclPolicy, machine) if err != nil { log.Error(). Caller(). @@ -77,7 +78,7 @@ func (h *Headscale) handlePollCommon( machine.LastSeen = &now } - if err := h.db.Updates(machine).Error; err != nil { + if err := h.db.db.Updates(machine).Error; err != nil { if err != nil { log.Error(). Str("handler", "PollNetMap"). @@ -325,7 +326,7 @@ func (h *Headscale) pollNetMapStream( // TODO(kradalby): Abstract away all the database calls, this can cause race conditions // when an outdated machine object is kept alive, e.g. db is update from // command line, but then overwritten. - err = h.UpdateMachineFromDatabase(machine) + err = h.db.UpdateMachineFromDatabase(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -346,7 +347,7 @@ func (h *Headscale) pollNetMapStream( Set(float64(now.Unix())) machine.LastSuccessfulUpdate = &now - err = h.TouchMachine(machine) + err = h.db.TouchMachine(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -409,7 +410,7 @@ func (h *Headscale) pollNetMapStream( // TODO(kradalby): Abstract away all the database calls, this can cause race conditions // when an outdated machine object is kept alive, e.g. db is update from // command line, but then overwritten. - err = h.UpdateMachineFromDatabase(machine) + err = h.db.UpdateMachineFromDatabase(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -425,7 +426,7 @@ func (h *Headscale) pollNetMapStream( } now := time.Now().UTC() machine.LastSeen = &now - err = h.TouchMachine(machine) + err = h.db.TouchMachine(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -456,7 +457,7 @@ func (h *Headscale) pollNetMapStream( updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname). Inc() - if h.isOutdated(machine) { + if h.db.isOutdated(machine, h.getLastStateChange()) { var lastUpdate time.Time if machine.LastSuccessfulUpdate != nil { lastUpdate = *machine.LastSuccessfulUpdate @@ -524,7 +525,7 @@ func (h *Headscale) pollNetMapStream( // TODO(kradalby): Abstract away all the database calls, this can cause race conditions // when an outdated machine object is kept alive, e.g. db is update from // command line, but then overwritten. - err = h.UpdateMachineFromDatabase(machine) + err = h.db.UpdateMachineFromDatabase(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -544,7 +545,7 @@ func (h *Headscale) pollNetMapStream( Set(float64(now.Unix())) machine.LastSuccessfulUpdate = &now - err = h.TouchMachine(machine) + err = h.db.TouchMachine(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -578,7 +579,7 @@ func (h *Headscale) pollNetMapStream( // TODO: Abstract away all the database calls, this can cause race conditions // when an outdated machine object is kept alive, e.g. db is update from // command line, but then overwritten. - err := h.UpdateMachineFromDatabase(machine) + err := h.db.UpdateMachineFromDatabase(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -594,7 +595,7 @@ func (h *Headscale) pollNetMapStream( } now := time.Now().UTC() machine.LastSeen = &now - err = h.TouchMachine(machine) + err = h.db.TouchMachine(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). diff --git a/hscontrol/protocol_common_utils.go b/hscontrol/protocol_common_utils.go index e05b04a24a..1dababa1ab 100644 --- a/hscontrol/protocol_common_utils.go +++ b/hscontrol/protocol_common_utils.go @@ -5,6 +5,7 @@ import ( "encoding/json" "sync" + "github.com/juanfont/headscale/hscontrol/util" "github.com/klauspost/compress/zstd" "github.com/rs/zerolog/log" "tailscale.com/smallzstd" @@ -27,7 +28,7 @@ func (h *Headscale) getMapResponseData( } var machineKey key.MachinePublic - err = machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey))) + err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) if err != nil { log.Error(). Caller(). @@ -50,11 +51,16 @@ func (h *Headscale) getMapKeepAliveResponseData( } if isNoise { - return h.marshalMapResponse(keepAliveResponse, key.MachinePublic{}, mapRequest.Compress, isNoise) + return h.marshalMapResponse( + keepAliveResponse, + key.MachinePublic{}, + mapRequest.Compress, + isNoise, + ) } var machineKey key.MachinePublic - err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey))) + err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) if err != nil { log.Error(). Caller(). @@ -104,7 +110,7 @@ func (h *Headscale) marshalMapResponse( } var respBody []byte - if compression == ZstdCompression { + if compression == util.ZstdCompression { respBody = zstdEncode(jsonBody) if !isNoise { // if legacy protocol respBody = h.privateKey.SealTo(machineKey, respBody) diff --git a/hscontrol/protocol_legacy.go b/hscontrol/protocol_legacy.go index 6712828631..f443ebad5f 100644 --- a/hscontrol/protocol_legacy.go +++ b/hscontrol/protocol_legacy.go @@ -7,6 +7,7 @@ import ( "net/http" "github.com/gorilla/mux" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -32,7 +33,7 @@ func (h *Headscale) RegistrationHandler( body, _ := io.ReadAll(req.Body) var machineKey key.MachinePublic - err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) + err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machineKeyStr))) if err != nil { log.Error(). Caller(). @@ -44,7 +45,7 @@ func (h *Headscale) RegistrationHandler( return } registerRequest := tailcfg.RegisterRequest{} - err = decode(body, ®isterRequest, &machineKey, h.privateKey) + err = util.DecodeAndUnmarshalNaCl(body, ®isterRequest, &machineKey, h.privateKey) if err != nil { log.Error(). Caller(). diff --git a/hscontrol/protocol_legacy_poll.go b/hscontrol/protocol_legacy_poll.go index 0121bf3f1d..3755faf1b0 100644 --- a/hscontrol/protocol_legacy_poll.go +++ b/hscontrol/protocol_legacy_poll.go @@ -8,6 +8,7 @@ import ( "net/http" "github.com/gorilla/mux" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "gorm.io/gorm" "tailscale.com/tailcfg" @@ -44,7 +45,7 @@ func (h *Headscale) PollNetMapHandler( body, _ := io.ReadAll(req.Body) var machineKey key.MachinePublic - err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) + err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machineKeyStr))) if err != nil { log.Error(). Str("handler", "PollNetMap"). @@ -56,7 +57,7 @@ func (h *Headscale) PollNetMapHandler( return } mapRequest := tailcfg.MapRequest{} - err = decode(body, &mapRequest, &machineKey, h.privateKey) + err = util.DecodeAndUnmarshalNaCl(body, &mapRequest, &machineKey, h.privateKey) if err != nil { log.Error(). Str("handler", "PollNetMap"). @@ -67,7 +68,7 @@ func (h *Headscale) PollNetMapHandler( return } - machine, err := h.GetMachineByMachineKey(machineKey) + machine, err := h.db.GetMachineByMachineKey(machineKey) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Warn(). diff --git a/hscontrol/protocol_noise_poll.go b/hscontrol/protocol_noise_poll.go index 38f2b1c7e9..c0790f978e 100644 --- a/hscontrol/protocol_noise_poll.go +++ b/hscontrol/protocol_noise_poll.go @@ -48,7 +48,11 @@ func (ns *noiseServer) NoisePollNetMapHandler( ns.nodeKey = mapRequest.NodeKey - machine, err := ns.headscale.GetMachineByAnyKey(ns.conn.Peer(), mapRequest.NodeKey, key.NodePublic{}) + machine, err := ns.headscale.db.GetMachineByAnyKey( + ns.conn.Peer(), + mapRequest.NodeKey, + key.NodePublic{}, + ) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Warn(). diff --git a/hscontrol/routes.go b/hscontrol/routes.go index 89f9a6941b..e3be2f691a 100644 --- a/hscontrol/routes.go +++ b/hscontrol/routes.go @@ -11,13 +11,10 @@ import ( "gorm.io/gorm" ) -const ( - ErrRouteIsNotAvailable = Error("route is not available") -) - var ( - ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0") - ExitRouteV6 = netip.MustParsePrefix("::/0") + ErrRouteIsNotAvailable = errors.New("route is not available") + ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0") + ExitRouteV6 = netip.MustParsePrefix("::/0") ) type Route struct { @@ -51,9 +48,9 @@ func (rs Routes) toPrefixes() []netip.Prefix { return prefixes } -func (h *Headscale) GetRoutes() ([]Route, error) { +func (hsdb *HSDatabase) GetRoutes() ([]Route, error) { var routes []Route - err := h.db.Preload("Machine").Find(&routes).Error + err := hsdb.db.Preload("Machine").Find(&routes).Error if err != nil { return nil, err } @@ -61,9 +58,9 @@ func (h *Headscale) GetRoutes() ([]Route, error) { return routes, nil } -func (h *Headscale) GetMachineRoutes(m *Machine) ([]Route, error) { +func (hsdb *HSDatabase) GetMachineRoutes(m *Machine) ([]Route, error) { var routes []Route - err := h.db. + err := hsdb.db. Preload("Machine"). Where("machine_id = ?", m.ID). Find(&routes).Error @@ -74,9 +71,9 @@ func (h *Headscale) GetMachineRoutes(m *Machine) ([]Route, error) { return routes, nil } -func (h *Headscale) GetRoute(id uint64) (*Route, error) { +func (hsdb *HSDatabase) GetRoute(id uint64) (*Route, error) { var route Route - err := h.db.Preload("Machine").First(&route, id).Error + err := hsdb.db.Preload("Machine").First(&route, id).Error if err != nil { return nil, err } @@ -84,8 +81,8 @@ func (h *Headscale) GetRoute(id uint64) (*Route, error) { return &route, nil } -func (h *Headscale) EnableRoute(id uint64) error { - route, err := h.GetRoute(id) +func (hsdb *HSDatabase) EnableRoute(id uint64) error { + route, err := hsdb.GetRoute(id) if err != nil { return err } @@ -94,14 +91,14 @@ func (h *Headscale) EnableRoute(id uint64) error { // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 if route.isExitRoute() { - return h.enableRoutes(&route.Machine, ExitRouteV4.String(), ExitRouteV6.String()) + return hsdb.enableRoutes(&route.Machine, ExitRouteV4.String(), ExitRouteV6.String()) } - return h.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String()) + return hsdb.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String()) } -func (h *Headscale) DisableRoute(id uint64) error { - route, err := h.GetRoute(id) +func (hsdb *HSDatabase) DisableRoute(id uint64) error { + route, err := hsdb.GetRoute(id) if err != nil { return err } @@ -112,15 +109,15 @@ func (h *Headscale) DisableRoute(id uint64) error { if !route.isExitRoute() { route.Enabled = false route.IsPrimary = false - err = h.db.Save(route).Error + err = hsdb.db.Save(route).Error if err != nil { return err } - return h.handlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } - routes, err := h.GetMachineRoutes(&route.Machine) + routes, err := hsdb.GetMachineRoutes(&route.Machine) if err != nil { return err } @@ -129,18 +126,18 @@ func (h *Headscale) DisableRoute(id uint64) error { if routes[i].isExitRoute() { routes[i].Enabled = false routes[i].IsPrimary = false - err = h.db.Save(&routes[i]).Error + err = hsdb.db.Save(&routes[i]).Error if err != nil { return err } } } - return h.handlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } -func (h *Headscale) DeleteRoute(id uint64) error { - route, err := h.GetRoute(id) +func (hsdb *HSDatabase) DeleteRoute(id uint64) error { + route, err := hsdb.GetRoute(id) if err != nil { return err } @@ -149,14 +146,14 @@ func (h *Headscale) DeleteRoute(id uint64) error { // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 if !route.isExitRoute() { - if err := h.db.Unscoped().Delete(&route).Error; err != nil { + if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil { return err } - return h.handlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } - routes, err := h.GetMachineRoutes(&route.Machine) + routes, err := hsdb.GetMachineRoutes(&route.Machine) if err != nil { return err } @@ -168,32 +165,32 @@ func (h *Headscale) DeleteRoute(id uint64) error { } } - if err := h.db.Unscoped().Delete(&routesToDelete).Error; err != nil { + if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil { return err } - return h.handlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } -func (h *Headscale) DeleteMachineRoutes(m *Machine) error { - routes, err := h.GetMachineRoutes(m) +func (hsdb *HSDatabase) DeleteMachineRoutes(m *Machine) error { + routes, err := hsdb.GetMachineRoutes(m) if err != nil { return err } for i := range routes { - if err := h.db.Unscoped().Delete(&routes[i]).Error; err != nil { + if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil { return err } } - return h.handlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } // isUniquePrefix returns if there is another machine providing the same route already. -func (h *Headscale) isUniquePrefix(route Route) bool { +func (hsdb *HSDatabase) isUniquePrefix(route Route) bool { var count int64 - h.db. + hsdb.db. Model(&Route{}). Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", route.Prefix, @@ -203,9 +200,9 @@ func (h *Headscale) isUniquePrefix(route Route) bool { return count == 0 } -func (h *Headscale) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { +func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { var route Route - err := h.db. + err := hsdb.db. Preload("Machine"). Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", IPPrefix(prefix), true, true, true). First(&route).Error @@ -222,9 +219,9 @@ func (h *Headscale) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { // getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) // Exit nodes are not considered for this, as they are never marked as Primary. -func (h *Headscale) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { +func (hsdb *HSDatabase) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { var routes []Route - err := h.db. + err := hsdb.db. Preload("Machine"). Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true). Find(&routes).Error @@ -235,9 +232,9 @@ func (h *Headscale) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { return routes, nil } -func (h *Headscale) processMachineRoutes(machine *Machine) error { +func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error { currentRoutes := []Route{} - err := h.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error + err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error if err != nil { return err } @@ -251,7 +248,7 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error { if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok { if !route.Advertised { currentRoutes[pos].Advertised = true - err := h.db.Save(¤tRoutes[pos]).Error + err := hsdb.db.Save(¤tRoutes[pos]).Error if err != nil { return err } @@ -260,7 +257,7 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error { } else if route.Advertised { currentRoutes[pos].Advertised = false currentRoutes[pos].Enabled = false - err := h.db.Save(¤tRoutes[pos]).Error + err := hsdb.db.Save(¤tRoutes[pos]).Error if err != nil { return err } @@ -275,7 +272,7 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error { Advertised: true, Enabled: false, } - err := h.db.Create(&route).Error + err := hsdb.db.Create(&route).Error if err != nil { return err } @@ -285,10 +282,10 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error { return nil } -func (h *Headscale) handlePrimarySubnetFailover() error { +func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { // first, get all the enabled routes var routes []Route - err := h.db. + err := hsdb.db. Preload("Machine"). Where("advertised = ? AND enabled = ?", true, true). Find(&routes).Error @@ -303,14 +300,14 @@ func (h *Headscale) handlePrimarySubnetFailover() error { } if !route.IsPrimary { - _, err := h.getPrimaryRoute(netip.Prefix(route.Prefix)) - if h.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) { + _, err := hsdb.getPrimaryRoute(netip.Prefix(route.Prefix)) + if hsdb.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) { log.Info(). Str("prefix", netip.Prefix(route.Prefix).String()). Str("machine", route.Machine.GivenName). Msg("Setting primary route") routes[pos].IsPrimary = true - err := h.db.Save(&routes[pos]).Error + err := hsdb.db.Save(&routes[pos]).Error if err != nil { log.Error().Err(err).Msg("error marking route as primary") @@ -336,7 +333,7 @@ func (h *Headscale) handlePrimarySubnetFailover() error { // find a new primary route var newPrimaryRoutes []Route - err := h.db. + err := hsdb.db. Preload("Machine"). Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", route.Prefix, @@ -375,7 +372,7 @@ func (h *Headscale) handlePrimarySubnetFailover() error { // disable the old primary route routes[pos].IsPrimary = false - err = h.db.Save(&routes[pos]).Error + err = hsdb.db.Save(&routes[pos]).Error if err != nil { log.Error().Err(err).Msg("error disabling old primary route") @@ -384,7 +381,7 @@ func (h *Headscale) handlePrimarySubnetFailover() error { // enable the new primary route newPrimaryRoute.IsPrimary = true - err = h.db.Save(&newPrimaryRoute).Error + err = hsdb.db.Save(&newPrimaryRoute).Error if err != nil { log.Error().Err(err).Msg("error enabling new primary route") @@ -396,7 +393,7 @@ func (h *Headscale) handlePrimarySubnetFailover() error { } if routesChanged { - h.setLastStateChangeToNow() + hsdb.notifyStateChange() } return nil diff --git a/hscontrol/routes_test.go b/hscontrol/routes_test.go index 1e5e2bbf7b..cf437a4d20 100644 --- a/hscontrol/routes_test.go +++ b/hscontrol/routes_test.go @@ -4,19 +4,20 @@ import ( "net/netip" "time" + "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" "tailscale.com/tailcfg" "tailscale.com/types/key" ) func (s *Suite) TestGetRoutes(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "test_get_route_machine") + _, err = app.db.GetMachine("test", "test_get_route_machine") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix("10.0.0.0/24") @@ -37,30 +38,30 @@ func (s *Suite) TestGetRoutes(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo), } - app.db.Save(&machine) + app.db.db.Save(&machine) - err = app.processMachineRoutes(&machine) + err = app.db.processMachineRoutes(&machine) c.Assert(err, check.IsNil) - advertisedRoutes, err := app.GetAdvertisedRoutes(&machine) + advertisedRoutes, err := app.db.GetAdvertisedRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(advertisedRoutes), check.Equals, 1) - err = app.enableRoutes(&machine, "192.168.0.0/24") + err = app.db.enableRoutes(&machine, "192.168.0.0/24") c.Assert(err, check.NotNil) - err = app.enableRoutes(&machine, "10.0.0.0/24") + err = app.db.enableRoutes(&machine, "10.0.0.0/24") c.Assert(err, check.IsNil) } func (s *Suite) TestGetEnableRoutes(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "test_enable_route_machine") + _, err = app.db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -88,54 +89,54 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo), } - app.db.Save(&machine) + app.db.db.Save(&machine) - err = app.processMachineRoutes(&machine) + err = app.db.processMachineRoutes(&machine) c.Assert(err, check.IsNil) - availableRoutes, err := app.GetAdvertisedRoutes(&machine) + availableRoutes, err := app.db.GetAdvertisedRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(err, check.IsNil) c.Assert(len(availableRoutes), check.Equals, 2) - noEnabledRoutes, err := app.GetEnabledRoutes(&machine) + noEnabledRoutes, err := app.db.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(noEnabledRoutes), check.Equals, 0) - err = app.enableRoutes(&machine, "192.168.0.0/24") + err = app.db.enableRoutes(&machine, "192.168.0.0/24") c.Assert(err, check.NotNil) - err = app.enableRoutes(&machine, "10.0.0.0/24") + err = app.db.enableRoutes(&machine, "10.0.0.0/24") c.Assert(err, check.IsNil) - enabledRoutes, err := app.GetEnabledRoutes(&machine) + enabledRoutes, err := app.db.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes), check.Equals, 1) // Adding it twice will just let it pass through - err = app.enableRoutes(&machine, "10.0.0.0/24") + err = app.db.enableRoutes(&machine, "10.0.0.0/24") c.Assert(err, check.IsNil) - enableRoutesAfterDoubleApply, err := app.GetEnabledRoutes(&machine) + enableRoutesAfterDoubleApply, err := app.db.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1) - err = app.enableRoutes(&machine, "150.0.10.0/25") + err = app.db.enableRoutes(&machine, "150.0.10.0/25") c.Assert(err, check.IsNil) - enabledRoutesWithAdditionalRoute, err := app.GetEnabledRoutes(&machine) + enabledRoutesWithAdditionalRoute, err := app.db.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2) } func (s *Suite) TestIsUniquePrefix(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "test_enable_route_machine") + _, err = app.db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -162,15 +163,15 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo1), } - app.db.Save(&machine1) + app.db.db.Save(&machine1) - err = app.processMachineRoutes(&machine1) + err = app.db.processMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine1, route.String()) + err = app.db.enableRoutes(&machine1, route.String()) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine1, route2.String()) + err = app.db.enableRoutes(&machine1, route2.String()) c.Assert(err, check.IsNil) hostInfo2 := tailcfg.Hostinfo{ @@ -187,39 +188,39 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { AuthKeyID: uint(pak.ID), HostInfo: HostInfo(hostInfo2), } - app.db.Save(&machine2) + app.db.db.Save(&machine2) - err = app.processMachineRoutes(&machine2) + err = app.db.processMachineRoutes(&machine2) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine2, route2.String()) + err = app.db.enableRoutes(&machine2, route2.String()) c.Assert(err, check.IsNil) - enabledRoutes1, err := app.GetEnabledRoutes(&machine1) + enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 2) - enabledRoutes2, err := app.GetEnabledRoutes(&machine2) + enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes2), check.Equals, 1) - routes, err := app.getMachinePrimaryRoutes(&machine1) + routes, err := app.db.getMachinePrimaryRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 2) - routes, err = app.getMachinePrimaryRoutes(&machine2) + routes, err = app.db.getMachinePrimaryRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 0) } func (s *Suite) TestSubnetFailover(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "test_enable_route_machine") + _, err = app.db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) prefix, err := netip.ParsePrefix( @@ -249,25 +250,25 @@ func (s *Suite) TestSubnetFailover(c *check.C) { HostInfo: HostInfo(hostInfo1), LastSeen: &now, } - app.db.Save(&machine1) + app.db.db.Save(&machine1) - err = app.processMachineRoutes(&machine1) + err = app.db.processMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine1, prefix.String()) + err = app.db.enableRoutes(&machine1, prefix.String()) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine1, prefix2.String()) + err = app.db.enableRoutes(&machine1, prefix2.String()) c.Assert(err, check.IsNil) - err = app.handlePrimarySubnetFailover() + err = app.db.handlePrimarySubnetFailover() c.Assert(err, check.IsNil) - enabledRoutes1, err := app.GetEnabledRoutes(&machine1) + enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 2) - route, err := app.getPrimaryRoute(prefix) + route, err := app.db.getPrimaryRoute(prefix) c.Assert(err, check.IsNil) c.Assert(route.MachineID, check.Equals, machine1.ID) @@ -286,70 +287,70 @@ func (s *Suite) TestSubnetFailover(c *check.C) { HostInfo: HostInfo(hostInfo2), LastSeen: &now, } - app.db.Save(&machine2) + app.db.db.Save(&machine2) - err = app.processMachineRoutes(&machine2) + err = app.db.processMachineRoutes(&machine2) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine2, prefix2.String()) + err = app.db.enableRoutes(&machine2, prefix2.String()) c.Assert(err, check.IsNil) - err = app.handlePrimarySubnetFailover() + err = app.db.handlePrimarySubnetFailover() c.Assert(err, check.IsNil) - enabledRoutes1, err = app.GetEnabledRoutes(&machine1) + enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 2) - enabledRoutes2, err := app.GetEnabledRoutes(&machine2) + enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes2), check.Equals, 1) - routes, err := app.getMachinePrimaryRoutes(&machine1) + routes, err := app.db.getMachinePrimaryRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 2) - routes, err = app.getMachinePrimaryRoutes(&machine2) + routes, err = app.db.getMachinePrimaryRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 0) // lets make machine1 lastseen 10 mins ago before := now.Add(-10 * time.Minute) machine1.LastSeen = &before - err = app.db.Save(&machine1).Error + err = app.db.db.Save(&machine1).Error c.Assert(err, check.IsNil) - err = app.handlePrimarySubnetFailover() + err = app.db.handlePrimarySubnetFailover() c.Assert(err, check.IsNil) - routes, err = app.getMachinePrimaryRoutes(&machine1) + routes, err = app.db.getMachinePrimaryRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 1) - routes, err = app.getMachinePrimaryRoutes(&machine2) + routes, err = app.db.getMachinePrimaryRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 1) machine2.HostInfo = HostInfo(tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{prefix, prefix2}, }) - err = app.db.Save(&machine2).Error + err = app.db.db.Save(&machine2).Error c.Assert(err, check.IsNil) - err = app.processMachineRoutes(&machine2) + err = app.db.processMachineRoutes(&machine2) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine2, prefix.String()) + err = app.db.enableRoutes(&machine2, prefix.String()) c.Assert(err, check.IsNil) - err = app.handlePrimarySubnetFailover() + err = app.db.handlePrimarySubnetFailover() c.Assert(err, check.IsNil) - routes, err = app.getMachinePrimaryRoutes(&machine1) + routes, err = app.db.getMachinePrimaryRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 0) - routes, err = app.getMachinePrimaryRoutes(&machine2) + routes, err = app.db.getMachinePrimaryRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 2) } @@ -358,13 +359,13 @@ func (s *Suite) TestSubnetFailover(c *check.C) { // including both the primary routes the node is responsible for, and the // exit node routes if enabled. func (s *Suite) TestAllowedIPRoutes(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "test_enable_route_machine") + _, err = app.db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) prefix, err := netip.ParsePrefix( @@ -398,9 +399,9 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { now := time.Now() machine1 := Machine{ ID: 1, - MachineKey: MachinePublicKeyStripPrefix(machineKey.Public()), - NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()), - DiscoKey: DiscoPublicKeyStripPrefix(discoKey.Public()), + MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), + NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), + DiscoKey: util.DiscoPublicKeyStripPrefix(discoKey.Public()), Hostname: "test_enable_route_machine", UserID: user.ID, RegisterMethod: RegisterMethodAuthKey, @@ -408,23 +409,23 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { HostInfo: HostInfo(hostInfo1), LastSeen: &now, } - app.db.Save(&machine1) + app.db.db.Save(&machine1) - err = app.processMachineRoutes(&machine1) + err = app.db.processMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine1, prefix.String()) + err = app.db.enableRoutes(&machine1, prefix.String()) c.Assert(err, check.IsNil) // We do not enable this one on purpose to test that it is not enabled - // err = app.enableRoutes(&machine1, prefix2.String()) + // err = app.db.enableRoutes(&machine1, prefix2.String()) // c.Assert(err, check.IsNil) - routes, err := app.GetMachineRoutes(&machine1) + routes, err := app.db.GetMachineRoutes(&machine1) c.Assert(err, check.IsNil) for _, route := range routes { if route.isExitRoute() { - err = app.EnableRoute(uint64(route.ID)) + err = app.db.EnableRoute(uint64(route.ID)) c.Assert(err, check.IsNil) // We only enable one exit route, so we can test that both are enabled @@ -432,14 +433,14 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { } } - err = app.handlePrimarySubnetFailover() + err = app.db.handlePrimarySubnetFailover() c.Assert(err, check.IsNil) - enabledRoutes1, err := app.GetEnabledRoutes(&machine1) + enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 3) - peer, err := app.toNode(machine1, "headscale.net", nil) + peer, err := app.db.toNode(machine1, app.aclPolicy, "headscale.net", nil) c.Assert(err, check.IsNil) c.Assert(len(peer.AllowedIPs), check.Equals, 3) @@ -469,35 +470,35 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { } } - err = app.DisableRoute(uint64(exitRouteV4.ID)) + err = app.db.DisableRoute(uint64(exitRouteV4.ID)) c.Assert(err, check.IsNil) - enabledRoutes1, err = app.GetEnabledRoutes(&machine1) + enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 1) // and now we delete only one of the exit routes // and we check if both are deleted - routes, err = app.GetMachineRoutes(&machine1) + routes, err = app.db.GetMachineRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 4) - err = app.DeleteRoute(uint64(exitRouteV4.ID)) + err = app.db.DeleteRoute(uint64(exitRouteV4.ID)) c.Assert(err, check.IsNil) - routes, err = app.GetMachineRoutes(&machine1) + routes, err = app.db.GetMachineRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 2) } func (s *Suite) TestDeleteRoutes(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.GetMachine("test", "test_enable_route_machine") + _, err = app.db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) prefix, err := netip.ParsePrefix( @@ -527,24 +528,24 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { HostInfo: HostInfo(hostInfo1), LastSeen: &now, } - app.db.Save(&machine1) + app.db.db.Save(&machine1) - err = app.processMachineRoutes(&machine1) + err = app.db.processMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine1, prefix.String()) + err = app.db.enableRoutes(&machine1, prefix.String()) c.Assert(err, check.IsNil) - err = app.enableRoutes(&machine1, prefix2.String()) + err = app.db.enableRoutes(&machine1, prefix2.String()) c.Assert(err, check.IsNil) - routes, err := app.GetMachineRoutes(&machine1) + routes, err := app.db.GetMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.DeleteRoute(uint64(routes[0].ID)) + err = app.db.DeleteRoute(uint64(routes[0].ID)) c.Assert(err, check.IsNil) - enabledRoutes1, err := app.GetEnabledRoutes(&machine1) + enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 1) } diff --git a/hscontrol/users.go b/hscontrol/users.go index 8782a8908b..fb3cea9c15 100644 --- a/hscontrol/users.go +++ b/hscontrol/users.go @@ -9,17 +9,18 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "google.golang.org/protobuf/types/known/timestamppb" "gorm.io/gorm" "tailscale.com/tailcfg" ) -const ( - ErrUserExists = Error("User already exists") - ErrUserNotFound = Error("User not found") - ErrUserStillHasNodes = Error("User not empty: node(s) found") - ErrInvalidUserName = Error("Invalid user name") +var ( + ErrUserExists = errors.New("user already exists") + ErrUserNotFound = errors.New("user not found") + ErrUserStillHasNodes = errors.New("user not empty: node(s) found") + ErrInvalidUserName = errors.New("invalid user name") ) const ( @@ -40,17 +41,17 @@ type User struct { // CreateUser creates a new User. Returns error if could not be created // or another user already exists. -func (h *Headscale) CreateUser(name string) (*User, error) { +func (hsdb *HSDatabase) CreateUser(name string) (*User, error) { err := CheckForFQDNRules(name) if err != nil { return nil, err } user := User{} - if err := h.db.Where("name = ?", name).First(&user).Error; err == nil { + if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil { return nil, ErrUserExists } user.Name = name - if err := h.db.Create(&user).Error; err != nil { + if err := hsdb.db.Create(&user).Error; err != nil { log.Error(). Str("func", "CreateUser"). Err(err). @@ -64,13 +65,13 @@ func (h *Headscale) CreateUser(name string) (*User, error) { // DestroyUser destroys a User. Returns error if the User does // not exist or if there are machines associated with it. -func (h *Headscale) DestroyUser(name string) error { - user, err := h.GetUser(name) +func (hsdb *HSDatabase) DestroyUser(name string) error { + user, err := hsdb.GetUser(name) if err != nil { return ErrUserNotFound } - machines, err := h.ListMachinesByUser(name) + machines, err := hsdb.ListMachinesByUser(name) if err != nil { return err } @@ -78,18 +79,18 @@ func (h *Headscale) DestroyUser(name string) error { return ErrUserStillHasNodes } - keys, err := h.ListPreAuthKeys(name) + keys, err := hsdb.ListPreAuthKeys(name) if err != nil { return err } for _, key := range keys { - err = h.DestroyPreAuthKey(key) + err = hsdb.DestroyPreAuthKey(key) if err != nil { return err } } - if result := h.db.Unscoped().Delete(&user); result.Error != nil { + if result := hsdb.db.Unscoped().Delete(&user); result.Error != nil { return result.Error } @@ -98,9 +99,9 @@ func (h *Headscale) DestroyUser(name string) error { // RenameUser renames a User. Returns error if the User does // not exist or if another User exists with the new name. -func (h *Headscale) RenameUser(oldName, newName string) error { +func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { var err error - oldUser, err := h.GetUser(oldName) + oldUser, err := hsdb.GetUser(oldName) if err != nil { return err } @@ -108,7 +109,7 @@ func (h *Headscale) RenameUser(oldName, newName string) error { if err != nil { return err } - _, err = h.GetUser(newName) + _, err = hsdb.GetUser(newName) if err == nil { return ErrUserExists } @@ -118,7 +119,7 @@ func (h *Headscale) RenameUser(oldName, newName string) error { oldUser.Name = newName - if result := h.db.Save(&oldUser); result.Error != nil { + if result := hsdb.db.Save(&oldUser); result.Error != nil { return result.Error } @@ -126,9 +127,9 @@ func (h *Headscale) RenameUser(oldName, newName string) error { } // GetUser fetches a user by name. -func (h *Headscale) GetUser(name string) (*User, error) { +func (hsdb *HSDatabase) GetUser(name string) (*User, error) { user := User{} - if result := h.db.First(&user, "name = ?", name); errors.Is( + if result := hsdb.db.First(&user, "name = ?", name); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { @@ -139,9 +140,9 @@ func (h *Headscale) GetUser(name string) (*User, error) { } // ListUsers gets all the existing users. -func (h *Headscale) ListUsers() ([]User, error) { +func (hsdb *HSDatabase) ListUsers() ([]User, error) { users := []User{} - if err := h.db.Find(&users).Error; err != nil { + if err := hsdb.db.Find(&users).Error; err != nil { return nil, err } @@ -149,18 +150,18 @@ func (h *Headscale) ListUsers() ([]User, error) { } // ListMachinesByUser gets all the nodes in a given user. -func (h *Headscale) ListMachinesByUser(name string) ([]Machine, error) { +func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) { err := CheckForFQDNRules(name) if err != nil { return nil, err } - user, err := h.GetUser(name) + user, err := hsdb.GetUser(name) if err != nil { return nil, err } machines := []Machine{} - if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil { + if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil { return nil, err } @@ -168,17 +169,17 @@ func (h *Headscale) ListMachinesByUser(name string) ([]Machine, error) { } // SetMachineUser assigns a Machine to a user. -func (h *Headscale) SetMachineUser(machine *Machine, username string) error { +func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error { err := CheckForFQDNRules(username) if err != nil { return err } - user, err := h.GetUser(username) + user, err := hsdb.GetUser(username) if err != nil { return err } machine.User = *user - if result := h.db.Save(&machine); result.Error != nil { + if result := hsdb.db.Save(&machine); result.Error != nil { return result.Error } @@ -211,7 +212,7 @@ func (n *User) toTailscaleLogin() *tailcfg.Login { return &login } -func (h *Headscale) getMapResponseUserProfiles( +func (hsdb *HSDatabase) getMapResponseUserProfiles( machine Machine, peers Machines, ) []tailcfg.UserProfile { @@ -225,8 +226,8 @@ func (h *Headscale) getMapResponseUserProfiles( for _, user := range userMap { displayName := user.Name - if h.cfg.BaseDomain != "" { - displayName = fmt.Sprintf("%s@%s", user.Name, h.cfg.BaseDomain) + if hsdb.baseDomain != "" { + displayName = fmt.Sprintf("%s@%s", user.Name, hsdb.baseDomain) } profiles = append(profiles, @@ -242,7 +243,7 @@ func (h *Headscale) getMapResponseUserProfiles( func (n *User) toProto() *v1.User { return &v1.User{ - Id: strconv.FormatUint(uint64(n.ID), Base10), + Id: strconv.FormatUint(uint64(n.ID), util.Base10), Name: n.Name, CreatedAt: timestamppb.New(n.CreatedAt), } diff --git a/hscontrol/users_test.go b/hscontrol/users_test.go index 12aa9880d4..1d68f92fb4 100644 --- a/hscontrol/users_test.go +++ b/hscontrol/users_test.go @@ -9,42 +9,42 @@ import ( ) func (s *Suite) TestCreateAndDestroyUser(c *check.C) { - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) c.Assert(user.Name, check.Equals, "test") - users, err := app.ListUsers() + users, err := app.db.ListUsers() c.Assert(err, check.IsNil) c.Assert(len(users), check.Equals, 1) - err = app.DestroyUser("test") + err = app.db.DestroyUser("test") c.Assert(err, check.IsNil) - _, err = app.GetUser("test") + _, err = app.db.GetUser("test") c.Assert(err, check.NotNil) } func (s *Suite) TestDestroyUserErrors(c *check.C) { - err := app.DestroyUser("test") + err := app.db.DestroyUser("test") c.Assert(err, check.Equals, ErrUserNotFound) - user, err := app.CreateUser("test") + user, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - err = app.DestroyUser("test") + err = app.db.DestroyUser("test") c.Assert(err, check.IsNil) - result := app.db.Preload("User").First(&pak, "key = ?", pak.Key) + result := app.db.db.Preload("User").First(&pak, "key = ?", pak.Key) // destroying a user also deletes all associated preauthkeys c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound) - user, err = app.CreateUser("test") + user, err = app.db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err = app.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) machine := Machine{ @@ -57,52 +57,52 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) - err = app.DestroyUser("test") + err = app.db.DestroyUser("test") c.Assert(err, check.Equals, ErrUserStillHasNodes) } func (s *Suite) TestRenameUser(c *check.C) { - userTest, err := app.CreateUser("test") + userTest, err := app.db.CreateUser("test") c.Assert(err, check.IsNil) c.Assert(userTest.Name, check.Equals, "test") - users, err := app.ListUsers() + users, err := app.db.ListUsers() c.Assert(err, check.IsNil) c.Assert(len(users), check.Equals, 1) - err = app.RenameUser("test", "test-renamed") + err = app.db.RenameUser("test", "test-renamed") c.Assert(err, check.IsNil) - _, err = app.GetUser("test") + _, err = app.db.GetUser("test") c.Assert(err, check.Equals, ErrUserNotFound) - _, err = app.GetUser("test-renamed") + _, err = app.db.GetUser("test-renamed") c.Assert(err, check.IsNil) - err = app.RenameUser("test-does-not-exit", "test") + err = app.db.RenameUser("test-does-not-exit", "test") c.Assert(err, check.Equals, ErrUserNotFound) - userTest2, err := app.CreateUser("test2") + userTest2, err := app.db.CreateUser("test2") c.Assert(err, check.IsNil) c.Assert(userTest2.Name, check.Equals, "test2") - err = app.RenameUser("test2", "test-renamed") + err = app.db.RenameUser("test2", "test-renamed") c.Assert(err, check.Equals, ErrUserExists) } func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { - userShared1, err := app.CreateUser("shared1") + userShared1, err := app.db.CreateUser("shared1") c.Assert(err, check.IsNil) - userShared2, err := app.CreateUser("shared2") + userShared2, err := app.db.CreateUser("shared2") c.Assert(err, check.IsNil) - userShared3, err := app.CreateUser("shared3") + userShared3, err := app.db.CreateUser("shared3") c.Assert(err, check.IsNil) - preAuthKeyShared1, err := app.CreatePreAuthKey( + preAuthKeyShared1, err := app.db.CreatePreAuthKey( userShared1.Name, false, false, @@ -111,7 +111,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { ) c.Assert(err, check.IsNil) - preAuthKeyShared2, err := app.CreatePreAuthKey( + preAuthKeyShared2, err := app.db.CreatePreAuthKey( userShared2.Name, false, false, @@ -120,7 +120,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { ) c.Assert(err, check.IsNil) - preAuthKeyShared3, err := app.CreatePreAuthKey( + preAuthKeyShared3, err := app.db.CreatePreAuthKey( userShared3.Name, false, false, @@ -129,7 +129,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { ) c.Assert(err, check.IsNil) - preAuthKey2Shared1, err := app.CreatePreAuthKey( + preAuthKey2Shared1, err := app.db.CreatePreAuthKey( userShared1.Name, false, false, @@ -138,7 +138,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { ) c.Assert(err, check.IsNil) - _, err = app.GetMachine(userShared1.Name, "test_get_shared_nodes_1") + _, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") c.Assert(err, check.NotNil) machineInShared1 := &Machine{ @@ -153,9 +153,9 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, AuthKeyID: uint(preAuthKeyShared1.ID), } - app.db.Save(machineInShared1) + app.db.db.Save(machineInShared1) - _, err = app.GetMachine(userShared1.Name, machineInShared1.Hostname) + _, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) c.Assert(err, check.IsNil) machineInShared2 := &Machine{ @@ -170,9 +170,9 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, AuthKeyID: uint(preAuthKeyShared2.ID), } - app.db.Save(machineInShared2) + app.db.db.Save(machineInShared2) - _, err = app.GetMachine(userShared2.Name, machineInShared2.Hostname) + _, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) c.Assert(err, check.IsNil) machineInShared3 := &Machine{ @@ -187,9 +187,9 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, AuthKeyID: uint(preAuthKeyShared3.ID), } - app.db.Save(machineInShared3) + app.db.db.Save(machineInShared3) - _, err = app.GetMachine(userShared3.Name, machineInShared3.Hostname) + _, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) c.Assert(err, check.IsNil) machine2InShared1 := &Machine{ @@ -204,12 +204,12 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, AuthKeyID: uint(preAuthKey2Shared1.ID), } - app.db.Save(machine2InShared1) + app.db.db.Save(machine2InShared1) - peersOfMachine1InShared1, err := app.getPeers(machineInShared1) + peersOfMachine1InShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) c.Assert(err, check.IsNil) - userProfiles := app.getMapResponseUserProfiles( + userProfiles := app.db.getMapResponseUserProfiles( *machineInShared1, peersOfMachine1InShared1, ) @@ -378,13 +378,13 @@ func TestCheckForFQDNRules(t *testing.T) { } func (s *Suite) TestSetMachineUser(c *check.C) { - oldUser, err := app.CreateUser("old") + oldUser, err := app.db.CreateUser("old") c.Assert(err, check.IsNil) - newUser, err := app.CreateUser("new") + newUser, err := app.db.CreateUser("new") c.Assert(err, check.IsNil) - pak, err := app.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) + pak, err := app.db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) c.Assert(err, check.IsNil) machine := Machine{ @@ -397,18 +397,18 @@ func (s *Suite) TestSetMachineUser(c *check.C) { RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.Save(&machine) + app.db.db.Save(&machine) c.Assert(machine.UserID, check.Equals, oldUser.ID) - err = app.SetMachineUser(&machine, newUser.Name) + err = app.db.SetMachineUser(&machine, newUser.Name) c.Assert(err, check.IsNil) c.Assert(machine.UserID, check.Equals, newUser.ID) c.Assert(machine.User.Name, check.Equals, newUser.Name) - err = app.SetMachineUser(&machine, "non-existing-user") + err = app.db.SetMachineUser(&machine, "non-existing-user") c.Assert(err, check.Equals, ErrUserNotFound) - err = app.SetMachineUser(&machine, newUser.Name) + err = app.db.SetMachineUser(&machine, newUser.Name) c.Assert(err, check.IsNil) c.Assert(machine.UserID, check.Equals, newUser.ID) c.Assert(machine.User.Name, check.Equals, newUser.Name) diff --git a/hscontrol/util/addr.go b/hscontrol/util/addr.go new file mode 100644 index 0000000000..d312a6e04d --- /dev/null +++ b/hscontrol/util/addr.go @@ -0,0 +1,42 @@ +package util + +import ( + "net/netip" + "reflect" + + "go4.org/netipx" +) + +func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) { + var network, broadcast netip.Addr + ipRange := netipx.RangeOfPrefix(na) + network = ipRange.From() + broadcast = ipRange.To() + + return network, broadcast +} + +func StringToIPPrefix(prefixes []string) ([]netip.Prefix, error) { + result := make([]netip.Prefix, len(prefixes)) + + for index, prefixStr := range prefixes { + prefix, err := netip.ParsePrefix(prefixStr) + if err != nil { + return []netip.Prefix{}, err + } + + result[index] = prefix + } + + return result, nil +} + +func StringOrPrefixListContains[T string | netip.Prefix](ts []T, t T) bool { + for _, v := range ts { + if reflect.DeepEqual(v, t) { + return true + } + } + + return false +} diff --git a/hscontrol/util/file.go b/hscontrol/util/file.go new file mode 100644 index 0000000000..7b424da768 --- /dev/null +++ b/hscontrol/util/file.go @@ -0,0 +1,43 @@ +package util + +import ( + "io/fs" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/spf13/viper" +) + +const ( + Base8 = 8 + Base10 = 10 + BitSize16 = 16 + BitSize32 = 32 + BitSize64 = 64 +) + +func AbsolutePathFromConfigPath(path string) string { + // If a relative path is provided, prefix it with the directory where + // the config file was found. + if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) { + dir, _ := filepath.Split(viper.ConfigFileUsed()) + if dir != "" { + path = filepath.Join(dir, path) + } + } + + return path +} + +func GetFileMode(key string) fs.FileMode { + modeStr := viper.GetString(key) + + mode, err := strconv.ParseUint(modeStr, Base8, BitSize64) + if err != nil { + return PermissionFallback + } + + return fs.FileMode(mode) +} diff --git a/hscontrol/util/key.go b/hscontrol/util/key.go new file mode 100644 index 0000000000..4eb1db6c08 --- /dev/null +++ b/hscontrol/util/key.go @@ -0,0 +1,117 @@ +package util + +import ( + "encoding/json" + "errors" + "regexp" + "strings" + + "tailscale.com/types/key" +) + +const ( + + // These constants are copied from the upstream tailscale.com/types/key + // library, because they are not exported. + // https://github.com/tailscale/tailscale/tree/main/types/key + + // nodePublicHexPrefix is the prefix used to identify a + // hex-encoded node public key. + // + // This prefix is used in the control protocol, so cannot be + // changed. + nodePublicHexPrefix = "nodekey:" + + // machinePublicHexPrefix is the prefix used to identify a + // hex-encoded machine public key. + // + // This prefix is used in the control protocol, so cannot be + // changed. + machinePublicHexPrefix = "mkey:" + + // discoPublicHexPrefix is the prefix used to identify a + // hex-encoded disco public key. + // + // This prefix is used in the control protocol, so cannot be + // changed. + discoPublicHexPrefix = "discokey:" + + // privateKey prefix. + privateHexPrefix = "privkey:" + + PermissionFallback = 0o700 + + ZstdCompression = "zstd" +) + +var ( + NodePublicKeyRegex = regexp.MustCompile("nodekey:[a-fA-F0-9]+") + ErrCannotDecryptResponse = errors.New("cannot decrypt response") +) + +func MachinePublicKeyStripPrefix(machineKey key.MachinePublic) string { + return strings.TrimPrefix(machineKey.String(), machinePublicHexPrefix) +} + +func NodePublicKeyStripPrefix(nodeKey key.NodePublic) string { + return strings.TrimPrefix(nodeKey.String(), nodePublicHexPrefix) +} + +func DiscoPublicKeyStripPrefix(discoKey key.DiscoPublic) string { + return strings.TrimPrefix(discoKey.String(), discoPublicHexPrefix) +} + +func MachinePublicKeyEnsurePrefix(machineKey string) string { + if !strings.HasPrefix(machineKey, machinePublicHexPrefix) { + return machinePublicHexPrefix + machineKey + } + + return machineKey +} + +func NodePublicKeyEnsurePrefix(nodeKey string) string { + if !strings.HasPrefix(nodeKey, nodePublicHexPrefix) { + return nodePublicHexPrefix + nodeKey + } + + return nodeKey +} + +func DiscoPublicKeyEnsurePrefix(discoKey string) string { + if !strings.HasPrefix(discoKey, discoPublicHexPrefix) { + return discoPublicHexPrefix + discoKey + } + + return discoKey +} + +func PrivateKeyEnsurePrefix(privateKey string) string { + if !strings.HasPrefix(privateKey, privateHexPrefix) { + return privateHexPrefix + privateKey + } + + return privateKey +} + +func DecodeAndUnmarshalNaCl( + msg []byte, + output interface{}, + pubKey *key.MachinePublic, + privKey *key.MachinePrivate, +) error { + // log.Trace(). + // Str("pubkey", pubKey.ShortString()). + // Int("length", len(msg)). + // Msg("Trying to decrypt") + + decrypted, ok := privKey.OpenFrom(*pubKey, msg) + if !ok { + return ErrCannotDecryptResponse + } + + if err := json.Unmarshal(decrypted, output); err != nil { + return err + } + + return nil +} diff --git a/hscontrol/util/net.go b/hscontrol/util/net.go new file mode 100644 index 0000000000..b704c936c0 --- /dev/null +++ b/hscontrol/util/net.go @@ -0,0 +1,12 @@ +package util + +import ( + "context" + "net" +) + +func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + + return d.DialContext(ctx, "unix", addr) +} diff --git a/hscontrol/util/string.go b/hscontrol/util/string.go new file mode 100644 index 0000000000..6f018affdb --- /dev/null +++ b/hscontrol/util/string.go @@ -0,0 +1,85 @@ +package util + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "strings" + + "tailscale.com/tailcfg" +) + +// GenerateRandomBytes returns securely generated random bytes. +// It will return an error if the system's secure random +// number generator fails to function correctly, in which +// case the caller should not continue. +func GenerateRandomBytes(n int) ([]byte, error) { + bytes := make([]byte, n) + + // Note that err == nil only if we read len(b) bytes. + if _, err := rand.Read(bytes); err != nil { + return nil, err + } + + return bytes, nil +} + +// GenerateRandomStringURLSafe returns a URL-safe, base64 encoded +// securely generated random string. +// It will return an error if the system's secure random +// number generator fails to function correctly, in which +// case the caller should not continue. +func GenerateRandomStringURLSafe(n int) (string, error) { + b, err := GenerateRandomBytes(n) + + return base64.RawURLEncoding.EncodeToString(b), err +} + +// GenerateRandomStringDNSSafe returns a DNS-safe +// securely generated random string. +// It will return an error if the system's secure random +// number generator fails to function correctly, in which +// case the caller should not continue. +func GenerateRandomStringDNSSafe(size int) (string, error) { + var str string + var err error + for len(str) < size { + str, err = GenerateRandomStringURLSafe(size) + if err != nil { + return "", err + } + str = strings.ToLower( + strings.ReplaceAll(strings.ReplaceAll(str, "_", ""), "-", ""), + ) + } + + return str[:size], nil +} + +func IsStringInSlice(slice []string, str string) bool { + for _, s := range slice { + if s == str { + return true + } + } + + return false +} + +func TailNodesToString(nodes []*tailcfg.Node) string { + temp := make([]string, len(nodes)) + + for index, node := range nodes { + temp[index] = node.Name + } + + return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) +} + +func TailMapResponseToString(resp tailcfg.MapResponse) string { + return fmt.Sprintf( + "{ Node: %s, Peers: %s }", + resp.Node.Name, + TailNodesToString(resp.Peers), + ) +} diff --git a/hscontrol/util/string_test.go b/hscontrol/util/string_test.go new file mode 100644 index 0000000000..87a8be1c0b --- /dev/null +++ b/hscontrol/util/string_test.go @@ -0,0 +1,15 @@ +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGenerateRandomStringDNSSafe(t *testing.T) { + for i := 0; i < 100000; i++ { + str, err := GenerateRandomStringDNSSafe(8) + assert.Nil(t, err) + assert.Len(t, str, 8) + } +} diff --git a/hscontrol/utils.go b/hscontrol/utils.go deleted file mode 100644 index 9cfbf0cab9..0000000000 --- a/hscontrol/utils.go +++ /dev/null @@ -1,361 +0,0 @@ -// Codehere is mostly taken from github.com/tailscale/tailscale -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package hscontrol - -import ( - "context" - "crypto/rand" - "encoding/base64" - "encoding/json" - "fmt" - "io/fs" - "net" - "net/netip" - "os" - "path/filepath" - "reflect" - "regexp" - "strconv" - "strings" - - "github.com/rs/zerolog/log" - "github.com/spf13/viper" - "go4.org/netipx" - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -const ( - ErrCannotDecryptResponse = Error("cannot decrypt response") - ErrCouldNotAllocateIP = Error("could not find any suitable IP") - - // These constants are copied from the upstream tailscale.com/types/key - // library, because they are not exported. - // https://github.com/tailscale/tailscale/tree/main/types/key - - // nodePublicHexPrefix is the prefix used to identify a - // hex-encoded node public key. - // - // This prefix is used in the control protocol, so cannot be - // changed. - nodePublicHexPrefix = "nodekey:" - - // machinePublicHexPrefix is the prefix used to identify a - // hex-encoded machine public key. - // - // This prefix is used in the control protocol, so cannot be - // changed. - machinePublicHexPrefix = "mkey:" - - // discoPublicHexPrefix is the prefix used to identify a - // hex-encoded disco public key. - // - // This prefix is used in the control protocol, so cannot be - // changed. - discoPublicHexPrefix = "discokey:" - - // privateKey prefix. - privateHexPrefix = "privkey:" - - PermissionFallback = 0o700 - - ZstdCompression = "zstd" -) - -var NodePublicKeyRegex = regexp.MustCompile("nodekey:[a-fA-F0-9]+") - -func MachinePublicKeyStripPrefix(machineKey key.MachinePublic) string { - return strings.TrimPrefix(machineKey.String(), machinePublicHexPrefix) -} - -func NodePublicKeyStripPrefix(nodeKey key.NodePublic) string { - return strings.TrimPrefix(nodeKey.String(), nodePublicHexPrefix) -} - -func DiscoPublicKeyStripPrefix(discoKey key.DiscoPublic) string { - return strings.TrimPrefix(discoKey.String(), discoPublicHexPrefix) -} - -func MachinePublicKeyEnsurePrefix(machineKey string) string { - if !strings.HasPrefix(machineKey, machinePublicHexPrefix) { - return machinePublicHexPrefix + machineKey - } - - return machineKey -} - -func NodePublicKeyEnsurePrefix(nodeKey string) string { - if !strings.HasPrefix(nodeKey, nodePublicHexPrefix) { - return nodePublicHexPrefix + nodeKey - } - - return nodeKey -} - -func DiscoPublicKeyEnsurePrefix(discoKey string) string { - if !strings.HasPrefix(discoKey, discoPublicHexPrefix) { - return discoPublicHexPrefix + discoKey - } - - return discoKey -} - -func PrivateKeyEnsurePrefix(privateKey string) string { - if !strings.HasPrefix(privateKey, privateHexPrefix) { - return privateHexPrefix + privateKey - } - - return privateKey -} - -// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors -type Error string - -func (e Error) Error() string { return string(e) } - -func decode( - msg []byte, - output interface{}, - pubKey *key.MachinePublic, - privKey *key.MachinePrivate, -) error { - log.Trace(). - Str("pubkey", pubKey.ShortString()). - Int("length", len(msg)). - Msg("Trying to decrypt") - - decrypted, ok := privKey.OpenFrom(*pubKey, msg) - if !ok { - return ErrCannotDecryptResponse - } - - if err := json.Unmarshal(decrypted, output); err != nil { - return err - } - - return nil -} - -func (h *Headscale) getAvailableIPs() (MachineAddresses, error) { - var ips MachineAddresses - var err error - ipPrefixes := h.cfg.IPPrefixes - for _, ipPrefix := range ipPrefixes { - var ip *netip.Addr - ip, err = h.getAvailableIP(ipPrefix) - if err != nil { - return ips, err - } - ips = append(ips, *ip) - } - - return ips, err -} - -func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) { - var network, broadcast netip.Addr - ipRange := netipx.RangeOfPrefix(na) - network = ipRange.From() - broadcast = ipRange.To() - - return network, broadcast -} - -func (h *Headscale) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) { - usedIps, err := h.getUsedIPs() - if err != nil { - return nil, err - } - - ipPrefixNetworkAddress, ipPrefixBroadcastAddress := GetIPPrefixEndpoints(ipPrefix) - - // Get the first IP in our prefix - ip := ipPrefixNetworkAddress.Next() - - for { - if !ipPrefix.Contains(ip) { - return nil, ErrCouldNotAllocateIP - } - - switch { - case ip.Compare(ipPrefixBroadcastAddress) == 0: - fallthrough - case usedIps.Contains(ip): - fallthrough - case ip == netip.Addr{} || ip.IsLoopback(): - ip = ip.Next() - - continue - - default: - return &ip, nil - } - } -} - -func (h *Headscale) getUsedIPs() (*netipx.IPSet, error) { - // FIXME: This really deserves a better data model, - // but this was quick to get running and it should be enough - // to begin experimenting with a dual stack tailnet. - var addressesSlices []string - h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices) - - var ips netipx.IPSetBuilder - for _, slice := range addressesSlices { - var machineAddresses MachineAddresses - err := machineAddresses.Scan(slice) - if err != nil { - return &netipx.IPSet{}, fmt.Errorf( - "failed to read ip from database: %w", - err, - ) - } - - for _, ip := range machineAddresses { - ips.Add(ip) - } - } - - ipSet, err := ips.IPSet() - if err != nil { - return &netipx.IPSet{}, fmt.Errorf( - "failed to build IP Set: %w", - err, - ) - } - - return ipSet, nil -} - -func tailNodesToString(nodes []*tailcfg.Node) string { - temp := make([]string, len(nodes)) - - for index, node := range nodes { - temp[index] = node.Name - } - - return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) -} - -func tailMapResponseToString(resp tailcfg.MapResponse) string { - return fmt.Sprintf( - "{ Node: %s, Peers: %s }", - resp.Node.Name, - tailNodesToString(resp.Peers), - ) -} - -func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) { - var d net.Dialer - - return d.DialContext(ctx, "unix", addr) -} - -func stringToIPPrefix(prefixes []string) ([]netip.Prefix, error) { - result := make([]netip.Prefix, len(prefixes)) - - for index, prefixStr := range prefixes { - prefix, err := netip.ParsePrefix(prefixStr) - if err != nil { - return []netip.Prefix{}, err - } - - result[index] = prefix - } - - return result, nil -} - -func contains[T string | netip.Prefix](ts []T, t T) bool { - for _, v := range ts { - if reflect.DeepEqual(v, t) { - return true - } - } - - return false -} - -// GenerateRandomBytes returns securely generated random bytes. -// It will return an error if the system's secure random -// number generator fails to function correctly, in which -// case the caller should not continue. -func GenerateRandomBytes(n int) ([]byte, error) { - bytes := make([]byte, n) - - // Note that err == nil only if we read len(b) bytes. - if _, err := rand.Read(bytes); err != nil { - return nil, err - } - - return bytes, nil -} - -// GenerateRandomStringURLSafe returns a URL-safe, base64 encoded -// securely generated random string. -// It will return an error if the system's secure random -// number generator fails to function correctly, in which -// case the caller should not continue. -func GenerateRandomStringURLSafe(n int) (string, error) { - b, err := GenerateRandomBytes(n) - - return base64.RawURLEncoding.EncodeToString(b), err -} - -// GenerateRandomStringDNSSafe returns a DNS-safe -// securely generated random string. -// It will return an error if the system's secure random -// number generator fails to function correctly, in which -// case the caller should not continue. -func GenerateRandomStringDNSSafe(size int) (string, error) { - var str string - var err error - for len(str) < size { - str, err = GenerateRandomStringURLSafe(size) - if err != nil { - return "", err - } - str = strings.ToLower( - strings.ReplaceAll(strings.ReplaceAll(str, "_", ""), "-", ""), - ) - } - - return str[:size], nil -} - -func IsStringInSlice(slice []string, str string) bool { - for _, s := range slice { - if s == str { - return true - } - } - - return false -} - -func AbsolutePathFromConfigPath(path string) string { - // If a relative path is provided, prefix it with the directory where - // the config file was found. - if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) { - dir, _ := filepath.Split(viper.ConfigFileUsed()) - if dir != "" { - path = filepath.Join(dir, path) - } - } - - return path -} - -func GetFileMode(key string) fs.FileMode { - modeStr := viper.GetString(key) - - mode, err := strconv.ParseUint(modeStr, Base8, BitSize64) - if err != nil { - return PermissionFallback - } - - return fs.FileMode(mode) -} diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 8ad8f329c5..452f852044 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -15,6 +15,7 @@ import ( "time" "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/hsic" "github.com/ory/dockertest/v3" @@ -220,7 +221,7 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*hscontrol.OIDC } portNotation := fmt.Sprintf("%d/tcp", port) - hash, _ := hscontrol.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength) + hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength) hostname := fmt.Sprintf("hs-oidcmock-%s", hash) diff --git a/integration/embedded_derp_test.go b/integration/embedded_derp_test.go index be128087d5..e9183cdcbd 100644 --- a/integration/embedded_derp_test.go +++ b/integration/embedded_derp_test.go @@ -6,7 +6,7 @@ import ( "net/url" "testing" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" @@ -110,7 +110,7 @@ func (s *EmbeddedDERPServerScenario) CreateHeadscaleEnv( return err } - hash, err := hscontrol.GenerateRandomStringDNSSafe(scenarioHashLength) + hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength) if err != nil { return err } diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 6b1652b084..0051b40013 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -24,6 +24,7 @@ import ( "github.com/davecgh/go-spew/spew" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/integrationutil" "github.com/ory/dockertest/v3" @@ -132,7 +133,7 @@ func WithHostPortBindings(bindings map[string][]string) Option { // in the Docker container name. func WithTestName(testName string) Option { return func(hsic *HeadscaleInContainer) { - hash, _ := hscontrol.GenerateRandomStringDNSSafe(hsicHashLength) + hash, _ := util.GenerateRandomStringDNSSafe(hsicHashLength) hostname := fmt.Sprintf("hs-%s-%s", testName, hash) hsic.hostname = hostname @@ -167,7 +168,7 @@ func New( network *dockertest.Network, opts ...Option, ) (*HeadscaleInContainer, error) { - hash, err := hscontrol.GenerateRandomStringDNSSafe(hsicHashLength) + hash, err := util.GenerateRandomStringDNSSafe(hsicHashLength) if err != nil { return nil, err } diff --git a/integration/scenario.go b/integration/scenario.go index 58005482a6..927d6c80cb 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -10,7 +10,7 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" @@ -105,7 +105,7 @@ type Scenario struct { // NewScenario creates a test Scenario which can be used to bootstraps a ControlServer with // a set of Users and TailscaleClients. func NewScenario() (*Scenario, error) { - hash, err := hscontrol.GenerateRandomStringDNSSafe(scenarioHashLength) + hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength) if err != nil { return nil, err } diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index cc285f3bb0..ffc7e0a90e 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -12,7 +12,7 @@ import ( "time" "github.com/cenkalti/backoff/v4" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/integrationutil" "github.com/ory/dockertest/v3" @@ -150,7 +150,7 @@ func New( network *dockertest.Network, opts ...Option, ) (*TailscaleInContainer, error) { - hash, err := hscontrol.GenerateRandomStringDNSSafe(tsicHashLength) + hash, err := util.GenerateRandomStringDNSSafe(tsicHashLength) if err != nil { return nil, err } From 8c4e46aaf73b46f3e88f9ad7ab441545affb9d61 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sun, 21 May 2023 19:37:59 +0300 Subject: [PATCH 2/2] Split code into modules This is a massive commit that restructures the code into modules: db/ All functions related to modifying the Database types/ All type definitions and methods that can be exclusivly used on these types without dependencies policy/ All Policy related code, now without dependencies on the Database. policy/matcher/ Dedicated code to match machines in a list of FilterRules Signed-off-by: Kristoffer Dalby --- cmd/headscale/cli/routes.go | 4 +- cmd/headscale/cli/utils.go | 5 +- hscontrol/api.go | 5 +- hscontrol/api_common.go | 11 +- hscontrol/app.go | 166 +- hscontrol/db/acls_test.go | 480 ++++++ hscontrol/{ => db}/addresses.go | 11 +- hscontrol/{ => db}/addresses_test.go | 74 +- hscontrol/{ => db}/api_key.go | 60 +- hscontrol/{ => db}/api_key_test.go | 42 +- hscontrol/{ => db}/db.go | 167 +- hscontrol/{ => db}/machine.go | 966 ++++------- hscontrol/db/machine_test.go | 797 +++++++++ hscontrol/{ => db}/preauth_keys.go | 89 +- hscontrol/{ => db}/preauth_keys_test.go | 106 +- hscontrol/{ => db}/routes.go | 214 ++- hscontrol/{ => db}/routes_test.go | 254 +-- hscontrol/db/suite_test.go | 74 + hscontrol/{ => db}/users.go | 148 +- hscontrol/db/users_test.go | 277 ++++ hscontrol/dns.go | 9 +- hscontrol/dns_test.go | 62 +- hscontrol/grpcv1.go | 50 +- hscontrol/machine_test.go | 1386 ---------------- hscontrol/matcher.go | 142 -- hscontrol/oidc.go | 12 +- hscontrol/{ => policy}/acls.go | 230 ++- hscontrol/{ => policy}/acls_test.go | 1464 ++++++++++------- hscontrol/{ => policy}/acls_types.go | 2 +- hscontrol/policy/matcher/matcher.go | 61 + hscontrol/policy/matcher/matcher_test.go | 1 + hscontrol/protocol_common.go | 62 +- hscontrol/protocol_common_poll.go | 50 +- hscontrol/protocol_common_utils.go | 5 +- hscontrol/{app_test.go => suite_test.go} | 31 +- hscontrol/types/api_key.go | 41 + hscontrol/types/common.go | 108 ++ hscontrol/types/machine.go | 254 +++ hscontrol/types/machine_test.go | 1 + hscontrol/types/preauth_key.go | 58 + hscontrol/types/routes.go | 71 + hscontrol/types/users.go | 55 + hscontrol/users_test.go | 415 ----- hscontrol/util/addr.go | 82 + .../{matcher_test.go => util/addr_test.go} | 4 +- hscontrol/util/const.go | 7 + hscontrol/util/dns.go | 69 + hscontrol/util/dns_test.go | 143 ++ integration/acl_test.go | 88 +- integration/hsic/hsic.go | 6 +- integration/ssh_test.go | 32 +- 51 files changed, 4669 insertions(+), 4282 deletions(-) create mode 100644 hscontrol/db/acls_test.go rename hscontrol/{ => db}/addresses.go (87%) rename hscontrol/{ => db}/addresses_test.go (71%) rename hscontrol/{ => db}/api_key.go (64%) rename hscontrol/{ => db}/api_key_test.go (63%) rename hscontrol/{ => db}/db.go (63%) rename hscontrol/{ => db}/machine.go (58%) create mode 100644 hscontrol/db/machine_test.go rename hscontrol/{ => db}/preauth_keys.go (59%) rename hscontrol/{ => db}/preauth_keys_test.go (57%) rename hscontrol/{ => db}/routes.go (62%) rename hscontrol/{ => db}/routes_test.go (60%) create mode 100644 hscontrol/db/suite_test.go rename hscontrol/{ => db}/users.go (50%) create mode 100644 hscontrol/db/users_test.go delete mode 100644 hscontrol/machine_test.go delete mode 100644 hscontrol/matcher.go rename hscontrol/{ => policy}/acls.go (79%) rename hscontrol/{ => policy}/acls_test.go (56%) rename hscontrol/{ => policy}/acls_types.go (99%) create mode 100644 hscontrol/policy/matcher/matcher.go create mode 100644 hscontrol/policy/matcher/matcher_test.go rename hscontrol/{app_test.go => suite_test.go} (54%) create mode 100644 hscontrol/types/api_key.go create mode 100644 hscontrol/types/common.go create mode 100644 hscontrol/types/machine.go create mode 100644 hscontrol/types/machine_test.go create mode 100644 hscontrol/types/preauth_key.go create mode 100644 hscontrol/types/routes.go create mode 100644 hscontrol/types/users.go delete mode 100644 hscontrol/users_test.go rename hscontrol/{matcher_test.go => util/addr_test.go} (96%) create mode 100644 hscontrol/util/const.go create mode 100644 hscontrol/util/dns.go create mode 100644 hscontrol/util/dns_test.go diff --git a/cmd/headscale/cli/routes.go b/cmd/headscale/cli/routes.go index 206209d976..90dd51a852 100644 --- a/cmd/headscale/cli/routes.go +++ b/cmd/headscale/cli/routes.go @@ -7,7 +7,7 @@ import ( "strconv" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/types" "github.com/pterm/pterm" "github.com/spf13/cobra" "google.golang.org/grpc/status" @@ -277,7 +277,7 @@ func routesToPtables(routes []*v1.Route) pterm.TableData { continue } - if prefix == hscontrol.ExitRouteV4 || prefix == hscontrol.ExitRouteV6 { + if prefix == types.ExitRouteV4 || prefix == types.ExitRouteV6 { isPrimaryStr = "-" } else { isPrimaryStr = strconv.FormatBool(route.IsPrimary) diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 2831dbf775..5ce7816f49 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -10,6 +10,7 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "google.golang.org/grpc" @@ -41,13 +42,15 @@ func getHeadscaleApp() (*hscontrol.Headscale, error) { if cfg.ACL.PolicyPath != "" { aclPath := util.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath) - err = app.LoadACLPolicyFromPath(aclPath) + pol, err := policy.LoadACLPolicyFromPath(aclPath) if err != nil { log.Fatal(). Str("path", aclPath). Err(err). Msg("Could not load the ACL policy") } + + app.ACLPolicy = pol } return app, nil diff --git a/hscontrol/api.go b/hscontrol/api.go index 8e3014199a..4a43aeb9f5 100644 --- a/hscontrol/api.go +++ b/hscontrol/api.go @@ -18,9 +18,6 @@ const ( // TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed. registrationHoldoff = time.Second * 5 reservedResponseHeaderSize = 4 - RegisterMethodAuthKey = "authkey" - RegisterMethodOIDC = "oidc" - RegisterMethodCLI = "cli" ) var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New( @@ -56,7 +53,7 @@ func (h *Headscale) HealthHandler( } } - if err := h.db.pingDB(req.Context()); err != nil { + if err := h.db.PingDB(req.Context()); err != nil { respond(err) return diff --git a/hscontrol/api_common.go b/hscontrol/api_common.go index f1b3fd8300..4d40c1d1b7 100644 --- a/hscontrol/api_common.go +++ b/hscontrol/api_common.go @@ -3,6 +3,7 @@ package hscontrol import ( "time" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" @@ -10,13 +11,13 @@ import ( func (h *Headscale) generateMapResponse( mapRequest tailcfg.MapRequest, - machine *Machine, + machine *types.Machine, ) (*tailcfg.MapResponse, error) { log.Trace(). Str("func", "generateMapResponse"). Str("machine", mapRequest.Hostinfo.Hostname). Msg("Creating Map response") - node, err := h.db.toNode(*machine, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig) + node, err := h.db.TailNode(*machine, h.ACLPolicy, h.cfg.DNSConfig) if err != nil { log.Error(). Caller(). @@ -27,7 +28,7 @@ func (h *Headscale) generateMapResponse( return nil, err } - peers, err := h.db.getValidPeers(h.aclPolicy, h.aclRules, machine) + peers, err := h.db.GetValidPeers(h.aclRules, machine) if err != nil { log.Error(). Caller(). @@ -38,9 +39,9 @@ func (h *Headscale) generateMapResponse( return nil, err } - profiles := h.db.getMapResponseUserProfiles(*machine, peers) + profiles := h.db.GetMapResponseUserProfiles(*machine, peers) - nodePeers, err := h.db.toNodes(peers, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig) + nodePeers, err := h.db.TailNodes(peers, h.ACLPolicy, h.cfg.DNSConfig) if err != nil { log.Error(). Caller(). diff --git a/hscontrol/app.go b/hscontrol/app.go index 38d4ec8cca..bb68ced7af 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -23,6 +23,9 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/juanfont/headscale" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/db" + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/patrickmn/go-cache" zerolog "github.com/philip-bui/grpc-zerolog" @@ -73,7 +76,7 @@ const ( // Headscale represents the base app of the service. type Headscale struct { cfg *Config - db *HSDatabase + db *db.HSDatabase dbString string dbType string dbDebug bool @@ -83,7 +86,7 @@ type Headscale struct { DERPMap *tailcfg.DERPMap DERPServer *DERPServer - aclPolicy *ACLPolicy + ACLPolicy *policy.ACLPolicy aclRules []tailcfg.FilterRule sshPolicy *tailcfg.SSHPolicy @@ -99,6 +102,12 @@ type Headscale struct { stateUpdateChan chan struct{} cancelStateUpdateChan chan struct{} + + // TODO(kradalby): Temporary measure to make sure we can update policy + // across modules, will be removed when aclRules are no longer stored + // globally but generated per node basis. + policyUpdateChan chan struct{} + cancelPolicyUpdateChan chan struct{} } func NewHeadscale(cfg *Config) (*Headscale, error) { @@ -119,7 +128,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) { var dbString string switch cfg.DBtype { - case Postgres: + case db.Postgres: dbString = fmt.Sprintf( "host=%s dbname=%s user=%s", cfg.DBhost, @@ -142,7 +151,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) { if cfg.DBpass != "" { dbString += fmt.Sprintf(" password=%s", cfg.DBpass) } - case Sqlite: + case db.Sqlite: dbString = cfg.DBpath default: return nil, errUnsupportedDatabase @@ -166,23 +175,28 @@ func NewHeadscale(cfg *Config) (*Headscale, error) { stateUpdateChan: make(chan struct{}), cancelStateUpdateChan: make(chan struct{}), + + policyUpdateChan: make(chan struct{}), + cancelPolicyUpdateChan: make(chan struct{}), } go app.watchStateChannel() + go app.watchPolicyChannel() - db, err := NewHeadscaleDatabase( + database, err := db.NewHeadscaleDatabase( cfg.DBtype, dbString, cfg.OIDC.StripEmaildomain, app.dbDebug, app.stateUpdateChan, + app.policyUpdateChan, cfg.IPPrefixes, cfg.BaseDomain) if err != nil { return nil, err } - app.db = db + app.db = database if cfg.OIDC.Issuer != "" { err = app.initOIDC() @@ -228,7 +242,7 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) { func (h *Headscale) expireEphemeralNodes(milliSeconds int64) { ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) for range ticker.C { - h.expireEphemeralNodesWorker() + h.db.ExpireEphemeralMachines(h.cfg.EphemeralNodeInactivityTimeout) } } @@ -237,112 +251,20 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) { func (h *Headscale) expireExpiredMachines(milliSeconds int64) { ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) for range ticker.C { - h.expireExpiredMachinesWorker() + h.db.ExpireExpiredMachines(h.getLastStateChange()) } } func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) { ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) for range ticker.C { - err := h.db.handlePrimarySubnetFailover() + err := h.db.HandlePrimarySubnetFailover() if err != nil { log.Error().Err(err).Msg("failed to handle primary subnet failover") } } } -func (h *Headscale) expireEphemeralNodesWorker() { - users, err := h.db.ListUsers() - if err != nil { - log.Error().Err(err).Msg("Error listing users") - - return - } - - for _, user := range users { - machines, err := h.db.ListMachinesByUser(user.Name) - if err != nil { - log.Error(). - Err(err). - Str("user", user.Name). - Msg("Error listing machines in user") - - return - } - - expiredFound := false - for _, machine := range machines { - if machine.isEphemeral() && machine.LastSeen != nil && - time.Now(). - After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) { - expiredFound = true - log.Info(). - Str("machine", machine.Hostname). - Msg("Ephemeral client removed from database") - - err = h.db.db.Unscoped().Delete(machine).Error - if err != nil { - log.Error(). - Err(err). - Str("machine", machine.Hostname). - Msg("🤮 Cannot delete ephemeral machine from the database") - } - } - } - - if expiredFound { - h.setLastStateChangeToNow() - } - } -} - -func (h *Headscale) expireExpiredMachinesWorker() { - users, err := h.db.ListUsers() - if err != nil { - log.Error().Err(err).Msg("Error listing users") - - return - } - - for _, user := range users { - machines, err := h.db.ListMachinesByUser(user.Name) - if err != nil { - log.Error(). - Err(err). - Str("user", user.Name). - Msg("Error listing machines in user") - - return - } - - expiredFound := false - for index, machine := range machines { - if machine.isExpired() && - machine.Expiry.After(h.getLastStateChange(user)) { - expiredFound = true - - err := h.db.ExpireMachine(&machines[index]) - if err != nil { - log.Error(). - Err(err). - Str("machine", machine.Hostname). - Str("name", machine.GivenName). - Msg("🤮 Cannot expire machine") - } else { - log.Info(). - Str("machine", machine.Hostname). - Str("name", machine.GivenName). - Msg("Machine successfully expired") - } - } - } - - if expiredFound { - h.setLastStateChangeToNow() - } - } -} - func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, @@ -565,6 +487,8 @@ func (h *Headscale) Serve() error { go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel) } + // TODO(kradalby): These should have cancel channels and be cleaned + // up on shutdown. go h.expireEphemeralNodes(updateInterval) go h.expireExpiredMachines(updateInterval) @@ -774,10 +698,12 @@ func (h *Headscale) Serve() error { if h.cfg.ACL.PolicyPath != "" { aclPath := util.AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath) - err := h.LoadACLPolicyFromPath(aclPath) + pol, err := policy.LoadACLPolicyFromPath(aclPath) if err != nil { log.Error().Err(err).Msg("Failed to reload ACL policy") } + + h.ACLPolicy = pol log.Info(). Str("path", aclPath). Msg("ACL policy successfully reloaded, notifying nodes of change") @@ -824,12 +750,12 @@ func (h *Headscale) Serve() error { close(h.stateUpdateChan) close(h.cancelStateUpdateChan) + <-h.cancelPolicyUpdateChan + close(h.policyUpdateChan) + close(h.cancelPolicyUpdateChan) + // Close db connections - db, err := h.db.db.DB() - if err != nil { - log.Error().Err(err).Msg("Failed to get db handle") - } - err = db.Close() + err = h.db.Close() if err != nil { log.Error().Err(err).Msg("Failed to close db") } @@ -936,6 +862,30 @@ func (h *Headscale) watchStateChannel() { } } +// TODO(kradalby): baby steps, make this more robust. +func (h *Headscale) watchPolicyChannel() { + for { + select { + case <-h.policyUpdateChan: + machines, err := h.db.ListMachines() + if err != nil { + log.Error().Err(err).Msg("failed to fetch machines during policy update") + } + + rules, sshPolicy, err := policy.GenerateFilterRules(h.ACLPolicy, machines, h.cfg.OIDC.StripEmaildomain) + if err != nil { + log.Error().Err(err).Msg("failed to update ACL rules") + } + + h.aclRules = rules + h.sshPolicy = sshPolicy + + case <-h.cancelPolicyUpdateChan: + return + } + } +} + func (h *Headscale) setLastStateChangeToNow() { var err error @@ -958,7 +908,7 @@ func (h *Headscale) setLastStateChangeToNow() { } } -func (h *Headscale) getLastStateChange(users ...User) time.Time { +func (h *Headscale) getLastStateChange(users ...types.User) time.Time { times := []time.Time{} // getLastStateChange takes a list of users as a "filter", if no users diff --git a/hscontrol/db/acls_test.go b/hscontrol/db/acls_test.go new file mode 100644 index 0000000000..884b6c5cc7 --- /dev/null +++ b/hscontrol/db/acls_test.go @@ -0,0 +1,480 @@ +package db + +import ( + "net/netip" + + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "gopkg.in/check.v1" + "tailscale.com/envknob" + "tailscale.com/tailcfg" +) + +// TODO(kradalby): +// Convert these tests to being non-database dependent and table driven. They are +// very verbose, and dont really need the database. + +func (s *Suite) TestSshRules(c *check.C) { + envknob.Setenv("HEADSCALE_EXPERIMENTAL_FEATURE_SSH", "1") + + user, err := db.CreateUser("user1") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("user1", "testmachine") + c.Assert(err, check.NotNil) + hostInfo := tailcfg.Hostinfo{ + OS: "centos", + Hostname: "testmachine", + RequestTags: []string{"tag:test"}, + } + + machine := types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: types.HostInfo(hostInfo), + } + err = db.MachineSave(&machine) + c.Assert(err, check.IsNil) + + aclPolicy := &policy.ACLPolicy{ + Groups: policy.Groups{ + "group:test": []string{"user1"}, + }, + Hosts: policy.Hosts{ + "client": netip.PrefixFrom(netip.MustParseAddr("100.64.99.42"), 32), + }, + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"*:*"}, + }, + }, + SSHs: []policy.SSH{ + { + Action: "accept", + Sources: []string{"group:test"}, + Destinations: []string{"client"}, + Users: []string{"autogroup:nonroot"}, + }, + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"client"}, + Users: []string{"autogroup:nonroot"}, + }, + }, + } + + _, sshPolicy, err := policy.GenerateFilterRules(aclPolicy, types.Machines{}, false) + + c.Assert(err, check.IsNil) + c.Assert(sshPolicy, check.NotNil) + c.Assert(sshPolicy.Rules, check.HasLen, 2) + c.Assert(sshPolicy.Rules[0].SSHUsers, check.HasLen, 1) + c.Assert(sshPolicy.Rules[0].Principals, check.HasLen, 1) + c.Assert(sshPolicy.Rules[0].Principals[0].UserLogin, check.Matches, "user1") + + c.Assert(sshPolicy.Rules[1].SSHUsers, check.HasLen, 1) + c.Assert(sshPolicy.Rules[1].Principals, check.HasLen, 1) + c.Assert(sshPolicy.Rules[1].Principals[0].NodeIP, check.Matches, "*") +} + +// this test should validate that we can expand a group in a TagOWner section and +// match properly the IP's of the related hosts. The owner is valid and the tag is also valid. +// the tag is matched in the Sources section. +func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) { + user, err := db.CreateUser("user1") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("user1", "testmachine") + c.Assert(err, check.NotNil) + hostInfo := tailcfg.Hostinfo{ + OS: "centos", + Hostname: "testmachine", + RequestTags: []string{"tag:test"}, + } + + machine := types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: types.HostInfo(hostInfo), + } + err = db.MachineSave(&machine) + c.Assert(err, check.IsNil) + + pol := &policy.ACLPolicy{ + Groups: policy.Groups{"group:test": []string{"user1", "user2"}}, + TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}}, + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"tag:test"}, + Destinations: []string{"*:*"}, + }, + }, + } + + machines, err := db.ListMachines() + c.Assert(err, check.IsNil) + + rules, _, err := policy.GenerateFilterRules(pol, machines, false) + c.Assert(err, check.IsNil) + + c.Assert(rules, check.HasLen, 1) + c.Assert(rules[0].SrcIPs, check.HasLen, 1) + c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32") +} + +// this test should validate that we can expand a group in a TagOWner section and +// match properly the IP's of the related hosts. The owner is valid and the tag is also valid. +// the tag is matched in the Destinations section. +func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) { + user, err := db.CreateUser("user1") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("user1", "testmachine") + c.Assert(err, check.NotNil) + hostInfo := tailcfg.Hostinfo{ + OS: "centos", + Hostname: "testmachine", + RequestTags: []string{"tag:test"}, + } + + machine := types.Machine{ + ID: 1, + MachineKey: "12345", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: types.HostInfo(hostInfo), + } + err = db.MachineSave(&machine) + c.Assert(err, check.IsNil) + + pol := &policy.ACLPolicy{ + Groups: policy.Groups{"group:test": []string{"user1", "user2"}}, + TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}}, + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"tag:test:*"}, + }, + }, + } + + machines, err := db.ListMachines() + c.Assert(err, check.IsNil) + + rules, _, err := policy.GenerateFilterRules(pol, machines, false) + c.Assert(err, check.IsNil) + + c.Assert(rules, check.HasLen, 1) + c.Assert(rules[0].DstPorts, check.HasLen, 1) + c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") +} + +// need a test with: +// tag on a host that isn't owned by a tag owners. So the user +// of the host should be valid. +func (s *Suite) TestInvalidTagValidUser(c *check.C) { + user, err := db.CreateUser("user1") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("user1", "testmachine") + c.Assert(err, check.NotNil) + hostInfo := tailcfg.Hostinfo{ + OS: "centos", + Hostname: "testmachine", + RequestTags: []string{"tag:foo"}, + } + + machine := types.Machine{ + ID: 1, + MachineKey: "12345", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: types.HostInfo(hostInfo), + } + err = db.MachineSave(&machine) + c.Assert(err, check.IsNil) + + pol := &policy.ACLPolicy{ + TagOwners: policy.TagOwners{"tag:test": []string{"user1"}}, + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"user1"}, + Destinations: []string{"*:*"}, + }, + }, + } + + machines, err := db.ListMachines() + c.Assert(err, check.IsNil) + + rules, _, err := policy.GenerateFilterRules(pol, machines, false) + c.Assert(err, check.IsNil) + + c.Assert(rules, check.HasLen, 1) + c.Assert(rules[0].SrcIPs, check.HasLen, 1) + c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32") +} + +// tag on a host is owned by a tag owner, the tag is valid. +// an ACL rule is matching the tag to a user. It should not be valid since the +// host should be tied to the tag now. +func (s *Suite) TestValidTagInvalidUser(c *check.C) { + user, err := db.CreateUser("user1") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("user1", "webserver") + c.Assert(err, check.NotNil) + hostInfo := tailcfg.Hostinfo{ + OS: "centos", + Hostname: "webserver", + RequestTags: []string{"tag:webapp"}, + } + + machine := types.Machine{ + ID: 1, + MachineKey: "12345", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "webserver", + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: types.HostInfo(hostInfo), + } + err = db.MachineSave(&machine) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("user1", "user") + hostInfo2 := tailcfg.Hostinfo{ + OS: "debian", + Hostname: "Hostname", + } + c.Assert(err, check.NotNil) + machine = types.Machine{ + ID: 2, + MachineKey: "56789", + NodeKey: "bar2", + DiscoKey: "faab", + Hostname: "user", + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")}, + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: types.HostInfo(hostInfo2), + } + err = db.MachineSave(&machine) + c.Assert(err, check.IsNil) + + pol := &policy.ACLPolicy{ + TagOwners: policy.TagOwners{"tag:webapp": []string{"user1"}}, + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"user1"}, + Destinations: []string{"tag:webapp:80,443"}, + }, + }, + } + + machines, err := db.ListMachines() + c.Assert(err, check.IsNil) + + rules, _, err := policy.GenerateFilterRules(pol, machines, false) + c.Assert(err, check.IsNil) + + c.Assert(rules, check.HasLen, 1) + c.Assert(rules[0].SrcIPs, check.HasLen, 1) + c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.2/32") + c.Assert(rules[0].DstPorts, check.HasLen, 2) + c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(80)) + c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80)) + c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") + c.Assert(rules[0].DstPorts[1].Ports.First, check.Equals, uint16(443)) + c.Assert(rules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443)) + c.Assert(rules[0].DstPorts[1].IP, check.Equals, "100.64.0.1/32") +} + +func (s *Suite) TestPortUser(c *check.C) { + user, err := db.CreateUser("testuser") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("testuser", "testmachine") + c.Assert(err, check.NotNil) + ips, _ := db.getAvailableIPs() + machine := types.Machine{ + ID: 0, + MachineKey: "12345", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + IPAddresses: ips, + AuthKeyID: uint(pak.ID), + } + err = db.MachineSave(&machine) + c.Assert(err, check.IsNil) + + acl := []byte(` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "testuser", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} + `) + pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") + c.Assert(err, check.IsNil) + c.Assert(pol, check.NotNil) + + machines, err := db.ListMachines() + c.Assert(err, check.IsNil) + + rules, _, err := policy.GenerateFilterRules(pol, machines, false) + c.Assert(err, check.IsNil) + + c.Assert(err, check.IsNil) + c.Assert(rules, check.NotNil) + + c.Assert(rules, check.HasLen, 1) + c.Assert(rules[0].DstPorts, check.HasLen, 1) + c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) + c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) + c.Assert(rules[0].SrcIPs, check.HasLen, 1) + c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") + c.Assert(len(ips), check.Equals, 1) + c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32") +} + +func (s *Suite) TestPortGroup(c *check.C) { + user, err := db.CreateUser("testuser") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("testuser", "testmachine") + c.Assert(err, check.NotNil) + ips, _ := db.getAvailableIPs() + machine := types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + IPAddresses: ips, + AuthKeyID: uint(pak.ID), + } + err = db.MachineSave(&machine) + c.Assert(err, check.IsNil) + + acl := []byte(` +{ + "groups": { + "group:example": [ + "testuser", + ], + }, + + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "group:example", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} + `) + pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") + c.Assert(err, check.IsNil) + + machines, err := db.ListMachines() + c.Assert(err, check.IsNil) + + rules, _, err := policy.GenerateFilterRules(pol, machines, false) + c.Assert(err, check.IsNil) + + c.Assert(rules, check.NotNil) + + c.Assert(rules, check.HasLen, 1) + c.Assert(rules[0].DstPorts, check.HasLen, 1) + c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) + c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) + c.Assert(rules[0].SrcIPs, check.HasLen, 1) + c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") + c.Assert(len(ips), check.Equals, 1) + c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32") +} diff --git a/hscontrol/addresses.go b/hscontrol/db/addresses.go similarity index 87% rename from hscontrol/addresses.go rename to hscontrol/db/addresses.go index 7f78935f8e..1a7d35defc 100644 --- a/hscontrol/addresses.go +++ b/hscontrol/db/addresses.go @@ -3,21 +3,22 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package hscontrol +package db import ( "errors" "fmt" "net/netip" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "go4.org/netipx" ) var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP") -func (hsdb *HSDatabase) getAvailableIPs() (MachineAddresses, error) { - var ips MachineAddresses +func (hsdb *HSDatabase) getAvailableIPs() (types.MachineAddresses, error) { + var ips types.MachineAddresses var err error for _, ipPrefix := range hsdb.ipPrefixes { var ip *netip.Addr @@ -68,11 +69,11 @@ func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) { // but this was quick to get running and it should be enough // to begin experimenting with a dual stack tailnet. var addressesSlices []string - hsdb.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices) + hsdb.db.Model(&types.Machine{}).Pluck("ip_addresses", &addressesSlices) var ips netipx.IPSetBuilder for _, slice := range addressesSlices { - var machineAddresses MachineAddresses + var machineAddresses types.MachineAddresses err := machineAddresses.Scan(slice) if err != nil { return &netipx.IPSet{}, fmt.Errorf( diff --git a/hscontrol/addresses_test.go b/hscontrol/db/addresses_test.go similarity index 71% rename from hscontrol/addresses_test.go rename to hscontrol/db/addresses_test.go index f3be93aab1..12891480bb 100644 --- a/hscontrol/addresses_test.go +++ b/hscontrol/db/addresses_test.go @@ -1,14 +1,16 @@ -package hscontrol +package db import ( "net/netip" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "go4.org/netipx" "gopkg.in/check.v1" ) func (s *Suite) TestGetAvailableIp(c *check.C) { - ips, err := app.db.getAvailableIPs() + ips, err := db.getAvailableIPs() c.Assert(err, check.IsNil) @@ -19,32 +21,32 @@ func (s *Suite) TestGetAvailableIp(c *check.C) { } func (s *Suite) TestGetUsedIps(c *check.C) { - ips, err := app.db.getAvailableIPs() + ips, err := db.getAvailableIPs() c.Assert(err, check.IsNil) - user, err := app.db.CreateUser("test-ip") + user, err := db.CreateUser("test-ip") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test", "testmachine") + _, err = db.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) - machine := Machine{ + machine := types.Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "testmachine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), IPAddresses: ips, } - app.db.db.Save(&machine) + db.db.Save(&machine) - usedIps, err := app.db.getUsedIPs() + usedIps, err := db.getUsedIPs() c.Assert(err, check.IsNil) @@ -56,46 +58,48 @@ func (s *Suite) TestGetUsedIps(c *check.C) { c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true) c.Assert(usedIps.Contains(expected), check.Equals, true) - machine1, err := app.db.GetMachineByID(0) + machine1, err := db.GetMachineByID(0) c.Assert(err, check.IsNil) c.Assert(len(machine1.IPAddresses), check.Equals, 1) c.Assert(machine1.IPAddresses[0], check.Equals, expected) + + c.Assert(channelUpdates, check.Equals, int32(0)) } func (s *Suite) TestGetMultiIp(c *check.C) { - user, err := app.db.CreateUser("test-ip-multi") + user, err := db.CreateUser("test-ip-multi") c.Assert(err, check.IsNil) for index := 1; index <= 350; index++ { - app.db.ipAllocationMutex.Lock() + db.ipAllocationMutex.Lock() - ips, err := app.db.getAvailableIPs() + ips, err := db.getAvailableIPs() c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test", "testmachine") + _, err = db.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) - machine := Machine{ + machine := types.Machine{ ID: uint64(index), MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "testmachine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), IPAddresses: ips, } - app.db.db.Save(&machine) + db.db.Save(&machine) - app.db.ipAllocationMutex.Unlock() + db.ipAllocationMutex.Unlock() } - usedIps, err := app.db.getUsedIPs() + usedIps, err := db.getUsedIPs() c.Assert(err, check.IsNil) expected0 := netip.MustParseAddr("10.27.0.1") @@ -117,7 +121,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { c.Assert(usedIps.Contains(expected300), check.Equals, true) // Check that we can read back the IPs - machine1, err := app.db.GetMachineByID(1) + machine1, err := db.GetMachineByID(1) c.Assert(err, check.IsNil) c.Assert(len(machine1.IPAddresses), check.Equals, 1) c.Assert( @@ -126,7 +130,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { netip.MustParseAddr("10.27.0.1"), ) - machine50, err := app.db.GetMachineByID(50) + machine50, err := db.GetMachineByID(50) c.Assert(err, check.IsNil) c.Assert(len(machine50.IPAddresses), check.Equals, 1) c.Assert( @@ -136,7 +140,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) { ) expectedNextIP := netip.MustParseAddr("10.27.1.95") - nextIP, err := app.db.getAvailableIPs() + nextIP, err := db.getAvailableIPs() c.Assert(err, check.IsNil) c.Assert(len(nextIP), check.Equals, 1) @@ -144,15 +148,17 @@ func (s *Suite) TestGetMultiIp(c *check.C) { // If we call get Available again, we should receive // the same IP, as it has not been reserved. - nextIP2, err := app.db.getAvailableIPs() + nextIP2, err := db.getAvailableIPs() c.Assert(err, check.IsNil) c.Assert(len(nextIP2), check.Equals, 1) c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String()) + + c.Assert(channelUpdates, check.Equals, int32(0)) } func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { - ips, err := app.db.getAvailableIPs() + ips, err := db.getAvailableIPs() c.Assert(err, check.IsNil) expected := netip.MustParseAddr("10.27.0.1") @@ -160,30 +166,32 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { c.Assert(len(ips), check.Equals, 1) c.Assert(ips[0].String(), check.Equals, expected.String()) - user, err := app.db.CreateUser("test-ip") + user, err := db.CreateUser("test-ip") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test", "testmachine") + _, err = db.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) - machine := Machine{ + machine := types.Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "testmachine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.db.Save(&machine) + db.db.Save(&machine) - ips2, err := app.db.getAvailableIPs() + ips2, err := db.getAvailableIPs() c.Assert(err, check.IsNil) c.Assert(len(ips2), check.Equals, 1) c.Assert(ips2[0].String(), check.Equals, expected.String()) + + c.Assert(channelUpdates, check.Equals, int32(0)) } diff --git a/hscontrol/api_key.go b/hscontrol/db/api_key.go similarity index 64% rename from hscontrol/api_key.go rename to hscontrol/db/api_key.go index bf2ccf3942..4e4030ebfe 100644 --- a/hscontrol/api_key.go +++ b/hscontrol/db/api_key.go @@ -1,4 +1,4 @@ -package hscontrol +package db import ( "errors" @@ -6,10 +6,9 @@ import ( "strings" "time" - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "golang.org/x/crypto/bcrypt" - "google.golang.org/protobuf/types/known/timestamppb" ) const ( @@ -19,22 +18,10 @@ const ( var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") -// APIKey describes the datamodel for API keys used to remotely authenticate with -// headscale. -type APIKey struct { - ID uint64 `gorm:"primary_key"` - Prefix string `gorm:"uniqueIndex"` - Hash []byte - - CreatedAt *time.Time - Expiration *time.Time - LastSeen *time.Time -} - // CreateAPIKey creates a new ApiKey in a user, and returns it. func (hsdb *HSDatabase) CreateAPIKey( expiration *time.Time, -) (string, *APIKey, error) { +) (string, *types.APIKey, error) { prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) if err != nil { return "", nil, err @@ -53,7 +40,7 @@ func (hsdb *HSDatabase) CreateAPIKey( return "", nil, err } - key := APIKey{ + key := types.APIKey{ Prefix: prefix, Hash: hash, Expiration: expiration, @@ -67,8 +54,8 @@ func (hsdb *HSDatabase) CreateAPIKey( } // ListAPIKeys returns the list of ApiKeys for a user. -func (hsdb *HSDatabase) ListAPIKeys() ([]APIKey, error) { - keys := []APIKey{} +func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { + keys := []types.APIKey{} if err := hsdb.db.Find(&keys).Error; err != nil { return nil, err } @@ -77,8 +64,8 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]APIKey, error) { } // GetAPIKey returns a ApiKey for a given key. -func (hsdb *HSDatabase) GetAPIKey(prefix string) (*APIKey, error) { - key := APIKey{} +func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { + key := types.APIKey{} if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { return nil, result.Error } @@ -87,9 +74,9 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*APIKey, error) { } // GetAPIKeyByID returns a ApiKey for a given id. -func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*APIKey, error) { - key := APIKey{} - if result := hsdb.db.Find(&APIKey{ID: id}).First(&key); result.Error != nil { +func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { + key := types.APIKey{} + if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { return nil, result.Error } @@ -98,7 +85,7 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*APIKey, error) { // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey // does not exist. -func (hsdb *HSDatabase) DestroyAPIKey(key APIKey) error { +func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { if result := hsdb.db.Unscoped().Delete(key); result.Error != nil { return result.Error } @@ -107,7 +94,7 @@ func (hsdb *HSDatabase) DestroyAPIKey(key APIKey) error { } // ExpireAPIKey marks a ApiKey as expired. -func (hsdb *HSDatabase) ExpireAPIKey(key *APIKey) error { +func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -136,24 +123,3 @@ func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { return true, nil } - -func (key *APIKey) toProto() *v1.ApiKey { - protoKey := v1.ApiKey{ - Id: key.ID, - Prefix: key.Prefix, - } - - if key.Expiration != nil { - protoKey.Expiration = timestamppb.New(*key.Expiration) - } - - if key.CreatedAt != nil { - protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) - } - - if key.LastSeen != nil { - protoKey.LastSeen = timestamppb.New(*key.LastSeen) - } - - return &protoKey -} diff --git a/hscontrol/api_key_test.go b/hscontrol/db/api_key_test.go similarity index 63% rename from hscontrol/api_key_test.go rename to hscontrol/db/api_key_test.go index 007b5d1642..0fc42c5a50 100644 --- a/hscontrol/api_key_test.go +++ b/hscontrol/db/api_key_test.go @@ -1,4 +1,4 @@ -package hscontrol +package db import ( "time" @@ -7,7 +7,7 @@ import ( ) func (*Suite) TestCreateAPIKey(c *check.C) { - apiKeyStr, apiKey, err := app.db.CreateAPIKey(nil) + apiKeyStr, apiKey, err := db.CreateAPIKey(nil) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) @@ -16,74 +16,82 @@ func (*Suite) TestCreateAPIKey(c *check.C) { c.Assert(apiKey.Hash, check.NotNil) c.Assert(apiKeyStr, check.Not(check.Equals), "") - _, err = app.db.ListAPIKeys() + _, err = db.ListAPIKeys() c.Assert(err, check.IsNil) - keys, err := app.db.ListAPIKeys() + keys, err := db.ListAPIKeys() c.Assert(err, check.IsNil) c.Assert(len(keys), check.Equals, 1) + + c.Assert(channelUpdates, check.Equals, int32(0)) } func (*Suite) TestAPIKeyDoesNotExist(c *check.C) { - key, err := app.db.GetAPIKey("does-not-exist") + key, err := db.GetAPIKey("does-not-exist") c.Assert(err, check.NotNil) c.Assert(key, check.IsNil) } func (*Suite) TestValidateAPIKeyOk(c *check.C) { nowPlus2 := time.Now().Add(2 * time.Hour) - apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2) + apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) - valid, err := app.db.ValidateAPIKey(apiKeyStr) + valid, err := db.ValidateAPIKey(apiKeyStr) c.Assert(err, check.IsNil) c.Assert(valid, check.Equals, true) + + c.Assert(channelUpdates, check.Equals, int32(0)) } func (*Suite) TestValidateAPIKeyNotOk(c *check.C) { nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour) - apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowMinus2) + apiKeyStr, apiKey, err := db.CreateAPIKey(&nowMinus2) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) - valid, err := app.db.ValidateAPIKey(apiKeyStr) + valid, err := db.ValidateAPIKey(apiKeyStr) c.Assert(err, check.IsNil) c.Assert(valid, check.Equals, false) now := time.Now() - apiKeyStrNow, apiKey, err := app.db.CreateAPIKey(&now) + apiKeyStrNow, apiKey, err := db.CreateAPIKey(&now) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) - validNow, err := app.db.ValidateAPIKey(apiKeyStrNow) + validNow, err := db.ValidateAPIKey(apiKeyStrNow) c.Assert(err, check.IsNil) c.Assert(validNow, check.Equals, false) - validSilly, err := app.db.ValidateAPIKey("nota.validkey") + validSilly, err := db.ValidateAPIKey("nota.validkey") c.Assert(err, check.NotNil) c.Assert(validSilly, check.Equals, false) - validWithErr, err := app.db.ValidateAPIKey("produceerrorkey") + validWithErr, err := db.ValidateAPIKey("produceerrorkey") c.Assert(err, check.NotNil) c.Assert(validWithErr, check.Equals, false) + + c.Assert(channelUpdates, check.Equals, int32(0)) } func (*Suite) TestExpireAPIKey(c *check.C) { nowPlus2 := time.Now().Add(2 * time.Hour) - apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2) + apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2) c.Assert(err, check.IsNil) c.Assert(apiKey, check.NotNil) - valid, err := app.db.ValidateAPIKey(apiKeyStr) + valid, err := db.ValidateAPIKey(apiKeyStr) c.Assert(err, check.IsNil) c.Assert(valid, check.Equals, true) - err = app.db.ExpireAPIKey(apiKey) + err = db.ExpireAPIKey(apiKey) c.Assert(err, check.IsNil) c.Assert(apiKey.Expiration, check.NotNil) - notValid, err := app.db.ValidateAPIKey(apiKeyStr) + notValid, err := db.ValidateAPIKey(apiKeyStr) c.Assert(err, check.IsNil) c.Assert(notValid, check.Equals, false) + + c.Assert(channelUpdates, check.Equals, int32(0)) } diff --git a/hscontrol/db.go b/hscontrol/db/db.go similarity index 63% rename from hscontrol/db.go rename to hscontrol/db/db.go index e80a3c3ed9..bc6de089bd 100644 --- a/hscontrol/db.go +++ b/hscontrol/db/db.go @@ -1,9 +1,7 @@ -package hscontrol +package db import ( "context" - "database/sql/driver" - "encoding/json" "errors" "fmt" "net/netip" @@ -11,11 +9,12 @@ import ( "time" "github.com/glebarez/sqlite" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/logger" - "tailscale.com/tailcfg" ) const ( @@ -26,7 +25,6 @@ const ( var ( errValueNotFound = errors.New("not found") - ErrCannotParsePrefix = errors.New("cannot parse prefix") errDatabaseNotSupported = errors.New("database type not supported") ) @@ -38,8 +36,9 @@ type KV struct { } type HSDatabase struct { - db *gorm.DB - notifyStateChan chan<- struct{} + db *gorm.DB + notifyStateChan chan<- struct{} + notifyPolicyChan chan<- struct{} ipAllocationMutex sync.Mutex @@ -54,6 +53,7 @@ func NewHeadscaleDatabase( dbType, connectionAddr string, stripEmailDomain, debug bool, notifyStateChan chan<- struct{}, + notifyPolicyChan chan<- struct{}, ipPrefixes []netip.Prefix, baseDomain string, ) (*HSDatabase, error) { @@ -63,8 +63,9 @@ func NewHeadscaleDatabase( } db := HSDatabase{ - db: dbConn, - notifyStateChan: notifyStateChan, + db: dbConn, + notifyStateChan: notifyStateChan, + notifyPolicyChan: notifyPolicyChan, ipPrefixes: ipPrefixes, baseDomain: baseDomain, @@ -79,30 +80,30 @@ func NewHeadscaleDatabase( _ = dbConn.Migrator().RenameTable("namespaces", "users") - err = dbConn.AutoMigrate(User{}) + err = dbConn.AutoMigrate(types.User{}) if err != nil { return nil, err } - _ = dbConn.Migrator().RenameColumn(&Machine{}, "namespace_id", "user_id") - _ = dbConn.Migrator().RenameColumn(&PreAuthKey{}, "namespace_id", "user_id") + _ = dbConn.Migrator().RenameColumn(&types.Machine{}, "namespace_id", "user_id") + _ = dbConn.Migrator().RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id") - _ = dbConn.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses") - _ = dbConn.Migrator().RenameColumn(&Machine{}, "name", "hostname") + _ = dbConn.Migrator().RenameColumn(&types.Machine{}, "ip_address", "ip_addresses") + _ = dbConn.Migrator().RenameColumn(&types.Machine{}, "name", "hostname") // GivenName is used as the primary source of DNS names, make sure // the field is populated and normalized if it was not when the // machine was registered. - _ = dbConn.Migrator().RenameColumn(&Machine{}, "nickname", "given_name") + _ = dbConn.Migrator().RenameColumn(&types.Machine{}, "nickname", "given_name") // If the Machine table has a column for registered, // find all occourences of "false" and drop them. Then // remove the column. - if dbConn.Migrator().HasColumn(&Machine{}, "registered") { + if dbConn.Migrator().HasColumn(&types.Machine{}, "registered") { log.Info(). Msg(`Database has legacy "registered" column in machine, removing...`) - machines := Machines{} + machines := types.Machines{} if err := dbConn.Not("registered").Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") } @@ -112,7 +113,7 @@ func NewHeadscaleDatabase( Str("machine", machine.Hostname). Str("machine_key", machine.MachineKey). Msg("Deleting unregistered machine") - if err := dbConn.Delete(&Machine{}, machine.ID).Error; err != nil { + if err := dbConn.Delete(&types.Machine{}, machine.ID).Error; err != nil { log.Error(). Err(err). Str("machine", machine.Hostname). @@ -121,23 +122,23 @@ func NewHeadscaleDatabase( } } - err := dbConn.Migrator().DropColumn(&Machine{}, "registered") + err := dbConn.Migrator().DropColumn(&types.Machine{}, "registered") if err != nil { log.Error().Err(err).Msg("Error dropping registered column") } } - err = dbConn.AutoMigrate(&Route{}) + err = dbConn.AutoMigrate(&types.Route{}) if err != nil { return nil, err } - if dbConn.Migrator().HasColumn(&Machine{}, "enabled_routes") { + if dbConn.Migrator().HasColumn(&types.Machine{}, "enabled_routes") { log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...") type MachineAux struct { ID uint64 - EnabledRoutes IPPrefixes + EnabledRoutes types.IPPrefixes } machinesAux := []MachineAux{} @@ -157,8 +158,8 @@ func NewHeadscaleDatabase( } err = dbConn.Preload("Machine"). - Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)). - First(&Route{}). + Where("machine_id = ? AND prefix = ?", machine.ID, types.IPPrefix(prefix)). + First(&types.Route{}). Error if err == nil { log.Info(). @@ -168,11 +169,11 @@ func NewHeadscaleDatabase( continue } - route := Route{ + route := types.Route{ MachineID: machine.ID, Advertised: true, Enabled: true, - Prefix: IPPrefix(prefix), + Prefix: types.IPPrefix(prefix), } if err := dbConn.Create(&route).Error; err != nil { log.Error().Err(err).Msg("Error creating route") @@ -185,26 +186,26 @@ func NewHeadscaleDatabase( } } - err = dbConn.Migrator().DropColumn(&Machine{}, "enabled_routes") + err = dbConn.Migrator().DropColumn(&types.Machine{}, "enabled_routes") if err != nil { log.Error().Err(err).Msg("Error dropping enabled_routes column") } } - err = dbConn.AutoMigrate(&Machine{}) + err = dbConn.AutoMigrate(&types.Machine{}) if err != nil { return nil, err } - if dbConn.Migrator().HasColumn(&Machine{}, "given_name") { - machines := Machines{} + if dbConn.Migrator().HasColumn(&types.Machine{}, "given_name") { + machines := types.Machines{} if err := dbConn.Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") } for item, machine := range machines { if machine.GivenName == "" { - normalizedHostname, err := NormalizeToFQDNRules( + normalizedHostname, err := util.NormalizeToFQDNRules( machine.Hostname, stripEmailDomain, ) @@ -233,19 +234,19 @@ func NewHeadscaleDatabase( return nil, err } - err = dbConn.AutoMigrate(&PreAuthKey{}) + err = dbConn.AutoMigrate(&types.PreAuthKey{}) if err != nil { return nil, err } - err = dbConn.AutoMigrate(&PreAuthKeyACLTag{}) + err = dbConn.AutoMigrate(&types.PreAuthKeyACLTag{}) if err != nil { return nil, err } _ = dbConn.Migrator().DropTable("shared_machines") - err = dbConn.AutoMigrate(&APIKey{}) + err = dbConn.AutoMigrate(&types.APIKey{}) if err != nil { return nil, err } @@ -339,7 +340,7 @@ func (hsdb *HSDatabase) setValue(key string, value string) error { return nil } -func (hsdb *HSDatabase) pingDB(ctx context.Context) error { +func (hsdb *HSDatabase) PingDB(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() sqlDB, err := hsdb.db.DB() @@ -350,97 +351,11 @@ func (hsdb *HSDatabase) pingDB(ctx context.Context) error { return sqlDB.PingContext(ctx) } -// This is a "wrapper" type around tailscales -// Hostinfo to allow us to add database "serialization" -// methods. This allows us to use a typed values throughout -// the code and not have to marshal/unmarshal and error -// check all over the code. -type HostInfo tailcfg.Hostinfo - -func (hi *HostInfo) Scan(destination interface{}) error { - switch value := destination.(type) { - case []byte: - return json.Unmarshal(value, hi) - - case string: - return json.Unmarshal([]byte(value), hi) - - default: - return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) - } -} - -// Value return json value, implement driver.Valuer interface. -func (hi HostInfo) Value() (driver.Value, error) { - bytes, err := json.Marshal(hi) - - return string(bytes), err -} - -type IPPrefix netip.Prefix - -func (i *IPPrefix) Scan(destination interface{}) error { - switch value := destination.(type) { - case string: - prefix, err := netip.ParsePrefix(value) - if err != nil { - return err - } - *i = IPPrefix(prefix) - - return nil - default: - return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination) - } -} - -// Value return json value, implement driver.Valuer interface. -func (i IPPrefix) Value() (driver.Value, error) { - prefixStr := netip.Prefix(i).String() - - return prefixStr, nil -} - -type IPPrefixes []netip.Prefix - -func (i *IPPrefixes) Scan(destination interface{}) error { - switch value := destination.(type) { - case []byte: - return json.Unmarshal(value, i) - - case string: - return json.Unmarshal([]byte(value), i) - - default: - return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) - } -} - -// Value return json value, implement driver.Valuer interface. -func (i IPPrefixes) Value() (driver.Value, error) { - bytes, err := json.Marshal(i) - - return string(bytes), err -} - -type StringList []string - -func (i *StringList) Scan(destination interface{}) error { - switch value := destination.(type) { - case []byte: - return json.Unmarshal(value, i) - - case string: - return json.Unmarshal([]byte(value), i) - - default: - return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) +func (hsdb *HSDatabase) Close() error { + db, err := hsdb.db.DB() + if err != nil { + return err } -} - -// Value return json value, implement driver.Valuer interface. -func (i StringList) Value() (driver.Value, error) { - bytes, err := json.Marshal(i) - return string(bytes), err + return db.Close() } diff --git a/hscontrol/machine.go b/hscontrol/db/machine.go similarity index 58% rename from hscontrol/machine.go rename to hscontrol/db/machine.go index 846112b15b..a8d3569e99 100644 --- a/hscontrol/machine.go +++ b/hscontrol/db/machine.go @@ -1,7 +1,6 @@ -package hscontrol +package db import ( - "database/sql/driver" "errors" "fmt" "net/netip" @@ -10,13 +9,12 @@ import ( "strings" "time" - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" "github.com/samber/lo" - "go4.org/netipx" - "google.golang.org/protobuf/types/known/timestamppb" "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -25,13 +23,12 @@ import ( const ( MachineGivenNameHashLength = 8 MachineGivenNameTrimSize = 2 - maxHostnameLength = 255 + MaxHostnameLength = 255 ) var ( ErrMachineNotFound = errors.New("machine not found") ErrMachineRouteIsNotAvailable = errors.New("route is not available on machine") - ErrMachineAddressesInvalid = errors.New("failed to parse machine addresses") ErrMachineNotFoundRegistrationCache = errors.New( "machine not found in registration cache", ) @@ -42,193 +39,27 @@ var ( ) ) -// Machine is a Headscale client. -type Machine struct { - ID uint64 `gorm:"primary_key"` - MachineKey string `gorm:"type:varchar(64);unique_index"` - NodeKey string - DiscoKey string - IPAddresses MachineAddresses - - // Hostname represents the name given by the Tailscale - // client during registration - Hostname string - - // Givenname represents either: - // a DNS normalized version of Hostname - // a valid name set by the User - // - // GivenName is the name used in all DNS related - // parts of headscale. - GivenName string `gorm:"type:varchar(63);unique_index"` - UserID uint - User User `gorm:"foreignKey:UserID"` - - RegisterMethod string - - ForcedTags StringList - - // TODO(kradalby): This seems like irrelevant information? - AuthKeyID uint - AuthKey *PreAuthKey - - LastSeen *time.Time - LastSuccessfulUpdate *time.Time - Expiry *time.Time - - HostInfo HostInfo - Endpoints StringList - - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time -} - -type ( - Machines []Machine - MachinesP []*Machine -) - -type MachineAddresses []netip.Addr - -func (ma MachineAddresses) ToStringSlice() []string { - strSlice := make([]string, 0, len(ma)) - for _, addr := range ma { - strSlice = append(strSlice, addr.String()) - } - - return strSlice -} - -// AppendToIPSet adds the individual ips in MachineAddresses to a -// given netipx.IPSetBuilder. -func (ma MachineAddresses) AppendToIPSet(build *netipx.IPSetBuilder) { - for _, ip := range ma { - build.Add(ip) - } -} - -func (ma *MachineAddresses) Scan(destination interface{}) error { - switch value := destination.(type) { - case string: - addresses := strings.Split(value, ",") - *ma = (*ma)[:0] - for _, addr := range addresses { - if len(addr) < 1 { - continue - } - parsed, err := netip.ParseAddr(addr) - if err != nil { - return err - } - *ma = append(*ma, parsed) - } - - return nil - - default: - return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) - } -} - -// Value return json value, implement driver.Valuer interface. -func (ma MachineAddresses) Value() (driver.Value, error) { - addresses := strings.Join(ma.ToStringSlice(), ",") - - return addresses, nil -} - -// isExpired returns whether the machine registration has expired. -func (machine Machine) isExpired() bool { - // If Expiry is not set, the client has not indicated that - // it wants an expiry time, it is therefor considered - // to mean "not expired" - if machine.Expiry == nil || machine.Expiry.IsZero() { - return false - } - - return time.Now().UTC().After(*machine.Expiry) -} - -// isOnline returns if the machine is connected to Headscale. -// This is really a naive implementation, as we don't really see -// if there is a working connection between the client and the server. -func (machine *Machine) isOnline() bool { - if machine.LastSeen == nil { - return false - } - - if machine.isExpired() { - return false - } - - return machine.LastSeen.After(time.Now().Add(-keepAliveInterval)) -} - -// isEphemeral returns if the machine is registered as an Ephemeral node. -// https://tailscale.com/kb/1111/ephemeral-nodes/ -func (machine *Machine) isEphemeral() bool { - return machine.AuthKey != nil && machine.AuthKey.Ephemeral -} - -func (machine *Machine) canAccess(filter []tailcfg.FilterRule, machine2 *Machine) bool { - for _, rule := range filter { - // TODO(kradalby): Cache or pregen this - matcher := MatchFromFilterRule(rule) - - if !matcher.SrcsContainsIPs([]netip.Addr(machine.IPAddresses)) { - continue - } - - if matcher.DestsContainsIP([]netip.Addr(machine2.IPAddresses)) { - return true - } - } - - return false -} - // filterMachinesByACL wrapper function to not have devs pass around locks and maps // related to the application outside of tests. func (hsdb *HSDatabase) filterMachinesByACL( aclRules []tailcfg.FilterRule, - currentMachine *Machine, peers Machines) Machines { - return filterMachinesByACL(currentMachine, peers, aclRules) -} - -// filterMachinesByACL returns the list of peers authorized to be accessed from a given machine. -func filterMachinesByACL( - machine *Machine, - machines Machines, - filter []tailcfg.FilterRule, -) Machines { - result := Machines{} - - for index, peer := range machines { - if peer.ID == machine.ID { - continue - } - - if machine.canAccess(filter, &machines[index]) || peer.canAccess(filter, machine) { - result = append(result, peer) - } - } - - return result + currentMachine *types.Machine, peers types.Machines, +) types.Machines { + return policy.FilterMachinesByACL(currentMachine, peers, aclRules) } -func (hsdb *HSDatabase) ListPeers(machine *Machine) (Machines, error) { +func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error) { log.Trace(). Caller(). Str("machine", machine.Hostname). Msg("Finding direct peers") - machines := Machines{} + machines := types.Machines{} if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("node_key <> ?", machine.NodeKey).Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") - return Machines{}, err + return types.Machines{}, err } sort.Slice(machines, func(i, j int) bool { return machines[i].ID < machines[j].ID }) @@ -242,22 +73,21 @@ func (hsdb *HSDatabase) ListPeers(machine *Machine) (Machines, error) { } func (hsdb *HSDatabase) getPeers( - aclPolicy *ACLPolicy, aclRules []tailcfg.FilterRule, - machine *Machine, -) (Machines, error) { - var peers Machines + machine *types.Machine, +) (types.Machines, error) { + var peers types.Machines var err error // If ACLs rules are defined, filter visible host list with the ACLs // else use the classic user scope - if aclPolicy != nil { - var machines []Machine + if len(aclRules) > 0 { + var machines []types.Machine machines, err = hsdb.ListMachines() if err != nil { log.Error().Err(err).Msg("Error retrieving list of machines") - return Machines{}, err + return types.Machines{}, err } peers = hsdb.filterMachinesByACL(aclRules, machine, machines) } else { @@ -268,7 +98,7 @@ func (hsdb *HSDatabase) getPeers( Err(err). Msg("Cannot fetch peers") - return Machines{}, err + return types.Machines{}, err } } @@ -283,20 +113,19 @@ func (hsdb *HSDatabase) getPeers( return peers, nil } -func (hsdb *HSDatabase) getValidPeers( - aclPolicy *ACLPolicy, +func (hsdb *HSDatabase) GetValidPeers( aclRules []tailcfg.FilterRule, - machine *Machine, -) (Machines, error) { - validPeers := make(Machines, 0) + machine *types.Machine, +) (types.Machines, error) { + validPeers := make(types.Machines, 0) - peers, err := hsdb.getPeers(aclPolicy, aclRules, machine) + peers, err := hsdb.getPeers(aclRules, machine) if err != nil { - return Machines{}, err + return types.Machines{}, err } for _, peer := range peers { - if !peer.isExpired() { + if !peer.IsExpired() { validPeers = append(validPeers, peer) } } @@ -304,8 +133,8 @@ func (hsdb *HSDatabase) getValidPeers( return validPeers, nil } -func (hsdb *HSDatabase) ListMachines() ([]Machine, error) { - machines := []Machine{} +func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) { + machines := []types.Machine{} if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil { return nil, err } @@ -313,8 +142,8 @@ func (hsdb *HSDatabase) ListMachines() ([]Machine, error) { return machines, nil } -func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) ([]Machine, error) { - machines := []Machine{} +func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) (types.Machines, error) { + machines := types.Machines{} if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("given_name = ?", givenName).Find(&machines).Error; err != nil { return nil, err } @@ -323,7 +152,7 @@ func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) ([]Machine, er } // GetMachine finds a Machine by name and user and returns the Machine struct. -func (hsdb *HSDatabase) GetMachine(user string, name string) (*Machine, error) { +func (hsdb *HSDatabase) GetMachine(user string, name string) (*types.Machine, error) { machines, err := hsdb.ListMachinesByUser(user) if err != nil { return nil, err @@ -339,7 +168,10 @@ func (hsdb *HSDatabase) GetMachine(user string, name string) (*Machine, error) { } // GetMachineByGivenName finds a Machine by given name and user and returns the Machine struct. -func (hsdb *HSDatabase) GetMachineByGivenName(user string, givenName string) (*Machine, error) { +func (hsdb *HSDatabase) GetMachineByGivenName( + user string, + givenName string, +) (*types.Machine, error) { machines, err := hsdb.ListMachinesByUser(user) if err != nil { return nil, err @@ -355,9 +187,9 @@ func (hsdb *HSDatabase) GetMachineByGivenName(user string, givenName string) (*M } // GetMachineByID finds a Machine by ID and returns the Machine struct. -func (hsdb *HSDatabase) GetMachineByID(id uint64) (*Machine, error) { - m := Machine{} - if result := hsdb.db.Preload("AuthKey").Preload("User").Find(&Machine{ID: id}).First(&m); result.Error != nil { +func (hsdb *HSDatabase) GetMachineByID(id uint64) (*types.Machine, error) { + m := types.Machine{} + if result := hsdb.db.Preload("AuthKey").Preload("User").Find(&types.Machine{ID: id}).First(&m); result.Error != nil { return nil, result.Error } @@ -367,8 +199,8 @@ func (hsdb *HSDatabase) GetMachineByID(id uint64) (*Machine, error) { // GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct. func (hsdb *HSDatabase) GetMachineByMachineKey( machineKey key.MachinePublic, -) (*Machine, error) { - m := Machine{} +) (*types.Machine, error) { + m := types.Machine{} if result := hsdb.db.Preload("AuthKey").Preload("User").First(&m, "machine_key = ?", util.MachinePublicKeyStripPrefix(machineKey)); result.Error != nil { return nil, result.Error } @@ -379,8 +211,8 @@ func (hsdb *HSDatabase) GetMachineByMachineKey( // GetMachineByNodeKey finds a Machine by its current NodeKey. func (hsdb *HSDatabase) GetMachineByNodeKey( nodeKey key.NodePublic, -) (*Machine, error) { - machine := Machine{} +) (*types.Machine, error) { + machine := types.Machine{} if result := hsdb.db.Preload("AuthKey").Preload("User").First(&machine, "node_key = ?", util.NodePublicKeyStripPrefix(nodeKey)); result.Error != nil { return nil, result.Error @@ -392,8 +224,8 @@ func (hsdb *HSDatabase) GetMachineByNodeKey( // GetMachineByAnyNodeKey finds a Machine by its MachineKey, its current NodeKey or the old one, and returns the Machine struct. func (hsdb *HSDatabase) GetMachineByAnyKey( machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, -) (*Machine, error) { - machine := Machine{} +) (*types.Machine, error) { + machine := types.Machine{} if result := hsdb.db.Preload("AuthKey").Preload("User").First(&machine, "machine_key = ? OR node_key = ? OR node_key = ?", util.MachinePublicKeyStripPrefix(machineKey), util.NodePublicKeyStripPrefix(nodeKey), @@ -404,9 +236,10 @@ func (hsdb *HSDatabase) GetMachineByAnyKey( return &machine, nil } +// TODO(kradalby): rename this, it sounds like a mix of getting and setting to db // UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database // and updates it with the latest data from the database. -func (hsdb *HSDatabase) UpdateMachineFromDatabase(machine *Machine) error { +func (hsdb *HSDatabase) UpdateMachineFromDatabase(machine *types.Machine) error { if result := hsdb.db.Find(machine).First(&machine); result.Error != nil { return result.Error } @@ -416,13 +249,9 @@ func (hsdb *HSDatabase) UpdateMachineFromDatabase(machine *Machine) error { // SetTags takes a Machine struct pointer and update the forced tags. func (hsdb *HSDatabase) SetTags( - machine *Machine, + machine *types.Machine, tags []string, - // TODO(kradalby): This is a temporary measure to be able to detach the - // database completely from the global h. In the future, as part of this - // reorg, the rules will be generated on a per node basis, and not be prone - // to throwing error at save. - updateACL func() error) error { +) error { newTags := []string{} for _, tag := range tags { if !util.StringOrPrefixListContains(newTags, tag) { @@ -430,10 +259,8 @@ func (hsdb *HSDatabase) SetTags( } } machine.ForcedTags = newTags - if err := updateACL(); err != nil && !errors.Is(err, errEmptyPolicy) { - return err - } + hsdb.notifyPolicyChan <- struct{}{} hsdb.notifyStateChange() if err := hsdb.db.Save(machine).Error; err != nil { @@ -444,7 +271,7 @@ func (hsdb *HSDatabase) SetTags( } // ExpireMachine takes a Machine struct and sets the expire field to now. -func (hsdb *HSDatabase) ExpireMachine(machine *Machine) error { +func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error { now := time.Now() machine.Expiry = &now @@ -459,8 +286,8 @@ func (hsdb *HSDatabase) ExpireMachine(machine *Machine) error { // RenameMachine takes a Machine struct and a new GivenName for the machines // and renames it. -func (hsdb *HSDatabase) RenameMachine(machine *Machine, newName string) error { - err := CheckForFQDNRules( +func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) error { + err := util.CheckForFQDNRules( newName, ) if err != nil { @@ -484,8 +311,8 @@ func (hsdb *HSDatabase) RenameMachine(machine *Machine, newName string) error { return nil } -// RefreshMachine takes a Machine struct and sets the expire field to now. -func (hsdb *HSDatabase) RefreshMachine(machine *Machine, expiry time.Time) error { +// RefreshMachine takes a Machine struct and a new expiry time. +func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time) error { now := time.Now() machine.LastSuccessfulUpdate = &now @@ -504,7 +331,7 @@ func (hsdb *HSDatabase) RefreshMachine(machine *Machine, expiry time.Time) error } // DeleteMachine softs deletes a Machine from the database. -func (hsdb *HSDatabase) DeleteMachine(machine *Machine) error { +func (hsdb *HSDatabase) DeleteMachine(machine *types.Machine) error { err := hsdb.DeleteMachineRoutes(machine) if err != nil { return err @@ -517,8 +344,8 @@ func (hsdb *HSDatabase) DeleteMachine(machine *Machine) error { return nil } -func (hsdb *HSDatabase) TouchMachine(machine *Machine) error { - return hsdb.db.Updates(Machine{ +func (hsdb *HSDatabase) TouchMachine(machine *types.Machine) error { + return hsdb.db.Updates(types.Machine{ ID: machine.ID, LastSeen: machine.LastSeen, LastSuccessfulUpdate: machine.LastSuccessfulUpdate, @@ -526,7 +353,7 @@ func (hsdb *HSDatabase) TouchMachine(machine *Machine) error { } // HardDeleteMachine hard deletes a Machine from the database. -func (hsdb *HSDatabase) HardDeleteMachine(machine *Machine) error { +func (hsdb *HSDatabase) HardDeleteMachine(machine *types.Machine) error { err := hsdb.DeleteMachineRoutes(machine) if err != nil { return err @@ -539,12 +366,7 @@ func (hsdb *HSDatabase) HardDeleteMachine(machine *Machine) error { return nil } -// GetHostInfo returns a Hostinfo struct for the machine. -func (machine *Machine) GetHostInfo() tailcfg.Hostinfo { - return tailcfg.Hostinfo(machine.HostInfo) -} - -func (hsdb *HSDatabase) isOutdated(machine *Machine, lastChange time.Time) bool { +func (hsdb *HSDatabase) IsOutdated(machine *types.Machine, lastChange time.Time) bool { if err := hsdb.UpdateMachineFromDatabase(machine); err != nil { // It does not seem meaningful to propagate this error as the end result // will have to be that the machine has to be considered outdated. @@ -570,291 +392,13 @@ func (hsdb *HSDatabase) isOutdated(machine *Machine, lastChange time.Time) bool return lastUpdate.Before(lastChange) } -func (machine Machine) String() string { - return machine.Hostname -} - -func (machines Machines) String() string { - temp := make([]string, len(machines)) - - for index, machine := range machines { - temp[index] = machine.Hostname - } - - return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) -} - -// TODO(kradalby): Remove when we have generics... -func (machines MachinesP) String() string { - temp := make([]string, len(machines)) - - for index, machine := range machines { - temp[index] = machine.Hostname - } - - return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) -} - -func (hsdb *HSDatabase) toNodes( - machines Machines, - aclPolicy *ACLPolicy, - baseDomain string, - dnsConfig *tailcfg.DNSConfig, -) ([]*tailcfg.Node, error) { - nodes := make([]*tailcfg.Node, len(machines)) - - for index, machine := range machines { - node, err := hsdb.toNode(machine, aclPolicy, baseDomain, dnsConfig) - if err != nil { - return nil, err - } - - nodes[index] = node - } - - return nodes, nil -} - -// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes -// as per the expected behaviour in the official SaaS. -func (hsdb *HSDatabase) toNode( - machine Machine, - aclPolicy *ACLPolicy, - baseDomain string, - dnsConfig *tailcfg.DNSConfig, -) (*tailcfg.Node, error) { - var nodeKey key.NodePublic - err := nodeKey.UnmarshalText([]byte(util.NodePublicKeyEnsurePrefix(machine.NodeKey))) - if err != nil { - log.Trace(). - Caller(). - Str("node_key", machine.NodeKey). - Msgf("Failed to parse node public key from hex") - - return nil, fmt.Errorf("failed to parse node public key: %w", err) - } - - var machineKey key.MachinePublic - // MachineKey is only used in the legacy protocol - if machine.MachineKey != "" { - err = machineKey.UnmarshalText( - []byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)), - ) - if err != nil { - return nil, fmt.Errorf("failed to parse machine public key: %w", err) - } - } - - var discoKey key.DiscoPublic - if machine.DiscoKey != "" { - err := discoKey.UnmarshalText( - []byte(util.DiscoPublicKeyEnsurePrefix(machine.DiscoKey)), - ) - if err != nil { - return nil, fmt.Errorf("failed to parse disco public key: %w", err) - } - } else { - discoKey = key.DiscoPublic{} - } - - addrs := []netip.Prefix{} - for _, machineAddress := range machine.IPAddresses { - ip := netip.PrefixFrom(machineAddress, machineAddress.BitLen()) - addrs = append(addrs, ip) - } - - allowedIPs := append( - []netip.Prefix{}, - addrs...) // we append the node own IP, as it is required by the clients - - primaryRoutes, err := hsdb.getMachinePrimaryRoutes(&machine) - if err != nil { - return nil, err - } - primaryPrefixes := Routes(primaryRoutes).toPrefixes() - - machineRoutes, err := hsdb.GetMachineRoutes(&machine) - if err != nil { - return nil, err - } - for _, route := range machineRoutes { - if route.Enabled && (route.IsPrimary || route.isExitRoute()) { - allowedIPs = append(allowedIPs, netip.Prefix(route.Prefix)) - } - } - - var derp string - if machine.HostInfo.NetInfo != nil { - derp = fmt.Sprintf("127.3.3.40:%d", machine.HostInfo.NetInfo.PreferredDERP) - } else { - derp = "127.3.3.40:0" // Zero means disconnected or unknown. - } - - var keyExpiry time.Time - if machine.Expiry != nil { - keyExpiry = *machine.Expiry - } else { - keyExpiry = time.Time{} - } - - var hostname string - if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS - hostname = fmt.Sprintf( - "%s.%s.%s", - machine.GivenName, - machine.User.Name, - baseDomain, - ) - if len(hostname) > maxHostnameLength { - return nil, fmt.Errorf( - "hostname %q is too long it cannot except 255 ASCII chars: %w", - hostname, - ErrHostnameTooLong, - ) - } - } else { - hostname = machine.GivenName - } - - hostInfo := machine.GetHostInfo() - - online := machine.isOnline() - - tags, _ := getTags(aclPolicy, machine, hsdb.stripEmailDomain) - tags = lo.Uniq(append(tags, machine.ForcedTags...)) - - node := tailcfg.Node{ - ID: tailcfg.NodeID(machine.ID), // this is the actual ID - StableID: tailcfg.StableNodeID( - strconv.FormatUint(machine.ID, util.Base10), - ), // in headscale, unlike tailcontrol server, IDs are permanent - Name: hostname, - - User: tailcfg.UserID(machine.UserID), - - Key: nodeKey, - KeyExpiry: keyExpiry, - - Machine: machineKey, - DiscoKey: discoKey, - Addresses: addrs, - AllowedIPs: allowedIPs, - Endpoints: machine.Endpoints, - DERP: derp, - Hostinfo: hostInfo.View(), - Created: machine.CreatedAt, - - Tags: tags, - - PrimaryRoutes: primaryPrefixes, - - LastSeen: machine.LastSeen, - Online: &online, - KeepAlive: true, - MachineAuthorized: !machine.isExpired(), - - Capabilities: []string{ - tailcfg.CapabilityFileSharing, - tailcfg.CapabilityAdmin, - tailcfg.CapabilitySSH, - }, - } - - return &node, nil -} - -func (machine *Machine) toProto() *v1.Machine { - machineProto := &v1.Machine{ - Id: machine.ID, - MachineKey: machine.MachineKey, - - NodeKey: machine.NodeKey, - DiscoKey: machine.DiscoKey, - IpAddresses: machine.IPAddresses.ToStringSlice(), - Name: machine.Hostname, - GivenName: machine.GivenName, - User: machine.User.toProto(), - ForcedTags: machine.ForcedTags, - Online: machine.isOnline(), - - // TODO(kradalby): Implement register method enum converter - // RegisterMethod: , - - CreatedAt: timestamppb.New(machine.CreatedAt), - } - - if machine.AuthKey != nil { - machineProto.PreAuthKey = machine.AuthKey.toProto() - } - - if machine.LastSeen != nil { - machineProto.LastSeen = timestamppb.New(*machine.LastSeen) - } - - if machine.LastSuccessfulUpdate != nil { - machineProto.LastSuccessfulUpdate = timestamppb.New( - *machine.LastSuccessfulUpdate, - ) - } - - if machine.Expiry != nil { - machineProto.Expiry = timestamppb.New(*machine.Expiry) - } - - return machineProto -} - -// getTags will return the tags of the current machine. -// Invalid tags are tags added by a user on a node, and that user doesn't have authority to add this tag. -// Valid tags are tags added by a user that is allowed in the ACL policy to add this tag. -func getTags( - aclPolicy *ACLPolicy, - machine Machine, - stripEmailDomain bool, -) ([]string, []string) { - validTags := make([]string, 0) - invalidTags := make([]string, 0) - if aclPolicy == nil { - return validTags, invalidTags - } - validTagMap := make(map[string]bool) - invalidTagMap := make(map[string]bool) - for _, tag := range machine.HostInfo.RequestTags { - owners, err := getTagOwners(aclPolicy, tag, stripEmailDomain) - if errors.Is(err, errInvalidTag) { - invalidTagMap[tag] = true - - continue - } - var found bool - for _, owner := range owners { - if machine.User.Name == owner { - found = true - } - } - if found { - validTagMap[tag] = true - } else { - invalidTagMap[tag] = true - } - } - for tag := range invalidTagMap { - invalidTags = append(invalidTags, tag) - } - for tag := range validTagMap { - validTags = append(validTags, tag) - } - - return validTags, invalidTags -} - func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( cache *cache.Cache, nodeKeyStr string, userName string, machineExpiry *time.Time, registrationMethod string, -) (*Machine, error) { +) (*types.Machine, error) { nodeKey := key.NodePublic{} err := nodeKey.UnmarshalText([]byte(nodeKeyStr)) if err != nil { @@ -869,7 +413,7 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( Msg("Registering machine from API/CLI or auth callback") if machineInterface, ok := cache.Get(util.NodePublicKeyStripPrefix(nodeKey)); ok { - if registrationMachine, ok := machineInterface.(Machine); ok { + if registrationMachine, ok := machineInterface.(types.Machine); ok { user, err := hsdb.GetUser(userName) if err != nil { return nil, fmt.Errorf( @@ -909,8 +453,8 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( } // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. -func (hsdb *HSDatabase) RegisterMachine(machine Machine, -) (*Machine, error) { +func (hsdb *HSDatabase) RegisterMachine(machine types.Machine, +) (*types.Machine, error) { log.Debug(). Str("machine", machine.Hostname). Str("machine_key", machine.MachineKey). @@ -966,9 +510,44 @@ func (hsdb *HSDatabase) RegisterMachine(machine Machine, return &machine, nil } +// MachineSetNodeKey sets the node key of a machine and saves it to the database. +func (hsdb *HSDatabase) MachineSetNodeKey(machine *types.Machine, nodeKey key.NodePublic) error { + machine.NodeKey = util.NodePublicKeyStripPrefix(nodeKey) + + if err := hsdb.db.Save(machine).Error; err != nil { + return err + } + + return nil +} + +// MachineSetMachineKey sets the machine key of a machine and saves it to the database. +func (hsdb *HSDatabase) MachineSetMachineKey( + machine *types.Machine, + nodeKey key.MachinePublic, +) error { + machine.MachineKey = util.MachinePublicKeyStripPrefix(nodeKey) + + if err := hsdb.db.Save(machine).Error; err != nil { + return err + } + + return nil +} + +// MachineSave saves a machine object to the database, prefer to use a specific save method rather +// than this. It is intended to be used when we are changing or. +func (hsdb *HSDatabase) MachineSave(machine *types.Machine) error { + if err := hsdb.db.Save(machine).Error; err != nil { + return err + } + + return nil +} + // GetAdvertisedRoutes returns the routes that are be advertised by the given machine. -func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error) { - routes := []Route{} +func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *types.Machine) ([]netip.Prefix, error) { + routes := types.Routes{} err := hsdb.db. Preload("Machine"). @@ -992,8 +571,8 @@ func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, e } // GetEnabledRoutes returns the routes that are enabled for the machine. -func (hsdb *HSDatabase) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) { - routes := []Route{} +func (hsdb *HSDatabase) GetEnabledRoutes(machine *types.Machine) ([]netip.Prefix, error) { + routes := types.Routes{} err := hsdb.db. Preload("Machine"). @@ -1017,7 +596,7 @@ func (hsdb *HSDatabase) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, erro return prefixes, nil } -func (hsdb *HSDatabase) IsRoutesEnabled(machine *Machine, routeStr string) bool { +func (hsdb *HSDatabase) IsRoutesEnabled(machine *types.Machine, routeStr string) bool { route, err := netip.ParsePrefix(routeStr) if err != nil { return false @@ -1040,7 +619,7 @@ func (hsdb *HSDatabase) IsRoutesEnabled(machine *Machine, routeStr string) bool } // enableRoutes enables new routes based on a list of new routes. -func (hsdb *HSDatabase) enableRoutes(machine *Machine, routeStrs ...string) error { +func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string) error { newRoutes := make([]netip.Prefix, len(routeStrs)) for index, routeStr := range routeStrs { route, err := netip.ParsePrefix(routeStr) @@ -1068,16 +647,16 @@ func (hsdb *HSDatabase) enableRoutes(machine *Machine, routeStrs ...string) erro // Separate loop so we don't leave things in a half-updated state for _, prefix := range newRoutes { - route := Route{} + route := types.Route{} err := hsdb.db.Preload("Machine"). - Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)). + Where("machine_id = ? AND prefix = ?", machine.ID, types.IPPrefix(prefix)). First(&route).Error if err == nil { route.Enabled = true // Mark already as primary if there is only this node offering this subnet // (and is not an exit route) - if !route.isExitRoute() { + if !route.IsExitRoute() { route.IsPrimary = hsdb.isUniquePrefix(route) } @@ -1095,81 +674,8 @@ func (hsdb *HSDatabase) enableRoutes(machine *Machine, routeStrs ...string) erro return nil } -// EnableAutoApprovedRoutes enables any routes advertised by a machine that match the ACL autoApprovers policy. -func (hsdb *HSDatabase) EnableAutoApprovedRoutes(aclPolicy *ACLPolicy, machine *Machine) error { - if len(machine.IPAddresses) == 0 { - return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs - } - - routes := []Route{} - err := hsdb.db. - Preload("Machine"). - Where("machine_id = ? AND advertised = true AND enabled = false", machine.ID). - Find(&routes).Error - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - log.Error(). - Caller(). - Err(err). - Str("machine", machine.Hostname). - Msg("Could not get advertised routes for machine") - - return err - } - - approvedRoutes := []Route{} - - for _, advertisedRoute := range routes { - routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers( - netip.Prefix(advertisedRoute.Prefix), - ) - if err != nil { - log.Err(err). - Str("advertisedRoute", advertisedRoute.String()). - Uint64("machineId", machine.ID). - Msg("Failed to resolve autoApprovers for advertised route") - - return err - } - - for _, approvedAlias := range routeApprovers { - if approvedAlias == machine.User.Name { - approvedRoutes = append(approvedRoutes, advertisedRoute) - } else { - approvedIps, err := aclPolicy.expandAlias([]Machine{*machine}, approvedAlias, hsdb.stripEmailDomain) - if err != nil { - log.Err(err). - Str("alias", approvedAlias). - Msg("Failed to expand alias when processing autoApprovers policy") - - return err - } - - // approvedIPs should contain all of machine's IPs if it matches the rule, so check for first - if approvedIps.Contains(machine.IPAddresses[0]) { - approvedRoutes = append(approvedRoutes, advertisedRoute) - } - } - } - } - - for i, approvedRoute := range approvedRoutes { - approvedRoutes[i].Enabled = true - err = hsdb.db.Save(&approvedRoutes[i]).Error - if err != nil { - log.Err(err). - Str("approvedRoute", approvedRoute.String()). - Uint64("machineId", machine.ID). - Msg("Failed to enable approved route") - - return err - } - } - - return nil -} - func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) { - normalizedHostname, err := NormalizeToFQDNRules( + normalizedHostname, err := util.NormalizeToFQDNRules( suppliedName, hsdb.stripEmailDomain, ) @@ -1179,7 +685,7 @@ func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool if randomSuffix { // Trim if a hostname will be longer than 63 chars after adding the hash. - trimmedHostnameLength := labelHostnameLength - MachineGivenNameHashLength - MachineGivenNameTrimSize + trimmedHostnameLength := util.LabelHostnameLength - MachineGivenNameHashLength - MachineGivenNameTrimSize if len(normalizedHostname) > trimmedHostnameLength { normalizedHostname = normalizedHostname[:trimmedHostnameLength] } @@ -1221,16 +727,260 @@ func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string return givenName, nil } -func (machines Machines) FilterByIP(ip netip.Addr) Machines { - found := make(Machines, 0) +func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Duration) { + users, err := hsdb.ListUsers() + if err != nil { + log.Error().Err(err).Msg("Error listing users") + + return + } - for _, machine := range machines { - for _, mIP := range machine.IPAddresses { - if ip == mIP { - found = append(found, machine) + for _, user := range users { + machines, err := hsdb.ListMachinesByUser(user.Name) + if err != nil { + log.Error(). + Err(err). + Str("user", user.Name). + Msg("Error listing machines in user") + + return + } + + expiredFound := false + for idx, machine := range machines { + if machine.IsEphemeral() && machine.LastSeen != nil && + time.Now(). + After(machine.LastSeen.Add(inactivityThreshhold)) { + expiredFound = true + log.Info(). + Str("machine", machine.Hostname). + Msg("Ephemeral client removed from database") + + err = hsdb.HardDeleteMachine(&machines[idx]) + if err != nil { + log.Error(). + Err(err). + Str("machine", machine.Hostname). + Msg("🤮 Cannot delete ephemeral machine from the database") + } + } + } + + if expiredFound { + hsdb.notifyStateChange() + } + } +} + +func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) { + users, err := hsdb.ListUsers() + if err != nil { + log.Error().Err(err).Msg("Error listing users") + + return + } + + for _, user := range users { + machines, err := hsdb.ListMachinesByUser(user.Name) + if err != nil { + log.Error(). + Err(err). + Str("user", user.Name). + Msg("Error listing machines in user") + + return + } + + expiredFound := false + for index, machine := range machines { + if machine.IsExpired() && + machine.Expiry.After(lastChange) { + expiredFound = true + + err := hsdb.ExpireMachine(&machines[index]) + if err != nil { + log.Error(). + Err(err). + Str("machine", machine.Hostname). + Str("name", machine.GivenName). + Msg("🤮 Cannot expire machine") + } else { + log.Info(). + Str("machine", machine.Hostname). + Str("name", machine.GivenName). + Msg("Machine successfully expired") + } } } + + if expiredFound { + hsdb.notifyStateChange() + } + } +} + +func (hsdb *HSDatabase) TailNodes( + machines types.Machines, + pol *policy.ACLPolicy, + dnsConfig *tailcfg.DNSConfig, +) ([]*tailcfg.Node, error) { + nodes := make([]*tailcfg.Node, len(machines)) + + for index, machine := range machines { + node, err := hsdb.TailNode(machine, pol, dnsConfig) + if err != nil { + return nil, err + } + + nodes[index] = node + } + + return nodes, nil +} + +// TailNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes +// as per the expected behaviour in the official SaaS. +func (hsdb *HSDatabase) TailNode( + machine types.Machine, + pol *policy.ACLPolicy, + dnsConfig *tailcfg.DNSConfig, +) (*tailcfg.Node, error) { + var nodeKey key.NodePublic + err := nodeKey.UnmarshalText([]byte(util.NodePublicKeyEnsurePrefix(machine.NodeKey))) + if err != nil { + log.Trace(). + Caller(). + Str("node_key", machine.NodeKey). + Msgf("Failed to parse node public key from hex") + + return nil, fmt.Errorf("failed to parse node public key: %w", err) + } + + var machineKey key.MachinePublic + // MachineKey is only used in the legacy protocol + if machine.MachineKey != "" { + err = machineKey.UnmarshalText( + []byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)), + ) + if err != nil { + return nil, fmt.Errorf("failed to parse machine public key: %w", err) + } + } + + var discoKey key.DiscoPublic + if machine.DiscoKey != "" { + err := discoKey.UnmarshalText( + []byte(util.DiscoPublicKeyEnsurePrefix(machine.DiscoKey)), + ) + if err != nil { + return nil, fmt.Errorf("failed to parse disco public key: %w", err) + } + } else { + discoKey = key.DiscoPublic{} } - return found + addrs := []netip.Prefix{} + for _, machineAddress := range machine.IPAddresses { + ip := netip.PrefixFrom(machineAddress, machineAddress.BitLen()) + addrs = append(addrs, ip) + } + + allowedIPs := append( + []netip.Prefix{}, + addrs...) // we append the node own IP, as it is required by the clients + + primaryRoutes, err := hsdb.GetMachinePrimaryRoutes(&machine) + if err != nil { + return nil, err + } + primaryPrefixes := primaryRoutes.Prefixes() + + machineRoutes, err := hsdb.GetMachineRoutes(&machine) + if err != nil { + return nil, err + } + for _, route := range machineRoutes { + if route.Enabled && (route.IsPrimary || route.IsExitRoute()) { + allowedIPs = append(allowedIPs, netip.Prefix(route.Prefix)) + } + } + + var derp string + if machine.HostInfo.NetInfo != nil { + derp = fmt.Sprintf("127.3.3.40:%d", machine.HostInfo.NetInfo.PreferredDERP) + } else { + derp = "127.3.3.40:0" // Zero means disconnected or unknown. + } + + var keyExpiry time.Time + if machine.Expiry != nil { + keyExpiry = *machine.Expiry + } else { + keyExpiry = time.Time{} + } + + var hostname string + if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS + hostname = fmt.Sprintf( + "%s.%s.%s", + machine.GivenName, + machine.User.Name, + hsdb.baseDomain, + ) + if len(hostname) > MaxHostnameLength { + return nil, fmt.Errorf( + "hostname %q is too long it cannot except 255 ASCII chars: %w", + hostname, + ErrHostnameTooLong, + ) + } + } else { + hostname = machine.GivenName + } + + hostInfo := machine.GetHostInfo() + + online := machine.IsOnline() + + tags, _ := pol.GetTagsOfMachine(machine, hsdb.stripEmailDomain) + tags = lo.Uniq(append(tags, machine.ForcedTags...)) + + node := tailcfg.Node{ + ID: tailcfg.NodeID(machine.ID), // this is the actual ID + StableID: tailcfg.StableNodeID( + strconv.FormatUint(machine.ID, util.Base10), + ), // in headscale, unlike tailcontrol server, IDs are permanent + Name: hostname, + + User: tailcfg.UserID(machine.UserID), + + Key: nodeKey, + KeyExpiry: keyExpiry, + + Machine: machineKey, + DiscoKey: discoKey, + Addresses: addrs, + AllowedIPs: allowedIPs, + Endpoints: machine.Endpoints, + DERP: derp, + Hostinfo: hostInfo.View(), + Created: machine.CreatedAt, + + Tags: tags, + + PrimaryRoutes: primaryPrefixes, + + LastSeen: machine.LastSeen, + Online: &online, + KeepAlive: true, + MachineAuthorized: !machine.IsExpired(), + + Capabilities: []string{ + tailcfg.CapabilityFileSharing, + tailcfg.CapabilityAdmin, + tailcfg.CapabilitySSH, + }, + } + + return &node, nil } diff --git a/hscontrol/db/machine_test.go b/hscontrol/db/machine_test.go new file mode 100644 index 0000000000..f34f64d43f --- /dev/null +++ b/hscontrol/db/machine_test.go @@ -0,0 +1,797 @@ +package db + +import ( + "fmt" + "net/netip" + "regexp" + "strconv" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "gopkg.in/check.v1" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +func (s *Suite) TestGetMachine(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("test", "testmachine") + c.Assert(err, check.NotNil) + + machine := &types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + db.db.Save(machine) + + _, err = db.GetMachine("test", "testmachine") + c.Assert(err, check.IsNil) + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestGetMachineByID(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachineByID(0) + c.Assert(err, check.NotNil) + + machine := types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + db.db.Save(&machine) + + _, err = db.GetMachineByID(0) + c.Assert(err, check.IsNil) + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestGetMachineByNodeKey(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachineByID(0) + c.Assert(err, check.NotNil) + + nodeKey := key.NewNode() + machineKey := key.NewMachine() + + machine := types.Machine{ + ID: 0, + MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), + NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + db.db.Save(&machine) + + _, err = db.GetMachineByNodeKey(nodeKey.Public()) + c.Assert(err, check.IsNil) + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachineByID(0) + c.Assert(err, check.NotNil) + + nodeKey := key.NewNode() + oldNodeKey := key.NewNode() + + machineKey := key.NewMachine() + + machine := types.Machine{ + ID: 0, + MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), + NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + db.db.Save(&machine) + + _, err = db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) + c.Assert(err, check.IsNil) + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestDeleteMachine(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + machine := types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(1), + } + db.db.Save(&machine) + + err = db.DeleteMachine(&machine) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine(user.Name, "testmachine") + c.Assert(err, check.NotNil) + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestHardDeleteMachine(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + machine := types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine3", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(1), + } + db.db.Save(&machine) + + err = db.HardDeleteMachine(&machine) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine(user.Name, "testmachine3") + c.Assert(err, check.NotNil) + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestListPeers(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachineByID(0) + c.Assert(err, check.NotNil) + + for index := 0; index <= 10; index++ { + machine := types.Machine{ + ID: uint64(index), + MachineKey: "foo" + strconv.Itoa(index), + NodeKey: "bar" + strconv.Itoa(index), + DiscoKey: "faa" + strconv.Itoa(index), + Hostname: "testmachine" + strconv.Itoa(index), + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + db.db.Save(&machine) + } + + machine0ByID, err := db.GetMachineByID(0) + c.Assert(err, check.IsNil) + + peersOfMachine0, err := db.ListPeers(machine0ByID) + c.Assert(err, check.IsNil) + + c.Assert(len(peersOfMachine0), check.Equals, 9) + c.Assert(peersOfMachine0[0].Hostname, check.Equals, "testmachine2") + c.Assert(peersOfMachine0[5].Hostname, check.Equals, "testmachine7") + c.Assert(peersOfMachine0[8].Hostname, check.Equals, "testmachine10") + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestGetACLFilteredPeers(c *check.C) { + type base struct { + user *types.User + key *types.PreAuthKey + } + + stor := make([]base, 0) + + for _, name := range []string{"test", "admin"} { + user, err := db.CreateUser(name) + c.Assert(err, check.IsNil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + stor = append(stor, base{user, pak}) + } + + _, err := db.GetMachineByID(0) + c.Assert(err, check.NotNil) + + for index := 0; index <= 10; index++ { + machine := types.Machine{ + ID: uint64(index), + MachineKey: "foo" + strconv.Itoa(index), + NodeKey: "bar" + strconv.Itoa(index), + DiscoKey: "faa" + strconv.Itoa(index), + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))), + }, + Hostname: "testmachine" + strconv.Itoa(index), + UserID: stor[index%2].user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(stor[index%2].key.ID), + } + db.db.Save(&machine) + } + + aclPolicy := &policy.ACLPolicy{ + Groups: map[string][]string{ + "group:test": {"admin"}, + }, + Hosts: map[string]netip.Prefix{}, + TagOwners: map[string][]string{}, + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"admin"}, + Destinations: []string{"*:*"}, + }, + { + Action: "accept", + Sources: []string{"test"}, + Destinations: []string{"test:*"}, + }, + }, + Tests: []policy.ACLTest{}, + } + + adminMachine, err := db.GetMachineByID(1) + c.Logf("Machine(%v), user: %v", adminMachine.Hostname, adminMachine.User) + c.Assert(err, check.IsNil) + + testMachine, err := db.GetMachineByID(2) + c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User) + c.Assert(err, check.IsNil) + + machines, err := db.ListMachines() + c.Assert(err, check.IsNil) + + aclRules, _, err := policy.GenerateFilterRules(aclPolicy, machines, false) + c.Assert(err, check.IsNil) + + peersOfTestMachine := db.filterMachinesByACL(aclRules, testMachine, machines) + peersOfAdminMachine := db.filterMachinesByACL(aclRules, adminMachine, machines) + + c.Log(peersOfTestMachine) + c.Assert(len(peersOfTestMachine), check.Equals, 9) + c.Assert(peersOfTestMachine[0].Hostname, check.Equals, "testmachine1") + c.Assert(peersOfTestMachine[1].Hostname, check.Equals, "testmachine3") + c.Assert(peersOfTestMachine[3].Hostname, check.Equals, "testmachine5") + + c.Log(peersOfAdminMachine) + c.Assert(len(peersOfAdminMachine), check.Equals, 9) + c.Assert(peersOfAdminMachine[0].Hostname, check.Equals, "testmachine2") + c.Assert(peersOfAdminMachine[2].Hostname, check.Equals, "testmachine4") + c.Assert(peersOfAdminMachine[5].Hostname, check.Equals, "testmachine7") + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestExpireMachine(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("test", "testmachine") + c.Assert(err, check.NotNil) + + machine := &types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + Expiry: &time.Time{}, + } + db.db.Save(machine) + + machineFromDB, err := db.GetMachine("test", "testmachine") + c.Assert(err, check.IsNil) + c.Assert(machineFromDB, check.NotNil) + + c.Assert(machineFromDB.IsExpired(), check.Equals, false) + + err = db.ExpireMachine(machineFromDB) + c.Assert(err, check.IsNil) + + c.Assert(machineFromDB.IsExpired(), check.Equals, true) + + c.Assert(channelUpdates, check.Equals, int32(1)) +} + +func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) { + input := types.MachineAddresses([]netip.Addr{ + netip.MustParseAddr("192.0.2.1"), + netip.MustParseAddr("2001:db8::1"), + }) + serialized, err := input.Value() + c.Assert(err, check.IsNil) + if serial, ok := serialized.(string); ok { + c.Assert(serial, check.Equals, "192.0.2.1,2001:db8::1") + } + + var deserialized types.MachineAddresses + err = deserialized.Scan(serialized) + c.Assert(err, check.IsNil) + + c.Assert(len(deserialized), check.Equals, len(input)) + for i := range deserialized { + c.Assert(deserialized[i], check.Equals, input[i]) + } + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestGenerateGivenName(c *check.C) { + user1, err := db.CreateUser("user-1") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("user-1", "testmachine") + c.Assert(err, check.NotNil) + + machine := &types.Machine{ + ID: 0, + MachineKey: "machine-key-1", + NodeKey: "node-key-1", + DiscoKey: "disco-key-1", + Hostname: "hostname-1", + GivenName: "hostname-1", + UserID: user1.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + db.db.Save(machine) + + givenName, err := db.GenerateGivenName("machine-key-2", "hostname-2") + comment := check.Commentf("Same user, unique machines, unique hostnames, no conflict") + c.Assert(err, check.IsNil, comment) + c.Assert(givenName, check.Equals, "hostname-2", comment) + + givenName, err = db.GenerateGivenName("machine-key-1", "hostname-1") + comment = check.Commentf("Same user, same machine, same hostname, no conflict") + c.Assert(err, check.IsNil, comment) + c.Assert(givenName, check.Equals, "hostname-1", comment) + + givenName, err = db.GenerateGivenName("machine-key-2", "hostname-1") + comment = check.Commentf("Same user, unique machines, same hostname, conflict") + c.Assert(err, check.IsNil, comment) + c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment) + + givenName, err = db.GenerateGivenName("machine-key-2", "hostname-1") + comment = check.Commentf("Unique users, unique machines, same hostname, conflict") + c.Assert(err, check.IsNil, comment) + c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment) + + c.Assert(channelUpdates, check.Equals, int32(0)) +} + +func (s *Suite) TestSetTags(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine("test", "testmachine") + c.Assert(err, check.NotNil) + + machine := &types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + db.db.Save(machine) + + // assign simple tags + sTags := []string{"tag:test", "tag:foo"} + err = db.SetTags(machine, sTags) + c.Assert(err, check.IsNil) + machine, err = db.GetMachine("test", "testmachine") + c.Assert(err, check.IsNil) + c.Assert(machine.ForcedTags, check.DeepEquals, types.StringList(sTags)) + + // assign duplicat tags, expect no errors but no doubles in DB + eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} + err = db.SetTags(machine, eTags) + c.Assert(err, check.IsNil) + machine, err = db.GetMachine("test", "testmachine") + c.Assert(err, check.IsNil) + c.Assert( + machine.ForcedTags, + check.DeepEquals, + types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}), + ) + + c.Assert(channelUpdates, check.Equals, int32(4)) +} + +func TestHeadscale_generateGivenName(t *testing.T) { + type args struct { + suppliedName string + randomSuffix bool + } + tests := []struct { + name string + db *HSDatabase + args args + want *regexp.Regexp + wantErr bool + }{ + { + name: "simple machine name generation", + db: &HSDatabase{ + stripEmailDomain: true, + }, + args: args{ + suppliedName: "testmachine", + randomSuffix: false, + }, + want: regexp.MustCompile("^testmachine$"), + wantErr: false, + }, + { + name: "machine name with 53 chars", + db: &HSDatabase{ + stripEmailDomain: true, + }, + args: args{ + suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", + randomSuffix: false, + }, + want: regexp.MustCompile("^testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine$"), + wantErr: false, + }, + { + name: "machine name with 63 chars", + db: &HSDatabase{ + stripEmailDomain: true, + }, + args: args{ + suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", + randomSuffix: false, + }, + want: regexp.MustCompile("^machineeee12345678901234567890123456789012345678901234567890123$"), + wantErr: false, + }, + { + name: "machine name with 64 chars", + db: &HSDatabase{ + stripEmailDomain: true, + }, + args: args{ + suppliedName: "machineeee123456789012345678901234567890123456789012345678901234", + randomSuffix: false, + }, + want: nil, + wantErr: true, + }, + { + name: "machine name with 73 chars", + db: &HSDatabase{ + stripEmailDomain: true, + }, + args: args{ + suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123", + randomSuffix: false, + }, + want: nil, + wantErr: true, + }, + { + name: "machine name with random suffix", + db: &HSDatabase{ + stripEmailDomain: true, + }, + args: args{ + suppliedName: "test", + randomSuffix: true, + }, + want: regexp.MustCompile(fmt.Sprintf("^test-[a-z0-9]{%d}$", MachineGivenNameHashLength)), + wantErr: false, + }, + { + name: "machine name with 63 chars with random suffix", + db: &HSDatabase{ + stripEmailDomain: true, + }, + args: args{ + suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", + randomSuffix: true, + }, + want: regexp.MustCompile(fmt.Sprintf("^machineeee1234567890123456789012345678901234567890123-[a-z0-9]{%d}$", MachineGivenNameHashLength)), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.db.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) + if (err != nil) != tt.wantErr { + t.Errorf( + "Headscale.GenerateGivenName() error = %v, wantErr %v", + err, + tt.wantErr, + ) + + return + } + + if tt.want != nil && !tt.want.MatchString(got) { + t.Errorf( + "Headscale.GenerateGivenName() = %v, does not match %v", + tt.want, + got, + ) + } + + if len(got) > util.LabelHostnameLength { + t.Errorf( + "Headscale.GenerateGivenName() = %v is larger than allowed DNS segment %d", + got, + util.LabelHostnameLength, + ) + } + }) + } +} + +func (s *Suite) TestAutoApproveRoutes(c *check.C) { + acl := []byte(` +{ + "tagOwners": { + "tag:exit": ["test"], + }, + + "groups": { + "group:test": ["test"] + }, + + "acls": [ + {"action": "accept", "users": ["*"], "ports": ["*:*"]}, + ], + + "autoApprovers": { + "exitNode": ["tag:exit"], + "routes": { + "10.10.0.0/16": ["group:test"], + "10.11.0.0/16": ["test"], + } + } +} + `) + + pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") + c.Assert(err, check.IsNil) + c.Assert(pol, check.NotNil) + + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + nodeKey := key.NewNode() + + defaultRouteV4 := netip.MustParsePrefix("0.0.0.0/0") + defaultRouteV6 := netip.MustParsePrefix("::/0") + route1 := netip.MustParsePrefix("10.10.0.0/16") + // Check if a subprefix of an autoapproved route is approved + route2 := netip.MustParsePrefix("10.11.0.0/24") + + machine := types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), + DiscoKey: "faa", + Hostname: "test", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: types.HostInfo{ + RequestTags: []string{"tag:exit"}, + RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2}, + }, + IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, + } + + db.db.Save(&machine) + + err = db.ProcessMachineRoutes(&machine) + c.Assert(err, check.IsNil) + + machine0ByID, err := db.GetMachineByID(0) + c.Assert(err, check.IsNil) + + err = db.EnableAutoApprovedRoutes(pol, machine0ByID) + c.Assert(err, check.IsNil) + + enabledRoutes, err := db.GetEnabledRoutes(machine0ByID) + c.Assert(err, check.IsNil) + c.Assert(enabledRoutes, check.HasLen, 4) + + c.Assert(channelUpdates, check.Equals, int32(4)) +} + +func TestMachine_canAccess(t *testing.T) { + type args struct { + filter []tailcfg.FilterRule + machine2 *types.Machine + } + tests := []struct { + name string + machine types.Machine + args args + want bool + }{ + { + name: "no-rules", + machine: types.Machine{ + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("10.0.0.1"), + }, + }, + args: args{ + filter: []tailcfg.FilterRule{}, + machine2: &types.Machine{ + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("10.0.0.2"), + }, + }, + }, + want: false, + }, + { + name: "wildcard", + machine: types.Machine{ + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("10.0.0.1"), + }, + }, + args: args{ + filter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "*", + Ports: tailcfg.PortRange{ + First: 0, + Last: 65535, + }, + }, + }, + }, + }, + machine2: &types.Machine{ + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("10.0.0.2"), + }, + }, + }, + want: true, + }, + { + name: "explicit-m1-to-m2", + machine: types.Machine{ + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("10.0.0.1"), + }, + }, + args: args{ + filter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"10.0.0.1"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "10.0.0.2", + Ports: tailcfg.PortRange{ + First: 0, + Last: 65535, + }, + }, + }, + }, + }, + machine2: &types.Machine{ + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("10.0.0.2"), + }, + }, + }, + want: true, + }, + { + name: "explicit-m2-to-m1", + machine: types.Machine{ + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("10.0.0.1"), + }, + }, + args: args{ + filter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"10.0.0.2"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "10.0.0.1", + Ports: tailcfg.PortRange{ + First: 0, + Last: 65535, + }, + }, + }, + }, + }, + machine2: &types.Machine{ + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("10.0.0.2"), + }, + }, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.machine.CanAccess(tt.args.filter, tt.args.machine2); got != tt.want { + t.Errorf("Machine.CanAccess() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/hscontrol/preauth_keys.go b/hscontrol/db/preauth_keys.go similarity index 59% rename from hscontrol/preauth_keys.go rename to hscontrol/db/preauth_keys.go index 1956762270..abb79c34c2 100644 --- a/hscontrol/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -1,17 +1,14 @@ -package hscontrol +package db import ( "crypto/rand" "encoding/hex" "errors" "fmt" - "strconv" "strings" "time" - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol/util" - "google.golang.org/protobuf/types/known/timestamppb" + "github.com/juanfont/headscale/hscontrol/types" "gorm.io/gorm" ) @@ -23,28 +20,6 @@ var ( ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid") ) -// PreAuthKey describes a pre-authorization key usable in a particular user. -type PreAuthKey struct { - ID uint64 `gorm:"primary_key"` - Key string - UserID uint - User User - Reusable bool - Ephemeral bool `gorm:"default:false"` - Used bool `gorm:"default:false"` - ACLTags []PreAuthKeyACLTag - - CreatedAt *time.Time - Expiration *time.Time -} - -// PreAuthKeyACLTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey. -type PreAuthKeyACLTag struct { - ID uint64 `gorm:"primary_key"` - PreAuthKeyID uint64 - Tag string -} - // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. func (hsdb *HSDatabase) CreatePreAuthKey( userName string, @@ -52,7 +27,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey( ephemeral bool, expiration *time.Time, aclTags []string, -) (*PreAuthKey, error) { +) (*types.PreAuthKey, error) { user, err := hsdb.GetUser(userName) if err != nil { return nil, err @@ -74,7 +49,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey( return nil, err } - key := PreAuthKey{ + key := types.PreAuthKey{ Key: kstr, UserID: user.ID, User: *user, @@ -94,7 +69,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey( for _, tag := range aclTags { if !seenTags[tag] { - if err := db.Save(&PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { + if err := db.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { return fmt.Errorf( "failed to ceate key tag in the database: %w", err, @@ -116,14 +91,14 @@ func (hsdb *HSDatabase) CreatePreAuthKey( } // ListPreAuthKeys returns the list of PreAuthKeys for a user. -func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]PreAuthKey, error) { +func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) { user, err := hsdb.GetUser(userName) if err != nil { return nil, err } - keys := []PreAuthKey{} - if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { + keys := []types.PreAuthKey{} + if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { return nil, err } @@ -131,8 +106,8 @@ func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]PreAuthKey, error) { } // GetPreAuthKey returns a PreAuthKey for a given key. -func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*PreAuthKey, error) { - pak, err := hsdb.checkKeyValidity(key) +func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) { + pak, err := hsdb.ValidatePreAuthKey(key) if err != nil { return nil, err } @@ -146,9 +121,9 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*PreAuthKey, err // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // does not exist. -func (hsdb *HSDatabase) DestroyPreAuthKey(pak PreAuthKey) error { +func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { return hsdb.db.Transaction(func(db *gorm.DB) error { - if result := db.Unscoped().Where(PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&PreAuthKeyACLTag{}); result.Error != nil { + if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil { return result.Error } @@ -161,7 +136,7 @@ func (hsdb *HSDatabase) DestroyPreAuthKey(pak PreAuthKey) error { } // MarkExpirePreAuthKey marks a PreAuthKey as expired. -func (hsdb *HSDatabase) ExpirePreAuthKey(k *PreAuthKey) error { +func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -170,7 +145,7 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *PreAuthKey) error { } // UsePreAuthKey marks a PreAuthKey as used. -func (hsdb *HSDatabase) UsePreAuthKey(k *PreAuthKey) error { +func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { k.Used = true if err := hsdb.db.Save(k).Error; err != nil { return fmt.Errorf("failed to update key used status in the database: %w", err) @@ -179,10 +154,10 @@ func (hsdb *HSDatabase) UsePreAuthKey(k *PreAuthKey) error { return nil } -// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node +// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node // If returns no error and a PreAuthKey, it can be used. -func (hsdb *HSDatabase) checkKeyValidity(k string) (*PreAuthKey, error) { - pak := PreAuthKey{} +func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { + pak := types.PreAuthKey{} if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( result.Error, gorm.ErrRecordNotFound, @@ -198,8 +173,8 @@ func (hsdb *HSDatabase) checkKeyValidity(k string) (*PreAuthKey, error) { return &pak, nil } - machines := []Machine{} - if err := hsdb.db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { + machines := types.Machines{} + if err := hsdb.db.Preload("AuthKey").Where(&types.Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { return nil, err } @@ -219,29 +194,3 @@ func (hsdb *HSDatabase) generateKey() (string, error) { return hex.EncodeToString(bytes), nil } - -func (key *PreAuthKey) toProto() *v1.PreAuthKey { - protoKey := v1.PreAuthKey{ - User: key.User.Name, - Id: strconv.FormatUint(key.ID, util.Base10), - Key: key.Key, - Ephemeral: key.Ephemeral, - Reusable: key.Reusable, - Used: key.Used, - AclTags: make([]string, len(key.ACLTags)), - } - - if key.Expiration != nil { - protoKey.Expiration = timestamppb.New(*key.Expiration) - } - - if key.CreatedAt != nil { - protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) - } - - for idx := range key.ACLTags { - protoKey.AclTags[idx] = key.ACLTags[idx].Tag - } - - return &protoKey -} diff --git a/hscontrol/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go similarity index 57% rename from hscontrol/preauth_keys_test.go rename to hscontrol/db/preauth_keys_test.go index a85a6c6103..e4a9773a03 100644 --- a/hscontrol/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -1,20 +1,22 @@ -package hscontrol +package db import ( "time" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" ) func (*Suite) TestCreatePreAuthKey(c *check.C) { - _, err := app.db.CreatePreAuthKey("bogus", true, false, nil, nil) + _, err := db.CreatePreAuthKey("bogus", true, false, nil, nil) c.Assert(err, check.NotNil) - user, err := app.db.CreateUser("test") + user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - key, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) + key, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) // Did we get a valid key? @@ -24,10 +26,10 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { // Make sure the User association is populated c.Assert(key.User.Name, check.Equals, user.Name) - _, err = app.db.ListPreAuthKeys("bogus") + _, err = db.ListPreAuthKeys("bogus") c.Assert(err, check.NotNil) - keys, err := app.db.ListPreAuthKeys(user.Name) + keys, err := db.ListPreAuthKeys(user.Name) c.Assert(err, check.IsNil) c.Assert(len(keys), check.Equals, 1) @@ -36,174 +38,176 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { } func (*Suite) TestExpiredPreAuthKey(c *check.C) { - user, err := app.db.CreateUser("test2") + user, err := db.CreateUser("test2") c.Assert(err, check.IsNil) now := time.Now() - pak, err := app.db.CreatePreAuthKey(user.Name, true, false, &now, nil) + pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil) c.Assert(err, check.IsNil) - key, err := app.db.checkKeyValidity(pak.Key) + key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.Equals, ErrPreAuthKeyExpired) c.Assert(key, check.IsNil) } func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) { - key, err := app.db.checkKeyValidity("potatoKey") + key, err := db.ValidatePreAuthKey("potatoKey") c.Assert(err, check.Equals, ErrPreAuthKeyNotFound) c.Assert(key, check.IsNil) } func (*Suite) TestValidateKeyOk(c *check.C) { - user, err := app.db.CreateUser("test3") + user, err := db.CreateUser("test3") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) - key, err := app.db.checkKeyValidity(pak.Key) + key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.IsNil) c.Assert(key.ID, check.Equals, pak.ID) } func (*Suite) TestAlreadyUsedKey(c *check.C) { - user, err := app.db.CreateUser("test4") + user, err := db.CreateUser("test4") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - machine := Machine{ + machine := types.Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "testest", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.db.Save(&machine) + db.db.Save(&machine) - key, err := app.db.checkKeyValidity(pak.Key) + key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) c.Assert(key, check.IsNil) } func (*Suite) TestReusableBeingUsedKey(c *check.C) { - user, err := app.db.CreateUser("test5") + user, err := db.CreateUser("test5") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) - machine := Machine{ + machine := types.Machine{ ID: 1, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "testest", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - app.db.db.Save(&machine) + db.db.Save(&machine) - key, err := app.db.checkKeyValidity(pak.Key) + key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.IsNil) c.Assert(key.ID, check.Equals, pak.ID) } func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { - user, err := app.db.CreateUser("test6") + user, err := db.CreateUser("test6") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - key, err := app.db.checkKeyValidity(pak.Key) + key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.IsNil) c.Assert(key.ID, check.Equals, pak.ID) } func (*Suite) TestEphemeralKey(c *check.C) { - user, err := app.db.CreateUser("test7") + user, err := db.CreateUser("test7") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, true, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil) c.Assert(err, check.IsNil) - now := time.Now() - machine := Machine{ + now := time.Now().Add(-time.Second * 30) + machine := types.Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "testest", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, LastSeen: &now, AuthKeyID: uint(pak.ID), } - app.db.db.Save(&machine) + db.db.Save(&machine) - _, err = app.db.checkKeyValidity(pak.Key) + _, err = db.ValidatePreAuthKey(pak.Key) // Ephemeral keys are by definition reusable c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test7", "testest") + _, err = db.GetMachine("test7", "testest") c.Assert(err, check.IsNil) - app.expireEphemeralNodesWorker() + db.ExpireEphemeralMachines(time.Second * 20) // The machine record should have been deleted - _, err = app.db.GetMachine("test7", "testest") + _, err = db.GetMachine("test7", "testest") c.Assert(err, check.NotNil) + + c.Assert(channelUpdates, check.Equals, int32(1)) } func (*Suite) TestExpirePreauthKey(c *check.C) { - user, err := app.db.CreateUser("test3") + user, err := db.CreateUser("test3") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) c.Assert(pak.Expiration, check.IsNil) - err = app.db.ExpirePreAuthKey(pak) + err = db.ExpirePreAuthKey(pak) c.Assert(err, check.IsNil) c.Assert(pak.Expiration, check.NotNil) - key, err := app.db.checkKeyValidity(pak.Key) + key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.Equals, ErrPreAuthKeyExpired) c.Assert(key, check.IsNil) } func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { - user, err := app.db.CreateUser("test6") + user, err := db.CreateUser("test6") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) pak.Used = true - app.db.db.Save(&pak) + db.db.Save(&pak) - _, err = app.db.checkKeyValidity(pak.Key) + _, err = db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) } func (*Suite) TestPreAuthKeyACLTags(c *check.C) { - user, err := app.db.CreateUser("test8") + user, err := db.CreateUser("test8") c.Assert(err, check.IsNil) - _, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"}) + _, err = db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"}) c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected tags := []string{"tag:test1", "tag:test2"} tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} - _, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate) + _, err = db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate) c.Assert(err, check.IsNil) - listedPaks, err := app.db.ListPreAuthKeys("test8") + listedPaks, err := db.ListPreAuthKeys("test8") c.Assert(err, check.IsNil) - c.Assert(listedPaks[0].toProto().AclTags, check.DeepEquals, tags) + c.Assert(listedPaks[0].Proto().AclTags, check.DeepEquals, tags) } diff --git a/hscontrol/routes.go b/hscontrol/db/routes.go similarity index 62% rename from hscontrol/routes.go rename to hscontrol/db/routes.go index e3be2f691a..bdb3f4c523 100644 --- a/hscontrol/routes.go +++ b/hscontrol/db/routes.go @@ -1,56 +1,33 @@ -package hscontrol +package db import ( "errors" - "fmt" "net/netip" - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" - "google.golang.org/protobuf/types/known/timestamppb" "gorm.io/gorm" ) -var ( - ErrRouteIsNotAvailable = errors.New("route is not available") - ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0") - ExitRouteV6 = netip.MustParsePrefix("::/0") -) - -type Route struct { - gorm.Model - - MachineID uint64 - Machine Machine - Prefix IPPrefix - - Advertised bool - Enabled bool - IsPrimary bool -} - -type Routes []Route +var ErrRouteIsNotAvailable = errors.New("route is not available") -func (r *Route) String() string { - return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String()) -} - -func (r *Route) isExitRoute() bool { - return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6 -} - -func (rs Routes) toPrefixes() []netip.Prefix { - prefixes := make([]netip.Prefix, len(rs)) - for i, r := range rs { - prefixes[i] = netip.Prefix(r.Prefix) +func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { + var routes types.Routes + err := hsdb.db.Preload("Machine").Find(&routes).Error + if err != nil { + return nil, err } - return prefixes + return routes, nil } -func (hsdb *HSDatabase) GetRoutes() ([]Route, error) { - var routes []Route - err := hsdb.db.Preload("Machine").Find(&routes).Error +func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) { + var routes types.Routes + err := hsdb.db. + Preload("Machine"). + Where("machine_id = ? AND advertised = true", machine.ID). + Find(&routes).Error if err != nil { return nil, err } @@ -58,8 +35,8 @@ func (hsdb *HSDatabase) GetRoutes() ([]Route, error) { return routes, nil } -func (hsdb *HSDatabase) GetMachineRoutes(m *Machine) ([]Route, error) { - var routes []Route +func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) { + var routes types.Routes err := hsdb.db. Preload("Machine"). Where("machine_id = ?", m.ID). @@ -71,8 +48,8 @@ func (hsdb *HSDatabase) GetMachineRoutes(m *Machine) ([]Route, error) { return routes, nil } -func (hsdb *HSDatabase) GetRoute(id uint64) (*Route, error) { - var route Route +func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { + var route types.Route err := hsdb.db.Preload("Machine").First(&route, id).Error if err != nil { return nil, err @@ -90,8 +67,12 @@ func (hsdb *HSDatabase) EnableRoute(id uint64) error { // Tailscale requires both IPv4 and IPv6 exit routes to // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 - if route.isExitRoute() { - return hsdb.enableRoutes(&route.Machine, ExitRouteV4.String(), ExitRouteV6.String()) + if route.IsExitRoute() { + return hsdb.enableRoutes( + &route.Machine, + types.ExitRouteV4.String(), + types.ExitRouteV6.String(), + ) } return hsdb.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String()) @@ -106,7 +87,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { // Tailscale requires both IPv4 and IPv6 exit routes to // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 - if !route.isExitRoute() { + if !route.IsExitRoute() { route.Enabled = false route.IsPrimary = false err = hsdb.db.Save(route).Error @@ -114,7 +95,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { return err } - return hsdb.handlePrimarySubnetFailover() + return hsdb.HandlePrimarySubnetFailover() } routes, err := hsdb.GetMachineRoutes(&route.Machine) @@ -123,7 +104,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { } for i := range routes { - if routes[i].isExitRoute() { + if routes[i].IsExitRoute() { routes[i].Enabled = false routes[i].IsPrimary = false err = hsdb.db.Save(&routes[i]).Error @@ -133,7 +114,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { } } - return hsdb.handlePrimarySubnetFailover() + return hsdb.HandlePrimarySubnetFailover() } func (hsdb *HSDatabase) DeleteRoute(id uint64) error { @@ -145,12 +126,12 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { // Tailscale requires both IPv4 and IPv6 exit routes to // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 - if !route.isExitRoute() { + if !route.IsExitRoute() { if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil { return err } - return hsdb.handlePrimarySubnetFailover() + return hsdb.HandlePrimarySubnetFailover() } routes, err := hsdb.GetMachineRoutes(&route.Machine) @@ -158,9 +139,9 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { return err } - routesToDelete := []Route{} + routesToDelete := types.Routes{} for _, r := range routes { - if r.isExitRoute() { + if r.IsExitRoute() { routesToDelete = append(routesToDelete, r) } } @@ -169,10 +150,10 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { return err } - return hsdb.handlePrimarySubnetFailover() + return hsdb.HandlePrimarySubnetFailover() } -func (hsdb *HSDatabase) DeleteMachineRoutes(m *Machine) error { +func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error { routes, err := hsdb.GetMachineRoutes(m) if err != nil { return err @@ -184,14 +165,14 @@ func (hsdb *HSDatabase) DeleteMachineRoutes(m *Machine) error { } } - return hsdb.handlePrimarySubnetFailover() + return hsdb.HandlePrimarySubnetFailover() } // isUniquePrefix returns if there is another machine providing the same route already. -func (hsdb *HSDatabase) isUniquePrefix(route Route) bool { +func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool { var count int64 hsdb.db. - Model(&Route{}). + Model(&types.Route{}). Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", route.Prefix, route.MachineID, @@ -200,11 +181,11 @@ func (hsdb *HSDatabase) isUniquePrefix(route Route) bool { return count == 0 } -func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { - var route Route +func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) { + var route types.Route err := hsdb.db. Preload("Machine"). - Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", IPPrefix(prefix), true, true, true). + Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true). First(&route).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err @@ -219,8 +200,8 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { // getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) // Exit nodes are not considered for this, as they are never marked as Primary. -func (hsdb *HSDatabase) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { - var routes []Route +func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, error) { + var routes types.Routes err := hsdb.db. Preload("Machine"). Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true). @@ -232,8 +213,8 @@ func (hsdb *HSDatabase) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { return routes, nil } -func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error { - currentRoutes := []Route{} +func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error { + currentRoutes := types.Routes{} err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error if err != nil { return err @@ -266,9 +247,9 @@ func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error { for prefix, exists := range advertisedRoutes { if !exists { - route := Route{ + route := types.Route{ MachineID: machine.ID, - Prefix: IPPrefix(prefix), + Prefix: types.IPPrefix(prefix), Advertised: true, Enabled: false, } @@ -282,9 +263,9 @@ func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error { return nil } -func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { +func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { // first, get all the enabled routes - var routes []Route + var routes types.Routes err := hsdb.db. Preload("Machine"). Where("advertised = ? AND enabled = ?", true, true). @@ -295,7 +276,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { routesChanged := false for pos, route := range routes { - if route.isExitRoute() { + if route.IsExitRoute() { continue } @@ -321,7 +302,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { } if route.IsPrimary { - if route.Machine.isOnline() { + if route.Machine.IsOnline() { continue } @@ -332,7 +313,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { Msgf("machine offline, finding a new primary subnet") // find a new primary route - var newPrimaryRoutes []Route + var newPrimaryRoutes types.Routes err := hsdb.db. Preload("Machine"). Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", @@ -346,9 +327,9 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { return err } - var newPrimaryRoute *Route + var newPrimaryRoute *types.Route for pos, r := range newPrimaryRoutes { - if r.Machine.isOnline() { + if r.Machine.IsOnline() { newPrimaryRoute = &newPrimaryRoutes[pos] break @@ -399,27 +380,78 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { return nil } -func (rs Routes) toProto() []*v1.Route { - protoRoutes := []*v1.Route{} - - for _, route := range rs { - protoRoute := v1.Route{ - Id: uint64(route.ID), - Machine: route.Machine.toProto(), - Prefix: netip.Prefix(route.Prefix).String(), - Advertised: route.Advertised, - Enabled: route.Enabled, - IsPrimary: route.IsPrimary, - CreatedAt: timestamppb.New(route.CreatedAt), - UpdatedAt: timestamppb.New(route.UpdatedAt), +// EnableAutoApprovedRoutes enables any routes advertised by a machine that match the ACL autoApprovers policy. +func (hsdb *HSDatabase) EnableAutoApprovedRoutes( + aclPolicy *policy.ACLPolicy, + machine *types.Machine, +) error { + if len(machine.IPAddresses) == 0 { + return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs + } + + routes, err := hsdb.GetMachineAdvertisedRoutes(machine) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + log.Error(). + Caller(). + Err(err). + Str("machine", machine.Hostname). + Msg("Could not get advertised routes for machine") + + return err + } + + approvedRoutes := types.Routes{} + + for _, advertisedRoute := range routes { + if advertisedRoute.Enabled { + continue } - if route.DeletedAt.Valid { - protoRoute.DeletedAt = timestamppb.New(route.DeletedAt.Time) + routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers( + netip.Prefix(advertisedRoute.Prefix), + ) + if err != nil { + log.Err(err). + Str("advertisedRoute", advertisedRoute.String()). + Uint64("machineId", machine.ID). + Msg("Failed to resolve autoApprovers for advertised route") + + return err + } + + for _, approvedAlias := range routeApprovers { + if approvedAlias == machine.User.Name { + approvedRoutes = append(approvedRoutes, advertisedRoute) + } else { + // TODO(kradalby): figure out how to get this to depend on less stuff + approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias, hsdb.stripEmailDomain) + if err != nil { + log.Err(err). + Str("alias", approvedAlias). + Msg("Failed to expand alias when processing autoApprovers policy") + + return err + } + + // approvedIPs should contain all of machine's IPs if it matches the rule, so check for first + if approvedIps.Contains(machine.IPAddresses[0]) { + approvedRoutes = append(approvedRoutes, advertisedRoute) + } + } } + } - protoRoutes = append(protoRoutes, &protoRoute) + for _, approvedRoute := range approvedRoutes { + err := hsdb.EnableRoute(uint64(approvedRoute.ID)) + if err != nil { + log.Err(err). + Str("approvedRoute", approvedRoute.String()). + Uint64("machineId", machine.ID). + Msg("Failed to enable approved route") + + return err + } } - return protoRoutes + return nil } diff --git a/hscontrol/routes_test.go b/hscontrol/db/routes_test.go similarity index 60% rename from hscontrol/routes_test.go rename to hscontrol/db/routes_test.go index cf437a4d20..d281452dbe 100644 --- a/hscontrol/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -1,9 +1,11 @@ -package hscontrol +package db import ( "net/netip" "time" + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" "tailscale.com/tailcfg" @@ -11,13 +13,13 @@ import ( ) func (s *Suite) TestGetRoutes(c *check.C) { - user, err := app.db.CreateUser("test") + user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test", "test_get_route_machine") + _, err = db.GetMachine("test", "test_get_route_machine") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix("10.0.0.0/24") @@ -27,41 +29,43 @@ func (s *Suite) TestGetRoutes(c *check.C) { RoutableIPs: []netip.Prefix{route}, } - machine := Machine{ + machine := types.Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "test_get_route_machine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo), + HostInfo: types.HostInfo(hostInfo), } - app.db.db.Save(&machine) + db.db.Save(&machine) - err = app.db.processMachineRoutes(&machine) + err = db.ProcessMachineRoutes(&machine) c.Assert(err, check.IsNil) - advertisedRoutes, err := app.db.GetAdvertisedRoutes(&machine) + advertisedRoutes, err := db.GetAdvertisedRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(advertisedRoutes), check.Equals, 1) - err = app.db.enableRoutes(&machine, "192.168.0.0/24") + err = db.enableRoutes(&machine, "192.168.0.0/24") c.Assert(err, check.NotNil) - err = app.db.enableRoutes(&machine, "10.0.0.0/24") + err = db.enableRoutes(&machine, "10.0.0.0/24") c.Assert(err, check.IsNil) + + c.Assert(channelUpdates, check.Equals, int32(0)) } func (s *Suite) TestGetEnableRoutes(c *check.C) { - user, err := app.db.CreateUser("test") + user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test", "test_enable_route_machine") + _, err = db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -78,65 +82,67 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { RoutableIPs: []netip.Prefix{route, route2}, } - machine := Machine{ + machine := types.Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "test_enable_route_machine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo), + HostInfo: types.HostInfo(hostInfo), } - app.db.db.Save(&machine) + db.db.Save(&machine) - err = app.db.processMachineRoutes(&machine) + err = db.ProcessMachineRoutes(&machine) c.Assert(err, check.IsNil) - availableRoutes, err := app.db.GetAdvertisedRoutes(&machine) + availableRoutes, err := db.GetAdvertisedRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(err, check.IsNil) c.Assert(len(availableRoutes), check.Equals, 2) - noEnabledRoutes, err := app.db.GetEnabledRoutes(&machine) + noEnabledRoutes, err := db.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(noEnabledRoutes), check.Equals, 0) - err = app.db.enableRoutes(&machine, "192.168.0.0/24") + err = db.enableRoutes(&machine, "192.168.0.0/24") c.Assert(err, check.NotNil) - err = app.db.enableRoutes(&machine, "10.0.0.0/24") + err = db.enableRoutes(&machine, "10.0.0.0/24") c.Assert(err, check.IsNil) - enabledRoutes, err := app.db.GetEnabledRoutes(&machine) + enabledRoutes, err := db.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes), check.Equals, 1) // Adding it twice will just let it pass through - err = app.db.enableRoutes(&machine, "10.0.0.0/24") + err = db.enableRoutes(&machine, "10.0.0.0/24") c.Assert(err, check.IsNil) - enableRoutesAfterDoubleApply, err := app.db.GetEnabledRoutes(&machine) + enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1) - err = app.db.enableRoutes(&machine, "150.0.10.0/25") + err = db.enableRoutes(&machine, "150.0.10.0/25") c.Assert(err, check.IsNil) - enabledRoutesWithAdditionalRoute, err := app.db.GetEnabledRoutes(&machine) + enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2) + + c.Assert(channelUpdates, check.Equals, int32(3)) } func (s *Suite) TestIsUniquePrefix(c *check.C) { - user, err := app.db.CreateUser("test") + user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test", "test_enable_route_machine") + _, err = db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -152,75 +158,77 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { hostInfo1 := tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{route, route2}, } - machine1 := Machine{ + machine1 := types.Machine{ ID: 1, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "test_enable_route_machine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo1), + HostInfo: types.HostInfo(hostInfo1), } - app.db.db.Save(&machine1) + db.db.Save(&machine1) - err = app.db.processMachineRoutes(&machine1) + err = db.ProcessMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine1, route.String()) + err = db.enableRoutes(&machine1, route.String()) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine1, route2.String()) + err = db.enableRoutes(&machine1, route2.String()) c.Assert(err, check.IsNil) hostInfo2 := tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{route2}, } - machine2 := Machine{ + machine2 := types.Machine{ ID: 2, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "test_enable_route_machine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo2), + HostInfo: types.HostInfo(hostInfo2), } - app.db.db.Save(&machine2) + db.db.Save(&machine2) - err = app.db.processMachineRoutes(&machine2) + err = db.ProcessMachineRoutes(&machine2) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine2, route2.String()) + err = db.enableRoutes(&machine2, route2.String()) c.Assert(err, check.IsNil) - enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) + enabledRoutes1, err := db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 2) - enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2) + enabledRoutes2, err := db.GetEnabledRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes2), check.Equals, 1) - routes, err := app.db.getMachinePrimaryRoutes(&machine1) + routes, err := db.GetMachinePrimaryRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 2) - routes, err = app.db.getMachinePrimaryRoutes(&machine2) + routes, err = db.GetMachinePrimaryRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 0) + + c.Assert(channelUpdates, check.Equals, int32(3)) } func (s *Suite) TestSubnetFailover(c *check.C) { - user, err := app.db.CreateUser("test") + user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test", "test_enable_route_machine") + _, err = db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) prefix, err := netip.ParsePrefix( @@ -238,134 +246,136 @@ func (s *Suite) TestSubnetFailover(c *check.C) { } now := time.Now() - machine1 := Machine{ + machine1 := types.Machine{ ID: 1, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "test_enable_route_machine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo1), + HostInfo: types.HostInfo(hostInfo1), LastSeen: &now, } - app.db.db.Save(&machine1) + db.db.Save(&machine1) - err = app.db.processMachineRoutes(&machine1) + err = db.ProcessMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine1, prefix.String()) + err = db.enableRoutes(&machine1, prefix.String()) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine1, prefix2.String()) + err = db.enableRoutes(&machine1, prefix2.String()) c.Assert(err, check.IsNil) - err = app.db.handlePrimarySubnetFailover() + err = db.HandlePrimarySubnetFailover() c.Assert(err, check.IsNil) - enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) + enabledRoutes1, err := db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 2) - route, err := app.db.getPrimaryRoute(prefix) + route, err := db.getPrimaryRoute(prefix) c.Assert(err, check.IsNil) c.Assert(route.MachineID, check.Equals, machine1.ID) hostInfo2 := tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{prefix2}, } - machine2 := Machine{ + machine2 := types.Machine{ ID: 2, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "test_enable_route_machine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo2), + HostInfo: types.HostInfo(hostInfo2), LastSeen: &now, } - app.db.db.Save(&machine2) + db.db.Save(&machine2) - err = app.db.processMachineRoutes(&machine2) + err = db.ProcessMachineRoutes(&machine2) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine2, prefix2.String()) + err = db.enableRoutes(&machine2, prefix2.String()) c.Assert(err, check.IsNil) - err = app.db.handlePrimarySubnetFailover() + err = db.HandlePrimarySubnetFailover() c.Assert(err, check.IsNil) - enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1) + enabledRoutes1, err = db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 2) - enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2) + enabledRoutes2, err := db.GetEnabledRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes2), check.Equals, 1) - routes, err := app.db.getMachinePrimaryRoutes(&machine1) + routes, err := db.GetMachinePrimaryRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 2) - routes, err = app.db.getMachinePrimaryRoutes(&machine2) + routes, err = db.GetMachinePrimaryRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 0) // lets make machine1 lastseen 10 mins ago before := now.Add(-10 * time.Minute) machine1.LastSeen = &before - err = app.db.db.Save(&machine1).Error + err = db.db.Save(&machine1).Error c.Assert(err, check.IsNil) - err = app.db.handlePrimarySubnetFailover() + err = db.HandlePrimarySubnetFailover() c.Assert(err, check.IsNil) - routes, err = app.db.getMachinePrimaryRoutes(&machine1) + routes, err = db.GetMachinePrimaryRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 1) - routes, err = app.db.getMachinePrimaryRoutes(&machine2) + routes, err = db.GetMachinePrimaryRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 1) - machine2.HostInfo = HostInfo(tailcfg.Hostinfo{ + machine2.HostInfo = types.HostInfo(tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{prefix, prefix2}, }) - err = app.db.db.Save(&machine2).Error + err = db.db.Save(&machine2).Error c.Assert(err, check.IsNil) - err = app.db.processMachineRoutes(&machine2) + err = db.ProcessMachineRoutes(&machine2) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine2, prefix.String()) + err = db.enableRoutes(&machine2, prefix.String()) c.Assert(err, check.IsNil) - err = app.db.handlePrimarySubnetFailover() + err = db.HandlePrimarySubnetFailover() c.Assert(err, check.IsNil) - routes, err = app.db.getMachinePrimaryRoutes(&machine1) + routes, err = db.GetMachinePrimaryRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 0) - routes, err = app.db.getMachinePrimaryRoutes(&machine2) + routes, err = db.GetMachinePrimaryRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 2) + + c.Assert(channelUpdates, check.Equals, int32(6)) } // TestAllowedIPRoutes tests that the AllowedIPs are correctly set for a node, // including both the primary routes the node is responsible for, and the // exit node routes if enabled. func (s *Suite) TestAllowedIPRoutes(c *check.C) { - user, err := app.db.CreateUser("test") + user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test", "test_enable_route_machine") + _, err = db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) prefix, err := netip.ParsePrefix( @@ -397,35 +407,35 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { machineKey := key.NewMachine() now := time.Now() - machine1 := Machine{ + machine1 := types.Machine{ ID: 1, MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), DiscoKey: util.DiscoPublicKeyStripPrefix(discoKey.Public()), Hostname: "test_enable_route_machine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo1), + HostInfo: types.HostInfo(hostInfo1), LastSeen: &now, } - app.db.db.Save(&machine1) + db.db.Save(&machine1) - err = app.db.processMachineRoutes(&machine1) + err = db.ProcessMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine1, prefix.String()) + err = db.enableRoutes(&machine1, prefix.String()) c.Assert(err, check.IsNil) // We do not enable this one on purpose to test that it is not enabled - // err = app.db.enableRoutes(&machine1, prefix2.String()) + // err = db.enableRoutes(&machine1, prefix2.String()) // c.Assert(err, check.IsNil) - routes, err := app.db.GetMachineRoutes(&machine1) + routes, err := db.GetMachineRoutes(&machine1) c.Assert(err, check.IsNil) for _, route := range routes { - if route.isExitRoute() { - err = app.db.EnableRoute(uint64(route.ID)) + if route.IsExitRoute() { + err = db.EnableRoute(uint64(route.ID)) c.Assert(err, check.IsNil) // We only enable one exit route, so we can test that both are enabled @@ -433,14 +443,14 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { } } - err = app.db.handlePrimarySubnetFailover() + err = db.HandlePrimarySubnetFailover() c.Assert(err, check.IsNil) - enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) + enabledRoutes1, err := db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 3) - peer, err := app.db.toNode(machine1, app.aclPolicy, "headscale.net", nil) + peer, err := db.TailNode(machine1, &policy.ACLPolicy{}, nil) c.Assert(err, check.IsNil) c.Assert(len(peer.AllowedIPs), check.Equals, 3) @@ -461,44 +471,46 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { // Now we disable only one of the exit routes // and we see if both are disabled - var exitRouteV4 Route + var exitRouteV4 types.Route for _, route := range routes { - if route.isExitRoute() && netip.Prefix(route.Prefix) == prefixExitNodeV4 { + if route.IsExitRoute() && netip.Prefix(route.Prefix) == prefixExitNodeV4 { exitRouteV4 = route break } } - err = app.db.DisableRoute(uint64(exitRouteV4.ID)) + err = db.DisableRoute(uint64(exitRouteV4.ID)) c.Assert(err, check.IsNil) - enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1) + enabledRoutes1, err = db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 1) // and now we delete only one of the exit routes // and we check if both are deleted - routes, err = app.db.GetMachineRoutes(&machine1) + routes, err = db.GetMachineRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 4) - err = app.db.DeleteRoute(uint64(exitRouteV4.ID)) + err = db.DeleteRoute(uint64(exitRouteV4.ID)) c.Assert(err, check.IsNil) - routes, err = app.db.GetMachineRoutes(&machine1) + routes, err = db.GetMachineRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 2) + + c.Assert(channelUpdates, check.Equals, int32(2)) } func (s *Suite) TestDeleteRoutes(c *check.C) { - user, err := app.db.CreateUser("test") + user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = app.db.GetMachine("test", "test_enable_route_machine") + _, err = db.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) prefix, err := netip.ParsePrefix( @@ -516,36 +528,38 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { } now := time.Now() - machine1 := Machine{ + machine1 := types.Machine{ ID: 1, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Hostname: "test_enable_route_machine", UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo1), + HostInfo: types.HostInfo(hostInfo1), LastSeen: &now, } - app.db.db.Save(&machine1) + db.db.Save(&machine1) - err = app.db.processMachineRoutes(&machine1) + err = db.ProcessMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine1, prefix.String()) + err = db.enableRoutes(&machine1, prefix.String()) c.Assert(err, check.IsNil) - err = app.db.enableRoutes(&machine1, prefix2.String()) + err = db.enableRoutes(&machine1, prefix2.String()) c.Assert(err, check.IsNil) - routes, err := app.db.GetMachineRoutes(&machine1) + routes, err := db.GetMachineRoutes(&machine1) c.Assert(err, check.IsNil) - err = app.db.DeleteRoute(uint64(routes[0].ID)) + err = db.DeleteRoute(uint64(routes[0].ID)) c.Assert(err, check.IsNil) - enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1) + enabledRoutes1, err := db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 1) + + c.Assert(channelUpdates, check.Equals, int32(2)) } diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go new file mode 100644 index 0000000000..01541b9efe --- /dev/null +++ b/hscontrol/db/suite_test.go @@ -0,0 +1,74 @@ +package db + +import ( + "net/netip" + "os" + "sync/atomic" + "testing" + + "gopkg.in/check.v1" +) + +func Test(t *testing.T) { + check.TestingT(t) +} + +var _ = check.Suite(&Suite{}) + +type Suite struct{} + +var ( + tmpDir string + db *HSDatabase + + // channelUpdates counts the number of times + // either of the channels was notified. + channelUpdates int32 +) + +func (s *Suite) SetUpTest(c *check.C) { + atomic.StoreInt32(&channelUpdates, 0) + s.ResetDB(c) +} + +func (s *Suite) TearDownTest(c *check.C) { + os.RemoveAll(tmpDir) +} + +func notificationSink(c <-chan struct{}) { + for { + <-c + atomic.AddInt32(&channelUpdates, 1) + } +} + +func (s *Suite) ResetDB(c *check.C) { + if len(tmpDir) != 0 { + os.RemoveAll(tmpDir) + } + var err error + tmpDir, err = os.MkdirTemp("", "autoygg-client-test") + if err != nil { + c.Fatal(err) + } + + sink := make(chan struct{}) + + go notificationSink(sink) + + db, err = NewHeadscaleDatabase( + "sqlite3", + tmpDir+"/headscale_test.db", + false, + false, + sink, + sink, + []netip.Prefix{ + netip.MustParsePrefix("10.27.0.0/23"), + }, + "", + ) + if err != nil { + c.Fatal(err) + } +} diff --git a/hscontrol/users.go b/hscontrol/db/users.go similarity index 50% rename from hscontrol/users.go rename to hscontrol/db/users.go index fb3cea9c15..e0ffd19f11 100644 --- a/hscontrol/users.go +++ b/hscontrol/db/users.go @@ -1,17 +1,12 @@ -package hscontrol +package db import ( "errors" "fmt" - "regexp" - "strconv" - "strings" - "time" - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" - "google.golang.org/protobuf/types/known/timestamppb" "gorm.io/gorm" "tailscale.com/tailcfg" ) @@ -20,33 +15,16 @@ var ( ErrUserExists = errors.New("user already exists") ErrUserNotFound = errors.New("user not found") ErrUserStillHasNodes = errors.New("user not empty: node(s) found") - ErrInvalidUserName = errors.New("invalid user name") ) -const ( - // value related to RFC 1123 and 952. - labelHostnameLength = 63 -) - -var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") - -// User is the way Headscale implements the concept of users in Tailscale -// -// At the end of the day, users in Tailscale are some kind of 'bubbles' or users -// that contain our machines. -type User struct { - gorm.Model - Name string `gorm:"unique"` -} - // CreateUser creates a new User. Returns error if could not be created // or another user already exists. -func (hsdb *HSDatabase) CreateUser(name string) (*User, error) { - err := CheckForFQDNRules(name) +func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { + err := util.CheckForFQDNRules(name) if err != nil { return nil, err } - user := User{} + user := types.User{} if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil { return nil, ErrUserExists } @@ -105,7 +83,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { if err != nil { return err } - err = CheckForFQDNRules(newName) + err = util.CheckForFQDNRules(newName) if err != nil { return err } @@ -127,8 +105,8 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { } // GetUser fetches a user by name. -func (hsdb *HSDatabase) GetUser(name string) (*User, error) { - user := User{} +func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { + user := types.User{} if result := hsdb.db.First(&user, "name = ?", name); errors.Is( result.Error, gorm.ErrRecordNotFound, @@ -140,8 +118,8 @@ func (hsdb *HSDatabase) GetUser(name string) (*User, error) { } // ListUsers gets all the existing users. -func (hsdb *HSDatabase) ListUsers() ([]User, error) { - users := []User{} +func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { + users := []types.User{} if err := hsdb.db.Find(&users).Error; err != nil { return nil, err } @@ -150,8 +128,8 @@ func (hsdb *HSDatabase) ListUsers() ([]User, error) { } // ListMachinesByUser gets all the nodes in a given user. -func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) { - err := CheckForFQDNRules(name) +func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) { + err := util.CheckForFQDNRules(name) if err != nil { return nil, err } @@ -160,8 +138,8 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) { return nil, err } - machines := []Machine{} - if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil { + machines := types.Machines{} + if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Machine{UserID: user.ID}).Find(&machines).Error; err != nil { return nil, err } @@ -169,8 +147,8 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) { } // SetMachineUser assigns a Machine to a user. -func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error { - err := CheckForFQDNRules(username) +func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string) error { + err := util.CheckForFQDNRules(username) if err != nil { return err } @@ -186,37 +164,11 @@ func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error return nil } -func (n *User) toTailscaleUser() *tailcfg.User { - user := tailcfg.User{ - ID: tailcfg.UserID(n.ID), - LoginName: n.Name, - DisplayName: n.Name, - ProfilePicURL: "", - Domain: "headscale.net", - Logins: []tailcfg.LoginID{}, - Created: time.Time{}, - } - - return &user -} - -func (n *User) toTailscaleLogin() *tailcfg.Login { - login := tailcfg.Login{ - ID: tailcfg.LoginID(n.ID), - LoginName: n.Name, - DisplayName: n.Name, - ProfilePicURL: "", - Domain: "headscale.net", - } - - return &login -} - -func (hsdb *HSDatabase) getMapResponseUserProfiles( - machine Machine, - peers Machines, +func (hsdb *HSDatabase) GetMapResponseUserProfiles( + machine types.Machine, + peers types.Machines, ) []tailcfg.UserProfile { - userMap := make(map[string]User) + userMap := make(map[string]types.User) userMap[machine.User.Name] = machine.User for _, peer := range peers { userMap[peer.User.Name] = peer.User // not worth checking if already is there @@ -240,63 +192,3 @@ func (hsdb *HSDatabase) getMapResponseUserProfiles( return profiles } - -func (n *User) toProto() *v1.User { - return &v1.User{ - Id: strconv.FormatUint(uint64(n.ID), util.Base10), - Name: n.Name, - CreatedAt: timestamppb.New(n.CreatedAt), - } -} - -// NormalizeToFQDNRules will replace forbidden chars in user -// it can also return an error if the user doesn't respect RFC 952 and 1123. -func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) { - name = strings.ToLower(name) - name = strings.ReplaceAll(name, "'", "") - atIdx := strings.Index(name, "@") - if stripEmailDomain && atIdx > 0 { - name = name[:atIdx] - } else { - name = strings.ReplaceAll(name, "@", ".") - } - name = invalidCharsInUserRegex.ReplaceAllString(name, "-") - - for _, elt := range strings.Split(name, ".") { - if len(elt) > labelHostnameLength { - return "", fmt.Errorf( - "label %v is more than 63 chars: %w", - elt, - ErrInvalidUserName, - ) - } - } - - return name, nil -} - -func CheckForFQDNRules(name string) error { - if len(name) > labelHostnameLength { - return fmt.Errorf( - "DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w", - name, - ErrInvalidUserName, - ) - } - if strings.ToLower(name) != name { - return fmt.Errorf( - "DNS segment should be lowercase. %v doesn't comply with this rule: %w", - name, - ErrInvalidUserName, - ) - } - if invalidCharsInUserRegex.MatchString(name) { - return fmt.Errorf( - "DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w", - name, - ErrInvalidUserName, - ) - } - - return nil -} diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go new file mode 100644 index 0000000000..02c0a2ad0b --- /dev/null +++ b/hscontrol/db/users_test.go @@ -0,0 +1,277 @@ +package db + +import ( + "net/netip" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "gopkg.in/check.v1" + "gorm.io/gorm" + "tailscale.com/tailcfg" +) + +func (s *Suite) TestCreateAndDestroyUser(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + c.Assert(user.Name, check.Equals, "test") + + users, err := db.ListUsers() + c.Assert(err, check.IsNil) + c.Assert(len(users), check.Equals, 1) + + err = db.DestroyUser("test") + c.Assert(err, check.IsNil) + + _, err = db.GetUser("test") + c.Assert(err, check.NotNil) +} + +func (s *Suite) TestDestroyUserErrors(c *check.C) { + err := db.DestroyUser("test") + c.Assert(err, check.Equals, ErrUserNotFound) + + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + err = db.DestroyUser("test") + c.Assert(err, check.IsNil) + + result := db.db.Preload("User").First(&pak, "key = ?", pak.Key) + // destroying a user also deletes all associated preauthkeys + c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound) + + user, err = db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + machine := types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + db.db.Save(&machine) + + err = db.DestroyUser("test") + c.Assert(err, check.Equals, ErrUserStillHasNodes) +} + +func (s *Suite) TestRenameUser(c *check.C) { + userTest, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + c.Assert(userTest.Name, check.Equals, "test") + + users, err := db.ListUsers() + c.Assert(err, check.IsNil) + c.Assert(len(users), check.Equals, 1) + + err = db.RenameUser("test", "test-renamed") + c.Assert(err, check.IsNil) + + _, err = db.GetUser("test") + c.Assert(err, check.Equals, ErrUserNotFound) + + _, err = db.GetUser("test-renamed") + c.Assert(err, check.IsNil) + + err = db.RenameUser("test-does-not-exit", "test") + c.Assert(err, check.Equals, ErrUserNotFound) + + userTest2, err := db.CreateUser("test2") + c.Assert(err, check.IsNil) + c.Assert(userTest2.Name, check.Equals, "test2") + + err = db.RenameUser("test2", "test-renamed") + c.Assert(err, check.Equals, ErrUserExists) +} + +func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { + userShared1, err := db.CreateUser("shared1") + c.Assert(err, check.IsNil) + + userShared2, err := db.CreateUser("shared2") + c.Assert(err, check.IsNil) + + userShared3, err := db.CreateUser("shared3") + c.Assert(err, check.IsNil) + + preAuthKeyShared1, err := db.CreatePreAuthKey( + userShared1.Name, + false, + false, + nil, + nil, + ) + c.Assert(err, check.IsNil) + + preAuthKeyShared2, err := db.CreatePreAuthKey( + userShared2.Name, + false, + false, + nil, + nil, + ) + c.Assert(err, check.IsNil) + + preAuthKeyShared3, err := db.CreatePreAuthKey( + userShared3.Name, + false, + false, + nil, + nil, + ) + c.Assert(err, check.IsNil) + + preAuthKey2Shared1, err := db.CreatePreAuthKey( + userShared1.Name, + false, + false, + nil, + nil, + ) + c.Assert(err, check.IsNil) + + _, err = db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") + c.Assert(err, check.NotNil) + + machineInShared1 := &types.Machine{ + ID: 1, + MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", + NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", + DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", + Hostname: "test_get_shared_nodes_1", + UserID: userShared1.ID, + User: *userShared1, + RegisterMethod: util.RegisterMethodAuthKey, + IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, + AuthKeyID: uint(preAuthKeyShared1.ID), + } + db.db.Save(machineInShared1) + + _, err = db.GetMachine(userShared1.Name, machineInShared1.Hostname) + c.Assert(err, check.IsNil) + + machineInShared2 := &types.Machine{ + ID: 2, + MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + Hostname: "test_get_shared_nodes_2", + UserID: userShared2.ID, + User: *userShared2, + RegisterMethod: util.RegisterMethodAuthKey, + IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, + AuthKeyID: uint(preAuthKeyShared2.ID), + } + db.db.Save(machineInShared2) + + _, err = db.GetMachine(userShared2.Name, machineInShared2.Hostname) + c.Assert(err, check.IsNil) + + machineInShared3 := &types.Machine{ + ID: 3, + MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + Hostname: "test_get_shared_nodes_3", + UserID: userShared3.ID, + User: *userShared3, + RegisterMethod: util.RegisterMethodAuthKey, + IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, + AuthKeyID: uint(preAuthKeyShared3.ID), + } + db.db.Save(machineInShared3) + + _, err = db.GetMachine(userShared3.Name, machineInShared3.Hostname) + c.Assert(err, check.IsNil) + + machine2InShared1 := &types.Machine{ + ID: 4, + MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + Hostname: "test_get_shared_nodes_4", + UserID: userShared1.ID, + User: *userShared1, + RegisterMethod: util.RegisterMethodAuthKey, + IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, + AuthKeyID: uint(preAuthKey2Shared1.ID), + } + db.db.Save(machine2InShared1) + + peersOfMachine1InShared1, err := db.getPeers([]tailcfg.FilterRule{}, machineInShared1) + c.Assert(err, check.IsNil) + + userProfiles := db.GetMapResponseUserProfiles( + *machineInShared1, + peersOfMachine1InShared1, + ) + + c.Assert(len(userProfiles), check.Equals, 3) + + found := false + for _, userProfiles := range userProfiles { + if userProfiles.DisplayName == userShared1.Name { + found = true + + break + } + } + c.Assert(found, check.Equals, true) + + found = false + for _, userProfile := range userProfiles { + if userProfile.DisplayName == userShared2.Name { + found = true + + break + } + } + c.Assert(found, check.Equals, true) +} + +func (s *Suite) TestSetMachineUser(c *check.C) { + oldUser, err := db.CreateUser("old") + c.Assert(err, check.IsNil) + + newUser, err := db.CreateUser("new") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + machine := types.Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + UserID: oldUser.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + db.db.Save(&machine) + c.Assert(machine.UserID, check.Equals, oldUser.ID) + + err = db.SetMachineUser(&machine, newUser.Name) + c.Assert(err, check.IsNil) + c.Assert(machine.UserID, check.Equals, newUser.ID) + c.Assert(machine.User.Name, check.Equals, newUser.Name) + + err = db.SetMachineUser(&machine, "non-existing-user") + c.Assert(err, check.Equals, ErrUserNotFound) + + err = db.SetMachineUser(&machine, newUser.Name) + c.Assert(err, check.IsNil) + c.Assert(machine.UserID, check.Equals, newUser.ID) + c.Assert(machine.User.Name, check.Equals, newUser.Name) +} diff --git a/hscontrol/dns.go b/hscontrol/dns.go index 72c5b03c0b..2c611f1b0e 100644 --- a/hscontrol/dns.go +++ b/hscontrol/dns.go @@ -7,6 +7,7 @@ import ( "strings" mapset "github.com/deckarep/golang-set/v2" + "github.com/juanfont/headscale/hscontrol/types" "go4.org/netipx" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" @@ -165,7 +166,7 @@ func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { // // This will produce a resolver like: // `https://dns.nextdns.io/?device_name=node-name&device_model=linux&device_ip=100.64.0.1` -func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine Machine) { +func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) { for _, resolver := range resolvers { if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { attrs := url.Values{ @@ -185,8 +186,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine Machine) { func getMapResponseDNSConfig( dnsConfigOrig *tailcfg.DNSConfig, baseDomain string, - machine Machine, - peers Machines, + machine types.Machine, + peers types.Machines, ) *tailcfg.DNSConfig { var dnsConfig *tailcfg.DNSConfig = dnsConfigOrig.Clone() if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled @@ -200,7 +201,7 @@ func getMapResponseDNSConfig( ), ) - userSet := mapset.NewSet[User]() + userSet := mapset.NewSet[types.User]() userSet.Add(machine.User) for _, p := range peers { userSet.Add(p.User) diff --git a/hscontrol/dns_test.go b/hscontrol/dns_test.go index 671a712f45..6bee0ea8b8 100644 --- a/hscontrol/dns_test.go +++ b/hscontrol/dns_test.go @@ -4,6 +4,8 @@ import ( "fmt" "net/netip" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" @@ -160,7 +162,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { _, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") c.Assert(err, check.NotNil) - machineInShared1 := &Machine{ + machineInShared1 := &types.Machine{ ID: 1, MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", @@ -168,16 +170,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Hostname: "test_get_shared_nodes_1", UserID: userShared1.ID, User: *userShared1, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, AuthKeyID: uint(preAuthKeyInShared1.ID), } - app.db.db.Save(machineInShared1) + err = app.db.MachineSave(machineInShared1) + c.Assert(err, check.IsNil) _, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) c.Assert(err, check.IsNil) - machineInShared2 := &Machine{ + machineInShared2 := &types.Machine{ ID: 2, MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", @@ -185,16 +188,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Hostname: "test_get_shared_nodes_2", UserID: userShared2.ID, User: *userShared2, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, AuthKeyID: uint(preAuthKeyInShared2.ID), } - app.db.db.Save(machineInShared2) + err = app.db.MachineSave(machineInShared2) + c.Assert(err, check.IsNil) _, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) c.Assert(err, check.IsNil) - machineInShared3 := &Machine{ + machineInShared3 := &types.Machine{ ID: 3, MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", @@ -202,16 +206,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Hostname: "test_get_shared_nodes_3", UserID: userShared3.ID, User: *userShared3, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, AuthKeyID: uint(preAuthKeyInShared3.ID), } - app.db.db.Save(machineInShared3) + err = app.db.MachineSave(machineInShared3) + c.Assert(err, check.IsNil) _, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) c.Assert(err, check.IsNil) - machine2InShared1 := &Machine{ + machine2InShared1 := &types.Machine{ ID: 4, MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", @@ -219,11 +224,12 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Hostname: "test_get_shared_nodes_4", UserID: userShared1.ID, User: *userShared1, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, AuthKeyID: uint(PreAuthKey2InShared1.ID), } - app.db.db.Save(machine2InShared1) + err = app.db.MachineSave(machine2InShared1) + c.Assert(err, check.IsNil) baseDomain := "foobar.headscale.net" dnsConfigOrig := tailcfg.DNSConfig{ @@ -232,7 +238,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { Proxied: true, } - peersOfMachineInShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) + peersOfMachineInShared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1) c.Assert(err, check.IsNil) dnsConfig := getMapResponseDNSConfig( @@ -307,7 +313,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { _, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") c.Assert(err, check.NotNil) - machineInShared1 := &Machine{ + machineInShared1 := &types.Machine{ ID: 1, MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", @@ -315,16 +321,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Hostname: "test_get_shared_nodes_1", UserID: userShared1.ID, User: *userShared1, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, AuthKeyID: uint(preAuthKeyInShared1.ID), } - app.db.db.Save(machineInShared1) + err = app.db.MachineSave(machineInShared1) + c.Assert(err, check.IsNil) _, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) c.Assert(err, check.IsNil) - machineInShared2 := &Machine{ + machineInShared2 := &types.Machine{ ID: 2, MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", @@ -332,16 +339,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Hostname: "test_get_shared_nodes_2", UserID: userShared2.ID, User: *userShared2, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, AuthKeyID: uint(preAuthKeyInShared2.ID), } - app.db.db.Save(machineInShared2) + err = app.db.MachineSave(machineInShared2) + c.Assert(err, check.IsNil) _, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) c.Assert(err, check.IsNil) - machineInShared3 := &Machine{ + machineInShared3 := &types.Machine{ ID: 3, MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", @@ -349,16 +357,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Hostname: "test_get_shared_nodes_3", UserID: userShared3.ID, User: *userShared3, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, AuthKeyID: uint(preAuthKeyInShared3.ID), } - app.db.db.Save(machineInShared3) + err = app.db.MachineSave(machineInShared3) + c.Assert(err, check.IsNil) _, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) c.Assert(err, check.IsNil) - machine2InShared1 := &Machine{ + machine2InShared1 := &types.Machine{ ID: 4, MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", @@ -366,11 +375,12 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Hostname: "test_get_shared_nodes_4", UserID: userShared1.ID, User: *userShared1, - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, AuthKeyID: uint(preAuthKey2InShared1.ID), } - app.db.db.Save(machine2InShared1) + err = app.db.MachineSave(machine2InShared1) + c.Assert(err, check.IsNil) baseDomain := "foobar.headscale.net" dnsConfigOrig := tailcfg.DNSConfig{ @@ -379,7 +389,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { Proxied: false, } - peersOfMachine1Shared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) + peersOfMachine1Shared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1) c.Assert(err, check.IsNil) dnsConfig := getMapResponseDNSConfig( diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 4a26d08eb7..8adf871ce7 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -8,6 +8,7 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "google.golang.org/grpc/codes" @@ -36,7 +37,7 @@ func (api headscaleV1APIServer) GetUser( return nil, err } - return &v1.GetUserResponse{User: user.toProto()}, nil + return &v1.GetUserResponse{User: user.Proto()}, nil } func (api headscaleV1APIServer) CreateUser( @@ -48,7 +49,7 @@ func (api headscaleV1APIServer) CreateUser( return nil, err } - return &v1.CreateUserResponse{User: user.toProto()}, nil + return &v1.CreateUserResponse{User: user.Proto()}, nil } func (api headscaleV1APIServer) RenameUser( @@ -65,7 +66,7 @@ func (api headscaleV1APIServer) RenameUser( return nil, err } - return &v1.RenameUserResponse{User: user.toProto()}, nil + return &v1.RenameUserResponse{User: user.Proto()}, nil } func (api headscaleV1APIServer) DeleteUser( @@ -91,7 +92,7 @@ func (api headscaleV1APIServer) ListUsers( response := make([]*v1.User, len(users)) for index, user := range users { - response[index] = user.toProto() + response[index] = user.Proto() } log.Trace().Caller().Interface("users", response).Msg("") @@ -128,7 +129,7 @@ func (api headscaleV1APIServer) CreatePreAuthKey( return nil, err } - return &v1.CreatePreAuthKeyResponse{PreAuthKey: preAuthKey.toProto()}, nil + return &v1.CreatePreAuthKeyResponse{PreAuthKey: preAuthKey.Proto()}, nil } func (api headscaleV1APIServer) ExpirePreAuthKey( @@ -159,7 +160,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys( response := make([]*v1.PreAuthKey, len(preAuthKeys)) for index, key := range preAuthKeys { - response[index] = key.toProto() + response[index] = key.Proto() } return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil @@ -179,13 +180,13 @@ func (api headscaleV1APIServer) RegisterMachine( request.GetKey(), request.GetUser(), nil, - RegisterMethodCLI, + util.RegisterMethodCLI, ) if err != nil { return nil, err } - return &v1.RegisterMachineResponse{Machine: machine.toProto()}, nil + return &v1.RegisterMachineResponse{Machine: machine.Proto()}, nil } func (api headscaleV1APIServer) GetMachine( @@ -197,7 +198,7 @@ func (api headscaleV1APIServer) GetMachine( return nil, err } - return &v1.GetMachineResponse{Machine: machine.toProto()}, nil + return &v1.GetMachineResponse{Machine: machine.Proto()}, nil } func (api headscaleV1APIServer) SetTags( @@ -218,7 +219,7 @@ func (api headscaleV1APIServer) SetTags( } } - err = api.h.db.SetTags(machine, request.GetTags(), api.h.UpdateACLRules) + err = api.h.db.SetTags(machine, request.GetTags()) if err != nil { return &v1.SetTagsResponse{ Machine: nil, @@ -230,7 +231,7 @@ func (api headscaleV1APIServer) SetTags( Strs("tags", request.GetTags()). Msg("Changing tags of machine") - return &v1.SetTagsResponse{Machine: machine.toProto()}, nil + return &v1.SetTagsResponse{Machine: machine.Proto()}, nil } func validateTag(tag string) error { @@ -283,7 +284,7 @@ func (api headscaleV1APIServer) ExpireMachine( Time("expiry", *machine.Expiry). Msg("machine expired") - return &v1.ExpireMachineResponse{Machine: machine.toProto()}, nil + return &v1.ExpireMachineResponse{Machine: machine.Proto()}, nil } func (api headscaleV1APIServer) RenameMachine( @@ -308,7 +309,7 @@ func (api headscaleV1APIServer) RenameMachine( Str("new_name", request.GetNewName()). Msg("machine renamed") - return &v1.RenameMachineResponse{Machine: machine.toProto()}, nil + return &v1.RenameMachineResponse{Machine: machine.Proto()}, nil } func (api headscaleV1APIServer) ListMachines( @@ -323,7 +324,7 @@ func (api headscaleV1APIServer) ListMachines( response := make([]*v1.Machine, len(machines)) for index, machine := range machines { - response[index] = machine.toProto() + response[index] = machine.Proto() } return &v1.ListMachinesResponse{Machines: response}, nil @@ -336,9 +337,8 @@ func (api headscaleV1APIServer) ListMachines( response := make([]*v1.Machine, len(machines)) for index, machine := range machines { - m := machine.toProto() - validTags, invalidTags := getTags( - api.h.aclPolicy, + m := machine.Proto() + validTags, invalidTags := api.h.ACLPolicy.GetTagsOfMachine( machine, api.h.cfg.OIDC.StripEmaildomain, ) @@ -364,7 +364,7 @@ func (api headscaleV1APIServer) MoveMachine( return nil, err } - return &v1.MoveMachineResponse{Machine: machine.toProto()}, nil + return &v1.MoveMachineResponse{Machine: machine.Proto()}, nil } func (api headscaleV1APIServer) GetRoutes( @@ -377,7 +377,7 @@ func (api headscaleV1APIServer) GetRoutes( } return &v1.GetRoutesResponse{ - Routes: Routes(routes).toProto(), + Routes: types.Routes(routes).Proto(), }, nil } @@ -420,7 +420,7 @@ func (api headscaleV1APIServer) GetMachineRoutes( } return &v1.GetMachineRoutesResponse{ - Routes: Routes(routes).toProto(), + Routes: types.Routes(routes).Proto(), }, nil } @@ -459,7 +459,7 @@ func (api headscaleV1APIServer) ExpireApiKey( ctx context.Context, request *v1.ExpireApiKeyRequest, ) (*v1.ExpireApiKeyResponse, error) { - var apiKey *APIKey + var apiKey *types.APIKey var err error apiKey, err = api.h.db.GetAPIKey(request.Prefix) @@ -486,7 +486,7 @@ func (api headscaleV1APIServer) ListApiKeys( response := make([]*v1.ApiKey, len(apiKeys)) for index, key := range apiKeys { - response[index] = key.toProto() + response[index] = key.Proto() } return &v1.ListApiKeysResponse{ApiKeys: response}, nil @@ -524,7 +524,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( return nil, err } - newMachine := Machine{ + newMachine := types.Machine{ MachineKey: request.GetKey(), Hostname: request.GetName(), GivenName: givenName, @@ -534,7 +534,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( LastSeen: &time.Time{}, LastSuccessfulUpdate: &time.Time{}, - HostInfo: HostInfo(hostinfo), + HostInfo: types.HostInfo(hostinfo), } nodeKey := key.NodePublic{} @@ -549,7 +549,7 @@ func (api headscaleV1APIServer) DebugCreateMachine( registerCacheExpiration, ) - return &v1.DebugCreateMachineResponse{Machine: newMachine.toProto()}, nil + return &v1.DebugCreateMachineResponse{Machine: newMachine.Proto()}, nil } func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {} diff --git a/hscontrol/machine_test.go b/hscontrol/machine_test.go deleted file mode 100644 index 0e7d7dea6a..0000000000 --- a/hscontrol/machine_test.go +++ /dev/null @@ -1,1386 +0,0 @@ -package hscontrol - -import ( - "fmt" - "net/netip" - "reflect" - "regexp" - "strconv" - "testing" - "time" - - "github.com/juanfont/headscale/hscontrol/util" - "gopkg.in/check.v1" - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -func (s *Suite) TestGetMachine(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("test", "testmachine") - c.Assert(err, check.NotNil) - - machine := &Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(machine) - - _, err = app.db.GetMachine("test", "testmachine") - c.Assert(err, check.IsNil) -} - -func (s *Suite) TestGetMachineByID(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachineByID(0) - c.Assert(err, check.NotNil) - - machine := Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(&machine) - - _, err = app.db.GetMachineByID(0) - c.Assert(err, check.IsNil) -} - -func (s *Suite) TestGetMachineByNodeKey(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachineByID(0) - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - machine := Machine{ - ID: 0, - MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), - NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(&machine) - - _, err = app.db.GetMachineByNodeKey(nodeKey.Public()) - c.Assert(err, check.IsNil) -} - -func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachineByID(0) - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - oldNodeKey := key.NewNode() - - machineKey := key.NewMachine() - - machine := Machine{ - ID: 0, - MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), - NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(&machine) - - _, err = app.db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) - c.Assert(err, check.IsNil) -} - -func (s *Suite) TestDeleteMachine(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - machine := Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(1), - } - app.db.db.Save(&machine) - - err = app.db.DeleteMachine(&machine) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine(user.Name, "testmachine") - c.Assert(err, check.NotNil) -} - -func (s *Suite) TestHardDeleteMachine(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - machine := Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine3", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(1), - } - app.db.db.Save(&machine) - - err = app.db.HardDeleteMachine(&machine) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine(user.Name, "testmachine3") - c.Assert(err, check.NotNil) -} - -func (s *Suite) TestListPeers(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachineByID(0) - c.Assert(err, check.NotNil) - - for index := 0; index <= 10; index++ { - machine := Machine{ - ID: uint64(index), - MachineKey: "foo" + strconv.Itoa(index), - NodeKey: "bar" + strconv.Itoa(index), - DiscoKey: "faa" + strconv.Itoa(index), - Hostname: "testmachine" + strconv.Itoa(index), - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(&machine) - } - - machine0ByID, err := app.db.GetMachineByID(0) - c.Assert(err, check.IsNil) - - peersOfMachine0, err := app.db.ListPeers(machine0ByID) - c.Assert(err, check.IsNil) - - c.Assert(len(peersOfMachine0), check.Equals, 9) - c.Assert(peersOfMachine0[0].Hostname, check.Equals, "testmachine2") - c.Assert(peersOfMachine0[5].Hostname, check.Equals, "testmachine7") - c.Assert(peersOfMachine0[8].Hostname, check.Equals, "testmachine10") -} - -func (s *Suite) TestGetACLFilteredPeers(c *check.C) { - type base struct { - user *User - key *PreAuthKey - } - - stor := make([]base, 0) - - for _, name := range []string{"test", "admin"} { - user, err := app.db.CreateUser(name) - c.Assert(err, check.IsNil) - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - stor = append(stor, base{user, pak}) - } - - _, err := app.db.GetMachineByID(0) - c.Assert(err, check.NotNil) - - for index := 0; index <= 10; index++ { - machine := Machine{ - ID: uint64(index), - MachineKey: "foo" + strconv.Itoa(index), - NodeKey: "bar" + strconv.Itoa(index), - DiscoKey: "faa" + strconv.Itoa(index), - IPAddresses: MachineAddresses{ - netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))), - }, - Hostname: "testmachine" + strconv.Itoa(index), - UserID: stor[index%2].user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(stor[index%2].key.ID), - } - app.db.db.Save(&machine) - } - - app.aclPolicy = &ACLPolicy{ - Groups: map[string][]string{ - "group:test": {"admin"}, - }, - Hosts: map[string]netip.Prefix{}, - TagOwners: map[string][]string{}, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"admin"}, - Destinations: []string{"*:*"}, - }, - { - Action: "accept", - Sources: []string{"test"}, - Destinations: []string{"test:*"}, - }, - }, - Tests: []ACLTest{}, - } - - err = app.UpdateACLRules() - c.Assert(err, check.IsNil) - - adminMachine, err := app.db.GetMachineByID(1) - c.Logf("Machine(%v), user: %v", adminMachine.Hostname, adminMachine.User) - c.Assert(err, check.IsNil) - - testMachine, err := app.db.GetMachineByID(2) - c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User) - c.Assert(err, check.IsNil) - - machines, err := app.db.ListMachines() - c.Assert(err, check.IsNil) - - peersOfTestMachine := app.db.filterMachinesByACL(app.aclRules, testMachine, machines) - peersOfAdminMachine := app.db.filterMachinesByACL(app.aclRules, adminMachine, machines) - - c.Log(peersOfTestMachine) - c.Assert(len(peersOfTestMachine), check.Equals, 9) - c.Assert(peersOfTestMachine[0].Hostname, check.Equals, "testmachine1") - c.Assert(peersOfTestMachine[1].Hostname, check.Equals, "testmachine3") - c.Assert(peersOfTestMachine[3].Hostname, check.Equals, "testmachine5") - - c.Log(peersOfAdminMachine) - c.Assert(len(peersOfAdminMachine), check.Equals, 9) - c.Assert(peersOfAdminMachine[0].Hostname, check.Equals, "testmachine2") - c.Assert(peersOfAdminMachine[2].Hostname, check.Equals, "testmachine4") - c.Assert(peersOfAdminMachine[5].Hostname, check.Equals, "testmachine7") -} - -func (s *Suite) TestExpireMachine(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("test", "testmachine") - c.Assert(err, check.NotNil) - - machine := &Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - Expiry: &time.Time{}, - } - app.db.db.Save(machine) - - machineFromDB, err := app.db.GetMachine("test", "testmachine") - c.Assert(err, check.IsNil) - c.Assert(machineFromDB, check.NotNil) - - c.Assert(machineFromDB.isExpired(), check.Equals, false) - - err = app.db.ExpireMachine(machineFromDB) - c.Assert(err, check.IsNil) - - c.Assert(machineFromDB.isExpired(), check.Equals, true) -} - -func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) { - input := MachineAddresses([]netip.Addr{ - netip.MustParseAddr("192.0.2.1"), - netip.MustParseAddr("2001:db8::1"), - }) - serialized, err := input.Value() - c.Assert(err, check.IsNil) - if serial, ok := serialized.(string); ok { - c.Assert(serial, check.Equals, "192.0.2.1,2001:db8::1") - } - - var deserialized MachineAddresses - err = deserialized.Scan(serialized) - c.Assert(err, check.IsNil) - - c.Assert(len(deserialized), check.Equals, len(input)) - for i := range deserialized { - c.Assert(deserialized[i], check.Equals, input[i]) - } -} - -func (s *Suite) TestGenerateGivenName(c *check.C) { - user1, err := app.db.CreateUser("user-1") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user1.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("user-1", "testmachine") - c.Assert(err, check.NotNil) - - machine := &Machine{ - ID: 0, - MachineKey: "machine-key-1", - NodeKey: "node-key-1", - DiscoKey: "disco-key-1", - Hostname: "hostname-1", - GivenName: "hostname-1", - UserID: user1.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(machine) - - givenName, err := app.db.GenerateGivenName("machine-key-2", "hostname-2") - comment := check.Commentf("Same user, unique machines, unique hostnames, no conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Equals, "hostname-2", comment) - - givenName, err = app.db.GenerateGivenName("machine-key-1", "hostname-1") - comment = check.Commentf("Same user, same machine, same hostname, no conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Equals, "hostname-1", comment) - - givenName, err = app.db.GenerateGivenName("machine-key-2", "hostname-1") - comment = check.Commentf("Same user, unique machines, same hostname, conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment) - - givenName, err = app.db.GenerateGivenName("machine-key-2", "hostname-1") - comment = check.Commentf("Unique users, unique machines, same hostname, conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment) -} - -func (s *Suite) TestSetTags(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("test", "testmachine") - c.Assert(err, check.NotNil) - - machine := &Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(machine) - - // assign simple tags - sTags := []string{"tag:test", "tag:foo"} - err = app.db.SetTags(machine, sTags, app.UpdateACLRules) - c.Assert(err, check.IsNil) - machine, err = app.db.GetMachine("test", "testmachine") - c.Assert(err, check.IsNil) - c.Assert(machine.ForcedTags, check.DeepEquals, StringList(sTags)) - - // assign duplicat tags, expect no errors but no doubles in DB - eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} - err = app.db.SetTags(machine, eTags, app.UpdateACLRules) - c.Assert(err, check.IsNil) - machine, err = app.db.GetMachine("test", "testmachine") - c.Assert(err, check.IsNil) - c.Assert( - machine.ForcedTags, - check.DeepEquals, - StringList([]string{"tag:bar", "tag:test", "tag:unknown"}), - ) -} - -func Test_getTags(t *testing.T) { - type args struct { - aclPolicy *ACLPolicy - machine Machine - stripEmailDomain bool - } - tests := []struct { - name string - args args - wantInvalid []string - wantValid []string - }{ - { - name: "valid tag one machine", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{ - "tag:valid": []string{"joe"}, - }, - }, - machine: Machine{ - User: User{ - Name: "joe", - }, - HostInfo: HostInfo{ - RequestTags: []string{"tag:valid"}, - }, - }, - stripEmailDomain: false, - }, - wantValid: []string{"tag:valid"}, - wantInvalid: nil, - }, - { - name: "invalid tag and valid tag one machine", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{ - "tag:valid": []string{"joe"}, - }, - }, - machine: Machine{ - User: User{ - Name: "joe", - }, - HostInfo: HostInfo{ - RequestTags: []string{"tag:valid", "tag:invalid"}, - }, - }, - stripEmailDomain: false, - }, - wantValid: []string{"tag:valid"}, - wantInvalid: []string{"tag:invalid"}, - }, - { - name: "multiple invalid and identical tags, should return only one invalid tag", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{ - "tag:valid": []string{"joe"}, - }, - }, - machine: Machine{ - User: User{ - Name: "joe", - }, - HostInfo: HostInfo{ - RequestTags: []string{ - "tag:invalid", - "tag:valid", - "tag:invalid", - }, - }, - }, - stripEmailDomain: false, - }, - wantValid: []string{"tag:valid"}, - wantInvalid: []string{"tag:invalid"}, - }, - { - name: "only invalid tags", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{ - "tag:valid": []string{"joe"}, - }, - }, - machine: Machine{ - User: User{ - Name: "joe", - }, - HostInfo: HostInfo{ - RequestTags: []string{"tag:invalid", "very-invalid"}, - }, - }, - stripEmailDomain: false, - }, - wantValid: nil, - wantInvalid: []string{"tag:invalid", "very-invalid"}, - }, - { - name: "empty ACLPolicy should return empty tags and should not panic", - args: args{ - aclPolicy: nil, - machine: Machine{ - User: User{ - Name: "joe", - }, - HostInfo: HostInfo{ - RequestTags: []string{"tag:invalid", "very-invalid"}, - }, - }, - stripEmailDomain: false, - }, - wantValid: nil, - wantInvalid: nil, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - gotValid, gotInvalid := getTags( - test.args.aclPolicy, - test.args.machine, - test.args.stripEmailDomain, - ) - for _, valid := range gotValid { - if !util.StringOrPrefixListContains(test.wantValid, valid) { - t.Errorf( - "valids: getTags() = %v, want %v", - gotValid, - test.wantValid, - ) - - break - } - } - for _, invalid := range gotInvalid { - if !util.StringOrPrefixListContains(test.wantInvalid, invalid) { - t.Errorf( - "invalids: getTags() = %v, want %v", - gotInvalid, - test.wantInvalid, - ) - - break - } - } - }) - } -} - -func Test_getFilteredByACLPeers(t *testing.T) { - type args struct { - machines []Machine - rules []tailcfg.FilterRule - machine *Machine - } - tests := []struct { - name string - args args - want Machines - }{ - { - name: "all hosts can talk to each other", - args: args{ - machines: []Machine{ // list of all machines in the database - { - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - { - ID: 2, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: User{Name: "marc"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - SrcIPs: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "*"}, - }, - }, - }, - machine: &Machine{ // current machine - ID: 1, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")}, - User: User{Name: "joe"}, - }, - }, - want: Machines{ - { - ID: 2, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.2")}, - User: User{Name: "marc"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.3")}, - User: User{Name: "mickael"}, - }, - }, - }, - { - name: "One host can talk to another, but not all hosts", - args: args{ - machines: []Machine{ // list of all machines in the database - { - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - { - ID: 2, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: User{Name: "marc"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - SrcIPs: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.2"}, - }, - }, - }, - machine: &Machine{ // current machine - ID: 1, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")}, - User: User{Name: "joe"}, - }, - }, - want: Machines{ - { - ID: 2, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.2")}, - User: User{Name: "marc"}, - }, - }, - }, - { - name: "host cannot directly talk to destination, but return path is authorized", - args: args{ - machines: []Machine{ // list of all machines in the database - { - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - { - ID: 2, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: User{Name: "marc"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - SrcIPs: []string{"100.64.0.3"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.2"}, - }, - }, - }, - machine: &Machine{ // current machine - ID: 2, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.2")}, - User: User{Name: "marc"}, - }, - }, - want: Machines{ - { - ID: 3, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.3")}, - User: User{Name: "mickael"}, - }, - }, - }, - { - name: "rules allows all hosts to reach one destination", - args: args{ - machines: []Machine{ // list of all machines in the database - { - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - { - ID: 2, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: User{Name: "marc"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - SrcIPs: []string{"*"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.2"}, - }, - }, - }, - machine: &Machine{ // current machine - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - }, - want: Machines{ - { - ID: 2, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: User{Name: "marc"}, - }, - }, - }, - { - name: "rules allows all hosts to reach one destination, destination can reach all hosts", - args: args{ - machines: []Machine{ // list of all machines in the database - { - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - { - ID: 2, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: User{Name: "marc"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - SrcIPs: []string{"*"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.2"}, - }, - }, - }, - machine: &Machine{ // current machine - ID: 2, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: User{Name: "marc"}, - }, - }, - want: Machines{ - { - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: User{Name: "mickael"}, - }, - }, - }, - { - name: "rule allows all hosts to reach all destinations", - args: args{ - machines: []Machine{ // list of all machines in the database - { - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - { - ID: 2, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: User{Name: "marc"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - SrcIPs: []string{"*"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "*"}, - }, - }, - }, - machine: &Machine{ // current machine - ID: 2, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.2")}, - User: User{Name: "marc"}, - }, - }, - want: Machines{ - { - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.3")}, - User: User{Name: "mickael"}, - }, - }, - }, - { - name: "without rule all communications are forbidden", - args: args{ - machines: []Machine{ // list of all machines in the database - { - ID: 1, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: User{Name: "joe"}, - }, - { - ID: 2, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: User{Name: "marc"}, - }, - { - ID: 3, - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - }, - machine: &Machine{ // current machine - ID: 2, - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.2")}, - User: User{Name: "marc"}, - }, - }, - want: Machines{}, - }, - { - // Investigating 699 - // Found some machines: [ts-head-8w6paa ts-unstable-lys2ib ts-head-upcrmb ts-unstable-rlwpvr] machine=ts-head-8w6paa - // ACL rules generated ACL=[{"DstPorts":[{"Bits":null,"IP":"*","Ports":{"First":0,"Last":65535}}],"SrcIPs":["fd7a:115c:a1e0::3","100.64.0.3","fd7a:115c:a1e0::4","100.64.0.4"]}] - // ACL Cache Map={"100.64.0.3":{"*":{}},"100.64.0.4":{"*":{}},"fd7a:115c:a1e0::3":{"*":{}},"fd7a:115c:a1e0::4":{"*":{}}} - name: "issue-699-broken-star", - args: args{ - machines: Machines{ // - { - ID: 1, - Hostname: "ts-head-upcrmb", - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - netip.MustParseAddr("fd7a:115c:a1e0::3"), - }, - User: User{Name: "user1"}, - }, - { - ID: 2, - Hostname: "ts-unstable-rlwpvr", - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.4"), - netip.MustParseAddr("fd7a:115c:a1e0::4"), - }, - User: User{Name: "user1"}, - }, - { - ID: 3, - Hostname: "ts-head-8w6paa", - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - netip.MustParseAddr("fd7a:115c:a1e0::1"), - }, - User: User{Name: "user2"}, - }, - { - ID: 4, - Hostname: "ts-unstable-lys2ib", - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.2"), - netip.MustParseAddr("fd7a:115c:a1e0::2"), - }, - User: User{Name: "user2"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - DstPorts: []tailcfg.NetPortRange{ - { - IP: "*", - Ports: tailcfg.PortRange{First: 0, Last: 65535}, - }, - }, - SrcIPs: []string{ - "fd7a:115c:a1e0::3", "100.64.0.3", - "fd7a:115c:a1e0::4", "100.64.0.4", - }, - }, - }, - machine: &Machine{ // current machine - ID: 3, - Hostname: "ts-head-8w6paa", - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - netip.MustParseAddr("fd7a:115c:a1e0::1"), - }, - User: User{Name: "user2"}, - }, - }, - want: Machines{ - { - ID: 1, - Hostname: "ts-head-upcrmb", - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.3"), - netip.MustParseAddr("fd7a:115c:a1e0::3"), - }, - User: User{Name: "user1"}, - }, - { - ID: 2, - Hostname: "ts-unstable-rlwpvr", - IPAddresses: MachineAddresses{ - netip.MustParseAddr("100.64.0.4"), - netip.MustParseAddr("fd7a:115c:a1e0::4"), - }, - User: User{Name: "user1"}, - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := filterMachinesByACL( - tt.args.machine, - tt.args.machines, - tt.args.rules, - ) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("filterMachinesByACL() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestHeadscale_generateGivenName(t *testing.T) { - type args struct { - suppliedName string - randomSuffix bool - } - tests := []struct { - name string - db *HSDatabase - args args - want *regexp.Regexp - wantErr bool - }{ - { - name: "simple machine name generation", - db: &HSDatabase{ - stripEmailDomain: true, - }, - args: args{ - suppliedName: "testmachine", - randomSuffix: false, - }, - want: regexp.MustCompile("^testmachine$"), - wantErr: false, - }, - { - name: "machine name with 53 chars", - db: &HSDatabase{ - stripEmailDomain: true, - }, - args: args{ - suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", - randomSuffix: false, - }, - want: regexp.MustCompile("^testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine$"), - wantErr: false, - }, - { - name: "machine name with 63 chars", - db: &HSDatabase{ - stripEmailDomain: true, - }, - args: args{ - suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", - randomSuffix: false, - }, - want: regexp.MustCompile("^machineeee12345678901234567890123456789012345678901234567890123$"), - wantErr: false, - }, - { - name: "machine name with 64 chars", - db: &HSDatabase{ - stripEmailDomain: true, - }, - args: args{ - suppliedName: "machineeee123456789012345678901234567890123456789012345678901234", - randomSuffix: false, - }, - want: nil, - wantErr: true, - }, - { - name: "machine name with 73 chars", - db: &HSDatabase{ - stripEmailDomain: true, - }, - args: args{ - suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123", - randomSuffix: false, - }, - want: nil, - wantErr: true, - }, - { - name: "machine name with random suffix", - db: &HSDatabase{ - stripEmailDomain: true, - }, - args: args{ - suppliedName: "test", - randomSuffix: true, - }, - want: regexp.MustCompile(fmt.Sprintf("^test-[a-z0-9]{%d}$", MachineGivenNameHashLength)), - wantErr: false, - }, - { - name: "machine name with 63 chars with random suffix", - db: &HSDatabase{ - stripEmailDomain: true, - }, - args: args{ - suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", - randomSuffix: true, - }, - want: regexp.MustCompile(fmt.Sprintf("^machineeee1234567890123456789012345678901234567890123-[a-z0-9]{%d}$", MachineGivenNameHashLength)), - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.db.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) - if (err != nil) != tt.wantErr { - t.Errorf( - "Headscale.GenerateGivenName() error = %v, wantErr %v", - err, - tt.wantErr, - ) - - return - } - - if tt.want != nil && !tt.want.MatchString(got) { - t.Errorf( - "Headscale.GenerateGivenName() = %v, does not match %v", - tt.want, - got, - ) - } - - if len(got) > labelHostnameLength { - t.Errorf( - "Headscale.GenerateGivenName() = %v is larger than allowed DNS segment %d", - got, - labelHostnameLength, - ) - } - }) - } -} - -func (s *Suite) TestAutoApproveRoutes(c *check.C) { - acl := []byte(` -{ - "tagOwners": { - "tag:exit": ["test"], - }, - - "groups": { - "group:test": ["test"] - }, - - "acls": [ - {"action": "accept", "users": ["*"], "ports": ["*:*"]}, - ], - - "autoApprovers": { - "exitNode": ["tag:exit"], - "routes": { - "10.10.0.0/16": ["group:test"], - "10.11.0.0/16": ["test"], - } - } -} - `) - - err := app.LoadACLPolicyFromBytes(acl, "hujson") - c.Assert(err, check.IsNil) - - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - nodeKey := key.NewNode() - - defaultRoute := netip.MustParsePrefix("0.0.0.0/0") - route1 := netip.MustParsePrefix("10.10.0.0/16") - // Check if a subprefix of an autoapproved route is approved - route2 := netip.MustParsePrefix("10.11.0.0/24") - - machine := Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), - DiscoKey: "faa", - Hostname: "test", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - HostInfo: HostInfo{ - RequestTags: []string{"tag:exit"}, - RoutableIPs: []netip.Prefix{defaultRoute, route1, route2}, - }, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, - } - - app.db.db.Save(&machine) - - err = app.db.processMachineRoutes(&machine) - c.Assert(err, check.IsNil) - - machine0ByID, err := app.db.GetMachineByID(0) - c.Assert(err, check.IsNil) - - err = app.db.EnableAutoApprovedRoutes(app.aclPolicy, machine0ByID) - c.Assert(err, check.IsNil) - - enabledRoutes, err := app.db.GetEnabledRoutes(machine0ByID) - c.Assert(err, check.IsNil) - c.Assert(enabledRoutes, check.HasLen, 3) -} - -func TestMachine_canAccess(t *testing.T) { - type args struct { - filter []tailcfg.FilterRule - machine2 *Machine - } - tests := []struct { - name string - machine Machine - args args - want bool - }{ - { - name: "no-rules", - machine: Machine{ - IPAddresses: MachineAddresses{ - netip.MustParseAddr("10.0.0.1"), - }, - }, - args: args{ - filter: []tailcfg.FilterRule{}, - machine2: &Machine{ - IPAddresses: MachineAddresses{ - netip.MustParseAddr("10.0.0.2"), - }, - }, - }, - want: false, - }, - { - name: "wildcard", - machine: Machine{ - IPAddresses: MachineAddresses{ - netip.MustParseAddr("10.0.0.1"), - }, - }, - args: args{ - filter: []tailcfg.FilterRule{ - { - SrcIPs: []string{"*"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "*", - Ports: tailcfg.PortRange{ - First: 0, - Last: 65535, - }, - }, - }, - }, - }, - machine2: &Machine{ - IPAddresses: MachineAddresses{ - netip.MustParseAddr("10.0.0.2"), - }, - }, - }, - want: true, - }, - { - name: "explicit-m1-to-m2", - machine: Machine{ - IPAddresses: MachineAddresses{ - netip.MustParseAddr("10.0.0.1"), - }, - }, - args: args{ - filter: []tailcfg.FilterRule{ - { - SrcIPs: []string{"10.0.0.1"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "10.0.0.2", - Ports: tailcfg.PortRange{ - First: 0, - Last: 65535, - }, - }, - }, - }, - }, - machine2: &Machine{ - IPAddresses: MachineAddresses{ - netip.MustParseAddr("10.0.0.2"), - }, - }, - }, - want: true, - }, - { - name: "explicit-m2-to-m1", - machine: Machine{ - IPAddresses: MachineAddresses{ - netip.MustParseAddr("10.0.0.1"), - }, - }, - args: args{ - filter: []tailcfg.FilterRule{ - { - SrcIPs: []string{"10.0.0.2"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "10.0.0.1", - Ports: tailcfg.PortRange{ - First: 0, - Last: 65535, - }, - }, - }, - }, - }, - machine2: &Machine{ - IPAddresses: MachineAddresses{ - netip.MustParseAddr("10.0.0.2"), - }, - }, - }, - want: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := tt.machine.canAccess(tt.args.filter, tt.args.machine2); got != tt.want { - t.Errorf("Machine.canAccess() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/hscontrol/matcher.go b/hscontrol/matcher.go deleted file mode 100644 index 3b4670e8d4..0000000000 --- a/hscontrol/matcher.go +++ /dev/null @@ -1,142 +0,0 @@ -package hscontrol - -import ( - "fmt" - "net/netip" - "strings" - - "go4.org/netipx" - "tailscale.com/tailcfg" -) - -// This is borrowed from, and updated to use IPSet -// https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162 -// TODO(kradalby): contribute upstream and make public. -var ( - zeroIP4 = netip.AddrFrom4([4]byte{}) - zeroIP6 = netip.AddrFrom16([16]byte{}) -) - -// parseIPSet parses arg as one: -// -// - an IP address (IPv4 or IPv6) -// - the string "*" to match everything (both IPv4 & IPv6) -// - a CIDR (e.g. "192.168.0.0/16") -// - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800") -// -// bits, if non-nil, is the legacy SrcBits CIDR length to make a IP -// address (without a slash) treated as a CIDR of *bits length. -// nolint -func parseIPSet(arg string, bits *int) (*netipx.IPSet, error) { - var ipSet netipx.IPSetBuilder - if arg == "*" { - ipSet.AddPrefix(netip.PrefixFrom(zeroIP4, 0)) - ipSet.AddPrefix(netip.PrefixFrom(zeroIP6, 0)) - - return ipSet.IPSet() - } - if strings.Contains(arg, "/") { - pfx, err := netip.ParsePrefix(arg) - if err != nil { - return nil, err - } - if pfx != pfx.Masked() { - return nil, fmt.Errorf("%v contains non-network bits set", pfx) - } - - ipSet.AddPrefix(pfx) - - return ipSet.IPSet() - } - if strings.Count(arg, "-") == 1 { - ip1s, ip2s, _ := strings.Cut(arg, "-") - - ip1, err := netip.ParseAddr(ip1s) - if err != nil { - return nil, err - } - - ip2, err := netip.ParseAddr(ip2s) - if err != nil { - return nil, err - } - - r := netipx.IPRangeFrom(ip1, ip2) - if !r.IsValid() { - return nil, fmt.Errorf("invalid IP range %q", arg) - } - - for _, prefix := range r.Prefixes() { - ipSet.AddPrefix(prefix) - } - - return ipSet.IPSet() - } - ip, err := netip.ParseAddr(arg) - if err != nil { - return nil, fmt.Errorf("invalid IP address %q", arg) - } - bits8 := uint8(ip.BitLen()) - if bits != nil { - if *bits < 0 || *bits > int(bits8) { - return nil, fmt.Errorf("invalid CIDR size %d for IP %q", *bits, arg) - } - bits8 = uint8(*bits) - } - - ipSet.AddPrefix(netip.PrefixFrom(ip, int(bits8))) - - return ipSet.IPSet() -} - -type Match struct { - Srcs *netipx.IPSet - Dests *netipx.IPSet -} - -func MatchFromFilterRule(rule tailcfg.FilterRule) Match { - srcs := new(netipx.IPSetBuilder) - dests := new(netipx.IPSetBuilder) - - for _, srcIP := range rule.SrcIPs { - set, _ := parseIPSet(srcIP, nil) - - srcs.AddSet(set) - } - - for _, dest := range rule.DstPorts { - set, _ := parseIPSet(dest.IP, nil) - - dests.AddSet(set) - } - - srcsSet, _ := srcs.IPSet() - destsSet, _ := dests.IPSet() - - match := Match{ - Srcs: srcsSet, - Dests: destsSet, - } - - return match -} - -func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool { - for _, ip := range ips { - if m.Srcs.Contains(ip) { - return true - } - } - - return false -} - -func (m *Match) DestsContainsIP(ips []netip.Addr) bool { - for _, ip := range ips { - if m.Dests.Contains(ip) { - return true - } - } - - return false -} diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index c666594e5e..4e68a22691 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -14,6 +14,8 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/mux" + "github.com/juanfont/headscale/hscontrol/db" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "golang.org/x/oauth2" @@ -638,7 +640,7 @@ func getUserName( claims *IDTokenClaims, stripEmaildomain bool, ) (string, error) { - userName, err := NormalizeToFQDNRules( + userName, err := util.NormalizeToFQDNRules( claims.Email, stripEmaildomain, ) @@ -663,9 +665,9 @@ func getUserName( func (h *Headscale) findOrCreateNewUserForOIDCCallback( writer http.ResponseWriter, userName string, -) (*User, error) { +) (*types.User, error) { user, err := h.db.GetUser(userName) - if errors.Is(err, ErrUserNotFound) { + if errors.Is(err, db.ErrUserNotFound) { user, err = h.db.CreateUser(userName) if err != nil { @@ -709,7 +711,7 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback( func (h *Headscale) registerMachineForOIDCCallback( writer http.ResponseWriter, - user *User, + user *types.User, nodeKey *key.NodePublic, expiry time.Time, ) error { @@ -719,7 +721,7 @@ func (h *Headscale) registerMachineForOIDCCallback( nodeKey.String(), user.Name, &expiry, - RegisterMethodOIDC, + util.RegisterMethodOIDC, ); err != nil { log.Error(). Caller(). diff --git a/hscontrol/acls.go b/hscontrol/policy/acls.go similarity index 79% rename from hscontrol/acls.go rename to hscontrol/policy/acls.go index 2c81046a49..6b42ebe779 100644 --- a/hscontrol/acls.go +++ b/hscontrol/policy/acls.go @@ -1,4 +1,4 @@ -package hscontrol +package policy import ( "encoding/json" @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/tailscale/hujson" @@ -22,12 +23,12 @@ import ( ) var ( - errEmptyPolicy = errors.New("empty policy") - errInvalidAction = errors.New("invalid action") - errInvalidGroup = errors.New("invalid group") - errInvalidTag = errors.New("invalid tag") - errInvalidPortFormat = errors.New("invalid port format") - errWildcardIsNeeded = errors.New("wildcard as port is required for the protocol") + ErrEmptyPolicy = errors.New("empty policy") + ErrInvalidAction = errors.New("invalid action") + ErrInvalidGroup = errors.New("invalid group") + ErrInvalidTag = errors.New("invalid tag") + ErrInvalidPortFormat = errors.New("invalid port format") + ErrWildcardIsNeeded = errors.New("wildcard as port is required for the protocol") ) const ( @@ -56,7 +57,7 @@ const ( var featureEnableSSH = envknob.RegisterBool("HEADSCALE_EXPERIMENTAL_FEATURE_SSH") // LoadACLPolicyFromPath loads the ACL policy from the specify path, and generates the ACL rules. -func (h *Headscale) LoadACLPolicyFromPath(path string) error { +func LoadACLPolicyFromPath(path string) (*ACLPolicy, error) { log.Debug(). Str("func", "LoadACLPolicy"). Str("path", path). @@ -64,13 +65,13 @@ func (h *Headscale) LoadACLPolicyFromPath(path string) error { policyFile, err := os.Open(path) if err != nil { - return err + return nil, err } defer policyFile.Close() policyBytes, err := io.ReadAll(policyFile) if err != nil { - return err + return nil, err } log.Debug(). @@ -80,90 +81,90 @@ func (h *Headscale) LoadACLPolicyFromPath(path string) error { switch filepath.Ext(path) { case ".yml", ".yaml": - return h.LoadACLPolicyFromBytes(policyBytes, "yaml") + return LoadACLPolicyFromBytes(policyBytes, "yaml") } - return h.LoadACLPolicyFromBytes(policyBytes, "hujson") + return LoadACLPolicyFromBytes(policyBytes, "hujson") } -func (h *Headscale) LoadACLPolicyFromBytes(acl []byte, format string) error { +func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) { var policy ACLPolicy switch format { case "yaml": err := yaml.Unmarshal(acl, &policy) if err != nil { - return err + return nil, err } default: ast, err := hujson.Parse(acl) if err != nil { - return err + return nil, err } ast.Standardize() acl = ast.Pack() err = json.Unmarshal(acl, &policy) if err != nil { - return err + return nil, err } } if policy.IsZero() { - return errEmptyPolicy + return nil, ErrEmptyPolicy } - h.aclPolicy = &policy - - return h.UpdateACLRules() + return &policy, nil } -func (h *Headscale) UpdateACLRules() error { - machines, err := h.db.ListMachines() - if err != nil { - return err - } - - if h.aclPolicy == nil { - return errEmptyPolicy +// TODO(kradalby): This needs to be replace with something that generates +// the rules as needed and not stores it on the global object, rules are +// per node and that should be taken into account. +func GenerateFilterRules( + policy *ACLPolicy, + machines types.Machines, + stripEmailDomain bool, +) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { + if policy == nil { + return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, ErrEmptyPolicy } - rules, err := h.aclPolicy.generateFilterRules(machines, h.cfg.OIDC.StripEmaildomain) + rules, err := policy.generateFilterRules(machines, stripEmailDomain) if err != nil { - return err + return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } log.Trace().Interface("ACL", rules).Msg("ACL rules generated") - h.aclRules = rules + var sshPolicy *tailcfg.SSHPolicy if featureEnableSSH() { - sshRules, err := h.generateSSHRules() + sshRules, err := generateSSHRules(policy, machines, stripEmailDomain) if err != nil { - return err + return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } log.Trace().Interface("SSH", sshRules).Msg("SSH rules generated") - if h.sshPolicy == nil { - h.sshPolicy = &tailcfg.SSHPolicy{} + if sshPolicy == nil { + sshPolicy = &tailcfg.SSHPolicy{} } - h.sshPolicy.Rules = sshRules - } else if h.aclPolicy != nil && len(h.aclPolicy.SSHs) > 0 { + sshPolicy.Rules = sshRules + } else if policy != nil && len(policy.SSHs) > 0 { log.Info().Msg("SSH ACLs has been defined, but HEADSCALE_EXPERIMENTAL_FEATURE_SSH is not enabled, this is a unstable feature, check docs before activating") } - return nil + return rules, sshPolicy, nil } // generateFilterRules takes a set of machines and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. func (pol *ACLPolicy) generateFilterRules( - machines []Machine, + machines types.Machines, stripEmailDomain bool, ) ([]tailcfg.FilterRule, error) { rules := []tailcfg.FilterRule{} for index, acl := range pol.ACLs { if acl.Action != "accept" { - return nil, errInvalidAction + return nil, ErrInvalidAction } srcIPs := []string{} @@ -219,16 +220,15 @@ func (pol *ACLPolicy) generateFilterRules( return rules, nil } -func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { +func generateSSHRules( + policy *ACLPolicy, + machines types.Machines, + stripEmailDomain bool, +) ([]*tailcfg.SSHRule, error) { rules := []*tailcfg.SSHRule{} - if h.aclPolicy == nil { - return nil, errEmptyPolicy - } - - machines, err := h.db.ListMachines() - if err != nil { - return nil, err + if policy == nil { + return nil, ErrEmptyPolicy } acceptAction := tailcfg.SSHAction{ @@ -251,7 +251,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { AllowLocalPortForwarding: false, } - for index, sshACL := range h.aclPolicy.SSHs { + for index, sshACL := range policy.SSHs { action := rejectAction switch sshACL.Action { case "accept": @@ -266,9 +266,9 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { } default: log.Error(). - Msgf("Error parsing SSH %d, unknown action '%s'", index, sshACL.Action) + Msgf("Error parsing SSH %d, unknown action '%s', skipping", index, sshACL.Action) - return nil, err + continue } principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) @@ -278,7 +278,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { Any: true, }) } else if isGroup(rawSrc) { - users, err := h.aclPolicy.getUsersInGroup(rawSrc, h.cfg.OIDC.StripEmaildomain) + users, err := policy.getUsersInGroup(rawSrc, stripEmailDomain) if err != nil { log.Error(). Msgf("Error parsing SSH %d, Source %d", index, innerIndex) @@ -292,10 +292,10 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) { }) } } else { - expandedSrcs, err := h.aclPolicy.expandAlias( + expandedSrcs, err := policy.ExpandAlias( machines, rawSrc, - h.cfg.OIDC.StripEmaildomain, + stripEmailDomain, ) if err != nil { log.Error(). @@ -346,10 +346,10 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) { // with the given src alias. func (pol *ACLPolicy) getIPsFromSource( src string, - machines []Machine, + machines types.Machines, stripEmaildomain bool, ) ([]string, error) { - ipSet, err := pol.expandAlias(machines, src, stripEmaildomain) + ipSet, err := pol.ExpandAlias(machines, src, stripEmaildomain) if err != nil { return []string{}, err } @@ -367,7 +367,7 @@ func (pol *ACLPolicy) getIPsFromSource( // which are associated with the dest alias. func (pol *ACLPolicy) getNetPortRangeFromDestination( dest string, - machines []Machine, + machines types.Machines, needsWildcard bool, stripEmaildomain bool, ) ([]tailcfg.NetPortRange, error) { @@ -390,7 +390,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination( return nil, fmt.Errorf( "failed to parse destination, tokens %v: %w", tokens, - errInvalidPortFormat, + ErrInvalidPortFormat, ) } else { tokens = []string{maybeIPv6Str, port} @@ -414,7 +414,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination( alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1]) } - expanded, err := pol.expandAlias( + expanded, err := pol.ExpandAlias( machines, alias, stripEmaildomain, @@ -499,13 +499,13 @@ func parseProtocol(protocol string) ([]int, bool, error) { // - an ip // - a cidr // and transform these in IPAddresses. -func (pol *ACLPolicy) expandAlias( - machines Machines, +func (pol *ACLPolicy) ExpandAlias( + machines types.Machines, alias string, stripEmailDomain bool, ) (*netipx.IPSet, error) { if isWildcard(alias) { - return parseIPSet("*", nil) + return util.ParseIPSet("*", nil) } build := netipx.IPSetBuilder{} @@ -532,9 +532,9 @@ func (pol *ACLPolicy) expandAlias( // if alias is an host // Note, this is recursive. if h, ok := pol.Hosts[alias]; ok { - log.Trace().Str("host", h.String()).Msg("expandAlias got hosts entry") + log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") - return pol.expandAlias(machines, h.String(), stripEmailDomain) + return pol.ExpandAlias(machines, h.String(), stripEmailDomain) } // if alias is an IP @@ -557,11 +557,11 @@ func (pol *ACLPolicy) expandAlias( // we assume in this function that we only have nodes from 1 user. func excludeCorrectlyTaggedNodes( aclPolicy *ACLPolicy, - nodes []Machine, + nodes types.Machines, user string, stripEmailDomain bool, -) []Machine { - out := []Machine{} +) types.Machines { + out := types.Machines{} tags := []string{} for tag := range aclPolicy.TagOwners { owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain) @@ -601,7 +601,7 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err } if needsWildcard { - return nil, errWildcardIsNeeded + return nil, ErrWildcardIsNeeded } ports := []tailcfg.PortRange{} @@ -634,15 +634,15 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err }) default: - return nil, errInvalidPortFormat + return nil, ErrInvalidPortFormat } } return &ports, nil } -func filterMachinesByUser(machines []Machine, user string) []Machine { - out := []Machine{} +func filterMachinesByUser(machines types.Machines, user string) types.Machines { + out := types.Machines{} for _, machine := range machines { if machine.User.Name == user { out = append(out, machine) @@ -664,7 +664,7 @@ func getTagOwners( if !ok { return []string{}, fmt.Errorf( "%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", - errInvalidTag, + ErrInvalidTag, tag, ) } @@ -696,22 +696,22 @@ func (pol *ACLPolicy) getUsersInGroup( return []string{}, fmt.Errorf( "group %v isn't registered. %w", group, - errInvalidGroup, + ErrInvalidGroup, ) } for _, group := range aclGroups { if isGroup(group) { return []string{}, fmt.Errorf( "%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", - errInvalidGroup, + ErrInvalidGroup, ) } - grp, err := NormalizeToFQDNRules(group, stripEmailDomain) + grp, err := util.NormalizeToFQDNRules(group, stripEmailDomain) if err != nil { return []string{}, fmt.Errorf( "failed to normalize group %q, err: %w", group, - errInvalidGroup, + ErrInvalidGroup, ) } users = append(users, grp) @@ -722,7 +722,7 @@ func (pol *ACLPolicy) getUsersInGroup( func (pol *ACLPolicy) getIPsFromGroup( group string, - machines Machines, + machines types.Machines, stripEmailDomain bool, ) (*netipx.IPSet, error) { build := netipx.IPSetBuilder{} @@ -743,7 +743,7 @@ func (pol *ACLPolicy) getIPsFromGroup( func (pol *ACLPolicy) getIPsFromTag( alias string, - machines Machines, + machines types.Machines, stripEmailDomain bool, ) (*netipx.IPSet, error) { build := netipx.IPSetBuilder{} @@ -758,12 +758,12 @@ func (pol *ACLPolicy) getIPsFromTag( // find tag owners owners, err := getTagOwners(pol, alias, stripEmailDomain) if err != nil { - if errors.Is(err, errInvalidTag) { + if errors.Is(err, ErrInvalidTag) { ipSet, _ := build.IPSet() if len(ipSet.Prefixes()) == 0 { return ipSet, fmt.Errorf( "%w. %v isn't owned by a TagOwner and no forced tags are defined", - errInvalidTag, + ErrInvalidTag, alias, ) } @@ -790,7 +790,7 @@ func (pol *ACLPolicy) getIPsFromTag( func (pol *ACLPolicy) getIPsForUser( user string, - machines Machines, + machines types.Machines, stripEmailDomain bool, ) (*netipx.IPSet, error) { build := netipx.IPSetBuilder{} @@ -812,9 +812,9 @@ func (pol *ACLPolicy) getIPsForUser( func (pol *ACLPolicy) getIPsFromSingleIP( ip netip.Addr, - machines Machines, + machines types.Machines, ) (*netipx.IPSet, error) { - log.Trace().Str("ip", ip.String()).Msg("expandAlias got ip") + log.Trace().Str("ip", ip.String()).Msg("ExpandAlias got ip") matches := machines.FilterByIP(ip) @@ -830,7 +830,7 @@ func (pol *ACLPolicy) getIPsFromSingleIP( func (pol *ACLPolicy) getIPsFromIPPrefix( prefix netip.Prefix, - machines Machines, + machines types.Machines, ) (*netipx.IPSet, error) { log.Trace().Str("prefix", prefix.String()).Msg("expandAlias got prefix") build := netipx.IPSetBuilder{} @@ -862,3 +862,65 @@ func isGroup(str string) bool { func isTag(str string) bool { return strings.HasPrefix(str, "tag:") } + +// getTags will return the tags of the current machine. +// Invalid tags are tags added by a user on a node, and that user doesn't have authority to add this tag. +// Valid tags are tags added by a user that is allowed in the ACL policy to add this tag. +func (pol *ACLPolicy) GetTagsOfMachine( + machine types.Machine, + stripEmailDomain bool, +) ([]string, []string) { + validTags := make([]string, 0) + invalidTags := make([]string, 0) + + validTagMap := make(map[string]bool) + invalidTagMap := make(map[string]bool) + for _, tag := range machine.HostInfo.RequestTags { + owners, err := getTagOwners(pol, tag, stripEmailDomain) + if errors.Is(err, ErrInvalidTag) { + invalidTagMap[tag] = true + + continue + } + var found bool + for _, owner := range owners { + if machine.User.Name == owner { + found = true + } + } + if found { + validTagMap[tag] = true + } else { + invalidTagMap[tag] = true + } + } + for tag := range invalidTagMap { + invalidTags = append(invalidTags, tag) + } + for tag := range validTagMap { + validTags = append(validTags, tag) + } + + return validTags, invalidTags +} + +// FilterMachinesByACL returns the list of peers authorized to be accessed from a given machine. +func FilterMachinesByACL( + machine *types.Machine, + machines types.Machines, + filter []tailcfg.FilterRule, +) types.Machines { + result := types.Machines{} + + for index, peer := range machines { + if peer.ID == machine.ID { + continue + } + + if machine.CanAccess(filter, &machines[index]) || peer.CanAccess(filter, machine) { + result = append(result, peer) + } + } + + return result +} diff --git a/hscontrol/acls_test.go b/hscontrol/policy/acls_test.go similarity index 56% rename from hscontrol/acls_test.go rename to hscontrol/policy/acls_test.go index 70a57b81ab..f6c5e10791 100644 --- a/hscontrol/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -1,4 +1,4 @@ -package hscontrol +package policy import ( "errors" @@ -7,15 +7,24 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "go4.org/netipx" "gopkg.in/check.v1" - "tailscale.com/envknob" "tailscale.com/tailcfg" ) +func Test(t *testing.T) { + check.TestingT(t) +} + +var _ = check.Suite(&Suite{}) + +type Suite struct{} + func (s *Suite) TestWrongPath(c *check.C) { - err := app.LoadACLPolicyFromPath("asdfg") + _, err := LoadACLPolicyFromPath("asdfg") c.Assert(err, check.NotNil) } @@ -23,7 +32,7 @@ func (s *Suite) TestBrokenHuJson(c *check.C) { acl := []byte(` { `) - err := app.LoadACLPolicyFromBytes(acl, "hujson") + _, err := LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.NotNil) } @@ -34,9 +43,9 @@ func (s *Suite) TestInvalidPolicyHuson(c *check.C) { "but_a_policy_though": false } `) - err := app.LoadACLPolicyFromBytes(acl, "hujson") + _, err := LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.NotNil) - c.Assert(err, check.Equals, errEmptyPolicy) + c.Assert(err, check.Equals, ErrEmptyPolicy) } func (s *Suite) TestParseHosts(c *check.C) { @@ -185,8 +194,13 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { ], } `) - err := app.LoadACLPolicyFromBytes(acl, "hujson") + pol, err := LoadACLPolicyFromBytes(acl, "hujson") + c.Assert(pol.ACLs, check.HasLen, 6) + c.Assert(err, check.IsNil) + + rules, err := pol.generateFilterRules(types.Machines{}, false) c.Assert(err, check.NotNil) + c.Assert(rules, check.IsNil) } func (s *Suite) TestBasicRule(c *check.C) { @@ -212,17 +226,17 @@ func (s *Suite) TestBasicRule(c *check.C) { ], } `) - err := app.LoadACLPolicyFromBytes(acl, "hujson") + pol, err := LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.IsNil) - rules, err := app.aclPolicy.generateFilterRules([]Machine{}, false) + rules, err := pol.generateFilterRules(types.Machines{}, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) } // TODO(kradalby): Make tests values safe, independent and descriptive. func (s *Suite) TestInvalidAction(c *check.C) { - app.aclPolicy = &ACLPolicy{ + pol := &ACLPolicy{ ACLs: []ACL{ { Action: "invalidAction", @@ -231,88 +245,13 @@ func (s *Suite) TestInvalidAction(c *check.C) { }, }, } - err := app.UpdateACLRules() - c.Assert(errors.Is(err, errInvalidAction), check.Equals, true) -} - -func (s *Suite) TestSshRules(c *check.C) { - envknob.Setenv("HEADSCALE_EXPERIMENTAL_FEATURE_SSH", "1") - - user, err := app.db.CreateUser("user1") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("user1", "testmachine") - c.Assert(err, check.NotNil) - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "testmachine", - RequestTags: []string{"tag:test"}, - } - - machine := Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo), - } - app.db.db.Save(&machine) - - app.aclPolicy = &ACLPolicy{ - Groups: Groups{ - "group:test": []string{"user1"}, - }, - Hosts: Hosts{ - "client": netip.PrefixFrom(netip.MustParseAddr("100.64.99.42"), 32), - }, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, - }, - }, - SSHs: []SSH{ - { - Action: "accept", - Sources: []string{"group:test"}, - Destinations: []string{"client"}, - Users: []string{"autogroup:nonroot"}, - }, - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"client"}, - Users: []string{"autogroup:nonroot"}, - }, - }, - } - - err = app.UpdateACLRules() - - c.Assert(err, check.IsNil) - c.Assert(app.sshPolicy, check.NotNil) - c.Assert(app.sshPolicy.Rules, check.HasLen, 2) - c.Assert(app.sshPolicy.Rules[0].SSHUsers, check.HasLen, 1) - c.Assert(app.sshPolicy.Rules[0].Principals, check.HasLen, 1) - c.Assert(app.sshPolicy.Rules[0].Principals[0].UserLogin, check.Matches, "user1") - - c.Assert(app.sshPolicy.Rules[1].SSHUsers, check.HasLen, 1) - c.Assert(app.sshPolicy.Rules[1].Principals, check.HasLen, 1) - c.Assert(app.sshPolicy.Rules[1].Principals[0].NodeIP, check.Matches, "*") + _, _, err := GenerateFilterRules(pol, types.Machines{}, false) + c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) } func (s *Suite) TestInvalidGroupInGroup(c *check.C) { // this ACL is wrong because the group in Sources sections doesn't exist - app.aclPolicy = &ACLPolicy{ + pol := &ACLPolicy{ Groups: Groups{ "group:test": []string{"foo"}, "group:error": []string{"foo", "group:test"}, @@ -325,13 +264,13 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) { }, }, } - err := app.UpdateACLRules() - c.Assert(errors.Is(err, errInvalidGroup), check.Equals, true) + _, _, err := GenerateFilterRules(pol, types.Machines{}, false) + c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) } func (s *Suite) TestInvalidTagOwners(c *check.C) { // this ACL is wrong because no tagOwners own the requested tag for the server - app.aclPolicy = &ACLPolicy{ + pol := &ACLPolicy{ ACLs: []ACL{ { Action: "accept", @@ -340,232 +279,9 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) { }, }, } - err := app.UpdateACLRules() - c.Assert(errors.Is(err, errInvalidTag), check.Equals, true) -} - -// this test should validate that we can expand a group in a TagOWner section and -// match properly the IP's of the related hosts. The owner is valid and the tag is also valid. -// the tag is matched in the Sources section. -func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) { - user, err := app.db.CreateUser("user1") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("user1", "testmachine") - c.Assert(err, check.NotNil) - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "testmachine", - RequestTags: []string{"tag:test"}, - } - - machine := Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo), - } - app.db.db.Save(&machine) - - app.aclPolicy = &ACLPolicy{ - Groups: Groups{"group:test": []string{"user1", "user2"}}, - TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}}, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"tag:test"}, - Destinations: []string{"*:*"}, - }, - }, - } - err = app.UpdateACLRules() - c.Assert(err, check.IsNil) - c.Assert(app.aclRules, check.HasLen, 1) - c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 1) - c.Assert(app.aclRules[0].SrcIPs[0], check.Equals, "100.64.0.1/32") -} - -// this test should validate that we can expand a group in a TagOWner section and -// match properly the IP's of the related hosts. The owner is valid and the tag is also valid. -// the tag is matched in the Destinations section. -func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) { - user, err := app.db.CreateUser("user1") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("user1", "testmachine") - c.Assert(err, check.NotNil) - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "testmachine", - RequestTags: []string{"tag:test"}, - } - - machine := Machine{ - ID: 1, - MachineKey: "12345", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo), - } - app.db.db.Save(&machine) - - app.aclPolicy = &ACLPolicy{ - Groups: Groups{"group:test": []string{"user1", "user2"}}, - TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}}, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"tag:test:*"}, - }, - }, - } - err = app.UpdateACLRules() - c.Assert(err, check.IsNil) - c.Assert(app.aclRules, check.HasLen, 1) - c.Assert(app.aclRules[0].DstPorts, check.HasLen, 1) - c.Assert(app.aclRules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") -} - -// need a test with: -// tag on a host that isn't owned by a tag owners. So the user -// of the host should be valid. -func (s *Suite) TestInvalidTagValidUser(c *check.C) { - user, err := app.db.CreateUser("user1") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("user1", "testmachine") - c.Assert(err, check.NotNil) - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "testmachine", - RequestTags: []string{"tag:foo"}, - } - - machine := Machine{ - ID: 1, - MachineKey: "12345", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo), - } - app.db.db.Save(&machine) - - app.aclPolicy = &ACLPolicy{ - TagOwners: TagOwners{"tag:test": []string{"user1"}}, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"user1"}, - Destinations: []string{"*:*"}, - }, - }, - } - err = app.UpdateACLRules() - c.Assert(err, check.IsNil) - c.Assert(app.aclRules, check.HasLen, 1) - c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 1) - c.Assert(app.aclRules[0].SrcIPs[0], check.Equals, "100.64.0.1/32") -} - -// tag on a host is owned by a tag owner, the tag is valid. -// an ACL rule is matching the tag to a user. It should not be valid since the -// host should be tied to the tag now. -func (s *Suite) TestValidTagInvalidUser(c *check.C) { - user, err := app.db.CreateUser("user1") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("user1", "webserver") - c.Assert(err, check.NotNil) - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "webserver", - RequestTags: []string{"tag:webapp"}, - } - - machine := Machine{ - ID: 1, - MachineKey: "12345", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "webserver", - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo), - } - app.db.db.Save(&machine) - _, err = app.db.GetMachine("user1", "user") - hostInfo2 := tailcfg.Hostinfo{ - OS: "debian", - Hostname: "Hostname", - } - c.Assert(err, check.NotNil) - machine = Machine{ - ID: 2, - MachineKey: "56789", - NodeKey: "bar2", - DiscoKey: "faab", - Hostname: "user", - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.2")}, - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - HostInfo: HostInfo(hostInfo2), - } - app.db.db.Save(&machine) - app.aclPolicy = &ACLPolicy{ - TagOwners: TagOwners{"tag:webapp": []string{"user1"}}, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"user1"}, - Destinations: []string{"tag:webapp:80,443"}, - }, - }, - } - err = app.UpdateACLRules() - c.Assert(err, check.IsNil) - c.Assert(app.aclRules, check.HasLen, 1) - c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 1) - c.Assert(app.aclRules[0].SrcIPs[0], check.Equals, "100.64.0.2/32") - c.Assert(app.aclRules[0].DstPorts, check.HasLen, 2) - c.Assert(app.aclRules[0].DstPorts[0].Ports.First, check.Equals, uint16(80)) - c.Assert(app.aclRules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80)) - c.Assert(app.aclRules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") - c.Assert(app.aclRules[0].DstPorts[1].Ports.First, check.Equals, uint16(443)) - c.Assert(app.aclRules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443)) - c.Assert(app.aclRules[0].DstPorts[1].IP, check.Equals, "100.64.0.1/32") + _, _, err := GenerateFilterRules(pol, types.Machines{}, false) + c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) } func (s *Suite) TestPortRange(c *check.C) { @@ -589,10 +305,11 @@ func (s *Suite) TestPortRange(c *check.C) { ], } `) - err := app.LoadACLPolicyFromBytes(acl, "hujson") + pol, err := LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.IsNil) + c.Assert(pol, check.NotNil) - rules, err := app.aclPolicy.generateFilterRules([]Machine{}, false) + rules, err := pol.generateFilterRules(types.Machines{}, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -644,10 +361,11 @@ func (s *Suite) TestProtocolParsing(c *check.C) { ], } `) - err := app.LoadACLPolicyFromBytes(acl, "hujson") + pol, err := LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.IsNil) + c.Assert(pol, check.NotNil) - rules, err := app.aclPolicy.generateFilterRules([]Machine{}, false) + rules, err := pol.generateFilterRules(types.Machines{}, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -678,10 +396,11 @@ func (s *Suite) TestPortWildcard(c *check.C) { ], } `) - err := app.LoadACLPolicyFromBytes(acl, "hujson") + pol, err := LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.IsNil) + c.Assert(pol, check.NotNil) - rules, err := app.aclPolicy.generateFilterRules([]Machine{}, false) + rules, err := pol.generateFilterRules(types.Machines{}, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -704,10 +423,11 @@ acls: - "*" dst: - host-1:*`) - err := app.LoadACLPolicyFromBytes(acl, "yaml") + pol, err := LoadACLPolicyFromBytes(acl, "yaml") c.Assert(err, check.IsNil) + c.Assert(pol, check.NotNil) - rules, err := app.aclPolicy.generateFilterRules([]Machine{}, false) + rules, err := pol.generateFilterRules(types.Machines{}, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -719,138 +439,6 @@ acls: c.Assert(rules[0].SrcIPs[0], check.Equals, "0.0.0.0/0") } -func (s *Suite) TestPortUser(c *check.C) { - user, err := app.db.CreateUser("testuser") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("testuser", "testmachine") - c.Assert(err, check.NotNil) - ips, _ := app.db.getAvailableIPs() - machine := Machine{ - ID: 0, - MachineKey: "12345", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - IPAddresses: ips, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(&machine) - - acl := []byte(` -{ - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "testuser", - ], - "dst": [ - "host-1:*", - ], - }, - ], -} - `) - err = app.LoadACLPolicyFromBytes(acl, "hujson") - c.Assert(err, check.IsNil) - - machines, err := app.db.ListMachines() - c.Assert(err, check.IsNil) - - rules, err := app.aclPolicy.generateFilterRules(machines, false) - c.Assert(err, check.IsNil) - c.Assert(rules, check.NotNil) - - c.Assert(rules, check.HasLen, 1) - c.Assert(rules[0].DstPorts, check.HasLen, 1) - c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) - c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) - c.Assert(rules[0].SrcIPs, check.HasLen, 1) - c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") - c.Assert(len(ips), check.Equals, 1) - c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32") -} - -func (s *Suite) TestPortGroup(c *check.C) { - user, err := app.db.CreateUser("testuser") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine("testuser", "testmachine") - c.Assert(err, check.NotNil) - ips, _ := app.db.getAvailableIPs() - machine := Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - IPAddresses: ips, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(&machine) - - acl := []byte(` -{ - "groups": { - "group:example": [ - "testuser", - ], - }, - - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "group:example", - ], - "dst": [ - "host-1:*", - ], - }, - ], -} - `) - err = app.LoadACLPolicyFromBytes(acl, "hujson") - c.Assert(err, check.IsNil) - - machines, err := app.db.ListMachines() - c.Assert(err, check.IsNil) - - rules, err := app.aclPolicy.generateFilterRules(machines, false) - c.Assert(err, check.IsNil) - c.Assert(rules, check.NotNil) - - c.Assert(rules, check.HasLen, 1) - c.Assert(rules[0].DstPorts, check.HasLen, 1) - c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) - c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) - c.Assert(rules[0].SrcIPs, check.HasLen, 1) - c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") - c.Assert(len(ips), check.Equals, 1) - c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32") -} - func Test_expandGroup(t *testing.T) { type field struct { pol ACLPolicy @@ -1151,54 +739,54 @@ func Test_expandPorts(t *testing.T) { func Test_listMachinesInUser(t *testing.T) { type args struct { - machines []Machine + machines types.Machines user string } tests := []struct { name string args args - want []Machine + want types.Machines }{ { name: "1 machine in user", args: args{ - machines: []Machine{ - {User: User{Name: "joe"}}, + machines: types.Machines{ + {User: types.User{Name: "joe"}}, }, user: "joe", }, - want: []Machine{ - {User: User{Name: "joe"}}, + want: types.Machines{ + {User: types.User{Name: "joe"}}, }, }, { name: "3 machines, 2 in user", args: args{ - machines: []Machine{ - {ID: 1, User: User{Name: "joe"}}, - {ID: 2, User: User{Name: "marc"}}, - {ID: 3, User: User{Name: "marc"}}, + machines: types.Machines{ + {ID: 1, User: types.User{Name: "joe"}}, + {ID: 2, User: types.User{Name: "marc"}}, + {ID: 3, User: types.User{Name: "marc"}}, }, user: "marc", }, - want: []Machine{ - {ID: 2, User: User{Name: "marc"}}, - {ID: 3, User: User{Name: "marc"}}, + want: types.Machines{ + {ID: 2, User: types.User{Name: "marc"}}, + {ID: 3, User: types.User{Name: "marc"}}, }, }, { name: "5 machines, 0 in user", args: args{ - machines: []Machine{ - {ID: 1, User: User{Name: "joe"}}, - {ID: 2, User: User{Name: "marc"}}, - {ID: 3, User: User{Name: "marc"}}, - {ID: 4, User: User{Name: "marc"}}, - {ID: 5, User: User{Name: "marc"}}, + machines: types.Machines{ + {ID: 1, User: types.User{Name: "joe"}}, + {ID: 2, User: types.User{Name: "marc"}}, + {ID: 3, User: types.User{Name: "marc"}}, + {ID: 4, User: types.User{Name: "marc"}}, + {ID: 5, User: types.User{Name: "marc"}}, }, user: "mickael", }, - want: []Machine{}, + want: types.Machines{}, }, } for _, test := range tests { @@ -1234,7 +822,7 @@ func Test_expandAlias(t *testing.T) { pol ACLPolicy } type args struct { - machines []Machine + machines types.Machines aclPolicy ACLPolicy alias string stripEmailDomain bool @@ -1253,10 +841,10 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "*", - machines: []Machine{ - {IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")}}, + machines: types.Machines{ + {IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}}, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.78.84.227"), }, }, @@ -1278,30 +866,30 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "group:accountant", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, - User: User{Name: "marc"}, + User: types.User{Name: "marc"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, }, stripEmailDomain: true, @@ -1320,30 +908,30 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "group:hr", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, - User: User{Name: "marc"}, + User: types.User{Name: "marc"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, }, stripEmailDomain: true, @@ -1358,7 +946,7 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "10.0.0.3", - machines: []Machine{}, + machines: types.Machines{}, stripEmailDomain: true, }, want: set([]string{ @@ -1373,7 +961,7 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "10.0.0.1", - machines: []Machine{}, + machines: types.Machines{}, stripEmailDomain: true, }, want: set([]string{ @@ -1388,12 +976,12 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "10.0.0.1", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("10.0.0.1"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, }, stripEmailDomain: true, @@ -1410,13 +998,13 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "10.0.0.1", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, }, stripEmailDomain: true, @@ -1433,13 +1021,13 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "fd7a:115c:a1e0:ab12:4843:2222:6273:2222", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, }, stripEmailDomain: true, @@ -1460,7 +1048,7 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "testy", - machines: []Machine{}, + machines: types.Machines{}, stripEmailDomain: true, }, want: set([]string{}, []string{"10.0.0.132/32"}), @@ -1477,7 +1065,7 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "homeNetwork", - machines: []Machine{}, + machines: types.Machines{}, stripEmailDomain: true, }, want: set([]string{}, []string{"192.168.1.0/24"}), @@ -1490,7 +1078,7 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "10.0.0.0/16", - machines: []Machine{}, + machines: types.Machines{}, aclPolicy: ACLPolicy{}, stripEmailDomain: true, }, @@ -1506,40 +1094,40 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "tag:hr-webserver", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:hr-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:hr-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, - User: User{Name: "marc"}, + User: types.User{Name: "marc"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, }, stripEmailDomain: true, @@ -1561,30 +1149,30 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "tag:hr-webserver", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, - User: User{Name: "marc"}, + User: types.User{Name: "marc"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, }, stripEmailDomain: true, @@ -1599,32 +1187,32 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "tag:hr-webserver", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, ForcedTags: []string{"tag:hr-webserver"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, ForcedTags: []string{"tag:hr-webserver"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, - User: User{Name: "marc"}, + User: types.User{Name: "marc"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, }, stripEmailDomain: true, @@ -1643,36 +1231,36 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "tag:hr-webserver", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, ForcedTags: []string{"tag:hr-webserver"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:hr-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, - User: User{Name: "marc"}, + User: types.User{Name: "marc"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, }, stripEmailDomain: true, @@ -1689,40 +1277,40 @@ func Test_expandAlias(t *testing.T) { }, args: args{ alias: "joe", - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, - User: User{Name: "marc"}, + User: types.User{Name: "marc"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, }, stripEmailDomain: true, @@ -1733,7 +1321,7 @@ func Test_expandAlias(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.field.pol.expandAlias( + got, err := test.field.pol.ExpandAlias( test.args.machines, test.args.alias, test.args.stripEmailDomain, @@ -1753,14 +1341,14 @@ func Test_expandAlias(t *testing.T) { func Test_excludeCorrectlyTaggedNodes(t *testing.T) { type args struct { aclPolicy *ACLPolicy - nodes []Machine + nodes types.Machines user string stripEmailDomain bool } tests := []struct { name string args args - want []Machine + want types.Machines wantErr bool }{ { @@ -1769,43 +1357,43 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { aclPolicy: &ACLPolicy{ TagOwners: TagOwners{"tag:accountant-webserver": []string{"joe"}}, }, - nodes: []Machine{ + nodes: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, }, user: "joe", stripEmailDomain: true, }, - want: []Machine{ + want: types.Machines{ { - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.4")}, - User: User{Name: "joe"}, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.4")}, + User: types.User{Name: "joe"}, }, }, }, @@ -1820,43 +1408,43 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { "tag:accountant-webserver": []string{"group:accountant"}, }, }, - nodes: []Machine{ + nodes: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, }, user: "joe", stripEmailDomain: true, }, - want: []Machine{ + want: types.Machines{ { - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.4")}, - User: User{Name: "joe"}, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.4")}, + User: types.User{Name: "joe"}, }, }, }, @@ -1866,39 +1454,39 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { aclPolicy: &ACLPolicy{ TagOwners: TagOwners{"tag:accountant-webserver": []string{"joe"}}, }, - nodes: []Machine{ + nodes: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "foo", RequestTags: []string{"tag:accountant-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, ForcedTags: []string{"tag:accountant-webserver"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, }, user: "joe", stripEmailDomain: true, }, - want: []Machine{ + want: types.Machines{ { - IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.4")}, - User: User{Name: "joe"}, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.4")}, + User: types.User{Name: "joe"}, }, }, }, @@ -1908,67 +1496,67 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { aclPolicy: &ACLPolicy{ TagOwners: TagOwners{"tag:accountant-webserver": []string{"joe"}}, }, - nodes: []Machine{ + nodes: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "hr-web1", RequestTags: []string{"tag:hr-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "hr-web2", RequestTags: []string{"tag:hr-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, }, user: "joe", stripEmailDomain: true, }, - want: []Machine{ + want: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "hr-web1", RequestTags: []string{"tag:hr-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, - User: User{Name: "joe"}, - HostInfo: HostInfo{ + User: types.User{Name: "joe"}, + HostInfo: types.HostInfo{ OS: "centos", Hostname: "hr-web2", RequestTags: []string{"tag:hr-webserver"}, }, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, - User: User{Name: "joe"}, + User: types.User{Name: "joe"}, }, }, }, @@ -1993,7 +1581,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { pol ACLPolicy } type args struct { - machines []Machine + machines types.Machines stripEmailDomain bool } tests := []struct { @@ -2024,7 +1612,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { }, }, args: args{ - machines: []Machine{}, + machines: types.Machines{}, stripEmailDomain: true, }, want: []tailcfg.FilterRule{ @@ -2064,27 +1652,30 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { }, }, args: args{ - machines: []Machine{ + machines: types.Machines{ { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, { - IPAddresses: MachineAddresses{ + IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), }, - User: User{Name: "mickael"}, + User: types.User{Name: "mickael"}, }, }, stripEmailDomain: true, }, want: []tailcfg.FilterRule{ { - SrcIPs: []string{"100.64.0.1/32", "fd7a:115c:a1e0:ab12:4843:2222:6273:2221/128"}, + SrcIPs: []string{ + "100.64.0.1/32", + "fd7a:115c:a1e0:ab12:4843:2222:6273:2221/128", + }, DstPorts: []tailcfg.NetPortRange{ { IP: "100.64.0.2/32", @@ -2113,14 +1704,631 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { tt.args.stripEmailDomain, ) if (err != nil) != tt.wantErr { - t.Errorf("ACLPolicy.generateFilterRules() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("ACLgenerateFilterRules() error = %v, wantErr %v", err, tt.wantErr) return } if diff := cmp.Diff(tt.want, got); diff != "" { log.Trace().Interface("got", got).Msg("result") - t.Errorf("ACLPolicy.generateFilterRules() = %v, want %v", got, tt.want) + t.Errorf("ACLgenerateFilterRules() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getTags(t *testing.T) { + type args struct { + aclPolicy *ACLPolicy + machine types.Machine + stripEmailDomain bool + } + tests := []struct { + name string + args args + wantInvalid []string + wantValid []string + }{ + { + name: "valid tag one machine", + args: args{ + aclPolicy: &ACLPolicy{ + TagOwners: TagOwners{ + "tag:valid": []string{"joe"}, + }, + }, + machine: types.Machine{ + User: types.User{ + Name: "joe", + }, + HostInfo: types.HostInfo{ + RequestTags: []string{"tag:valid"}, + }, + }, + stripEmailDomain: false, + }, + wantValid: []string{"tag:valid"}, + wantInvalid: nil, + }, + { + name: "invalid tag and valid tag one machine", + args: args{ + aclPolicy: &ACLPolicy{ + TagOwners: TagOwners{ + "tag:valid": []string{"joe"}, + }, + }, + machine: types.Machine{ + User: types.User{ + Name: "joe", + }, + HostInfo: types.HostInfo{ + RequestTags: []string{"tag:valid", "tag:invalid"}, + }, + }, + stripEmailDomain: false, + }, + wantValid: []string{"tag:valid"}, + wantInvalid: []string{"tag:invalid"}, + }, + { + name: "multiple invalid and identical tags, should return only one invalid tag", + args: args{ + aclPolicy: &ACLPolicy{ + TagOwners: TagOwners{ + "tag:valid": []string{"joe"}, + }, + }, + machine: types.Machine{ + User: types.User{ + Name: "joe", + }, + HostInfo: types.HostInfo{ + RequestTags: []string{ + "tag:invalid", + "tag:valid", + "tag:invalid", + }, + }, + }, + stripEmailDomain: false, + }, + wantValid: []string{"tag:valid"}, + wantInvalid: []string{"tag:invalid"}, + }, + { + name: "only invalid tags", + args: args{ + aclPolicy: &ACLPolicy{ + TagOwners: TagOwners{ + "tag:valid": []string{"joe"}, + }, + }, + machine: types.Machine{ + User: types.User{ + Name: "joe", + }, + HostInfo: types.HostInfo{ + RequestTags: []string{"tag:invalid", "very-invalid"}, + }, + }, + stripEmailDomain: false, + }, + wantValid: nil, + wantInvalid: []string{"tag:invalid", "very-invalid"}, + }, + { + name: "empty ACLPolicy should return empty tags and should not panic", + args: args{ + aclPolicy: &ACLPolicy{}, + machine: types.Machine{ + User: types.User{ + Name: "joe", + }, + HostInfo: types.HostInfo{ + RequestTags: []string{"tag:invalid", "very-invalid"}, + }, + }, + stripEmailDomain: false, + }, + wantValid: nil, + wantInvalid: []string{"tag:invalid", "very-invalid"}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + gotValid, gotInvalid := test.args.aclPolicy.GetTagsOfMachine( + test.args.machine, + test.args.stripEmailDomain, + ) + for _, valid := range gotValid { + if !util.StringOrPrefixListContains(test.wantValid, valid) { + t.Errorf( + "valids: getTags() = %v, want %v", + gotValid, + test.wantValid, + ) + + break + } + } + for _, invalid := range gotInvalid { + if !util.StringOrPrefixListContains(test.wantInvalid, invalid) { + t.Errorf( + "invalids: getTags() = %v, want %v", + gotInvalid, + test.wantInvalid, + ) + + break + } + } + }) + } +} + +func Test_getFilteredByACLPeers(t *testing.T) { + type args struct { + machines types.Machines + rules []tailcfg.FilterRule + machine *types.Machine + } + tests := []struct { + name string + args args + want types.Machines + }{ + { + name: "all hosts can talk to each other", + args: args{ + machines: types.Machines{ // list of all machines in the database + { + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + }, + User: types.User{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + }, + User: types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "*"}, + }, + }, + }, + machine: &types.Machine{ // current machine + ID: 1, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + User: types.User{Name: "joe"}, + }, + }, + want: types.Machines{ + { + ID: 2, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")}, + User: types.User{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.3")}, + User: types.User{Name: "mickael"}, + }, + }, + }, + { + name: "One host can talk to another, but not all hosts", + args: args{ + machines: types.Machines{ // list of all machines in the database + { + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + }, + User: types.User{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + }, + User: types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2"}, + }, + }, + }, + machine: &types.Machine{ // current machine + ID: 1, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + User: types.User{Name: "joe"}, + }, + }, + want: types.Machines{ + { + ID: 2, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")}, + User: types.User{Name: "marc"}, + }, + }, + }, + { + name: "host cannot directly talk to destination, but return path is authorized", + args: args{ + machines: types.Machines{ // list of all machines in the database + { + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + }, + User: types.User{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + }, + User: types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"100.64.0.3"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2"}, + }, + }, + }, + machine: &types.Machine{ // current machine + ID: 2, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")}, + User: types.User{Name: "marc"}, + }, + }, + want: types.Machines{ + { + ID: 3, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.3")}, + User: types.User{Name: "mickael"}, + }, + }, + }, + { + name: "rules allows all hosts to reach one destination", + args: args{ + machines: types.Machines{ // list of all machines in the database + { + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + }, + User: types.User{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + }, + User: types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2"}, + }, + }, + }, + machine: &types.Machine{ // current machine + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + }, + want: types.Machines{ + { + ID: 2, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + }, + User: types.User{Name: "marc"}, + }, + }, + }, + { + name: "rules allows all hosts to reach one destination, destination can reach all hosts", + args: args{ + machines: types.Machines{ // list of all machines in the database + { + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + }, + User: types.User{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + }, + User: types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2"}, + }, + }, + }, + machine: &types.Machine{ // current machine + ID: 2, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + }, + User: types.User{Name: "marc"}, + }, + }, + want: types.Machines{ + { + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + }, + User: types.User{Name: "mickael"}, + }, + }, + }, + { + name: "rule allows all hosts to reach all destinations", + args: args{ + machines: types.Machines{ // list of all machines in the database + { + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + }, + User: types.User{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + }, + User: types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "*"}, + }, + }, + }, + machine: &types.Machine{ // current machine + ID: 2, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")}, + User: types.User{Name: "marc"}, + }, + }, + want: types.Machines{ + { + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.3")}, + User: types.User{Name: "mickael"}, + }, + }, + }, + { + name: "without rule all communications are forbidden", + args: args{ + machines: types.Machines{ // list of all machines in the database + { + ID: 1, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + }, + User: types.User{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + }, + User: types.User{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + }, + User: types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + }, + machine: &types.Machine{ // current machine + ID: 2, + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")}, + User: types.User{Name: "marc"}, + }, + }, + want: types.Machines{}, + }, + { + // Investigating 699 + // Found some machines: [ts-head-8w6paa ts-unstable-lys2ib ts-head-upcrmb ts-unstable-rlwpvr] machine=ts-head-8w6paa + // ACL rules generated ACL=[{"DstPorts":[{"Bits":null,"IP":"*","Ports":{"First":0,"Last":65535}}],"SrcIPs":["fd7a:115c:a1e0::3","100.64.0.3","fd7a:115c:a1e0::4","100.64.0.4"]}] + // ACL Cache Map={"100.64.0.3":{"*":{}},"100.64.0.4":{"*":{}},"fd7a:115c:a1e0::3":{"*":{}},"fd7a:115c:a1e0::4":{"*":{}}} + name: "issue-699-broken-star", + args: args{ + machines: types.Machines{ // + { + ID: 1, + Hostname: "ts-head-upcrmb", + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + netip.MustParseAddr("fd7a:115c:a1e0::3"), + }, + User: types.User{Name: "user1"}, + }, + { + ID: 2, + Hostname: "ts-unstable-rlwpvr", + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.4"), + netip.MustParseAddr("fd7a:115c:a1e0::4"), + }, + User: types.User{Name: "user1"}, + }, + { + ID: 3, + Hostname: "ts-head-8w6paa", + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + netip.MustParseAddr("fd7a:115c:a1e0::1"), + }, + User: types.User{Name: "user2"}, + }, + { + ID: 4, + Hostname: "ts-unstable-lys2ib", + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.2"), + netip.MustParseAddr("fd7a:115c:a1e0::2"), + }, + User: types.User{Name: "user2"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + DstPorts: []tailcfg.NetPortRange{ + { + IP: "*", + Ports: tailcfg.PortRange{First: 0, Last: 65535}, + }, + }, + SrcIPs: []string{ + "fd7a:115c:a1e0::3", "100.64.0.3", + "fd7a:115c:a1e0::4", "100.64.0.4", + }, + }, + }, + machine: &types.Machine{ // current machine + ID: 3, + Hostname: "ts-head-8w6paa", + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + netip.MustParseAddr("fd7a:115c:a1e0::1"), + }, + User: types.User{Name: "user2"}, + }, + }, + want: types.Machines{ + { + ID: 1, + Hostname: "ts-head-upcrmb", + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.3"), + netip.MustParseAddr("fd7a:115c:a1e0::3"), + }, + User: types.User{Name: "user1"}, + }, + { + ID: 2, + Hostname: "ts-unstable-rlwpvr", + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.4"), + netip.MustParseAddr("fd7a:115c:a1e0::4"), + }, + User: types.User{Name: "user1"}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := FilterMachinesByACL( + tt.args.machine, + tt.args.machines, + tt.args.rules, + ) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("filterMachinesByACL() = %v, want %v", got, tt.want) } }) } diff --git a/hscontrol/acls_types.go b/hscontrol/policy/acls_types.go similarity index 99% rename from hscontrol/acls_types.go rename to hscontrol/policy/acls_types.go index 0e55351503..e9c44909d1 100644 --- a/hscontrol/acls_types.go +++ b/hscontrol/policy/acls_types.go @@ -1,4 +1,4 @@ -package hscontrol +package policy import ( "encoding/json" diff --git a/hscontrol/policy/matcher/matcher.go b/hscontrol/policy/matcher/matcher.go new file mode 100644 index 0000000000..8458339cd0 --- /dev/null +++ b/hscontrol/policy/matcher/matcher.go @@ -0,0 +1,61 @@ +package matcher + +import ( + "net/netip" + + "github.com/juanfont/headscale/hscontrol/util" + "go4.org/netipx" + "tailscale.com/tailcfg" +) + +type Match struct { + Srcs *netipx.IPSet + Dests *netipx.IPSet +} + +func MatchFromFilterRule(rule tailcfg.FilterRule) Match { + srcs := new(netipx.IPSetBuilder) + dests := new(netipx.IPSetBuilder) + + for _, srcIP := range rule.SrcIPs { + set, _ := util.ParseIPSet(srcIP, nil) + + srcs.AddSet(set) + } + + for _, dest := range rule.DstPorts { + set, _ := util.ParseIPSet(dest.IP, nil) + + dests.AddSet(set) + } + + srcsSet, _ := srcs.IPSet() + destsSet, _ := dests.IPSet() + + match := Match{ + Srcs: srcsSet, + Dests: destsSet, + } + + return match +} + +func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool { + for _, ip := range ips { + if m.Srcs.Contains(ip) { + return true + } + } + + return false +} + +func (m *Match) DestsContainsIP(ips []netip.Addr) bool { + for _, ip := range ips { + if m.Dests.Contains(ip) { + return true + } + } + + return false +} diff --git a/hscontrol/policy/matcher/matcher_test.go b/hscontrol/policy/matcher/matcher_test.go new file mode 100644 index 0000000000..54cf8a0643 --- /dev/null +++ b/hscontrol/policy/matcher/matcher_test.go @@ -0,0 +1 @@ +package matcher diff --git a/hscontrol/protocol_common.go b/hscontrol/protocol_common.go index 5cd0ddb4e8..ae034fb6f2 100644 --- a/hscontrol/protocol_common.go +++ b/hscontrol/protocol_common.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "gorm.io/gorm" @@ -171,7 +172,7 @@ func (h *Headscale) handleRegisterCommon( // that we rely on a method that calls back some how (OpenID or CLI) // We create the machine and then keep it around until a callback // happens - newMachine := Machine{ + newMachine := types.Machine{ MachineKey: util.MachinePublicKeyStripPrefix(machineKey), Hostname: registerRequest.Hostinfo.Hostname, GivenName: givenName, @@ -214,8 +215,7 @@ func (h *Headscale) handleRegisterCommon( []byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)), ) if err != nil || storedMachineKey.IsZero() { - machine.MachineKey = util.MachinePublicKeyStripPrefix(machineKey) - if err := h.db.db.Save(&machine).Error; err != nil { + if err := h.db.MachineSetMachineKey(machine, machineKey); err != nil { log.Error(). Caller(). Str("func", "RegistrationHandler"). @@ -244,7 +244,7 @@ func (h *Headscale) handleRegisterCommon( // If machine is not expired, and it is register, we have a already accepted this machine, // let it proceed with a valid registration - if !machine.isExpired() { + if !machine.IsExpired() { h.handleMachineValidRegistrationCommon(writer, *machine, machineKey, isNoise) return @@ -253,7 +253,7 @@ func (h *Headscale) handleRegisterCommon( // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) && - !machine.isExpired() { + !machine.IsExpired() { h.handleMachineRefreshKeyCommon( writer, registerRequest, @@ -312,7 +312,7 @@ func (h *Headscale) handleAuthKeyCommon( Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) resp := tailcfg.RegisterResponse{} - pak, err := h.db.checkKeyValidity(registerRequest.Auth.AuthKey) + pak, err := h.db.ValidatePreAuthKey(registerRequest.Auth.AuthKey) if err != nil { log.Error(). Caller(). @@ -333,7 +333,7 @@ func (h *Headscale) handleAuthKeyCommon( Err(err). Msg("Cannot encode message") http.Error(writer, "Internal server error", http.StatusInternalServerError) - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). + machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). Inc() return @@ -358,10 +358,10 @@ func (h *Headscale) handleAuthKeyCommon( Msg("Failed authentication via AuthKey") if pak != nil { - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). + machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). Inc() } else { - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", "unknown").Inc() + machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", "unknown").Inc() } return @@ -401,10 +401,10 @@ func (h *Headscale) handleAuthKeyCommon( return } - aclTags := pak.toProto().AclTags + aclTags := pak.Proto().AclTags if len(aclTags) > 0 { // This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login - err = h.db.SetTags(machine, aclTags, h.UpdateACLRules) + err = h.db.SetTags(machine, aclTags) if err != nil { log.Error(). @@ -433,17 +433,17 @@ func (h *Headscale) handleAuthKeyCommon( return } - machineToRegister := Machine{ + machineToRegister := types.Machine{ Hostname: registerRequest.Hostinfo.Hostname, GivenName: givenName, UserID: pak.User.ID, MachineKey: util.MachinePublicKeyStripPrefix(machineKey), - RegisterMethod: RegisterMethodAuthKey, + RegisterMethod: util.RegisterMethodAuthKey, Expiry: ®isterRequest.Expiry, NodeKey: nodeKey, LastSeen: &now, AuthKeyID: uint(pak.ID), - ForcedTags: pak.toProto().AclTags, + ForcedTags: pak.Proto().AclTags, } machine, err = h.db.RegisterMachine( @@ -455,7 +455,7 @@ func (h *Headscale) handleAuthKeyCommon( Bool("noise", isNoise). Err(err). Msg("could not register machine") - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). + machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). Inc() http.Error(writer, "Internal server error", http.StatusInternalServerError) @@ -470,7 +470,7 @@ func (h *Headscale) handleAuthKeyCommon( Bool("noise", isNoise). Err(err). Msg("Failed to use pre-auth key") - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). + machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). Inc() http.Error(writer, "Internal server error", http.StatusInternalServerError) @@ -478,10 +478,10 @@ func (h *Headscale) handleAuthKeyCommon( } resp.MachineAuthorized = true - resp.User = *pak.User.toTailscaleUser() + resp.User = *pak.User.TailscaleUser() // Provide LoginName when registering with pre-auth key // Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName* - resp.Login = *pak.User.toTailscaleLogin() + resp.Login = *pak.User.TailscaleLogin() respBody, err := h.marshalResponse(resp, machineKey, isNoise) if err != nil { @@ -492,13 +492,13 @@ func (h *Headscale) handleAuthKeyCommon( Str("machine", registerRequest.Hostinfo.Hostname). Err(err). Msg("Cannot encode message") - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). + machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). Inc() http.Error(writer, "Internal server error", http.StatusInternalServerError) return } - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.User.Name). + machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "success", pak.User.Name). Inc() writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.WriteHeader(http.StatusOK) @@ -581,7 +581,7 @@ func (h *Headscale) handleNewMachineCommon( func (h *Headscale) handleMachineLogOutCommon( writer http.ResponseWriter, - machine Machine, + machine types.Machine, machineKey key.MachinePublic, isNoise bool, ) { @@ -608,7 +608,7 @@ func (h *Headscale) handleMachineLogOutCommon( resp.AuthURL = "" resp.MachineAuthorized = false resp.NodeKeyExpired = true - resp.User = *machine.User.toTailscaleUser() + resp.User = *machine.User.TailscaleUser() respBody, err := h.marshalResponse(resp, machineKey, isNoise) if err != nil { log.Error(). @@ -634,7 +634,7 @@ func (h *Headscale) handleMachineLogOutCommon( return } - if machine.isEphemeral() { + if machine.IsEphemeral() { err = h.db.HardDeleteMachine(&machine) if err != nil { log.Error(). @@ -655,7 +655,7 @@ func (h *Headscale) handleMachineLogOutCommon( func (h *Headscale) handleMachineValidRegistrationCommon( writer http.ResponseWriter, - machine Machine, + machine types.Machine, machineKey key.MachinePublic, isNoise bool, ) { @@ -670,8 +670,8 @@ func (h *Headscale) handleMachineValidRegistrationCommon( resp.AuthURL = "" resp.MachineAuthorized = true - resp.User = *machine.User.toTailscaleUser() - resp.Login = *machine.User.toTailscaleLogin() + resp.User = *machine.User.TailscaleUser() + resp.Login = *machine.User.TailscaleLogin() respBody, err := h.marshalResponse(resp, machineKey, isNoise) if err != nil { @@ -710,7 +710,7 @@ func (h *Headscale) handleMachineValidRegistrationCommon( func (h *Headscale) handleMachineRefreshKeyCommon( writer http.ResponseWriter, registerRequest tailcfg.RegisterRequest, - machine Machine, + machine types.Machine, machineKey key.MachinePublic, isNoise bool, ) { @@ -721,9 +721,9 @@ func (h *Headscale) handleMachineRefreshKeyCommon( Bool("noise", isNoise). Str("machine", machine.Hostname). Msg("We have the OldNodeKey in the database. This is a key refresh") - machine.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey) - if err := h.db.db.Save(&machine).Error; err != nil { + err := h.db.MachineSetNodeKey(&machine, registerRequest.NodeKey) + if err != nil { log.Error(). Caller(). Err(err). @@ -734,7 +734,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon( } resp.AuthURL = "" - resp.User = *machine.User.toTailscaleUser() + resp.User = *machine.User.TailscaleUser() respBody, err := h.marshalResponse(resp, machineKey, isNoise) if err != nil { log.Error(). @@ -770,7 +770,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon( func (h *Headscale) handleMachineExpiredOrLoggedOutCommon( writer http.ResponseWriter, registerRequest tailcfg.RegisterRequest, - machine Machine, + machine types.Machine, machineKey key.MachinePublic, isNoise bool, ) { diff --git a/hscontrol/protocol_common_poll.go b/hscontrol/protocol_common_poll.go index 502c633a8f..3d432387c1 100644 --- a/hscontrol/protocol_common_poll.go +++ b/hscontrol/protocol_common_poll.go @@ -6,6 +6,7 @@ import ( "net/http" "time" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" @@ -24,16 +25,16 @@ const machineNameContextKey = contextKey("machineName") func (h *Headscale) handlePollCommon( writer http.ResponseWriter, ctx context.Context, - machine *Machine, + machine *types.Machine, mapRequest tailcfg.MapRequest, isNoise bool, ) { machine.Hostname = mapRequest.Hostinfo.Hostname - machine.HostInfo = HostInfo(*mapRequest.Hostinfo) + machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo) machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) now := time.Now().UTC() - err := h.db.processMachineRoutes(machine) + err := h.db.ProcessMachineRoutes(machine) if err != nil { log.Error(). Caller(). @@ -43,18 +44,13 @@ func (h *Headscale) handlePollCommon( } // update ACLRules with peer informations (to update server tags if necessary) - if h.aclPolicy != nil { - err := h.UpdateACLRules() - if err != nil { - log.Error(). - Caller(). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Err(err) - } + if h.ACLPolicy != nil { + // TODO(kradalby): Since this is not blocking, I might have introduced a bug here. + // It will be resolved later as we change up the policy stuff. + h.policyUpdateChan <- struct{}{} // update routes with peer information - err = h.db.EnableAutoApprovedRoutes(h.aclPolicy, machine) + err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, machine) if err != nil { log.Error(). Caller(). @@ -78,19 +74,17 @@ func (h *Headscale) handlePollCommon( machine.LastSeen = &now } - if err := h.db.db.Updates(machine).Error; err != nil { - if err != nil { - log.Error(). - Str("handler", "PollNetMap"). - Bool("noise", isNoise). - Str("node_key", machine.NodeKey). - Str("machine", machine.Hostname). - Err(err). - Msg("Failed to persist/update machine in the database") - http.Error(writer, "", http.StatusInternalServerError) + if err := h.db.MachineSave(machine); err != nil { + log.Error(). + Str("handler", "PollNetMap"). + Bool("noise", isNoise). + Str("node_key", machine.NodeKey). + Str("machine", machine.Hostname). + Err(err). + Msg("Failed to persist/update machine in the database") + http.Error(writer, "", http.StatusInternalServerError) - return - } + return } mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise) @@ -244,7 +238,7 @@ func (h *Headscale) handlePollCommon( func (h *Headscale) pollNetMapStream( writer http.ResponseWriter, ctxReq context.Context, - machine *Machine, + machine *types.Machine, mapRequest tailcfg.MapRequest, pollDataChan chan []byte, keepAliveChan chan []byte, @@ -457,7 +451,7 @@ func (h *Headscale) pollNetMapStream( updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname). Inc() - if h.db.isOutdated(machine, h.getLastStateChange()) { + if h.db.IsOutdated(machine, h.getLastStateChange()) { var lastUpdate time.Time if machine.LastSuccessfulUpdate != nil { lastUpdate = *machine.LastSuccessfulUpdate @@ -626,7 +620,7 @@ func (h *Headscale) scheduledPollWorker( updateChan chan struct{}, keepAliveChan chan []byte, mapRequest tailcfg.MapRequest, - machine *Machine, + machine *types.Machine, isNoise bool, ) { keepAliveTicker := time.NewTicker(keepAliveInterval) diff --git a/hscontrol/protocol_common_utils.go b/hscontrol/protocol_common_utils.go index 1dababa1ab..8990eeb364 100644 --- a/hscontrol/protocol_common_utils.go +++ b/hscontrol/protocol_common_utils.go @@ -5,6 +5,7 @@ import ( "encoding/json" "sync" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/klauspost/compress/zstd" "github.com/rs/zerolog/log" @@ -15,7 +16,7 @@ import ( func (h *Headscale) getMapResponseData( mapRequest tailcfg.MapRequest, - machine *Machine, + machine *types.Machine, isNoise bool, ) ([]byte, error) { mapResponse, err := h.generateMapResponse(mapRequest, machine) @@ -43,7 +44,7 @@ func (h *Headscale) getMapResponseData( func (h *Headscale) getMapKeepAliveResponseData( mapRequest tailcfg.MapRequest, - machine *Machine, + machine *types.Machine, isNoise bool, ) ([]byte, error) { keepAliveResponse := tailcfg.MapResponse{ diff --git a/hscontrol/app_test.go b/hscontrol/suite_test.go similarity index 54% rename from hscontrol/app_test.go rename to hscontrol/suite_test.go index 1b4e91e827..69a651a8de 100644 --- a/hscontrol/app_test.go +++ b/hscontrol/suite_test.go @@ -18,7 +18,7 @@ type Suite struct{} var ( tmpDir string - app Headscale + app *Headscale ) func (s *Suite) SetUpTest(c *check.C) { @@ -34,11 +34,15 @@ func (s *Suite) ResetDB(c *check.C) { os.RemoveAll(tmpDir) } var err error - tmpDir, err = os.MkdirTemp("", "autoygg-client-test") + tmpDir, err = os.MkdirTemp("", "autoygg-client-test2") if err != nil { c.Fatal(err) } cfg := Config{ + PrivateKeyPath: tmpDir + "/private.key", + NoisePrivateKeyPath: tmpDir + "/noise_private.key", + DBtype: "sqlite3", + DBpath: tmpDir + "/headscale_test.db", IPPrefixes: []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, @@ -47,29 +51,8 @@ func (s *Suite) ResetDB(c *check.C) { }, } - // TODO(kradalby): make this use NewHeadscale properly so it doesnt drift - app = Headscale{ - cfg: &cfg, - dbType: "sqlite3", - dbString: tmpDir + "/headscale_test.db", - - stateUpdateChan: make(chan struct{}), - cancelStateUpdateChan: make(chan struct{}), - } - - go app.watchStateChannel() - - db, err := NewHeadscaleDatabase( - app.dbType, - app.dbString, - cfg.OIDC.StripEmaildomain, - false, - app.stateUpdateChan, - cfg.IPPrefixes, - "", - ) + app, err = NewHeadscale(&cfg) if err != nil { c.Fatal(err) } - app.db = db } diff --git a/hscontrol/types/api_key.go b/hscontrol/types/api_key.go new file mode 100644 index 0000000000..8ca0004494 --- /dev/null +++ b/hscontrol/types/api_key.go @@ -0,0 +1,41 @@ +package types + +import ( + "time" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// APIKey describes the datamodel for API keys used to remotely authenticate with +// headscale. +type APIKey struct { + ID uint64 `gorm:"primary_key"` + Prefix string `gorm:"uniqueIndex"` + Hash []byte + + CreatedAt *time.Time + Expiration *time.Time + LastSeen *time.Time +} + +func (key *APIKey) Proto() *v1.ApiKey { + protoKey := v1.ApiKey{ + Id: key.ID, + Prefix: key.Prefix, + } + + if key.Expiration != nil { + protoKey.Expiration = timestamppb.New(*key.Expiration) + } + + if key.CreatedAt != nil { + protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) + } + + if key.LastSeen != nil { + protoKey.LastSeen = timestamppb.New(*key.LastSeen) + } + + return &protoKey +} diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go new file mode 100644 index 0000000000..96ad1b782e --- /dev/null +++ b/hscontrol/types/common.go @@ -0,0 +1,108 @@ +package types + +import ( + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "net/netip" + + "tailscale.com/tailcfg" +) + +var ErrCannotParsePrefix = errors.New("cannot parse prefix") + +// This is a "wrapper" type around tailscales +// Hostinfo to allow us to add database "serialization" +// methods. This allows us to use a typed values throughout +// the code and not have to marshal/unmarshal and error +// check all over the code. +type HostInfo tailcfg.Hostinfo + +func (hi *HostInfo) Scan(destination interface{}) error { + switch value := destination.(type) { + case []byte: + return json.Unmarshal(value, hi) + + case string: + return json.Unmarshal([]byte(value), hi) + + default: + return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) + } +} + +// Value return json value, implement driver.Valuer interface. +func (hi HostInfo) Value() (driver.Value, error) { + bytes, err := json.Marshal(hi) + + return string(bytes), err +} + +type IPPrefix netip.Prefix + +func (i *IPPrefix) Scan(destination interface{}) error { + switch value := destination.(type) { + case string: + prefix, err := netip.ParsePrefix(value) + if err != nil { + return err + } + *i = IPPrefix(prefix) + + return nil + default: + return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination) + } +} + +// Value return json value, implement driver.Valuer interface. +func (i IPPrefix) Value() (driver.Value, error) { + prefixStr := netip.Prefix(i).String() + + return prefixStr, nil +} + +type IPPrefixes []netip.Prefix + +func (i *IPPrefixes) Scan(destination interface{}) error { + switch value := destination.(type) { + case []byte: + return json.Unmarshal(value, i) + + case string: + return json.Unmarshal([]byte(value), i) + + default: + return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) + } +} + +// Value return json value, implement driver.Valuer interface. +func (i IPPrefixes) Value() (driver.Value, error) { + bytes, err := json.Marshal(i) + + return string(bytes), err +} + +type StringList []string + +func (i *StringList) Scan(destination interface{}) error { + switch value := destination.(type) { + case []byte: + return json.Unmarshal(value, i) + + case string: + return json.Unmarshal([]byte(value), i) + + default: + return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) + } +} + +// Value return json value, implement driver.Valuer interface. +func (i StringList) Value() (driver.Value, error) { + bytes, err := json.Marshal(i) + + return string(bytes), err +} diff --git a/hscontrol/types/machine.go b/hscontrol/types/machine.go new file mode 100644 index 0000000000..a4ca03e019 --- /dev/null +++ b/hscontrol/types/machine.go @@ -0,0 +1,254 @@ +package types + +import ( + "database/sql/driver" + "errors" + "fmt" + "net/netip" + "strings" + "time" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/policy/matcher" + "go4.org/netipx" + "google.golang.org/protobuf/types/known/timestamppb" + "tailscale.com/tailcfg" +) + +const ( + // TODO(kradalby): Move out of here when we got circdeps under control. + keepAliveInterval = 60 * time.Second +) + +var ErrMachineAddressesInvalid = errors.New("failed to parse machine addresses") + +// Machine is a Headscale client. +type Machine struct { + ID uint64 `gorm:"primary_key"` + MachineKey string `gorm:"type:varchar(64);unique_index"` + NodeKey string + DiscoKey string + IPAddresses MachineAddresses + + // Hostname represents the name given by the Tailscale + // client during registration + Hostname string + + // Givenname represents either: + // a DNS normalized version of Hostname + // a valid name set by the User + // + // GivenName is the name used in all DNS related + // parts of headscale. + GivenName string `gorm:"type:varchar(63);unique_index"` + UserID uint + User User `gorm:"foreignKey:UserID"` + + RegisterMethod string + + ForcedTags StringList + + // TODO(kradalby): This seems like irrelevant information? + AuthKeyID uint + AuthKey *PreAuthKey + + LastSeen *time.Time + LastSuccessfulUpdate *time.Time + Expiry *time.Time + + HostInfo HostInfo + Endpoints StringList + + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time +} + +type ( + Machines []Machine + MachinesP []*Machine +) + +type MachineAddresses []netip.Addr + +func (ma MachineAddresses) ToStringSlice() []string { + strSlice := make([]string, 0, len(ma)) + for _, addr := range ma { + strSlice = append(strSlice, addr.String()) + } + + return strSlice +} + +// AppendToIPSet adds the individual ips in MachineAddresses to a +// given netipx.IPSetBuilder. +func (ma MachineAddresses) AppendToIPSet(build *netipx.IPSetBuilder) { + for _, ip := range ma { + build.Add(ip) + } +} + +func (ma *MachineAddresses) Scan(destination interface{}) error { + switch value := destination.(type) { + case string: + addresses := strings.Split(value, ",") + *ma = (*ma)[:0] + for _, addr := range addresses { + if len(addr) < 1 { + continue + } + parsed, err := netip.ParseAddr(addr) + if err != nil { + return err + } + *ma = append(*ma, parsed) + } + + return nil + + default: + return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination) + } +} + +// Value return json value, implement driver.Valuer interface. +func (ma MachineAddresses) Value() (driver.Value, error) { + addresses := strings.Join(ma.ToStringSlice(), ",") + + return addresses, nil +} + +// IsExpired returns whether the machine registration has expired. +func (machine Machine) IsExpired() bool { + // If Expiry is not set, the client has not indicated that + // it wants an expiry time, it is therefor considered + // to mean "not expired" + if machine.Expiry == nil || machine.Expiry.IsZero() { + return false + } + + return time.Now().UTC().After(*machine.Expiry) +} + +// IsOnline returns if the machine is connected to Headscale. +// This is really a naive implementation, as we don't really see +// if there is a working connection between the client and the server. +func (machine *Machine) IsOnline() bool { + if machine.LastSeen == nil { + return false + } + + if machine.IsExpired() { + return false + } + + return machine.LastSeen.After(time.Now().Add(-keepAliveInterval)) +} + +// IsEphemeral returns if the machine is registered as an Ephemeral node. +// https://tailscale.com/kb/1111/ephemeral-nodes/ +func (machine *Machine) IsEphemeral() bool { + return machine.AuthKey != nil && machine.AuthKey.Ephemeral +} + +func (machine *Machine) CanAccess(filter []tailcfg.FilterRule, machine2 *Machine) bool { + for _, rule := range filter { + // TODO(kradalby): Cache or pregen this + matcher := matcher.MatchFromFilterRule(rule) + + if !matcher.SrcsContainsIPs([]netip.Addr(machine.IPAddresses)) { + continue + } + + if matcher.DestsContainsIP([]netip.Addr(machine2.IPAddresses)) { + return true + } + } + + return false +} + +func (machines Machines) FilterByIP(ip netip.Addr) Machines { + found := make(Machines, 0) + + for _, machine := range machines { + for _, mIP := range machine.IPAddresses { + if ip == mIP { + found = append(found, machine) + } + } + } + + return found +} + +func (machine *Machine) Proto() *v1.Machine { + machineProto := &v1.Machine{ + Id: machine.ID, + MachineKey: machine.MachineKey, + + NodeKey: machine.NodeKey, + DiscoKey: machine.DiscoKey, + IpAddresses: machine.IPAddresses.ToStringSlice(), + Name: machine.Hostname, + GivenName: machine.GivenName, + User: machine.User.Proto(), + ForcedTags: machine.ForcedTags, + Online: machine.IsOnline(), + + // TODO(kradalby): Implement register method enum converter + // RegisterMethod: , + + CreatedAt: timestamppb.New(machine.CreatedAt), + } + + if machine.AuthKey != nil { + machineProto.PreAuthKey = machine.AuthKey.Proto() + } + + if machine.LastSeen != nil { + machineProto.LastSeen = timestamppb.New(*machine.LastSeen) + } + + if machine.LastSuccessfulUpdate != nil { + machineProto.LastSuccessfulUpdate = timestamppb.New( + *machine.LastSuccessfulUpdate, + ) + } + + if machine.Expiry != nil { + machineProto.Expiry = timestamppb.New(*machine.Expiry) + } + + return machineProto +} + +// GetHostInfo returns a Hostinfo struct for the machine. +func (machine *Machine) GetHostInfo() tailcfg.Hostinfo { + return tailcfg.Hostinfo(machine.HostInfo) +} + +func (machine Machine) String() string { + return machine.Hostname +} + +func (machines Machines) String() string { + temp := make([]string, len(machines)) + + for index, machine := range machines { + temp[index] = machine.Hostname + } + + return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) +} + +// TODO(kradalby): Remove when we have generics... +func (machines MachinesP) String() string { + temp := make([]string, len(machines)) + + for index, machine := range machines { + temp[index] = machine.Hostname + } + + return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) +} diff --git a/hscontrol/types/machine_test.go b/hscontrol/types/machine_test.go new file mode 100644 index 0000000000..ab1254f4c2 --- /dev/null +++ b/hscontrol/types/machine_test.go @@ -0,0 +1 @@ +package types diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go new file mode 100644 index 0000000000..0d8c9cff5d --- /dev/null +++ b/hscontrol/types/preauth_key.go @@ -0,0 +1,58 @@ +package types + +import ( + "strconv" + "time" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/util" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// PreAuthKey describes a pre-authorization key usable in a particular user. +type PreAuthKey struct { + ID uint64 `gorm:"primary_key"` + Key string + UserID uint + User User + Reusable bool + Ephemeral bool `gorm:"default:false"` + Used bool `gorm:"default:false"` + ACLTags []PreAuthKeyACLTag + + CreatedAt *time.Time + Expiration *time.Time +} + +// PreAuthKeyACLTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey. +type PreAuthKeyACLTag struct { + ID uint64 `gorm:"primary_key"` + PreAuthKeyID uint64 + Tag string +} + +func (key *PreAuthKey) Proto() *v1.PreAuthKey { + protoKey := v1.PreAuthKey{ + User: key.User.Name, + Id: strconv.FormatUint(key.ID, util.Base10), + Key: key.Key, + Ephemeral: key.Ephemeral, + Reusable: key.Reusable, + Used: key.Used, + AclTags: make([]string, len(key.ACLTags)), + } + + if key.Expiration != nil { + protoKey.Expiration = timestamppb.New(*key.Expiration) + } + + if key.CreatedAt != nil { + protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) + } + + for idx := range key.ACLTags { + protoKey.AclTags[idx] = key.ACLTags[idx].Tag + } + + return &protoKey +} diff --git a/hscontrol/types/routes.go b/hscontrol/types/routes.go new file mode 100644 index 0000000000..1f430712dc --- /dev/null +++ b/hscontrol/types/routes.go @@ -0,0 +1,71 @@ +package types + +import ( + "fmt" + "net/netip" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "google.golang.org/protobuf/types/known/timestamppb" + "gorm.io/gorm" +) + +var ( + ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0") + ExitRouteV6 = netip.MustParsePrefix("::/0") +) + +type Route struct { + gorm.Model + + MachineID uint64 + Machine Machine + Prefix IPPrefix + + Advertised bool + Enabled bool + IsPrimary bool +} + +type Routes []Route + +func (r *Route) String() string { + return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String()) +} + +func (r *Route) IsExitRoute() bool { + return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6 +} + +func (rs Routes) Prefixes() []netip.Prefix { + prefixes := make([]netip.Prefix, len(rs)) + for i, r := range rs { + prefixes[i] = netip.Prefix(r.Prefix) + } + + return prefixes +} + +func (rs Routes) Proto() []*v1.Route { + protoRoutes := []*v1.Route{} + + for _, route := range rs { + protoRoute := v1.Route{ + Id: uint64(route.ID), + Machine: route.Machine.Proto(), + Prefix: netip.Prefix(route.Prefix).String(), + Advertised: route.Advertised, + Enabled: route.Enabled, + IsPrimary: route.IsPrimary, + CreatedAt: timestamppb.New(route.CreatedAt), + UpdatedAt: timestamppb.New(route.UpdatedAt), + } + + if route.DeletedAt.Valid { + protoRoute.DeletedAt = timestamppb.New(route.DeletedAt.Time) + } + + protoRoutes = append(protoRoutes, &protoRoute) + } + + return protoRoutes +} diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go new file mode 100644 index 0000000000..d5e3c452e3 --- /dev/null +++ b/hscontrol/types/users.go @@ -0,0 +1,55 @@ +package types + +import ( + "strconv" + "time" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/util" + "google.golang.org/protobuf/types/known/timestamppb" + "gorm.io/gorm" + "tailscale.com/tailcfg" +) + +// User is the way Headscale implements the concept of users in Tailscale +// +// At the end of the day, users in Tailscale are some kind of 'bubbles' or users +// that contain our machines. +type User struct { + gorm.Model + Name string `gorm:"unique"` +} + +func (n *User) TailscaleUser() *tailcfg.User { + user := tailcfg.User{ + ID: tailcfg.UserID(n.ID), + LoginName: n.Name, + DisplayName: n.Name, + ProfilePicURL: "", + Domain: "headscale.net", + Logins: []tailcfg.LoginID{}, + Created: time.Time{}, + } + + return &user +} + +func (n *User) TailscaleLogin() *tailcfg.Login { + login := tailcfg.Login{ + ID: tailcfg.LoginID(n.ID), + LoginName: n.Name, + DisplayName: n.Name, + ProfilePicURL: "", + Domain: "headscale.net", + } + + return &login +} + +func (n *User) Proto() *v1.User { + return &v1.User{ + Id: strconv.FormatUint(uint64(n.ID), util.Base10), + Name: n.Name, + CreatedAt: timestamppb.New(n.CreatedAt), + } +} diff --git a/hscontrol/users_test.go b/hscontrol/users_test.go deleted file mode 100644 index 1d68f92fb4..0000000000 --- a/hscontrol/users_test.go +++ /dev/null @@ -1,415 +0,0 @@ -package hscontrol - -import ( - "net/netip" - "testing" - - "gopkg.in/check.v1" - "gorm.io/gorm" -) - -func (s *Suite) TestCreateAndDestroyUser(c *check.C) { - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - c.Assert(user.Name, check.Equals, "test") - - users, err := app.db.ListUsers() - c.Assert(err, check.IsNil) - c.Assert(len(users), check.Equals, 1) - - err = app.db.DestroyUser("test") - c.Assert(err, check.IsNil) - - _, err = app.db.GetUser("test") - c.Assert(err, check.NotNil) -} - -func (s *Suite) TestDestroyUserErrors(c *check.C) { - err := app.db.DestroyUser("test") - c.Assert(err, check.Equals, ErrUserNotFound) - - user, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - err = app.db.DestroyUser("test") - c.Assert(err, check.IsNil) - - result := app.db.db.Preload("User").First(&pak, "key = ?", pak.Key) - // destroying a user also deletes all associated preauthkeys - c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound) - - user, err = app.db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - machine := Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(&machine) - - err = app.db.DestroyUser("test") - c.Assert(err, check.Equals, ErrUserStillHasNodes) -} - -func (s *Suite) TestRenameUser(c *check.C) { - userTest, err := app.db.CreateUser("test") - c.Assert(err, check.IsNil) - c.Assert(userTest.Name, check.Equals, "test") - - users, err := app.db.ListUsers() - c.Assert(err, check.IsNil) - c.Assert(len(users), check.Equals, 1) - - err = app.db.RenameUser("test", "test-renamed") - c.Assert(err, check.IsNil) - - _, err = app.db.GetUser("test") - c.Assert(err, check.Equals, ErrUserNotFound) - - _, err = app.db.GetUser("test-renamed") - c.Assert(err, check.IsNil) - - err = app.db.RenameUser("test-does-not-exit", "test") - c.Assert(err, check.Equals, ErrUserNotFound) - - userTest2, err := app.db.CreateUser("test2") - c.Assert(err, check.IsNil) - c.Assert(userTest2.Name, check.Equals, "test2") - - err = app.db.RenameUser("test2", "test-renamed") - c.Assert(err, check.Equals, ErrUserExists) -} - -func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { - userShared1, err := app.db.CreateUser("shared1") - c.Assert(err, check.IsNil) - - userShared2, err := app.db.CreateUser("shared2") - c.Assert(err, check.IsNil) - - userShared3, err := app.db.CreateUser("shared3") - c.Assert(err, check.IsNil) - - preAuthKeyShared1, err := app.db.CreatePreAuthKey( - userShared1.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - preAuthKeyShared2, err := app.db.CreatePreAuthKey( - userShared2.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - preAuthKeyShared3, err := app.db.CreatePreAuthKey( - userShared3.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - preAuthKey2Shared1, err := app.db.CreatePreAuthKey( - userShared1.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") - c.Assert(err, check.NotNil) - - machineInShared1 := &Machine{ - ID: 1, - MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", - NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", - DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", - Hostname: "test_get_shared_nodes_1", - UserID: userShared1.ID, - User: *userShared1, - RegisterMethod: RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, - AuthKeyID: uint(preAuthKeyShared1.ID), - } - app.db.db.Save(machineInShared1) - - _, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) - c.Assert(err, check.IsNil) - - machineInShared2 := &Machine{ - ID: 2, - MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - Hostname: "test_get_shared_nodes_2", - UserID: userShared2.ID, - User: *userShared2, - RegisterMethod: RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, - AuthKeyID: uint(preAuthKeyShared2.ID), - } - app.db.db.Save(machineInShared2) - - _, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) - c.Assert(err, check.IsNil) - - machineInShared3 := &Machine{ - ID: 3, - MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - Hostname: "test_get_shared_nodes_3", - UserID: userShared3.ID, - User: *userShared3, - RegisterMethod: RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, - AuthKeyID: uint(preAuthKeyShared3.ID), - } - app.db.db.Save(machineInShared3) - - _, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) - c.Assert(err, check.IsNil) - - machine2InShared1 := &Machine{ - ID: 4, - MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - Hostname: "test_get_shared_nodes_4", - UserID: userShared1.ID, - User: *userShared1, - RegisterMethod: RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, - AuthKeyID: uint(preAuthKey2Shared1.ID), - } - app.db.db.Save(machine2InShared1) - - peersOfMachine1InShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1) - c.Assert(err, check.IsNil) - - userProfiles := app.db.getMapResponseUserProfiles( - *machineInShared1, - peersOfMachine1InShared1, - ) - - c.Assert(len(userProfiles), check.Equals, 3) - - found := false - for _, userProfiles := range userProfiles { - if userProfiles.DisplayName == userShared1.Name { - found = true - - break - } - } - c.Assert(found, check.Equals, true) - - found = false - for _, userProfile := range userProfiles { - if userProfile.DisplayName == userShared2.Name { - found = true - - break - } - } - c.Assert(found, check.Equals, true) -} - -func TestNormalizeToFQDNRules(t *testing.T) { - type args struct { - name string - stripEmailDomain bool - } - tests := []struct { - name string - args args - want string - wantErr bool - }{ - { - name: "normalize simple name", - args: args{ - name: "normalize-simple.name", - stripEmailDomain: false, - }, - want: "normalize-simple.name", - wantErr: false, - }, - { - name: "normalize an email", - args: args{ - name: "foo.bar@example.com", - stripEmailDomain: false, - }, - want: "foo.bar.example.com", - wantErr: false, - }, - { - name: "normalize an email domain should be removed", - args: args{ - name: "foo.bar@example.com", - stripEmailDomain: true, - }, - want: "foo.bar", - wantErr: false, - }, - { - name: "strip enabled no email passed as argument", - args: args{ - name: "not-email-and-strip-enabled", - stripEmailDomain: true, - }, - want: "not-email-and-strip-enabled", - wantErr: false, - }, - { - name: "normalize complex email", - args: args{ - name: "foo.bar+complex-email@example.com", - stripEmailDomain: false, - }, - want: "foo.bar-complex-email.example.com", - wantErr: false, - }, - { - name: "user name with space", - args: args{ - name: "name space", - stripEmailDomain: false, - }, - want: "name-space", - wantErr: false, - }, - { - name: "user with quote", - args: args{ - name: "Jamie's iPhone 5", - stripEmailDomain: false, - }, - want: "jamies-iphone-5", - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain) - if (err != nil) != tt.wantErr { - t.Errorf( - "NormalizeToFQDNRules() error = %v, wantErr %v", - err, - tt.wantErr, - ) - - return - } - if got != tt.want { - t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestCheckForFQDNRules(t *testing.T) { - type args struct { - name string - } - tests := []struct { - name string - args args - wantErr bool - }{ - { - name: "valid: user", - args: args{name: "valid-user"}, - wantErr: false, - }, - { - name: "invalid: capitalized user", - args: args{name: "Invalid-CapItaLIzed-user"}, - wantErr: true, - }, - { - name: "invalid: email as user", - args: args{name: "foo.bar@example.com"}, - wantErr: true, - }, - { - name: "invalid: chars in user name", - args: args{name: "super-user+name"}, - wantErr: true, - }, - { - name: "invalid: too long name for user", - args: args{ - name: "super-long-useruseruser-name-that-should-be-a-little-more-than-63-chars", - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := CheckForFQDNRules(tt.args.name); (err != nil) != tt.wantErr { - t.Errorf("CheckForFQDNRules() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func (s *Suite) TestSetMachineUser(c *check.C) { - oldUser, err := app.db.CreateUser("old") - c.Assert(err, check.IsNil) - - newUser, err := app.db.CreateUser("new") - c.Assert(err, check.IsNil) - - pak, err := app.db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - machine := Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: oldUser.ID, - RegisterMethod: RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - app.db.db.Save(&machine) - c.Assert(machine.UserID, check.Equals, oldUser.ID) - - err = app.db.SetMachineUser(&machine, newUser.Name) - c.Assert(err, check.IsNil) - c.Assert(machine.UserID, check.Equals, newUser.ID) - c.Assert(machine.User.Name, check.Equals, newUser.Name) - - err = app.db.SetMachineUser(&machine, "non-existing-user") - c.Assert(err, check.Equals, ErrUserNotFound) - - err = app.db.SetMachineUser(&machine, newUser.Name) - c.Assert(err, check.IsNil) - c.Assert(machine.UserID, check.Equals, newUser.ID) - c.Assert(machine.User.Name, check.Equals, newUser.Name) -} diff --git a/hscontrol/util/addr.go b/hscontrol/util/addr.go index d312a6e04d..5c02c9338c 100644 --- a/hscontrol/util/addr.go +++ b/hscontrol/util/addr.go @@ -1,12 +1,94 @@ package util import ( + "fmt" "net/netip" "reflect" + "strings" "go4.org/netipx" ) +// This is borrowed from, and updated to use IPSet +// https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162 +// TODO(kradalby): contribute upstream and make public. +var ( + zeroIP4 = netip.AddrFrom4([4]byte{}) + zeroIP6 = netip.AddrFrom16([16]byte{}) +) + +// parseIPSet parses arg as one: +// +// - an IP address (IPv4 or IPv6) +// - the string "*" to match everything (both IPv4 & IPv6) +// - a CIDR (e.g. "192.168.0.0/16") +// - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800") +// +// bits, if non-nil, is the legacy SrcBits CIDR length to make a IP +// address (without a slash) treated as a CIDR of *bits length. +// nolint +func ParseIPSet(arg string, bits *int) (*netipx.IPSet, error) { + var ipSet netipx.IPSetBuilder + if arg == "*" { + ipSet.AddPrefix(netip.PrefixFrom(zeroIP4, 0)) + ipSet.AddPrefix(netip.PrefixFrom(zeroIP6, 0)) + + return ipSet.IPSet() + } + if strings.Contains(arg, "/") { + pfx, err := netip.ParsePrefix(arg) + if err != nil { + return nil, err + } + if pfx != pfx.Masked() { + return nil, fmt.Errorf("%v contains non-network bits set", pfx) + } + + ipSet.AddPrefix(pfx) + + return ipSet.IPSet() + } + if strings.Count(arg, "-") == 1 { + ip1s, ip2s, _ := strings.Cut(arg, "-") + + ip1, err := netip.ParseAddr(ip1s) + if err != nil { + return nil, err + } + + ip2, err := netip.ParseAddr(ip2s) + if err != nil { + return nil, err + } + + r := netipx.IPRangeFrom(ip1, ip2) + if !r.IsValid() { + return nil, fmt.Errorf("invalid IP range %q", arg) + } + + for _, prefix := range r.Prefixes() { + ipSet.AddPrefix(prefix) + } + + return ipSet.IPSet() + } + ip, err := netip.ParseAddr(arg) + if err != nil { + return nil, fmt.Errorf("invalid IP address %q", arg) + } + bits8 := uint8(ip.BitLen()) + if bits != nil { + if *bits < 0 || *bits > int(bits8) { + return nil, fmt.Errorf("invalid CIDR size %d for IP %q", *bits, arg) + } + bits8 = uint8(*bits) + } + + ipSet.AddPrefix(netip.PrefixFrom(ip, int(bits8))) + + return ipSet.IPSet() +} + func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) { var network, broadcast netip.Addr ipRange := netipx.RangeOfPrefix(na) diff --git a/hscontrol/matcher_test.go b/hscontrol/util/addr_test.go similarity index 96% rename from hscontrol/matcher_test.go rename to hscontrol/util/addr_test.go index fb0e9b076c..45b2b92f89 100644 --- a/hscontrol/matcher_test.go +++ b/hscontrol/util/addr_test.go @@ -1,4 +1,4 @@ -package hscontrol +package util import ( "net/netip" @@ -105,7 +105,7 @@ func Test_parseIPSet(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := parseIPSet(tt.args.arg, tt.args.bits) + got, err := ParseIPSet(tt.args.arg, tt.args.bits) if (err != nil) != tt.wantErr { t.Errorf("parseIPSet() error = %v, wantErr %v", err, tt.wantErr) diff --git a/hscontrol/util/const.go b/hscontrol/util/const.go new file mode 100644 index 0000000000..4f7c811c85 --- /dev/null +++ b/hscontrol/util/const.go @@ -0,0 +1,7 @@ +package util + +const ( + RegisterMethodAuthKey = "authkey" + RegisterMethodOIDC = "oidc" + RegisterMethodCLI = "cli" +) diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go new file mode 100644 index 0000000000..72af8f8357 --- /dev/null +++ b/hscontrol/util/dns.go @@ -0,0 +1,69 @@ +package util + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +const ( + // value related to RFC 1123 and 952. + LabelHostnameLength = 63 +) + +var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") + +var ErrInvalidUserName = errors.New("invalid user name") + +// NormalizeToFQDNRules will replace forbidden chars in user +// it can also return an error if the user doesn't respect RFC 952 and 1123. +func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) { + name = strings.ToLower(name) + name = strings.ReplaceAll(name, "'", "") + atIdx := strings.Index(name, "@") + if stripEmailDomain && atIdx > 0 { + name = name[:atIdx] + } else { + name = strings.ReplaceAll(name, "@", ".") + } + name = invalidCharsInUserRegex.ReplaceAllString(name, "-") + + for _, elt := range strings.Split(name, ".") { + if len(elt) > LabelHostnameLength { + return "", fmt.Errorf( + "label %v is more than 63 chars: %w", + elt, + ErrInvalidUserName, + ) + } + } + + return name, nil +} + +func CheckForFQDNRules(name string) error { + if len(name) > LabelHostnameLength { + return fmt.Errorf( + "DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w", + name, + ErrInvalidUserName, + ) + } + if strings.ToLower(name) != name { + return fmt.Errorf( + "DNS segment should be lowercase. %v doesn't comply with this rule: %w", + name, + ErrInvalidUserName, + ) + } + if invalidCharsInUserRegex.MatchString(name) { + return fmt.Errorf( + "DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w", + name, + ErrInvalidUserName, + ) + } + + return nil +} diff --git a/hscontrol/util/dns_test.go b/hscontrol/util/dns_test.go new file mode 100644 index 0000000000..ab66a13034 --- /dev/null +++ b/hscontrol/util/dns_test.go @@ -0,0 +1,143 @@ +package util + +import "testing" + +func TestNormalizeToFQDNRules(t *testing.T) { + type args struct { + name string + stripEmailDomain bool + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "normalize simple name", + args: args{ + name: "normalize-simple.name", + stripEmailDomain: false, + }, + want: "normalize-simple.name", + wantErr: false, + }, + { + name: "normalize an email", + args: args{ + name: "foo.bar@example.com", + stripEmailDomain: false, + }, + want: "foo.bar.example.com", + wantErr: false, + }, + { + name: "normalize an email domain should be removed", + args: args{ + name: "foo.bar@example.com", + stripEmailDomain: true, + }, + want: "foo.bar", + wantErr: false, + }, + { + name: "strip enabled no email passed as argument", + args: args{ + name: "not-email-and-strip-enabled", + stripEmailDomain: true, + }, + want: "not-email-and-strip-enabled", + wantErr: false, + }, + { + name: "normalize complex email", + args: args{ + name: "foo.bar+complex-email@example.com", + stripEmailDomain: false, + }, + want: "foo.bar-complex-email.example.com", + wantErr: false, + }, + { + name: "user name with space", + args: args{ + name: "name space", + stripEmailDomain: false, + }, + want: "name-space", + wantErr: false, + }, + { + name: "user with quote", + args: args{ + name: "Jamie's iPhone 5", + stripEmailDomain: false, + }, + want: "jamies-iphone-5", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain) + if (err != nil) != tt.wantErr { + t.Errorf( + "NormalizeToFQDNRules() error = %v, wantErr %v", + err, + tt.wantErr, + ) + + return + } + if got != tt.want { + t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCheckForFQDNRules(t *testing.T) { + type args struct { + name string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "valid: user", + args: args{name: "valid-user"}, + wantErr: false, + }, + { + name: "invalid: capitalized user", + args: args{name: "Invalid-CapItaLIzed-user"}, + wantErr: true, + }, + { + name: "invalid: email as user", + args: args{name: "foo.bar@example.com"}, + wantErr: true, + }, + { + name: "invalid: chars in user name", + args: args{name: "super-user+name"}, + wantErr: true, + }, + { + name: "invalid: too long name for user", + args: args{ + name: "super-long-useruseruser-name-that-should-be-a-little-more-than-63-chars", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := CheckForFQDNRules(tt.args.name); (err != nil) != tt.wantErr { + t.Errorf("CheckForFQDNRules() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/integration/acl_test.go b/integration/acl_test.go index e85e28cd4a..ca184b8388 100644 --- a/integration/acl_test.go +++ b/integration/acl_test.go @@ -6,7 +6,7 @@ import ( "strings" "testing" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" @@ -45,7 +45,7 @@ var veryLargeDestination = []string{ "208.0.0.0/4:*", } -func aclScenario(t *testing.T, policy *hscontrol.ACLPolicy, clientsPerUser int) *Scenario { +func aclScenario(t *testing.T, policy *policy.ACLPolicy, clientsPerUser int) *Scenario { t.Helper() scenario, err := NewScenario() assert.NoError(t, err) @@ -92,7 +92,7 @@ func TestACLHostsInNetMapTable(t *testing.T) { // they can access minus one (them self). tests := map[string]struct { users map[string]int - policy hscontrol.ACLPolicy + policy policy.ACLPolicy want map[string]int }{ // Test that when we have no ACL, each client netmap has @@ -102,8 +102,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { "user1": 2, "user2": 2, }, - policy: hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + policy: policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"*"}, @@ -123,8 +123,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { "user1": 2, "user2": 2, }, - policy: hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + policy: policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"user1"}, @@ -149,8 +149,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { "user1": 2, "user2": 2, }, - policy: hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + policy: policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"user1"}, @@ -186,8 +186,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { "user1": 2, "user2": 2, }, - policy: hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + policy: policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"user1"}, @@ -214,8 +214,8 @@ func TestACLHostsInNetMapTable(t *testing.T) { "user1": 2, "user2": 2, }, - policy: hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + policy: policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"user1"}, @@ -282,8 +282,8 @@ func TestACLAllowUser80Dst(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + &policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"user1"}, @@ -338,11 +338,11 @@ func TestACLDenyAllPort80(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &hscontrol.ACLPolicy{ + &policy.ACLPolicy{ Groups: map[string][]string{ "group:integration-acl-test": {"user1", "user2"}, }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"group:integration-acl-test"}, @@ -387,8 +387,8 @@ func TestACLAllowUserDst(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + &policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"user1"}, @@ -445,8 +445,8 @@ func TestACLAllowStarDst(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + &policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"user1"}, @@ -504,11 +504,11 @@ func TestACLNamedHostsCanReachBySubnet(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &hscontrol.ACLPolicy{ - Hosts: hscontrol.Hosts{ + &policy.ACLPolicy{ + Hosts: policy.Hosts{ "all": netip.MustParsePrefix("100.64.0.0/24"), }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ // Everyone can curl test3 { Action: "accept", @@ -603,16 +603,16 @@ func TestACLNamedHostsCanReach(t *testing.T) { IntegrationSkip(t) tests := map[string]struct { - policy hscontrol.ACLPolicy + policy policy.ACLPolicy }{ "ipv4": { - policy: hscontrol.ACLPolicy{ - Hosts: hscontrol.Hosts{ + policy: policy.ACLPolicy{ + Hosts: policy.Hosts{ "test1": netip.MustParsePrefix("100.64.0.1/32"), "test2": netip.MustParsePrefix("100.64.0.2/32"), "test3": netip.MustParsePrefix("100.64.0.3/32"), }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ // Everyone can curl test3 { Action: "accept", @@ -629,13 +629,13 @@ func TestACLNamedHostsCanReach(t *testing.T) { }, }, "ipv6": { - policy: hscontrol.ACLPolicy{ - Hosts: hscontrol.Hosts{ + policy: policy.ACLPolicy{ + Hosts: policy.Hosts{ "test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), "test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), "test3": netip.MustParsePrefix("fd7a:115c:a1e0::3/128"), }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ // Everyone can curl test3 { Action: "accept", @@ -854,11 +854,11 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { IntegrationSkip(t) tests := map[string]struct { - policy hscontrol.ACLPolicy + policy policy.ACLPolicy }{ "ipv4": { - policy: hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + policy: policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"100.64.0.1"}, @@ -868,8 +868,8 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { }, }, "ipv6": { - policy: hscontrol.ACLPolicy{ - ACLs: []hscontrol.ACL{ + policy: policy.ACLPolicy{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"fd7a:115c:a1e0::1"}, @@ -879,12 +879,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { }, }, "hostv4cidr": { - policy: hscontrol.ACLPolicy{ - Hosts: hscontrol.Hosts{ + policy: policy.ACLPolicy{ + Hosts: policy.Hosts{ "test1": netip.MustParsePrefix("100.64.0.1/32"), "test2": netip.MustParsePrefix("100.64.0.2/32"), }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"test1"}, @@ -894,12 +894,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { }, }, "hostv6cidr": { - policy: hscontrol.ACLPolicy{ - Hosts: hscontrol.Hosts{ + policy: policy.ACLPolicy{ + Hosts: policy.Hosts{ "test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), "test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"test1"}, @@ -909,12 +909,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { }, }, "group": { - policy: hscontrol.ACLPolicy{ + policy: policy.ACLPolicy{ Groups: map[string][]string{ "group:one": {"user1"}, "group:two": {"user2"}, }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"group:one"}, diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 0051b40013..d27eb06fcd 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -23,7 +23,7 @@ import ( "github.com/davecgh/go-spew/spew" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/integrationutil" @@ -60,7 +60,7 @@ type HeadscaleInContainer struct { port int extraPorts []string hostPortBindings map[string][]string - aclPolicy *hscontrol.ACLPolicy + aclPolicy *policy.ACLPolicy env map[string]string tlsCert []byte tlsKey []byte @@ -73,7 +73,7 @@ type Option = func(c *HeadscaleInContainer) // WithACLPolicy adds a hscontrol.ACLPolicy policy to the // HeadscaleInContainer instance. -func WithACLPolicy(acl *hscontrol.ACLPolicy) Option { +func WithACLPolicy(acl *policy.ACLPolicy) Option { return func(hsic *HeadscaleInContainer) { // TODO(kradalby): Move somewhere appropriate hsic.env["HEADSCALE_ACL_POLICY_PATH"] = aclPolicyPath diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 922ced622d..006ac0cbac 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" @@ -57,18 +57,18 @@ func TestSSHOneUserAllToAll(t *testing.T) { err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithSSH()}, hsic.WithACLPolicy( - &hscontrol.ACLPolicy{ + &policy.ACLPolicy{ Groups: map[string][]string{ "group:integration-test": {"user1"}, }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"*"}, Destinations: []string{"*:*"}, }, }, - SSHs: []hscontrol.SSH{ + SSHs: []policy.SSH{ { Action: "accept", Sources: []string{"group:integration-test"}, @@ -134,18 +134,18 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) { err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithSSH()}, hsic.WithACLPolicy( - &hscontrol.ACLPolicy{ + &policy.ACLPolicy{ Groups: map[string][]string{ "group:integration-test": {"user1", "user2"}, }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"*"}, Destinations: []string{"*:*"}, }, }, - SSHs: []hscontrol.SSH{ + SSHs: []policy.SSH{ { Action: "accept", Sources: []string{"group:integration-test"}, @@ -216,18 +216,18 @@ func TestSSHNoSSHConfigured(t *testing.T) { err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithSSH()}, hsic.WithACLPolicy( - &hscontrol.ACLPolicy{ + &policy.ACLPolicy{ Groups: map[string][]string{ "group:integration-test": {"user1"}, }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"*"}, Destinations: []string{"*:*"}, }, }, - SSHs: []hscontrol.SSH{}, + SSHs: []policy.SSH{}, }, ), hsic.WithTestName("sshnoneconfigured"), @@ -286,18 +286,18 @@ func TestSSHIsBlockedInACL(t *testing.T) { err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithSSH()}, hsic.WithACLPolicy( - &hscontrol.ACLPolicy{ + &policy.ACLPolicy{ Groups: map[string][]string{ "group:integration-test": {"user1"}, }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"*"}, Destinations: []string{"*:80"}, }, }, - SSHs: []hscontrol.SSH{ + SSHs: []policy.SSH{ { Action: "accept", Sources: []string{"group:integration-test"}, @@ -364,19 +364,19 @@ func TestSSUserOnlyIsolation(t *testing.T) { err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithSSH()}, hsic.WithACLPolicy( - &hscontrol.ACLPolicy{ + &policy.ACLPolicy{ Groups: map[string][]string{ "group:ssh1": {"useracl1"}, "group:ssh2": {"useracl2"}, }, - ACLs: []hscontrol.ACL{ + ACLs: []policy.ACL{ { Action: "accept", Sources: []string{"*"}, Destinations: []string{"*:*"}, }, }, - SSHs: []hscontrol.SSH{ + SSHs: []policy.SSH{ { Action: "accept", Sources: []string{"group:ssh1"},