Skip to content

Commit

Permalink
session_rpcserver: pass all known pairs to RealToPseudo
Browse files Browse the repository at this point in the history
In this commit, we keep track of all known privacy map pairs for a
session along with any new pairs to be persisted.
  • Loading branch information
ellemouton committed Aug 31, 2023
1 parent 741366e commit 44a2ab9
Showing 1 changed file with 75 additions and 49 deletions.
124 changes: 75 additions & 49 deletions session_rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -838,12 +838,71 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
return nil, fmt.Errorf("expiry must be in the future")
}

// If the privacy mapper is being used for this session, then we need
// to keep track of all our known privacy map pairs for this session
// along with any new pairs that we need to persist.
var (
privacy = !req.NoPrivacyMapper
privacyMapPairs = make(map[string]string)
knownPrivMapPairs = firewalldb.NewPrivacyMapPairs(nil)
newPrivMapPairs = make(map[string]string)
)

// If a previous session ID has been set to link this new one to, we
// first check if we have the referenced session, and we make sure it
// has been revoked.
var (
linkedGroupID *session.ID
linkedGroupSession *session.Session
)
if len(req.LinkedGroupId) != 0 {
var groupID session.ID
copy(groupID[:], req.LinkedGroupId)

// Check that the group actually does exist.
groupSess, err := s.cfg.db.GetSessionByID(groupID)
if err != nil {
return nil, err
}

// Ensure that the linked session is in fact the first session
// in its group.
if groupSess.ID != groupSess.GroupID {
return nil, fmt.Errorf("can not link to session "+
"%x since it is not the first in the session "+
"group %x", groupSess.ID, groupSess.GroupID)
}

// Now we need to check that all the sessions in the group are
// no longer active.
ok, err := s.cfg.db.CheckSessionGroupPredicate(
groupID, func(s *session.Session) bool {
return s.State == session.StateRevoked ||
s.State == session.StateExpired
},
)
if err != nil {
return nil, err
}

if !ok {
return nil, fmt.Errorf("a linked session in group "+
"%x is still active", groupID)
}

linkedGroupID = &groupID
linkedGroupSession = groupSess

privDB := s.cfg.privMap(groupID)
err = privDB.View(func(tx firewalldb.PrivacyMapTx) error {
knownPrivMapPairs, err = tx.FetchAllPairs()

return err
})
if err != nil {
return nil, err
}
}

// First need to fetch all the perms that need to be baked into this
// mac based on the features.
allFeatures, err := s.cfg.autopilot.ListFeatures(ctx)
Expand Down Expand Up @@ -892,8 +951,21 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
return nil, err
}

// Store the new privacy map pairs in
// the newPrivMap pairs map so that
// they are later persisted to the real
// priv map db.
for k, v := range privMapPairs {
privacyMapPairs[k] = v
newPrivMapPairs[k] = v
}

// Also add the new pairs to the known
// set of pairs.
err = knownPrivMapPairs.Add(
privMapPairs,
)
if err != nil {
return nil, err
}
}

Expand Down Expand Up @@ -1017,52 +1089,6 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
caveats = append(caveats, firewall.MetaPrivacyCaveat)
}

// If a previous session ID has been set to link this new one to, we
// first check if we have the referenced session, and we make sure it
// has been revoked.
var (
linkedGroupID *session.ID
linkedGroupSession *session.Session
)
if len(req.LinkedGroupId) != 0 {
var groupID session.ID
copy(groupID[:], req.LinkedGroupId)

// Check that the group actually does exist.
groupSess, err := s.cfg.db.GetSessionByID(groupID)
if err != nil {
return nil, err
}

// Ensure that the linked session is in fact the first session
// in its group.
if groupSess.ID != groupSess.GroupID {
return nil, fmt.Errorf("can not link to session "+
"%x since it is not the first in the session "+
"group %x", groupSess.ID, groupSess.GroupID)
}

// Now we need to check that all the sessions in the group are
// no longer active.
ok, err := s.cfg.db.CheckSessionGroupPredicate(
groupID, func(s *session.Session) bool {
return s.State == session.StateRevoked ||
s.State == session.StateExpired
},
)
if err != nil {
return nil, err
}

if !ok {
return nil, fmt.Errorf("a linked session in group "+
"%x is still active", groupID)
}

linkedGroupID = &groupID
linkedGroupSession = groupSess
}

s.sessRegMu.Lock()
defer s.sessRegMu.Unlock()

Expand Down Expand Up @@ -1101,7 +1127,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
// Register all the privacy map pairs for this session ID.
privDB := s.cfg.privMap(sess.GroupID)
err = privDB.Update(func(tx firewalldb.PrivacyMapTx) error {
for r, p := range privacyMapPairs {
for r, p := range newPrivMapPairs {
err := tx.NewPair(r, p)
if err != nil {
return err
Expand Down

0 comments on commit 44a2ab9

Please sign in to comment.