Skip to content

Commit

Permalink
Optional ExtendPayload interface support
Browse files Browse the repository at this point in the history
Signed-off-by: Andy Lo-A-Foe <[email protected]>
  • Loading branch information
loafoe committed Feb 7, 2024
1 parent 9a67dbd commit d6efca3
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 11 deletions.
5 changes: 5 additions & 0 deletions connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,8 @@ type RefreshConnector interface {
type TokenIdentityConnector interface {
TokenIdentity(ctx context.Context, subjectTokenType, subjectToken string) (Identity, error)
}

// PayloadExtender allows connectors to enhance the payload before signing
type PayloadExtender interface {
ExtendPayload(scopes []string, payload []byte, connectorData []byte) ([]byte, error)
}
12 changes: 6 additions & 6 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -726,14 +726,14 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
implicitOrHybrid = true
var err error

accessToken, _, err = s.newAccessToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID)
accessToken, _, err = s.newAccessToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID, authReq.ConnectorData)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}

idToken, idTokenExpiry, err = s.newIDToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, accessToken, code.ID, authReq.ConnectorID)
idToken, idTokenExpiry, err = s.newIDToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, accessToken, code.ID, authReq.ConnectorID, authReq.ConnectorData)
if err != nil {
s.logger.Errorf("failed to create ID token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
Expand Down Expand Up @@ -941,14 +941,14 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
}

func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) {
accessToken, _, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
accessToken, _, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID, authCode.ConnectorData)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return nil, err
}

idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ID, authCode.ConnectorID)
idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ID, authCode.ConnectorID, authCode.ConnectorData)
if err != nil {
s.logger.Errorf("failed to create ID token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
Expand Down Expand Up @@ -1206,14 +1206,14 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
Groups: identity.Groups,
}

accessToken, _, err := s.newAccessToken(client.ID, claims, scopes, nonce, connID)
accessToken, _, err := s.newAccessToken(client.ID, claims, scopes, nonce, connID, identity.ConnectorData)
if err != nil {
s.logger.Errorf("password grant failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}

idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, nonce, accessToken, "", connID)
idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, nonce, accessToken, "", connID, identity.ConnectorData)
if err != nil {
s.logger.Errorf("password grant failed to create new ID token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
Expand Down
23 changes: 20 additions & 3 deletions server/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,17 +302,23 @@ type federatedIDClaims struct {
UserID string `json:"user_id,omitempty"`
}

func (s *Server) newAccessToken(clientID string, claims storage.Claims, scopes []string, nonce, connID string) (accessToken string, expiry time.Time, err error) {
return s.newIDToken(clientID, claims, scopes, nonce, storage.NewID(), "", connID)
func (s *Server) newAccessToken(clientID string, claims storage.Claims, scopes []string, nonce, connID string, connectorData []byte) (accessToken string, expiry time.Time, err error) {
return s.newIDToken(clientID, claims, scopes, nonce, storage.NewID(), "", connID, connectorData)
}

func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string) (idToken string, expiry time.Time, err error) {
func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string, connectorData []byte) (idToken string, expiry time.Time, err error) {
keys, err := s.storage.GetKeys()
if err != nil {
s.logger.Errorf("Failed to get keys: %v", err)
return "", expiry, err
}

conn, err := s.getConnector(connID)
if err != nil {
s.logger.Errorf("Failed to get connector with id %q : %v", connID, err)
return "", expiry, err
}

signingKey := keys.SigningKey
if signingKey == nil {
return "", expiry, fmt.Errorf("no key to sign payload with")
Expand Down Expand Up @@ -416,6 +422,17 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str
return "", expiry, fmt.Errorf("could not serialize claims: %v", err)
}

switch c := conn.Connector.(type) {
case connector.PayloadExtender:
extendedPayload, err := c.ExtendPayload(scopes, payload, connectorData)
if err != nil {
s.logger.Warnf("failed to enhance payload: %w", err)
break
}
payload = extendedPayload
default:
}

if idToken, err = signPayload(signingKey, signingAlg, payload); err != nil {
return "", expiry, fmt.Errorf("failed to sign payload: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions server/refreshhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,14 +361,14 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
Groups: ident.Groups,
}

accessToken, _, err := s.newAccessToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID)
accessToken, _, err := s.newAccessToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID, rCtx.connectorData)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
s.refreshTokenErrHelper(w, newInternalServerError())
return
}

idToken, expiry, err := s.newIDToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, accessToken, "", rCtx.storageToken.ConnectorID)
idToken, expiry, err := s.newIDToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, accessToken, "", rCtx.storageToken.ConnectorID, rCtx.connectorData)
if err != nil {
s.logger.Errorf("failed to create ID token: %v", err)
s.refreshTokenErrHelper(w, newInternalServerError())
Expand Down

0 comments on commit d6efca3

Please sign in to comment.