Skip to content

Commit

Permalink
fix: do not retry sending responses (#3764)
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr authored May 8, 2024
1 parent c558e40 commit 1bbfdb5
Showing 1 changed file with 56 additions and 57 deletions.
113 changes: 56 additions & 57 deletions oauth2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -960,71 +960,72 @@ func (h *Handler) oauth2TokenExchange(w http.ResponseWriter, r *http.Request) {
return
}

err = h.r.Persister().Transaction(ctx, func(ctx context.Context, _ *pop.Connection) error {
var err error

if accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeClientCredentials)) ||
accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeJWTBearer)) {
var accessTokenKeyID string
if h.c.AccessTokenStrategy(ctx, client.AccessTokenStrategySource(accessRequest.GetClient())) == "jwt" {
accessTokenKeyID, err = h.r.AccessTokenJWTStrategy().GetPublicKeyID(ctx)
if err != nil {
return err
}
if accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeClientCredentials)) ||
accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeJWTBearer)) {
var accessTokenKeyID string
if h.c.AccessTokenStrategy(ctx, client.AccessTokenStrategySource(accessRequest.GetClient())) == "jwt" {
accessTokenKeyID, err = h.r.AccessTokenJWTStrategy().GetPublicKeyID(ctx)
if err != nil {
h.logOrAudit(err, r)
h.r.OAuth2Provider().WriteAccessError(ctx, w, accessRequest, err)
events.Trace(ctx, events.TokenExchangeError, events.WithRequest(accessRequest))
return
}
}

// only for client_credentials, otherwise Authentication is included in session
if accessRequest.GetGrantTypes().ExactOne("client_credentials") {
session.Subject = accessRequest.GetClient().GetID()
}
session.ClientID = accessRequest.GetClient().GetID()
session.KID = accessTokenKeyID
session.DefaultSession.Claims.Issuer = h.c.IssuerURL(r.Context()).String()
session.DefaultSession.Claims.IssuedAt = time.Now().UTC()

scopes := accessRequest.GetRequestedScopes()

// Added for compatibility with MITREid
if h.c.GrantAllClientCredentialsScopesPerDefault(r.Context()) && len(scopes) == 0 {
for _, scope := range accessRequest.GetClient().GetScopes() {
accessRequest.GrantScope(scope)
}
}
// only for client_credentials, otherwise Authentication is included in session
if accessRequest.GetGrantTypes().ExactOne("client_credentials") {
session.Subject = accessRequest.GetClient().GetID()
}
session.ClientID = accessRequest.GetClient().GetID()
session.KID = accessTokenKeyID
session.DefaultSession.Claims.Issuer = h.c.IssuerURL(r.Context()).String()
session.DefaultSession.Claims.IssuedAt = time.Now().UTC()

for _, scope := range scopes {
if h.r.Config().GetScopeStrategy(ctx)(accessRequest.GetClient().GetScopes(), scope) {
accessRequest.GrantScope(scope)
}
}
scopes := accessRequest.GetRequestedScopes()

for _, audience := range accessRequest.GetRequestedAudience() {
if h.r.AudienceStrategy()(accessRequest.GetClient().GetAudience(), []string{audience}) == nil {
accessRequest.GrantAudience(audience)
}
// Added for compatibility with MITREid
if h.c.GrantAllClientCredentialsScopesPerDefault(r.Context()) && len(scopes) == 0 {
for _, scope := range accessRequest.GetClient().GetScopes() {
accessRequest.GrantScope(scope)
}
}

for _, hook := range h.r.AccessRequestHooks() {
if err = hook(ctx, accessRequest); err != nil {
return err
for _, scope := range scopes {
if h.r.Config().GetScopeStrategy(ctx)(accessRequest.GetClient().GetScopes(), scope) {
accessRequest.GrantScope(scope)
}
}

accessResponse, err := h.r.OAuth2Provider().NewAccessResponse(ctx, accessRequest)
if err != nil {
return err
for _, audience := range accessRequest.GetRequestedAudience() {
if h.r.AudienceStrategy()(accessRequest.GetClient().GetAudience(), []string{audience}) == nil {
accessRequest.GrantAudience(audience)
}
}
}

h.r.OAuth2Provider().WriteAccessResponse(ctx, w, accessRequest, accessResponse)

return nil
})
for _, hook := range h.r.AccessRequestHooks() {
if err = hook(ctx, accessRequest); err != nil {
h.logOrAudit(err, r)
h.r.OAuth2Provider().WriteAccessError(ctx, w, accessRequest, err)
events.Trace(ctx, events.TokenExchangeError, events.WithRequest(accessRequest))
return
}
}

if err != nil {
var accessResponse fosite.AccessResponder
if err := h.r.Persister().Transaction(ctx, func(ctx context.Context, _ *pop.Connection) error {
var err error
accessResponse, err = h.r.OAuth2Provider().NewAccessResponse(ctx, accessRequest)
return err
}); err != nil {
h.logOrAudit(err, r)
h.r.OAuth2Provider().WriteAccessError(ctx, w, accessRequest, err)
events.Trace(ctx, events.TokenExchangeError)
events.Trace(ctx, events.TokenExchangeError, events.WithRequest(accessRequest))
return
}

h.r.OAuth2Provider().WriteAccessResponse(ctx, w, accessRequest, accessResponse)
}

// swagger:route GET /oauth2/auth oAuth2 oAuth2Authorize
Expand Down Expand Up @@ -1126,8 +1127,9 @@ func (h *Handler) oAuth2Authorize(w http.ResponseWriter, r *http.Request, _ http
claims.Add("sid", session.ConsentRequest.LoginSessionID)

// done
if err := h.r.Persister().Transaction(ctx, func(ctx context.Context, _ *pop.Connection) error {
response, err := h.r.OAuth2Provider().NewAuthorizeResponse(ctx, authorizeRequest, &Session{
var response fosite.AuthorizeResponder
if err := h.r.Persister().Transaction(ctx, func(ctx context.Context, _ *pop.Connection) (err error) {
response, err = h.r.OAuth2Provider().NewAuthorizeResponse(ctx, authorizeRequest, &Session{
DefaultSession: &openid.DefaultSession{
Claims: claims,
Headers: &jwt.Headers{Extra: map[string]interface{}{
Expand All @@ -1145,17 +1147,14 @@ func (h *Handler) oAuth2Authorize(w http.ResponseWriter, r *http.Request, _ http
MirrorTopLevelClaims: h.c.MirrorTopLevelClaims(ctx),
Flow: flow,
})
if err != nil {
return err
}

h.r.OAuth2Provider().WriteAuthorizeResponse(ctx, w, authorizeRequest, response)
return nil
return err
}); err != nil {
x.LogError(r, err, h.r.Logger())
h.writeAuthorizeError(w, r, authorizeRequest, err)
return
}

h.r.OAuth2Provider().WriteAuthorizeResponse(ctx, w, authorizeRequest, response)
}

// Delete OAuth 2.0 Access Token Parameters
Expand Down

0 comments on commit 1bbfdb5

Please sign in to comment.