Skip to content

Commit

Permalink
remove type wrappers for db
Browse files Browse the repository at this point in the history
Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby committed Sep 30, 2024
1 parent b0f67f5 commit a1d1233
Show file tree
Hide file tree
Showing 14 changed files with 55 additions and 143 deletions.
6 changes: 3 additions & 3 deletions hscontrol/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func NewHeadscaleDatabase(

type NodeAux struct {
ID uint64
EnabledRoutes types.IPPrefixes
EnabledRoutes []netip.Prefix `gorm:"serializer:json"`
}

nodesAux := []NodeAux{}
Expand All @@ -220,7 +220,7 @@ func NewHeadscaleDatabase(
}

err = tx.Preload("Node").
Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
Where("node_id = ? AND prefix = ?", node.ID, prefix).
First(&types.Route{}).
Error
if err == nil {
Expand All @@ -235,7 +235,7 @@ func NewHeadscaleDatabase(
NodeID: node.ID,
Advertised: true,
Enabled: true,
Prefix: types.IPPrefix(prefix),
Prefix: prefix,
}
if err := tx.Create(&route).Error; err != nil {
log.Error().Err(err).Msg("Error creating route")
Expand Down
13 changes: 5 additions & 8 deletions hscontrol/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
)

func TestMigrations(t *testing.T) {
ipp := func(p string) types.IPPrefix {
return types.IPPrefix(netip.MustParsePrefix(p))
ipp := func(p string) netip.Prefix {
return netip.MustParsePrefix(p)
}
r := func(id uint64, p string, a, e, i bool) types.Route {
return types.Route{
Expand Down Expand Up @@ -56,9 +57,7 @@ func TestMigrations(t *testing.T) {
r(31, "::/0", true, false, false),
r(32, "192.168.0.24/32", true, true, true),
}
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
return x == y
})); diff != "" {
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
}
},
Expand Down Expand Up @@ -103,9 +102,7 @@ func TestMigrations(t *testing.T) {
r(13, "::/0", true, true, false),
r(13, "10.18.80.2/32", true, true, true),
}
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
return x == y
})); diff != "" {
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
}
},
Expand Down
14 changes: 10 additions & 4 deletions hscontrol/db/node.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package db

import (
"encoding/json"
"errors"
"fmt"
"net/netip"
Expand Down Expand Up @@ -206,21 +207,26 @@ func SetTags(
) error {
if len(tags) == 0 {
// if no tags are provided, we remove all forced tags
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", types.StringList{}).Error; err != nil {
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", "[]").Error; err != nil {
return fmt.Errorf("failed to remove tags for node in the database: %w", err)
}

return nil
}

var newTags types.StringList
var newTags []string
for _, tag := range tags {
if !util.StringOrPrefixListContains(newTags, tag) {
newTags = append(newTags, tag)
}
}

if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", newTags).Error; err != nil {
b, err := json.Marshal(newTags)
if err != nil {
return err
}

if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", string(b)).Error; err != nil {
return fmt.Errorf("failed to update tags for node in the database: %w", err)
}

Expand Down Expand Up @@ -578,7 +584,7 @@ func enableRoutes(tx *gorm.DB,
for _, prefix := range newRoutes {
route := types.Route{}
err := tx.Preload("Node").
Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
Where("node_id = ? AND prefix = ?", node.ID, prefix.String()).
First(&route).Error
if err == nil {
route.Enabled = true
Expand Down
6 changes: 3 additions & 3 deletions hscontrol/db/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ func (s *Suite) TestSetTags(c *check.C) {
c.Assert(err, check.IsNil)
node, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil)
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags))
c.Assert(node.ForcedTags, check.DeepEquals, sTags)

// assign duplicate tags, expect no errors but no doubles in DB
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
Expand All @@ -361,15 +361,15 @@ func (s *Suite) TestSetTags(c *check.C) {
c.Assert(
node.ForcedTags,
check.DeepEquals,
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
[]string{"tag:bar", "tag:test", "tag:unknown"},
)

// test removing tags
err = db.SetTags(node.ID, []string{})
c.Assert(err, check.IsNil)
node, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil)
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList([]string{}))
c.Assert(node.ForcedTags, check.DeepEquals, []string{})
}

