Skip to content

Commit

Permalink
move refresh logic to db layer
Browse files Browse the repository at this point in the history
Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby committed Jan 17, 2025
1 parent 38312d2 commit 3f56787
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 116 deletions.
5 changes: 5 additions & 0 deletions hscontrol/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ func (h *Headscale) handleRegister(
logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId)
now := time.Now().UTC()
logTrace("handleRegister called, looking up machine in DB")

// TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs
// key refreshes. This will allow us to remove the machineKey from the registration request.
node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey)
logTrace("handleRegister database lookup has returned")
if errors.Is(err, gorm.ErrRecordNotFound) {
Expand Down Expand Up @@ -329,6 +332,8 @@ func (h *Headscale) handleAuthKey(
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
// TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs
// key refreshes. This will allow us to remove the machineKey from the registration request.
node, _ := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
if node != nil {
log.Trace().
Expand Down
107 changes: 63 additions & 44 deletions hscontrol/db/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,64 +343,83 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
}

func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
// HandleNodeFromAuthPath is called from the OIDC or CLI auth path
// with a registrationID to register or reauthenticate a node.
// If the node found in the registration cache is not already registered,
// it will be registered with the user and the node will be removed from the cache.
// If the node is already registered, the expiry will be updated.
// The node, and a boolean indicating if it was a new node or not, will be returned.
func (hsdb *HSDatabase) HandleNodeFromAuthPath(
registrationID types.RegistrationID,
userID types.UserID,
nodeExpiry *time.Time,
registrationMethod string,
ipv4 *netip.Addr,
ipv6 *netip.Addr,
) (*types.Node, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
) (*types.Node, bool, error) {
var newNode bool
node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
if reg, ok := hsdb.regCache.Get(registrationID); ok {
user, err := GetUserByID(tx, userID)
if err != nil {
return nil, fmt.Errorf(
"failed to find user in register node from auth callback, %w",
err,
if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
user, err := GetUserByID(tx, userID)
if err != nil {
return nil, fmt.Errorf(
"failed to find user in register node from auth callback, %w",
err,
)
}

log.Debug().
Str("registration_id", registrationID.String()).
Str("username", user.Username()).
Str("registrationMethod", registrationMethod).
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
Msg("Registering node from API/CLI or auth callback")

// TODO(kradalby): This looks quite wrong? why ID 0?
// Why not always?
// Registration of expired node with different user
if reg.Node.ID != 0 &&
reg.Node.UserID != user.ID {
return nil, ErrDifferentRegisteredUser
}

reg.Node.UserID = user.ID
reg.Node.User = *user
reg.Node.RegisterMethod = registrationMethod

if nodeExpiry != nil {
reg.Node.Expiry = nodeExpiry
}

node, err := RegisterNode(
tx,
reg.Node,
ipv4, ipv6,
)
}

log.Debug().
Str("registration_id", registrationID.String()).
Str("username", user.Username()).
Str("registrationMethod", registrationMethod).
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
Msg("Registering node from API/CLI or auth callback")

// TODO(kradalby): This looks quite wrong? why ID 0?
// Why not always?
// Registration of expired node with different user
if reg.Node.ID != 0 &&
reg.Node.UserID != user.ID {
return nil, ErrDifferentRegisteredUser
}

reg.Node.UserID = user.ID
reg.Node.User = *user
reg.Node.RegisterMethod = registrationMethod

if nodeExpiry != nil {
reg.Node.Expiry = nodeExpiry
if err == nil {
hsdb.regCache.Delete(registrationID)
}

// Signal to waiting clients that the machine has been registered.
close(reg.Registered)
newNode = true
return node, err
} else {
// If the node is already registered, this is a refresh.
err := NodeSetExpiry(tx, node.ID, *nodeExpiry)
if err != nil {
return nil, err
}
return node, nil
}

node, err := RegisterNode(
tx,
reg.Node,
ipv4, ipv6,
)

if err == nil {
hsdb.regCache.Delete(registrationID)
}

// Signal to waiting clients that the machine has been registered.
close(reg.Registered)
return node, err
}

return nil, ErrNodeNotFoundRegistrationCache
})

return node, newNode, err
}

func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
Expand Down
2 changes: 1 addition & 1 deletion hscontrol/grpcv1.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ func (api headscaleV1APIServer) RegisterNode(
return nil, fmt.Errorf("looking up user: %w", err)
}

node, err := api.h.db.RegisterNodeFromAuthCallback(
node, _, err := api.h.db.HandleNodeFromAuthPath(
registrationId,
types.UserID(user.ID),
nil,
Expand Down
113 changes: 42 additions & 71 deletions hscontrol/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,49 +286,27 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return
}

// Retrieve the node and the machine key from the state cache and
// database.
// TODO(kradalby): Is this comment right?
// If the node exists, then the node should be reauthenticated,
// if the node does not exist, and the machine key exists, then
// this is a new node that should be registered.
node, mKey := a.getMachineKeyFromState(state)
registrationId := a.getRegistrationIDFromState(state)

// Reauthenticate the node if it does exists.
if node != nil {
err := a.reauthenticateNode(node, nodeExpiry)
// Register the node if it does not exist.
if registrationId != nil {
verb := "Reauthenticated"
newNode, err := a.handleRegistrationID(user, *registrationId, nodeExpiry)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}

// TODO(kradalby): replace with go-elem
var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: user.DisplayNameOrUsername(),
Verb: "Reauthenticated",
}); err != nil {
http.Error(writer, fmt.Errorf("rendering OIDC callback template: %w", err).Error(), http.StatusInternalServerError)
return
}

writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(content.Bytes())
if err != nil {
util.LogErr(err, "Failed to write response")
}

return
}

