Skip to content

Commit

Permalink
Remove variables and leftovers of pregenerated ACL content
Browse files Browse the repository at this point in the history
Prior to the code reorg, we would generate rules from the Policy and
store it on the global object. Now we generate it on the fly for each node
and this commit cleans up the old variables to make sure we have no
unexpected side effects.

Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby committed Jun 8, 2023
1 parent 084d1d5 commit 725bbd7
Show file tree
Hide file tree
Showing 11 changed files with 410 additions and 292 deletions.
42 changes: 0 additions & 42 deletions hscontrol/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ type Headscale struct {
DERPServer *DERPServer

ACLPolicy *policy.ACLPolicy
aclRules []tailcfg.FilterRule
sshPolicy *tailcfg.SSHPolicy

lastStateChange *xsync.MapOf[string, time.Time]

Expand All @@ -102,12 +100,6 @@ type Headscale struct {

stateUpdateChan chan struct{}
cancelStateUpdateChan chan struct{}

// TODO(kradalby): Temporary measure to make sure we can update policy
// across modules, will be removed when aclRules are no longer stored
// globally but generated per node basis.
policyUpdateChan chan struct{}
cancelPolicyUpdateChan chan struct{}
}

func NewHeadscale(cfg *Config) (*Headscale, error) {
Expand Down Expand Up @@ -168,28 +160,22 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
dbString: dbString,
privateKey2019: privateKey,
noisePrivateKey: noisePrivateKey,
aclRules: tailcfg.FilterAllowAll, // default allowall
registrationCache: registrationCache,
pollNetMapStreamWG: sync.WaitGroup{},
lastStateChange: xsync.NewMapOf[time.Time](),

stateUpdateChan: make(chan struct{}),
cancelStateUpdateChan: make(chan struct{}),

policyUpdateChan: make(chan struct{}),
cancelPolicyUpdateChan: make(chan struct{}),
}

go app.watchStateChannel()
go app.watchPolicyChannel()

database, err := db.NewHeadscaleDatabase(
cfg.DBtype,
dbString,
cfg.OIDC.StripEmaildomain,
app.dbDebug,
app.stateUpdateChan,
app.policyUpdateChan,
cfg.IPPrefixes,
cfg.BaseDomain)
if err != nil {
Expand Down Expand Up @@ -750,10 +736,6 @@ func (h *Headscale) Serve() error {
close(h.stateUpdateChan)
close(h.cancelStateUpdateChan)

<-h.cancelPolicyUpdateChan
close(h.policyUpdateChan)
close(h.cancelPolicyUpdateChan)

// Close db connections
err = h.db.Close()
if err != nil {
Expand Down Expand Up @@ -862,30 +844,6 @@ func (h *Headscale) watchStateChannel() {
}
}

// TODO(kradalby): baby steps, make this more robust.
func (h *Headscale) watchPolicyChannel() {
for {
select {
case <-h.policyUpdateChan:
machines, err := h.db.ListMachines()
if err != nil {
log.Error().Err(err).Msg("failed to fetch machines during policy update")
}

rules, sshPolicy, err := policy.GenerateFilterRules(h.ACLPolicy, machines, h.cfg.OIDC.StripEmaildomain)
if err != nil {
log.Error().Err(err).Msg("failed to update ACL rules")
}

h.aclRules = rules
h.sshPolicy = sshPolicy

case <-h.cancelPolicyUpdateChan:
return
}
}
}

func (h *Headscale) setLastStateChangeToNow() {
var err error

Expand Down
11 changes: 4 additions & 7 deletions hscontrol/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@ type KV struct {
}

type HSDatabase struct {
db *gorm.DB
notifyStateChan chan<- struct{}
notifyPolicyChan chan<- struct{}
db *gorm.DB
notifyStateChan chan<- struct{}

ipAllocationMutex sync.Mutex

Expand All @@ -53,7 +52,6 @@ func NewHeadscaleDatabase(
dbType, connectionAddr string,
stripEmailDomain, debug bool,
notifyStateChan chan<- struct{},
notifyPolicyChan chan<- struct{},
ipPrefixes []netip.Prefix,
baseDomain string,
) (*HSDatabase, error) {
Expand All @@ -63,9 +61,8 @@ func NewHeadscaleDatabase(
}

db := HSDatabase{
db: dbConn,
notifyStateChan: notifyStateChan,
notifyPolicyChan: notifyPolicyChan,
db: dbConn,
notifyStateChan: notifyStateChan,

ipPrefixes: ipPrefixes,
baseDomain: baseDomain,
Expand Down
3 changes: 1 addition & 2 deletions hscontrol/db/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ var (
)
)

// ListPeers returns all peers of machine, regardless of any Policy.
// ListPeers returns all peers of machine, regardless of any Policy or if the node is expired.
func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error) {
log.Trace().
Caller().
Expand Down Expand Up @@ -218,7 +218,6 @@ func (hsdb *HSDatabase) SetTags(
}
machine.ForcedTags = newTags

hsdb.notifyPolicyChan <- struct{}{}
hsdb.notifyStateChange()

if err := hsdb.db.Save(machine).Error; err != nil {
Expand Down
2 changes: 1 addition & 1 deletion hscontrol/db/machine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ func (s *Suite) TestSetTags(c *check.C) {
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
)

c.Assert(channelUpdates, check.Equals, int32(4))
c.Assert(channelUpdates, check.Equals, int32(2))
}

func TestHeadscale_generateGivenName(t *testing.T) {
Expand Down
1 change: 0 additions & 1 deletion hscontrol/db/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ func (s *Suite) ResetDB(c *check.C) {
false,
false,
sink,
sink,
[]netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
},
Expand Down
77 changes: 48 additions & 29 deletions hscontrol/mapper/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/juanfont/headscale/hscontrol/util"
"github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"tailscale.com/smallzstd"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
Expand Down Expand Up @@ -69,33 +70,18 @@ func NewMapper(
}
}

func (m *Mapper) tempWrap(
machine *types.Machine,
pol *policy.ACLPolicy,
) (*tailcfg.MapResponse, error) {
peers, err := m.db.ListPeers(machine)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot fetch peers")

return nil, err
}

return fullMapResponse(
pol,
machine,
peers,
m.stripEmailDomain,
m.baseDomain,
m.dnsCfg,
m.derpMap,
m.logtail,
m.randomClientPort,
)
}

// TODO: Optimise
// As this work continues, the idea is that there will be one Mapper instance
// per node, attached to the open stream between the control and client.
// This means that this can hold a state per machine and we can use that to
// improve the mapresponses sent.
// We could:
// - Keep information about the previous mapresponse so we can send a diff
// - Store hashes
// - Create a "minifier" that removes info not needed for the node

// fullMapResponse is the internal function for generating a MapResponse
// for a machine.
func fullMapResponse(
pol *policy.ACLPolicy,
machine *types.Machine,
Expand All @@ -113,11 +99,23 @@ func fullMapResponse(
return nil, err
}

rules, sshPolicy, err := policy.GenerateFilterRules(pol, peers, stripEmailDomain)
rules, sshPolicy, err := policy.GenerateFilterRules(
pol,
// The policy is currently calculated for the entire Headscale network
append(peers, *machine),
stripEmailDomain,
)
if err != nil {
return nil, err
}

// Filter out peers that have expired.
peers = lo.Filter(peers, func(item types.Machine, index int) bool {
return !item.IsExpired()
})

// If there are filter rules present, see if there are any machines that cannot
// access eachother at all and remove them from the peers.
if len(rules) > 0 {
peers = policy.FilterMachinesByACL(machine, peers, rules)
}
Expand Down Expand Up @@ -278,12 +276,33 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) {
}
}

// CreateMapResponse returns a MapResponse for the given machine.
func (m Mapper) CreateMapResponse(
mapRequest tailcfg.MapRequest,
machine *types.Machine,
pol *policy.ACLPolicy,
) ([]byte, error) {
mapResponse, err := m.tempWrap(machine, pol)
peers, err := m.db.ListPeers(machine)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot fetch peers")

return nil, err
}

mapResponse, err := fullMapResponse(
pol,
machine,
peers,
m.stripEmailDomain,
m.baseDomain,
m.dnsCfg,
m.derpMap,
m.logtail,
m.randomClientPort,
)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 725bbd7

Please sign in to comment.