func TestHeadscale_generateGivenName(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion hscontrol/db/preauth_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func CreatePreAuthKey(
Ephemeral: ephemeral,
CreatedAt: &now,
Expiration: expiration,
Tags: types.StringList(aclTags),
Tags: aclTags,
}

if err := tx.Save(&key).Error; err != nil {
Expand Down
8 changes: 4 additions & 4 deletions hscontrol/db/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func getRoutesByPrefix(tx *gorm.DB, pref netip.Prefix) (types.Routes, error) {
err := tx.
Preload("Node").
Preload("Node.User").
Where("prefix = ?", types.IPPrefix(pref)).
Where("prefix = ?", pref.String()).
Find(&routes).Error
if err != nil {
return nil, err
Expand Down Expand Up @@ -285,7 +285,7 @@ func isUniquePrefix(tx *gorm.DB, route types.Route) bool {
var count int64
tx.Model(&types.Route{}).
Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?",
route.Prefix,
route.Prefix.String(),
route.NodeID,
true, true).Count(&count)

Expand All @@ -296,7 +296,7 @@ func getPrimaryRoute(tx *gorm.DB, prefix netip.Prefix) (*types.Route, error) {
var route types.Route
err := tx.
Preload("Node").
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true).
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", prefix.String(), true, true, true).
First(&route).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
Expand Down Expand Up @@ -391,7 +391,7 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
if !exists {
route := types.Route{
NodeID: node.ID.Uint64(),
Prefix: types.IPPrefix(prefix),
Prefix: prefix,
Advertised: true,
Enabled: false,
}
Expand Down
19 changes: 6 additions & 13 deletions hscontrol/db/routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
}

var (
ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
ipp = func(s string) netip.Prefix { return netip.MustParsePrefix(s) }
mkNode = func(nid types.NodeID) types.Node {
return types.Node{ID: nid}
}
Expand All @@ -297,7 +297,7 @@ var np = func(nid types.NodeID) *types.Node {
return &no
}

var r = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route {
var r = func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) types.Route {
return types.Route{
Model: gorm.Model{
ID: id,
Expand All @@ -309,7 +309,7 @@ var r = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary
}
}

var rp = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route {
var rp = func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) *types.Route {
ro := r(id, nid, prefix, enabled, primary)
return &ro
}
Expand Down Expand Up @@ -1065,7 +1065,7 @@ func TestFailoverRouteTx(t *testing.T) {
}

func TestFailoverRoute(t *testing.T) {
r := func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route {
r := func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) types.Route {
return types.Route{
Model: gorm.Model{
ID: id,
Expand All @@ -1078,7 +1078,7 @@ func TestFailoverRoute(t *testing.T) {
IsPrimary: primary,
}
}
rp := func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route {
rp := func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) *types.Route {
ro := r(id, nid, prefix, enabled, primary)
return &ro
}
Expand Down Expand Up @@ -1201,13 +1201,6 @@ func TestFailoverRoute(t *testing.T) {
},
}

cmps := append(
util.Comparers,
cmp.Comparer(func(x, y types.IPPrefix) bool {
return netip.Prefix(x) == netip.Prefix(y)
}),
)

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotf := failoverRoute(smap(tt.isConnected), &tt.failingRoute, tt.routes)
Expand All @@ -1231,7 +1224,7 @@ func TestFailoverRoute(t *testing.T) {
"old": gotf.old,
}

if diff := cmp.Diff(want, got, cmps...); diff != "" {
if diff := cmp.Diff(want, got, util.Comparers...); diff != "" {
t.Fatalf("failoverRoute unexpected result (-want +got):\n%s", diff)
}
}
Expand Down
6 changes: 3 additions & 3 deletions hscontrol/mapper/mapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,19 +195,19 @@ func Test_fullMapResponse(t *testing.T) {
Hostinfo: &tailcfg.Hostinfo{},
Routes: []types.Route{
{
Prefix: types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")),
Prefix: netip.MustParsePrefix("0.0.0.0/0"),
Advertised: true,
Enabled: true,
IsPrimary: false,
},
{
Prefix: types.IPPrefix(netip.MustParsePrefix("192.168.0.0/24")),
Prefix: netip.MustParsePrefix("192.168.0.0/24"),
Advertised: true,
Enabled: true,
IsPrimary: true,
},
{
Prefix: types.IPPrefix(netip.MustParsePrefix("172.0.0.0/10")),
Prefix: netip.MustParsePrefix("172.0.0.0/10"),
Advertised: true,
Enabled: false,
IsPrimary: true,
Expand Down
6 changes: 3 additions & 3 deletions hscontrol/mapper/tail_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,19 @@ func TestTailNode(t *testing.T) {
Hostinfo: &tailcfg.Hostinfo{},
Routes: []types.Route{
{
Prefix: types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")),
Prefix: netip.MustParsePrefix("0.0.0.0/0"),
Advertised: true,
Enabled: true,
IsPrimary: false,
},
{
Prefix: types.IPPrefix(netip.MustParsePrefix("192.168.0.0/24")),
Prefix: netip.MustParsePrefix("192.168.0.0/24"),
Advertised: true,
Enabled: true,
IsPrimary: true,
},
{
Prefix: types.IPPrefix(netip.MustParsePrefix("172.0.0.0/10")),
Prefix: netip.MustParsePrefix("172.0.0.0/10"),
Advertised: true,
Enabled: false,
IsPrimary: true,
Expand Down
17 changes: 5 additions & 12 deletions hscontrol/policy/acls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ func TestParsing(t *testing.T) {
],
},
],
}
}
`,
want: []tailcfg.FilterRule{
{
Expand Down Expand Up @@ -2383,7 +2383,7 @@ func TestReduceFilterRules(t *testing.T) {
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
},
ForcedTags: types.StringList{"tag:access-servers"},
ForcedTags: []string{"tag:access-servers"},
},
peers: types.Nodes{
&types.Node{
Expand Down Expand Up @@ -3180,7 +3180,7 @@ func Test_getFilteredByACLPeers(t *testing.T) {
Routes: types.Routes{
types.Route{
NodeID: 2,
Prefix: types.IPPrefix(netip.MustParsePrefix("10.33.0.0/16")),
Prefix: netip.MustParsePrefix("10.33.0.0/16"),
IsPrimary: true,
Enabled: true,
},
Expand Down Expand Up @@ -3213,7 +3213,7 @@ func Test_getFilteredByACLPeers(t *testing.T) {
Routes: types.Routes{
types.Route{
NodeID: 2,
Prefix: types.IPPrefix(netip.MustParsePrefix("10.33.0.0/16")),
Prefix: netip.MustParsePrefix("10.33.0.0/16"),
IsPrimary: true,
Enabled: true,
},
Expand All @@ -3223,21 +3223,14 @@ func Test_getFilteredByACLPeers(t *testing.T) {
},
}

// TODO(kradalby): Remove when we have gotten rid of IPPrefix type
prefixComparer := cmp.Comparer(func(x, y types.IPPrefix) bool {
return x == y
})
comparers := append([]cmp.Option{}, util.Comparers...)
comparers = append(comparers, prefixComparer)

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FilterNodesByACL(
tt.args.node,
tt.args.nodes,
tt.args.rules,
)
if diff := cmp.Diff(tt.want, got, comparers...); diff != "" {
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff)
}
})
Expand Down
Loading

0 comments on commit a1d1233

Please sign in to comment.