// Register the node if it does not exist.
if registrationId != nil {
if err := a.registerNode(user, *registrationId, nodeExpiry); err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
if newNode {
verb = "Authenticated"
}

content, err := renderOIDCCallbackTemplate(user)
// TODO(kradalby): replace with go-elem
content, err := renderOIDCCallbackTemplate(user, verb)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
Expand Down Expand Up @@ -462,33 +440,6 @@ func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.Regis
return &regInfo.RegistrationID
}

// reauthenticateNode updates the node expiry in the database
// and notifies the node and its peers about the change.
func (a *AuthProviderOIDC) reauthenticateNode(
node *types.Node,
expiry time.Time,
) error {
err := a.db.NodeSetExpiry(node.ID, expiry)
if err != nil {
return err
}

ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
a.notifier.NotifyByNodeID(
ctx,
types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: []types.NodeID{node.ID},
},
node.ID,
)

ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)

return nil
}

func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
claims *types.OIDCClaims,
) (*types.User, error) {
Expand Down Expand Up @@ -544,43 +495,63 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
return user, nil
}

func (a *AuthProviderOIDC) registerNode(
func (a *AuthProviderOIDC) handleRegistrationID(
user *types.User,
registrationID types.RegistrationID,
expiry time.Time,
) error {
) (bool, error) {
ipv4, ipv6, err := a.ipAlloc.Next()
if err != nil {
return err
return false, err
}

if _, err := a.db.RegisterNodeFromAuthCallback(
node, newNode, err := a.db.HandleNodeFromAuthPath(
registrationID,
types.UserID(user.ID),
&expiry,
util.RegisterMethodOIDC,
ipv4, ipv6,
); err != nil {
return fmt.Errorf("could not register node: %w", err)
}

err = nodesChangedHook(a.db, a.polMan, a.notifier)
)
if err != nil {
return fmt.Errorf("updating resources using node: %w", err)
return false, fmt.Errorf("could not register node: %w", err)
}

return nil
// Send an update to all nodes if this is a new node that they need to know
// about.
// If this is a refresh, just send new expiry updates.
if newNode {
err = nodesChangedHook(a.db, a.polMan, a.notifier)
if err != nil {
return false, fmt.Errorf("updating resources using node: %w", err)
}
} else {
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
a.notifier.NotifyByNodeID(
ctx,
types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: []types.NodeID{node.ID},
},
node.ID,
)

ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
}

return newNode, nil
}

// TODO(kradalby):
// Rewrite in elem-go.
func renderOIDCCallbackTemplate(
user *types.User,
verb string,
) (*bytes.Buffer, error) {
var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: user.DisplayNameOrUsername(),
Verb: "Authenticated",
Verb: verb,
}); err != nil {
return nil, fmt.Errorf("rendering OIDC callback template: %w", err)
}
Expand Down

0 comments on commit 3f56787

Please sign in to comment.