diff --git a/server/api.go b/server/api.go index 4454c3ca49..5b0abb0bf5 100644 --- a/server/api.go +++ b/server/api.go @@ -51,7 +51,7 @@ type dexAPI struct { } func (d dexAPI) GetClient(ctx context.Context, req *api.GetClientReq) (*api.GetClientResp, error) { - c, err := d.s.GetClient(req.Id) + c, err := d.s.GetClient(ctx, req.Id) if err != nil { return nil, err } @@ -108,7 +108,7 @@ func (d dexAPI) UpdateClient(ctx context.Context, req *api.UpdateClientReq) (*ap return nil, errors.New("update client: no client ID supplied") } - err := d.s.UpdateClient(req.Id, func(old storage.Client) (storage.Client, error) { + err := d.s.UpdateClient(ctx, req.Id, func(old storage.Client) (storage.Client, error) { if req.RedirectUris != nil { old.RedirectURIs = req.RedirectUris } @@ -134,7 +134,7 @@ func (d dexAPI) UpdateClient(ctx context.Context, req *api.UpdateClientReq) (*ap } func (d dexAPI) DeleteClient(ctx context.Context, req *api.DeleteClientReq) (*api.DeleteClientResp, error) { - err := d.s.DeleteClient(req.Id) + err := d.s.DeleteClient(ctx, req.Id) if err != nil { if err == storage.ErrNotFound { return &api.DeleteClientResp{NotFound: true}, nil @@ -219,7 +219,7 @@ func (d dexAPI) UpdatePassword(ctx context.Context, req *api.UpdatePasswordReq) return old, nil } - if err := d.s.UpdatePassword(req.Email, updater); err != nil { + if err := d.s.UpdatePassword(ctx, req.Email, updater); err != nil { if err == storage.ErrNotFound { return &api.UpdatePasswordResp{NotFound: true}, nil } @@ -235,7 +235,7 @@ func (d dexAPI) DeletePassword(ctx context.Context, req *api.DeletePasswordReq) return nil, errors.New("no email supplied") } - err := d.s.DeletePassword(req.Email) + err := d.s.DeletePassword(ctx, req.Email) if err != nil { if err == storage.ErrNotFound { return &api.DeletePasswordResp{NotFound: true}, nil @@ -268,7 +268,7 @@ func (d dexAPI) GetDiscovery(ctx context.Context, req *api.DiscoveryReq) (*api.D } func (d dexAPI) ListPasswords(ctx context.Context, req *api.ListPasswordReq) (*api.ListPasswordResp, error) { - passwordList, err := d.s.ListPasswords() + passwordList, err := d.s.ListPasswords(ctx) if err != nil { d.logger.Error("failed to list passwords", "err", err) return nil, fmt.Errorf("list passwords: %v", err) @@ -298,7 +298,7 @@ func (d dexAPI) VerifyPassword(ctx context.Context, req *api.VerifyPasswordReq) return nil, errors.New("no password to verify supplied") } - password, err := d.s.GetPassword(req.Email) + password, err := d.s.GetPassword(ctx, req.Email) if err != nil { if err == storage.ErrNotFound { return &api.VerifyPasswordResp{ @@ -327,7 +327,7 @@ func (d dexAPI) ListRefresh(ctx context.Context, req *api.ListRefreshReq) (*api. return nil, err } - offlineSessions, err := d.s.GetOfflineSessions(id.UserId, id.ConnId) + offlineSessions, err := d.s.GetOfflineSessions(ctx, id.UserId, id.ConnId) if err != nil { if err == storage.ErrNotFound { // This means that this user-client pair does not have a refresh token yet. @@ -381,7 +381,7 @@ func (d dexAPI) RevokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (* return old, nil } - if err := d.s.UpdateOfflineSessions(id.UserId, id.ConnId, updater); err != nil { + if err := d.s.UpdateOfflineSessions(ctx, id.UserId, id.ConnId, updater); err != nil { if err == storage.ErrNotFound { return &api.RevokeRefreshResp{NotFound: true}, nil } @@ -397,7 +397,7 @@ func (d dexAPI) RevokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (* // // TODO(ericchiang): we don't have any good recourse if this call fails. // Consider garbage collection of refresh tokens with no associated ref. - if err := d.s.DeleteRefresh(refreshID); err != nil { + if err := d.s.DeleteRefresh(ctx, refreshID); err != nil { d.logger.Error("failed to delete refresh token", "err", err) return nil, err } @@ -448,7 +448,7 @@ func (d dexAPI) CreateConnector(ctx context.Context, req *api.CreateConnectorReq return &api.CreateConnectorResp{}, nil } -func (d dexAPI) UpdateConnector(_ context.Context, req *api.UpdateConnectorReq) (*api.UpdateConnectorResp, error) { +func (d dexAPI) UpdateConnector(ctx context.Context, req *api.UpdateConnectorReq) (*api.UpdateConnectorResp, error) { if !featureflags.APIConnectorsCRUD.Enabled() { return nil, fmt.Errorf("%s feature flag is not enabled", featureflags.APIConnectorsCRUD.Name) } @@ -485,7 +485,7 @@ func (d dexAPI) UpdateConnector(_ context.Context, req *api.UpdateConnectorReq) return old, nil } - if err := d.s.UpdateConnector(req.Id, updater); err != nil { + if err := d.s.UpdateConnector(ctx, req.Id, updater); err != nil { if err == storage.ErrNotFound { return &api.UpdateConnectorResp{NotFound: true}, nil } @@ -505,7 +505,7 @@ func (d dexAPI) DeleteConnector(ctx context.Context, req *api.DeleteConnectorReq return nil, errors.New("no id supplied") } - err := d.s.DeleteConnector(req.Id) + err := d.s.DeleteConnector(ctx, req.Id) if err != nil { if err == storage.ErrNotFound { return &api.DeleteConnectorResp{NotFound: true}, nil @@ -521,7 +521,7 @@ func (d dexAPI) ListConnectors(ctx context.Context, req *api.ListConnectorReq) ( return nil, fmt.Errorf("%s feature flag is not enabled", featureflags.APIConnectorsCRUD.Name) } - connectorList, err := d.s.ListConnectors() + connectorList, err := d.s.ListConnectors(ctx) if err != nil { d.logger.Error("api: failed to list connectors", "err", err) return nil, fmt.Errorf("list connectors: %v", err) diff --git a/server/api_test.go b/server/api_test.go index af8d363f9f..d929ff3128 100644 --- a/server/api_test.go +++ b/server/api_test.go @@ -149,7 +149,7 @@ func TestPassword(t *testing.T) { t.Fatalf("Unable to update password: %v", err) } - pass, err := s.GetPassword(updateReq.Email) + pass, err := s.GetPassword(ctx, updateReq.Email) if err != nil { t.Fatalf("Unable to retrieve password: %v", err) } @@ -449,7 +449,7 @@ func TestUpdateClient(t *testing.T) { t.Errorf("expected in response NotFound: %t", tc.want.NotFound) } - client, err := s.GetClient(tc.req.Id) + client, err := s.GetClient(ctx, tc.req.Id) if err != nil { t.Errorf("no client found in the storage: %v", err) } diff --git a/server/deviceflowhandlers.go b/server/deviceflowhandlers.go index 06f3a7b2d5..380e40aacb 100644 --- a/server/deviceflowhandlers.go +++ b/server/deviceflowhandlers.go @@ -199,6 +199,7 @@ func (s *Server) handleDeviceTokenDeprecated(w http.ResponseWriter, r *http.Requ } func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() deviceCode := r.Form.Get("device_code") if deviceCode == "" { s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest) @@ -208,7 +209,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { now := s.now() // Grab the device token, check validity - deviceToken, err := s.storage.GetDeviceToken(deviceCode) + deviceToken, err := s.storage.GetDeviceToken(ctx, deviceCode) if err != nil { if err != storage.ErrNotFound { s.logger.ErrorContext(r.Context(), "failed to get device code", "err", err) @@ -240,7 +241,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { return old, nil } // Update device token last request time in storage - if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil { + if err := s.storage.UpdateDeviceToken(ctx, deviceCode, updater); err != nil { s.logger.ErrorContext(r.Context(), "failed to update device token", "err", err) s.renderError(r, w, http.StatusInternalServerError, "") return @@ -299,7 +300,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { return } - authCode, err := s.storage.GetAuthCode(code) + authCode, err := s.storage.GetAuthCode(ctx, code) if err != nil || s.now().After(authCode.Expiry) { errCode := http.StatusBadRequest if err != nil && err != storage.ErrNotFound { @@ -311,7 +312,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { } // Grab the device request from storage - deviceReq, err := s.storage.GetDeviceRequest(userCode) + deviceReq, err := s.storage.GetDeviceRequest(ctx, userCode) if err != nil || s.now().After(deviceReq.Expiry) { errCode := http.StatusBadRequest if err != nil && err != storage.ErrNotFound { @@ -322,7 +323,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { return } - client, err := s.storage.GetClient(deviceReq.ClientID) + client, err := s.storage.GetClient(ctx, deviceReq.ClientID) if err != nil { if err != storage.ErrNotFound { s.logger.ErrorContext(r.Context(), "failed to get client", "err", err) @@ -345,7 +346,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { } // Grab the device token from storage - old, err := s.storage.GetDeviceToken(deviceReq.DeviceCode) + old, err := s.storage.GetDeviceToken(ctx, deviceReq.DeviceCode) if err != nil || s.now().After(old.Expiry) { errCode := http.StatusBadRequest if err != nil && err != storage.ErrNotFound { @@ -373,7 +374,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { } // Update refresh token in the storage, store the token and mark as complete - if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil { + if err := s.storage.UpdateDeviceToken(ctx, deviceReq.DeviceCode, updater); err != nil { s.logger.ErrorContext(r.Context(), "failed to update device token", "err", err) s.renderError(r, w, http.StatusBadRequest, "") return @@ -391,6 +392,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { } func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() switch r.Method { case http.MethodPost: err := r.ParseForm() @@ -409,7 +411,7 @@ func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) { userCode = strings.ToUpper(userCode) // Find the user code in the available requests - deviceRequest, err := s.storage.GetDeviceRequest(userCode) + deviceRequest, err := s.storage.GetDeviceRequest(ctx, userCode) if err != nil || s.now().After(deviceRequest.Expiry) { if err != nil && err != storage.ErrNotFound { s.logger.ErrorContext(r.Context(), "failed to get device request", "err", err) diff --git a/server/handlers.go b/server/handlers.go index 5954820caa..a00b290b61 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -32,8 +32,9 @@ const ( ) func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() // TODO(ericchiang): Cache this. - keys, err := s.storage.GetKeys() + keys, err := s.storage.GetKeys(ctx) if err != nil { s.logger.ErrorContext(r.Context(), "failed to get keys", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") @@ -135,6 +136,7 @@ func (s *Server) constructDiscovery() discovery { // handleAuthorization handles the OAuth2 auth endpoint. func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() // Extract the arguments if err := r.ParseForm(); err != nil { s.logger.ErrorContext(r.Context(), "failed to parse arguments", "err", err) @@ -144,8 +146,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { } connectorID := r.Form.Get("connector_id") - - connectors, err := s.storage.ListConnectors() + connectors, err := s.storage.ListConnectors(ctx) if err != nil { s.logger.ErrorContext(r.Context(), "failed to get list of connectors", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve connector list.") @@ -219,7 +220,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { return } - conn, err := s.getConnector(connID) + conn, err := s.getConnector(ctx, connID) if err != nil { s.logger.ErrorContext(r.Context(), "Failed to get connector", "err", err) s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist") @@ -314,6 +315,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { } func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() authID := r.URL.Query().Get("state") if authID == "" { s.renderError(r, w, http.StatusBadRequest, "User session error.") @@ -322,7 +324,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { backLink := r.URL.Query().Get("back") - authReq, err := s.storage.GetAuthRequest(authID) + authReq, err := s.storage.GetAuthRequest(ctx, authID) if err != nil { if err == storage.ErrNotFound { s.logger.ErrorContext(r.Context(), "invalid 'state' parameter provided", "err", err) @@ -345,7 +347,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { return } - conn, err := s.getConnector(authReq.ConnectorID) + conn, err := s.getConnector(ctx, authReq.ConnectorID) if err != nil { s.logger.ErrorContext(r.Context(), "failed to get connector", "connector_id", authReq.ConnectorID, "err", err) s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") @@ -390,7 +392,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { } if canSkipApproval { - authReq, err = s.storage.GetAuthRequest(authReq.ID) + authReq, err = s.storage.GetAuthRequest(ctx, authReq.ID) if err != nil { s.logger.ErrorContext(r.Context(), "failed to get finalized auth request", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Login error.") @@ -425,7 +427,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) return } - authReq, err := s.storage.GetAuthRequest(authID) + authReq, err := s.storage.GetAuthRequest(ctx, authID) if err != nil { if err == storage.ErrNotFound { s.logger.ErrorContext(r.Context(), "invalid 'state' parameter provided", "err", err) @@ -448,7 +450,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) return } - conn, err := s.getConnector(authReq.ConnectorID) + conn, err := s.getConnector(ctx, authReq.ConnectorID) if err != nil { s.logger.ErrorContext(r.Context(), "failed to get connector", "connector_id", authReq.ConnectorID, "err", err) s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") @@ -490,7 +492,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) } if canSkipApproval { - authReq, err = s.storage.GetAuthRequest(authReq.ID) + authReq, err = s.storage.GetAuthRequest(ctx, authReq.ID) if err != nil { s.logger.ErrorContext(r.Context(), "failed to get finalized auth request", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Login error.") @@ -521,7 +523,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, a.ConnectorData = identity.ConnectorData return a, nil } - if err := s.storage.UpdateAuthRequest(authReq.ID, updater); err != nil { + if err := s.storage.UpdateAuthRequest(ctx, authReq.ID, updater); err != nil { return "", false, fmt.Errorf("failed to update auth request: %v", err) } @@ -545,7 +547,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, if offlineAccessRequested && canRefresh { // Try to retrieve an existing OfflineSession object for the corresponding user. - session, err := s.storage.GetOfflineSessions(identity.UserID, authReq.ConnectorID) + session, err := s.storage.GetOfflineSessions(ctx, identity.UserID, authReq.ConnectorID) switch { case err != nil && err == storage.ErrNotFound: offlineSessions := storage.OfflineSessions{ @@ -563,7 +565,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, } case err == nil: // Update existing OfflineSession obj with new RefreshTokenRef. - if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + if err := s.storage.UpdateOfflineSessions(ctx, session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { if len(identity.ConnectorData) > 0 { old.ConnectorData = identity.ConnectorData } @@ -594,6 +596,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, } func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() macEncoded := r.FormValue("hmac") if macEncoded == "" { s.renderError(r, w, http.StatusUnauthorized, "Unauthorized request") @@ -605,7 +608,7 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { return } - authReq, err := s.storage.GetAuthRequest(r.FormValue("req")) + authReq, err := s.storage.GetAuthRequest(ctx, r.FormValue("req")) if err != nil { s.logger.ErrorContext(r.Context(), "failed to get auth request", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Database error.") @@ -629,7 +632,7 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: - client, err := s.storage.GetClient(authReq.ClientID) + client, err := s.storage.GetClient(ctx, authReq.ClientID) if err != nil { s.logger.ErrorContext(r.Context(), "Failed to get client", "client_id", authReq.ClientID, "err", err) s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve client.") @@ -654,7 +657,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe return } - if err := s.storage.DeleteAuthRequest(authReq.ID); err != nil { + if err := s.storage.DeleteAuthRequest(ctx, authReq.ID); err != nil { if err != storage.ErrNotFound { s.logger.ErrorContext(r.Context(), "Failed to delete authorization request", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") @@ -786,6 +789,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe } func (s *Server) withClientFromStorage(w http.ResponseWriter, r *http.Request, handler func(http.ResponseWriter, *http.Request, storage.Client)) { + ctx := r.Context() clientID, clientSecret, ok := r.BasicAuth() if ok { var err error @@ -802,7 +806,7 @@ func (s *Server) withClientFromStorage(w http.ResponseWriter, r *http.Request, h clientSecret = r.PostFormValue("client_secret") } - client, err := s.storage.GetClient(clientID) + client, err := s.storage.GetClient(ctx, clientID) if err != nil { if err != storage.ErrNotFound { s.logger.ErrorContext(r.Context(), "failed to get client", "err", err) @@ -885,7 +889,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s return } - authCode, err := s.storage.GetAuthCode(code) + authCode, err := s.storage.GetAuthCode(ctx, code) if err != nil || s.now().After(authCode.Expiry) || authCode.ClientID != client.ID { if err != storage.ErrNotFound { s.logger.ErrorContext(r.Context(), "failed to get auth code", "err", err) @@ -950,7 +954,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au return nil, err } - if err := s.storage.DeleteAuthCode(authCode.ID); err != nil { + if err := s.storage.DeleteAuthCode(ctx, authCode.ID); err != nil { s.logger.ErrorContext(ctx, "failed to delete auth code", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return nil, err @@ -960,7 +964,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au // Ensure the connector supports refresh tokens. // // Connectors like `saml` do not implement RefreshConnector. - conn, err := s.getConnector(authCode.ConnectorID) + conn, err := s.getConnector(ctx, authCode.ConnectorID) if err != nil { s.logger.ErrorContext(ctx, "connector not found", "connector_id", authCode.ConnectorID, "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -1016,7 +1020,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au defer func() { if deleteToken { // Delete newly created refresh token from storage. - if err := s.storage.DeleteRefresh(refresh.ID); err != nil { + if err := s.storage.DeleteRefresh(ctx, refresh.ID); err != nil { s.logger.ErrorContext(ctx, "failed to delete refresh token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return @@ -1032,7 +1036,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au } // Try to retrieve an existing OfflineSession object for the corresponding user. - if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil { + if session, err := s.storage.GetOfflineSessions(ctx, refresh.Claims.UserID, refresh.ConnectorID); err != nil { if err != storage.ErrNotFound { s.logger.ErrorContext(ctx, "failed to get offline session", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -1057,7 +1061,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au } else { if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok { // Delete old refresh token from storage. - if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil && err != storage.ErrNotFound { + if err := s.storage.DeleteRefresh(ctx, oldTokenRef.ID); err != nil && err != storage.ErrNotFound { s.logger.ErrorContext(ctx, "failed to delete refresh token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true @@ -1066,7 +1070,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au } // Update existing OfflineSession obj with new RefreshTokenRef. - if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + if err := s.storage.UpdateOfflineSessions(ctx, session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { old.Refresh[tokenRef.ClientID] = &tokenRef return old, nil }); err != nil { @@ -1140,7 +1144,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli continue } - isTrusted, err := s.validateCrossClientTrust(r.Context(), client.ID, peerID) + isTrusted, err := s.validateCrossClientTrust(ctx, client.ID, peerID) if err != nil { s.tokenErrHelper(w, errInvalidClient, fmt.Sprintf("Error validating cross client trust %v.", err), http.StatusBadRequest) return @@ -1165,7 +1169,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli // Which connector connID := s.passwordConnector - conn, err := s.getConnector(connID) + conn, err := s.getConnector(ctx, connID) if err != nil { s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest) return @@ -1201,14 +1205,14 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli Groups: identity.Groups, } - accessToken, _, err := s.newAccessToken(r.Context(), client.ID, claims, scopes, nonce, connID) + accessToken, _, err := s.newAccessToken(ctx, client.ID, claims, scopes, nonce, connID) if err != nil { s.logger.ErrorContext(r.Context(), "password grant failed to create new access token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } - idToken, expiry, err := s.newIDToken(r.Context(), client.ID, claims, scopes, nonce, accessToken, "", connID) + idToken, expiry, err := s.newIDToken(ctx, client.ID, claims, scopes, nonce, accessToken, "", connID) if err != nil { s.logger.ErrorContext(r.Context(), "password grant failed to create new ID token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -1268,7 +1272,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli defer func() { if deleteToken { // Delete newly created refresh token from storage. - if err := s.storage.DeleteRefresh(refresh.ID); err != nil { + if err := s.storage.DeleteRefresh(ctx, refresh.ID); err != nil { s.logger.ErrorContext(r.Context(), "failed to delete refresh token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return @@ -1284,7 +1288,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli } // Try to retrieve an existing OfflineSession object for the corresponding user. - if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil { + if session, err := s.storage.GetOfflineSessions(ctx, refresh.Claims.UserID, refresh.ConnectorID); err != nil { if err != storage.ErrNotFound { s.logger.ErrorContext(r.Context(), "failed to get offline session", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -1310,7 +1314,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli } else { if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok { // Delete old refresh token from storage. - if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil { + if err := s.storage.DeleteRefresh(ctx, oldTokenRef.ID); err != nil { if err == storage.ErrNotFound { s.logger.Warn("database inconsistent, refresh token missing", "token_id", oldTokenRef.ID) } else { @@ -1323,7 +1327,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli } // Update existing OfflineSession obj with new RefreshTokenRef. - if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + if err := s.storage.UpdateOfflineSessions(ctx, session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { old.Refresh[tokenRef.ClientID] = &tokenRef old.ConnectorData = identity.ConnectorData return old, nil @@ -1371,7 +1375,7 @@ func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, cli return } - conn, err := s.getConnector(connID) + conn, err := s.getConnector(ctx, connID) if err != nil { s.logger.ErrorContext(r.Context(), "failed to get connector", "err", err) s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest) diff --git a/server/handlers_test.go b/server/handlers_test.go index 08b02f75c9..1aa4bfa58a 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -138,7 +138,7 @@ type emptyStorage struct { storage.Storage } -func (*emptyStorage) GetAuthRequest(string) (storage.AuthRequest, error) { +func (*emptyStorage) GetAuthRequest(context.Context, string) (storage.AuthRequest, error) { return storage.AuthRequest{}, storage.ErrNotFound } @@ -407,7 +407,7 @@ func TestHandlePassword(t *testing.T) { err := json.Unmarshal(rr.Body.Bytes(), &ref) require.NoError(t, err) - newSess, err := s.storage.GetOfflineSessions("0-385-28089-0", "test") + newSess, err := s.storage.GetOfflineSessions(ctx, "0-385-28089-0", "test") if tc.offlineSessionCreated { require.NoError(t, err) require.Equal(t, `{"test": "true"}`, string(newSess.ConnectorData)) @@ -562,7 +562,7 @@ func TestHandlePasswordLoginWithSkipApproval(t *testing.T) { cb, _ := url.Parse(resp.Header.Get("Location")) require.Equal(t, tc.expectedRes, cb.Path) - offlineSession, err := s.storage.GetOfflineSessions("0-385-28089-0", connID) + offlineSession, err := s.storage.GetOfflineSessions(ctx, "0-385-28089-0", connID) if tc.offlineSessionCreated { require.NoError(t, err) require.NotEmpty(t, offlineSession) @@ -701,7 +701,7 @@ func TestHandleConnectorCallbackWithSkipApproval(t *testing.T) { cb, _ := url.Parse(resp.Header.Get("Location")) require.Equal(t, tc.expectedRes, cb.Path) - offlineSession, err := s.storage.GetOfflineSessions("0-385-28089-0", connID) + offlineSession, err := s.storage.GetOfflineSessions(ctx, "0-385-28089-0", connID) if tc.offlineSessionCreated { require.NoError(t, err) require.NotEmpty(t, offlineSession) diff --git a/server/introspectionhandler.go b/server/introspectionhandler.go index 802e29b6e7..42ad1b3c70 100644 --- a/server/introspectionhandler.go +++ b/server/introspectionhandler.go @@ -263,7 +263,7 @@ func (s *Server) introspectAccessToken(ctx context.Context, token string) (*Intr return nil, newIntrospectInternalServerError() } - client, err := s.storage.GetClient(clientID) + client, err := s.storage.GetClient(ctx, clientID) if err != nil { s.logger.ErrorContext(ctx, "error while fetching client from storage", "err", err.Error()) return nil, newIntrospectInternalServerError() diff --git a/server/oauth2.go b/server/oauth2.go index cc81a8a52d..18cc3dd46d 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -351,7 +351,7 @@ func genSubject(userID string, connID string) (string, error) { } func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string) (idToken string, expiry time.Time, err error) { - keys, err := s.storage.GetKeys() + keys, err := s.storage.GetKeys(ctx) if err != nil { s.logger.ErrorContext(ctx, "failed to get keys", "err", err) return "", expiry, err @@ -453,6 +453,7 @@ func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage // parse the initial request from the OAuth2 client. func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthRequest, error) { + ctx := r.Context() if err := r.ParseForm(); err != nil { return nil, newDisplayedErr(http.StatusBadRequest, "Failed to parse request.") } @@ -477,7 +478,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques codeChallengeMethod = codeChallengeMethodPlain } - client, err := s.storage.GetClient(clientID) + client, err := s.storage.GetClient(ctx, clientID) if err != nil { if err == storage.ErrNotFound { return nil, newDisplayedErr(http.StatusNotFound, "Invalid client_id (%q).", clientID) @@ -499,7 +500,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques } if connectorID != "" { - connectors, err := s.storage.ListConnectors() + connectors, err := s.storage.ListConnectors(ctx) if err != nil { s.logger.ErrorContext(r.Context(), "failed to list connectors", "err", err) return nil, newRedirectedErr(errServerError, "Unable to retrieve connectors") @@ -634,7 +635,7 @@ func (s *Server) validateCrossClientTrust(ctx context.Context, clientID, peerID if peerID == clientID { return true, nil } - peer, err := s.storage.GetClient(peerID) + peer, err := s.storage.GetClient(ctx, peerID) if err != nil { if err != storage.ErrNotFound { s.logger.ErrorContext(ctx, "failed to get client", "err", err) @@ -707,7 +708,7 @@ type storageKeySet struct { storage.Storage } -func (s *storageKeySet) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) { +func (s *storageKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) { jws, err := jose.ParseSigned(jwt, []jose.SignatureAlgorithm{jose.RS256, jose.RS384, jose.RS512, jose.ES256, jose.ES384, jose.ES512}) if err != nil { return nil, err @@ -719,7 +720,7 @@ func (s *storageKeySet) VerifySignature(_ context.Context, jwt string) (payload break } - skeys, err := s.Storage.GetKeys() + skeys, err := s.Storage.GetKeys(ctx) if err != nil { return nil, err } diff --git a/server/oauth2_test.go b/server/oauth2_test.go index 2b733df825..70e4095c86 100644 --- a/server/oauth2_test.go +++ b/server/oauth2_test.go @@ -599,7 +599,7 @@ func TestValidRedirectURI(t *testing.T) { func TestStorageKeySet(t *testing.T) { s := memory.New(logger) - if err := s.UpdateKeys(func(keys storage.Keys) (storage.Keys, error) { + if err := s.UpdateKeys(context.TODO(), func(keys storage.Keys) (storage.Keys, error) { keys.SigningKey = &jose.JSONWebKey{ Key: testKey, KeyID: "testkey", diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index 391d552251..de8d9b7b8d 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -84,7 +84,7 @@ func (s *Server) getRefreshTokenFromStorage(ctx context.Context, clientID *strin refreshCtx := refreshContext{requestToken: token} // Get RefreshToken - refresh, err := s.storage.GetRefresh(token.RefreshId) + refresh, err := s.storage.GetRefresh(ctx, token.RefreshId) if err != nil { if err != storage.ErrNotFound { s.logger.ErrorContext(ctx, "failed to get refresh token", "err", err) @@ -126,14 +126,14 @@ func (s *Server) getRefreshTokenFromStorage(ctx context.Context, clientID *strin refreshCtx.storageToken = &refresh // Get Connector - refreshCtx.connector, err = s.getConnector(refresh.ConnectorID) + refreshCtx.connector, err = s.getConnector(ctx, refresh.ConnectorID) if err != nil { s.logger.ErrorContext(ctx, "connector not found", "connector_id", refresh.ConnectorID, "err", err) return nil, newInternalServerError() } // Get Connector Data - session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID) + session, err := s.storage.GetOfflineSessions(ctx, refresh.Claims.UserID, refresh.ConnectorID) switch { case err != nil: if err != storage.ErrNotFound { @@ -223,7 +223,7 @@ func (s *Server) updateOfflineSession(ctx context.Context, refresh *storage.Refr // Update LastUsed time stamp in refresh token reference object // in offline session for the user. - err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater) + err := s.storage.UpdateOfflineSessions(ctx, refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater) if err != nil { s.logger.ErrorContext(ctx, "failed to update offline session", "err", err) return newInternalServerError() @@ -314,7 +314,7 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) ( } // Update refresh token in the storage. - err := s.storage.UpdateRefreshToken(rCtx.storageToken.ID, refreshTokenUpdater) + err := s.storage.UpdateRefreshToken(ctx, rCtx.storageToken.ID, refreshTokenUpdater) if err != nil { s.logger.ErrorContext(ctx, "failed to update refresh token", "err", err) return nil, ident, newInternalServerError() diff --git a/server/rotation.go b/server/rotation.go index dfd776d677..286b4b57af 100644 --- a/server/rotation.go +++ b/server/rotation.go @@ -95,7 +95,7 @@ func (s *Server) startKeyRotation(ctx context.Context, strategy rotationStrategy } func (k keyRotator) rotate() error { - keys, err := k.GetKeys() + keys, err := k.GetKeys(context.Background()) if err != nil && err != storage.ErrNotFound { return fmt.Errorf("get keys: %v", err) } @@ -128,7 +128,7 @@ func (k keyRotator) rotate() error { } var nextRotation time.Time - err = k.Storage.UpdateKeys(func(keys storage.Keys) (storage.Keys, error) { + err = k.Storage.UpdateKeys(context.Background(), func(keys storage.Keys) (storage.Keys, error) { tNow := k.now() // if you are running multiple instances of dex, another instance diff --git a/server/rotation_test.go b/server/rotation_test.go index 1d0d2f100a..17e06c6d9d 100644 --- a/server/rotation_test.go +++ b/server/rotation_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "io" "log/slog" "sort" @@ -14,7 +15,7 @@ import ( ) func signingKeyID(t *testing.T, s storage.Storage) string { - keys, err := s.GetKeys() + keys, err := s.GetKeys(context.TODO()) if err != nil { t.Fatal(err) } @@ -22,7 +23,7 @@ func signingKeyID(t *testing.T, s storage.Storage) string { } func verificationKeyIDs(t *testing.T, s storage.Storage) (ids []string) { - keys, err := s.GetKeys() + keys, err := s.GetKeys(context.TODO()) if err != nil { t.Fatal(err) } diff --git a/server/server.go b/server/server.go index 5c5faa3003..8c0462969a 100644 --- a/server/server.go +++ b/server/server.go @@ -316,7 +316,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) // Retrieves connector objects in backend storage. This list includes the static connectors // defined in the ConfigMap and dynamic connectors retrieved from the storage. - storageConnectors, err := c.Storage.ListConnectors() + storageConnectors, err := c.Storage.ListConnectors(ctx) if err != nil { return nil, fmt.Errorf("server: failed to list connector objects from storage: %v", err) } @@ -535,7 +535,7 @@ type passwordDB struct { } func (db passwordDB) Login(ctx context.Context, s connector.Scopes, email, password string) (connector.Identity, bool, error) { - p, err := db.s.GetPassword(email) + p, err := db.s.GetPassword(ctx, email) if err != nil { if err != storage.ErrNotFound { return connector.Identity{}, false, fmt.Errorf("get password: %v", err) @@ -560,7 +560,7 @@ func (db passwordDB) Login(ctx context.Context, s connector.Scopes, email, passw func (db passwordDB) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) { // If the user has been deleted, the refresh token will be rejected. - p, err := db.s.GetPassword(identity.Email) + p, err := db.s.GetPassword(ctx, identity.Email) if err != nil { if err == storage.ErrNotFound { return connector.Identity{}, errors.New("user not found") @@ -602,13 +602,13 @@ type keyCacher struct { keys atomic.Value // Always holds nil or type *storage.Keys. } -func (k *keyCacher) GetKeys() (storage.Keys, error) { +func (k *keyCacher) GetKeys(ctx context.Context) (storage.Keys, error) { keys, ok := k.keys.Load().(*storage.Keys) if ok && keys != nil && k.now().Before(keys.NextRotation) { return *keys, nil } - storageKeys, err := k.Storage.GetKeys() + storageKeys, err := k.Storage.GetKeys(ctx) if err != nil { return storageKeys, err } @@ -626,7 +626,7 @@ func (s *Server) startGarbageCollection(ctx context.Context, frequency time.Dura case <-ctx.Done(): return case <-time.After(frequency): - if r, err := s.storage.GarbageCollect(now()); err != nil { + if r, err := s.storage.GarbageCollect(ctx, now()); err != nil { s.logger.ErrorContext(ctx, "garbage collection failed", "err", err) } else if !r.IsEmpty() { s.logger.InfoContext(ctx, "garbage collection run, delete auth", @@ -719,8 +719,8 @@ func (s *Server) OpenConnector(conn storage.Connector) (Connector, error) { // getConnector retrieves the connector object with the given id from the storage // and updates the connector list for server if necessary. -func (s *Server) getConnector(id string) (Connector, error) { - storageConnector, err := s.storage.GetConnector(id) +func (s *Server) getConnector(ctx context.Context, id string) (Connector, error) { + storageConnector, err := s.storage.GetConnector(ctx, id) if err != nil { return Connector{}, fmt.Errorf("failed to get connector object from storage: %v", err) } diff --git a/server/server_test.go b/server/server_test.go index 8936c90a07..1e8c4df0cf 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -875,7 +875,7 @@ func TestOAuth2CodeFlow(t *testing.T) { t.Fatal(err) } - tokens, err := s.storage.ListRefreshTokens() + tokens, err := s.storage.ListRefreshTokens(ctx) if err != nil { t.Fatalf("failed to get existed refresh token: %v", err) } @@ -1369,15 +1369,15 @@ type storageWithKeysTrigger struct { f func() } -func (s storageWithKeysTrigger) GetKeys() (storage.Keys, error) { +func (s storageWithKeysTrigger) GetKeys(ctx context.Context) (storage.Keys, error) { s.f() - return s.Storage.GetKeys() + return s.Storage.GetKeys(ctx) } func TestKeyCacher(t *testing.T) { tNow := time.Now() now := func() time.Time { return tNow } - + ctx := context.TODO() s := memory.New(logger) tests := []struct { @@ -1390,7 +1390,7 @@ func TestKeyCacher(t *testing.T) { }, { before: func() { - s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) { + s.UpdateKeys(ctx, func(old storage.Keys) (storage.Keys, error) { old.NextRotation = tNow.Add(time.Minute) return old, nil }) @@ -1410,7 +1410,7 @@ func TestKeyCacher(t *testing.T) { { before: func() { tNow = tNow.Add(time.Hour) - s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) { + s.UpdateKeys(ctx, func(old storage.Keys) (storage.Keys, error) { old.NextRotation = tNow.Add(time.Minute) return old, nil }) @@ -1428,7 +1428,7 @@ func TestKeyCacher(t *testing.T) { for i, tc := range tests { gotCall = false tc.before() - s.GetKeys() + s.GetKeys(context.TODO()) if gotCall != tc.wantCallToStorage { t.Errorf("case %d: expected call to storage=%t got call to storage=%t", i, tc.wantCallToStorage, gotCall) } diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 84ad1cba5f..58ae3d958d 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -148,7 +148,7 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) { t.Fatalf("failed creating auth request: %v", err) } - if err := s.UpdateAuthRequest(a1.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) { + if err := s.UpdateAuthRequest(ctx, a1.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) { old.Claims = identity old.ConnectorID = "connID" return old, nil @@ -156,7 +156,7 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) { t.Fatalf("failed to update auth request: %v", err) } - got, err := s.GetAuthRequest(a1.ID) + got, err := s.GetAuthRequest(ctx, a1.ID) if err != nil { t.Fatalf("failed to get auth req: %v", err) } @@ -168,15 +168,15 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) { t.Fatalf("storage does not support PKCE, wanted challenge=%#v got %#v", codeChallenge, got.PKCE) } - if err := s.DeleteAuthRequest(a1.ID); err != nil { + if err := s.DeleteAuthRequest(ctx, a1.ID); err != nil { t.Fatalf("failed to delete auth request: %v", err) } - if err := s.DeleteAuthRequest(a2.ID); err != nil { + if err := s.DeleteAuthRequest(ctx, a2.ID); err != nil { t.Fatalf("failed to delete auth request: %v", err) } - _, err = s.GetAuthRequest(a1.ID) + _, err = s.GetAuthRequest(ctx, a1.ID) mustBeErrNotFound(t, "auth request", err) } @@ -234,7 +234,7 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) { t.Fatalf("failed creating auth code: %v", err) } - got, err := s.GetAuthCode(a1.ID) + got, err := s.GetAuthCode(ctx, a1.ID) if err != nil { t.Fatalf("failed to get auth code: %v", err) } @@ -246,15 +246,15 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) { t.Errorf("auth code retrieved from storage did not match: %s", diff) } - if err := s.DeleteAuthCode(a1.ID); err != nil { + if err := s.DeleteAuthCode(ctx, a1.ID); err != nil { t.Fatalf("delete auth code: %v", err) } - if err := s.DeleteAuthCode(a2.ID); err != nil { + if err := s.DeleteAuthCode(ctx, a2.ID); err != nil { t.Fatalf("delete auth code: %v", err) } - _, err = s.GetAuthCode(a1.ID) + _, err = s.GetAuthCode(ctx, a1.ID) mustBeErrNotFound(t, "auth code", err) } @@ -268,7 +268,7 @@ func testClientCRUD(t *testing.T, s storage.Storage) { Name: "dex client", LogoURL: "https://goo.gl/JIyzIC", } - err := s.DeleteClient(id1) + err := s.DeleteClient(ctx, id1) mustBeErrNotFound(t, "client", err) if err := s.CreateClient(ctx, c1); err != nil { @@ -293,7 +293,7 @@ func testClientCRUD(t *testing.T, s storage.Storage) { } getAndCompare := func(_ string, want storage.Client) { - gc, err := s.GetClient(id1) + gc, err := s.GetClient(ctx, id1) if err != nil { t.Errorf("get client: %v", err) return @@ -306,7 +306,7 @@ func testClientCRUD(t *testing.T, s storage.Storage) { getAndCompare(id1, c1) newSecret := "barfoo" - err = s.UpdateClient(id1, func(old storage.Client) (storage.Client, error) { + err = s.UpdateClient(ctx, id1, func(old storage.Client) (storage.Client, error) { old.Secret = newSecret return old, nil }) @@ -316,15 +316,15 @@ func testClientCRUD(t *testing.T, s storage.Storage) { c1.Secret = newSecret getAndCompare(id1, c1) - if err := s.DeleteClient(id1); err != nil { + if err := s.DeleteClient(ctx, id1); err != nil { t.Fatalf("delete client: %v", err) } - if err := s.DeleteClient(id2); err != nil { + if err := s.DeleteClient(ctx, id2); err != nil { t.Fatalf("delete client: %v", err) } - _, err = s.GetClient(id1) + _, err = s.GetClient(ctx, id1) mustBeErrNotFound(t, "client", err) } @@ -359,7 +359,7 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { mustBeErrAlreadyExists(t, "refresh token", err) getAndCompare := func(id string, want storage.RefreshToken) { - gr, err := s.GetRefresh(id) + gr, err := s.GetRefresh(ctx, id) if err != nil { t.Errorf("get refresh: %v", err) return @@ -419,7 +419,7 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { r.LastUsed = updatedAt return r, nil } - if err := s.UpdateRefreshToken(id, updater); err != nil { + if err := s.UpdateRefreshToken(ctx, id, updater); err != nil { t.Errorf("failed to update refresh token: %v", err) } refresh.Token = "spam" @@ -429,15 +429,15 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { // Ensure that updating the first token doesn't impact the second. Issue #847. getAndCompare(id2, refresh2) - if err := s.DeleteRefresh(id); err != nil { + if err := s.DeleteRefresh(ctx, id); err != nil { t.Fatalf("failed to delete refresh request: %v", err) } - if err := s.DeleteRefresh(id2); err != nil { + if err := s.DeleteRefresh(ctx, id2); err != nil { t.Fatalf("failed to delete refresh request: %v", err) } - _, err = s.GetRefresh(id) + _, err = s.GetRefresh(ctx, id) mustBeErrNotFound(t, "refresh token", err) } @@ -485,7 +485,7 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) { } getAndCompare := func(id string, want storage.Password) { - gr, err := s.GetPassword(id) + gr, err := s.GetPassword(ctx, id) if err != nil { t.Errorf("get password %q: %v", id, err) return @@ -498,7 +498,7 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) { getAndCompare("jane@example.com", password1) getAndCompare("JANE@example.com", password1) // Emails should be case insensitive - if err := s.UpdatePassword(password1.Email, func(old storage.Password) (storage.Password, error) { + if err := s.UpdatePassword(ctx, password1.Email, func(old storage.Password) (storage.Password, error) { old.Username = "jane doe" return old, nil }); err != nil { @@ -512,7 +512,7 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) { passwordList = append(passwordList, password1, password2) listAndCompare := func(want []storage.Password) { - passwords, err := s.ListPasswords() + passwords, err := s.ListPasswords(ctx) if err != nil { t.Errorf("list password: %v", err) return @@ -526,15 +526,15 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) { listAndCompare(passwordList) - if err := s.DeletePassword(password1.Email); err != nil { + if err := s.DeletePassword(ctx, password1.Email); err != nil { t.Fatalf("failed to delete password: %v", err) } - if err := s.DeletePassword(password2.Email); err != nil { + if err := s.DeletePassword(ctx, password2.Email); err != nil { t.Fatalf("failed to delete password: %v", err) } - _, err = s.GetPassword(password1.Email) + _, err = s.GetPassword(ctx, password1.Email) mustBeErrNotFound(t, "password", err) } @@ -571,7 +571,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { } getAndCompare := func(userID string, connID string, want storage.OfflineSessions) { - gr, err := s.GetOfflineSessions(userID, connID) + gr, err := s.GetOfflineSessions(ctx, userID, connID) if err != nil { t.Errorf("get offline session: %v", err) return @@ -592,7 +592,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { } session1.Refresh[tokenRef.ClientID] = &tokenRef - if err := s.UpdateOfflineSessions(session1.UserID, session1.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + if err := s.UpdateOfflineSessions(ctx, session1.UserID, session1.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { old.Refresh[tokenRef.ClientID] = &tokenRef return old, nil }); err != nil { @@ -601,15 +601,15 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { getAndCompare(userID1, "Conn1", session1) - if err := s.DeleteOfflineSessions(session1.UserID, session1.ConnID); err != nil { + if err := s.DeleteOfflineSessions(ctx, session1.UserID, session1.ConnID); err != nil { t.Fatalf("failed to delete offline session: %v", err) } - if err := s.DeleteOfflineSessions(session2.UserID, session2.ConnID); err != nil { + if err := s.DeleteOfflineSessions(ctx, session2.UserID, session2.ConnID); err != nil { t.Fatalf("failed to delete offline session: %v", err) } - _, err = s.GetOfflineSessions(session1.UserID, session1.ConnID) + _, err = s.GetOfflineSessions(ctx, session1.UserID, session1.ConnID) mustBeErrNotFound(t, "offline session", err) } @@ -646,7 +646,7 @@ func testConnectorCRUD(t *testing.T, s storage.Storage) { } getAndCompare := func(id string, want storage.Connector) { - gr, err := s.GetConnector(id) + gr, err := s.GetConnector(ctx, id) if err != nil { t.Errorf("get connector: %v", err) return @@ -660,7 +660,7 @@ func testConnectorCRUD(t *testing.T, s storage.Storage) { getAndCompare(id1, c1) - if err := s.UpdateConnector(c1.ID, func(old storage.Connector) (storage.Connector, error) { + if err := s.UpdateConnector(ctx, c1.ID, func(old storage.Connector) (storage.Connector, error) { old.Type = "oidc" return old, nil }); err != nil { @@ -672,7 +672,7 @@ func testConnectorCRUD(t *testing.T, s storage.Storage) { connectorList := []storage.Connector{c1, c2} listAndCompare := func(want []storage.Connector) { - connectors, err := s.ListConnectors() + connectors, err := s.ListConnectors(ctx) if err != nil { t.Errorf("list connectors: %v", err) return @@ -690,21 +690,23 @@ func testConnectorCRUD(t *testing.T, s storage.Storage) { } listAndCompare(connectorList) - if err := s.DeleteConnector(c1.ID); err != nil { + if err := s.DeleteConnector(ctx, c1.ID); err != nil { t.Fatalf("failed to delete connector: %v", err) } - if err := s.DeleteConnector(c2.ID); err != nil { + if err := s.DeleteConnector(ctx, c2.ID); err != nil { t.Fatalf("failed to delete connector: %v", err) } - _, err = s.GetConnector(c1.ID) + _, err = s.GetConnector(ctx, c1.ID) mustBeErrNotFound(t, "connector", err) } func testKeysCRUD(t *testing.T, s storage.Storage) { + ctx := context.TODO() + updateAndCompare := func(k storage.Keys) { - err := s.UpdateKeys(func(oldKeys storage.Keys) (storage.Keys, error) { + err := s.UpdateKeys(ctx, func(oldKeys storage.Keys) (storage.Keys, error) { return k, nil }) if err != nil { @@ -712,7 +714,7 @@ func testKeysCRUD(t *testing.T, s storage.Storage) { return } - if got, err := s.GetKeys(); err != nil { + if got, err := s.GetKeys(ctx); err != nil { t.Errorf("failed to get keys: %v", err) } else { got.NextRotation = got.NextRotation.UTC() @@ -786,24 +788,24 @@ func testGC(t *testing.T, s storage.Storage) { } for _, tz := range []*time.Location{time.UTC, est, pst} { - result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz)) + result, err := s.GarbageCollect(ctx, expiry.Add(-time.Hour).In(tz)) if err != nil { t.Errorf("garbage collection failed: %v", err) } else if result.AuthCodes != 0 || result.AuthRequests != 0 { t.Errorf("expected no garbage collection results, got %#v", result) } - if _, err := s.GetAuthCode(c.ID); err != nil { + if _, err := s.GetAuthCode(ctx, c.ID); err != nil { t.Errorf("expected to be able to get auth code after GC: %v", err) } } - if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil { + if r, err := s.GarbageCollect(ctx, expiry.Add(time.Hour)); err != nil { t.Errorf("garbage collection failed: %v", err) } else if r.AuthCodes != 1 { t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes) } - if _, err := s.GetAuthCode(c.ID); err == nil { + if _, err := s.GetAuthCode(ctx, c.ID); err == nil { t.Errorf("expected auth code to be GC'd") } else if err != storage.ErrNotFound { t.Errorf("expected storage.ErrNotFound, got %v", err) @@ -837,24 +839,24 @@ func testGC(t *testing.T, s storage.Storage) { } for _, tz := range []*time.Location{time.UTC, est, pst} { - result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz)) + result, err := s.GarbageCollect(ctx, expiry.Add(-time.Hour).In(tz)) if err != nil { t.Errorf("garbage collection failed: %v", err) } else if result.AuthCodes != 0 || result.AuthRequests != 0 { t.Errorf("expected no garbage collection results, got %#v", result) } - if _, err := s.GetAuthRequest(a.ID); err != nil { + if _, err := s.GetAuthRequest(ctx, a.ID); err != nil { t.Errorf("expected to be able to get auth request after GC: %v", err) } } - if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil { + if r, err := s.GarbageCollect(ctx, expiry.Add(time.Hour)); err != nil { t.Errorf("garbage collection failed: %v", err) } else if r.AuthRequests != 1 { t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests) } - if _, err := s.GetAuthRequest(a.ID); err == nil { + if _, err := s.GetAuthRequest(ctx, a.ID); err == nil { t.Errorf("expected auth request to be GC'd") } else if err != storage.ErrNotFound { t.Errorf("expected storage.ErrNotFound, got %v", err) @@ -874,23 +876,23 @@ func testGC(t *testing.T, s storage.Storage) { } for _, tz := range []*time.Location{time.UTC, est, pst} { - result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz)) + result, err := s.GarbageCollect(ctx, expiry.Add(-time.Hour).In(tz)) if err != nil { t.Errorf("garbage collection failed: %v", err) } else if result.DeviceRequests != 0 { t.Errorf("expected no device garbage collection results, got %#v", result) } - if _, err := s.GetDeviceRequest(d.UserCode); err != nil { + if _, err := s.GetDeviceRequest(ctx, d.UserCode); err != nil { t.Errorf("expected to be able to get auth request after GC: %v", err) } } - if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil { + if r, err := s.GarbageCollect(ctx, expiry.Add(time.Hour)); err != nil { t.Errorf("garbage collection failed: %v", err) } else if r.DeviceRequests != 1 { t.Errorf("expected to garbage collect 1 device request, got %d", r.DeviceRequests) } - if _, err := s.GetDeviceRequest(d.UserCode); err == nil { + if _, err := s.GetDeviceRequest(ctx, d.UserCode); err == nil { t.Errorf("expected device request to be GC'd") } else if err != storage.ErrNotFound { t.Errorf("expected storage.ErrNotFound, got %v", err) @@ -914,23 +916,23 @@ func testGC(t *testing.T, s storage.Storage) { } for _, tz := range []*time.Location{time.UTC, est, pst} { - result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz)) + result, err := s.GarbageCollect(ctx, expiry.Add(-time.Hour).In(tz)) if err != nil { t.Errorf("garbage collection failed: %v", err) } else if result.DeviceTokens != 0 { t.Errorf("expected no device token garbage collection results, got %#v", result) } - if _, err := s.GetDeviceToken(dt.DeviceCode); err != nil { + if _, err := s.GetDeviceToken(ctx, dt.DeviceCode); err != nil { t.Errorf("expected to be able to get device token after GC: %v", err) } } - if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil { + if r, err := s.GarbageCollect(ctx, expiry.Add(time.Hour)); err != nil { t.Errorf("garbage collection failed: %v", err) } else if r.DeviceTokens != 1 { t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens) } - if _, err := s.GetDeviceToken(dt.DeviceCode); err == nil { + if _, err := s.GetDeviceToken(ctx, dt.DeviceCode); err == nil { t.Errorf("expected device token to be GC'd") } else if err != storage.ErrNotFound { t.Errorf("expected storage.ErrNotFound, got %v", err) @@ -969,7 +971,7 @@ func testTimezones(t *testing.T, s storage.Storage) { if err := s.CreateAuthCode(ctx, c); err != nil { t.Fatalf("failed creating auth code: %v", err) } - got, err := s.GetAuthCode(c.ID) + got, err := s.GetAuthCode(ctx, c.ID) if err != nil { t.Fatalf("failed to get auth code: %v", err) } @@ -1003,7 +1005,7 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) { err := s.CreateDeviceRequest(ctx, d1) mustBeErrAlreadyExists(t, "device request", err) - got, err := s.GetDeviceRequest(d1.UserCode) + got, err := s.GetDeviceRequest(ctx, d1.UserCode) if err != nil { t.Fatalf("failed to get device request: %v", err) } @@ -1041,7 +1043,7 @@ func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { mustBeErrAlreadyExists(t, "device token", err) // Update the device token, simulate a redemption - if err := s.UpdateDeviceToken(d1.DeviceCode, func(old storage.DeviceToken) (storage.DeviceToken, error) { + if err := s.UpdateDeviceToken(ctx, d1.DeviceCode, func(old storage.DeviceToken) (storage.DeviceToken, error) { old.Token = "token data" old.Status = "complete" return old, nil @@ -1050,7 +1052,7 @@ func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { } // Retrieve the device token - got, err := s.GetDeviceToken(d1.DeviceCode) + got, err := s.GetDeviceToken(ctx, d1.DeviceCode) if err != nil { t.Fatalf("failed to get device token: %v", err) } diff --git a/storage/conformance/transactions.go b/storage/conformance/transactions.go index 69ed5517ad..60365c9a74 100644 --- a/storage/conformance/transactions.go +++ b/storage/conformance/transactions.go @@ -42,9 +42,9 @@ func testClientConcurrentUpdate(t *testing.T, s storage.Storage) { var err1, err2 error - err1 = s.UpdateClient(c.ID, func(old storage.Client) (storage.Client, error) { + err1 = s.UpdateClient(ctx, c.ID, func(old storage.Client) (storage.Client, error) { old.Secret = "new secret 1" - err2 = s.UpdateClient(c.ID, func(old storage.Client) (storage.Client, error) { + err2 = s.UpdateClient(ctx, c.ID, func(old storage.Client) (storage.Client, error) { old.Secret = "new secret 2" return old, nil }) @@ -87,9 +87,9 @@ func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) { var err1, err2 error - err1 = s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) { + err1 = s.UpdateAuthRequest(ctx, a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) { old.State = "state 1" - err2 = s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) { + err2 = s.UpdateAuthRequest(ctx, a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) { old.State = "state 2" return old, nil }) @@ -121,9 +121,9 @@ func testPasswordConcurrentUpdate(t *testing.T, s storage.Storage) { var err1, err2 error - err1 = s.UpdatePassword(password.Email, func(old storage.Password) (storage.Password, error) { + err1 = s.UpdatePassword(ctx, password.Email, func(old storage.Password) (storage.Password, error) { old.Username = "user 1" - err2 = s.UpdatePassword(password.Email, func(old storage.Password) (storage.Password, error) { + err2 = s.UpdatePassword(ctx, password.Email, func(old storage.Password) (storage.Password, error) { old.Username = "user 2" return old, nil }) @@ -163,8 +163,9 @@ func testKeysConcurrentUpdate(t *testing.T, s storage.Storage) { var err1, err2 error - err1 = s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) { - err2 = s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) { + ctx := context.TODO() + err1 = s.UpdateKeys(ctx, func(old storage.Keys) (storage.Keys, error) { + err2 = s.UpdateKeys(ctx, func(old storage.Keys) (storage.Keys, error) { return keys1, nil }) return keys2, nil diff --git a/storage/ent/client/authcode.go b/storage/ent/client/authcode.go index 8ac1231484..aa5bd184c3 100644 --- a/storage/ent/client/authcode.go +++ b/storage/ent/client/authcode.go @@ -34,8 +34,8 @@ func (d *Database) CreateAuthCode(ctx context.Context, code storage.AuthCode) er } // GetAuthCode extracts an auth code from the database by id. -func (d *Database) GetAuthCode(id string) (storage.AuthCode, error) { - authCode, err := d.client.AuthCode.Get(context.TODO(), id) +func (d *Database) GetAuthCode(ctx context.Context, id string) (storage.AuthCode, error) { + authCode, err := d.client.AuthCode.Get(ctx, id) if err != nil { return storage.AuthCode{}, convertDBError("get auth code: %w", err) } @@ -43,8 +43,8 @@ func (d *Database) GetAuthCode(id string) (storage.AuthCode, error) { } // DeleteAuthCode deletes an auth code from the database by id. -func (d *Database) DeleteAuthCode(id string) error { - err := d.client.AuthCode.DeleteOneID(id).Exec(context.TODO()) +func (d *Database) DeleteAuthCode(ctx context.Context, id string) error { + err := d.client.AuthCode.DeleteOneID(id).Exec(ctx) if err != nil { return convertDBError("delete auth code: %w", err) } diff --git a/storage/ent/client/authrequest.go b/storage/ent/client/authrequest.go index 42db702d68..25d3e41569 100644 --- a/storage/ent/client/authrequest.go +++ b/storage/ent/client/authrequest.go @@ -40,8 +40,8 @@ func (d *Database) CreateAuthRequest(ctx context.Context, authRequest storage.Au } // GetAuthRequest extracts an auth request from the database by id. -func (d *Database) GetAuthRequest(id string) (storage.AuthRequest, error) { - authRequest, err := d.client.AuthRequest.Get(context.TODO(), id) +func (d *Database) GetAuthRequest(ctx context.Context, id string) (storage.AuthRequest, error) { + authRequest, err := d.client.AuthRequest.Get(ctx, id) if err != nil { return storage.AuthRequest{}, convertDBError("get auth request: %w", err) } @@ -49,8 +49,8 @@ func (d *Database) GetAuthRequest(id string) (storage.AuthRequest, error) { } // DeleteAuthRequest deletes an auth request from the database by id. -func (d *Database) DeleteAuthRequest(id string) error { - err := d.client.AuthRequest.DeleteOneID(id).Exec(context.TODO()) +func (d *Database) DeleteAuthRequest(ctx context.Context, id string) error { + err := d.client.AuthRequest.DeleteOneID(id).Exec(ctx) if err != nil { return convertDBError("delete auth request: %w", err) } @@ -58,8 +58,8 @@ func (d *Database) DeleteAuthRequest(id string) error { } // UpdateAuthRequest changes an auth request by id using an updater function and saves it to the database. -func (d *Database) UpdateAuthRequest(id string, updater func(old storage.AuthRequest) (storage.AuthRequest, error)) error { - tx, err := d.BeginTx(context.TODO()) +func (d *Database) UpdateAuthRequest(ctx context.Context, id string, updater func(old storage.AuthRequest) (storage.AuthRequest, error)) error { + tx, err := d.BeginTx(ctx) if err != nil { return fmt.Errorf("update auth request tx: %w", err) } diff --git a/storage/ent/client/client.go b/storage/ent/client/client.go index 4cb02c0c83..1957a76a9b 100644 --- a/storage/ent/client/client.go +++ b/storage/ent/client/client.go @@ -24,8 +24,8 @@ func (d *Database) CreateClient(ctx context.Context, client storage.Client) erro } // ListClients extracts an array of oauth2 clients from the database. -func (d *Database) ListClients() ([]storage.Client, error) { - clients, err := d.client.OAuth2Client.Query().All(context.TODO()) +func (d *Database) ListClients(ctx context.Context) ([]storage.Client, error) { + clients, err := d.client.OAuth2Client.Query().All(ctx) if err != nil { return nil, convertDBError("list clients: %w", err) } @@ -38,8 +38,8 @@ func (d *Database) ListClients() ([]storage.Client, error) { } // GetClient extracts an oauth2 client from the database by id. -func (d *Database) GetClient(id string) (storage.Client, error) { - client, err := d.client.OAuth2Client.Get(context.TODO(), id) +func (d *Database) GetClient(ctx context.Context, id string) (storage.Client, error) { + client, err := d.client.OAuth2Client.Get(ctx, id) if err != nil { return storage.Client{}, convertDBError("get client: %w", err) } @@ -47,8 +47,8 @@ func (d *Database) GetClient(id string) (storage.Client, error) { } // DeleteClient deletes an oauth2 client from the database by id. -func (d *Database) DeleteClient(id string) error { - err := d.client.OAuth2Client.DeleteOneID(id).Exec(context.TODO()) +func (d *Database) DeleteClient(ctx context.Context, id string) error { + err := d.client.OAuth2Client.DeleteOneID(id).Exec(ctx) if err != nil { return convertDBError("delete client: %w", err) } @@ -56,13 +56,13 @@ func (d *Database) DeleteClient(id string) error { } // UpdateClient changes an oauth2 client by id using an updater function and saves it to the database. -func (d *Database) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { - tx, err := d.BeginTx(context.TODO()) +func (d *Database) UpdateClient(ctx context.Context, id string, updater func(old storage.Client) (storage.Client, error)) error { + tx, err := d.BeginTx(ctx) if err != nil { return convertDBError("update client tx: %w", err) } - client, err := tx.OAuth2Client.Get(context.TODO(), id) + client, err := tx.OAuth2Client.Get(ctx, id) if err != nil { return rollback(tx, "update client database: %w", err) } @@ -79,7 +79,7 @@ func (d *Database) UpdateClient(id string, updater func(old storage.Client) (sto SetLogoURL(newClient.LogoURL). SetRedirectUris(newClient.RedirectURIs). SetTrustedPeers(newClient.TrustedPeers). - Save(context.TODO()) + Save(ctx) if err != nil { return rollback(tx, "update client uploading: %w", err) } diff --git a/storage/ent/client/connector.go b/storage/ent/client/connector.go index 1534e52241..f0cff8ba6a 100644 --- a/storage/ent/client/connector.go +++ b/storage/ent/client/connector.go @@ -22,8 +22,8 @@ func (d *Database) CreateConnector(ctx context.Context, connector storage.Connec } // ListConnectors extracts an array of connectors from the database. -func (d *Database) ListConnectors() ([]storage.Connector, error) { - connectors, err := d.client.Connector.Query().All(context.TODO()) +func (d *Database) ListConnectors(ctx context.Context) ([]storage.Connector, error) { + connectors, err := d.client.Connector.Query().All(ctx) if err != nil { return nil, convertDBError("list connectors: %w", err) } @@ -36,8 +36,8 @@ func (d *Database) ListConnectors() ([]storage.Connector, error) { } // GetConnector extracts a connector from the database by id. -func (d *Database) GetConnector(id string) (storage.Connector, error) { - connector, err := d.client.Connector.Get(context.TODO(), id) +func (d *Database) GetConnector(ctx context.Context, id string) (storage.Connector, error) { + connector, err := d.client.Connector.Get(ctx, id) if err != nil { return storage.Connector{}, convertDBError("get connector: %w", err) } @@ -45,8 +45,8 @@ func (d *Database) GetConnector(id string) (storage.Connector, error) { } // DeleteConnector deletes a connector from the database by id. -func (d *Database) DeleteConnector(id string) error { - err := d.client.Connector.DeleteOneID(id).Exec(context.TODO()) +func (d *Database) DeleteConnector(ctx context.Context, id string) error { + err := d.client.Connector.DeleteOneID(id).Exec(ctx) if err != nil { return convertDBError("delete connector: %w", err) } @@ -54,13 +54,13 @@ func (d *Database) DeleteConnector(id string) error { } // UpdateConnector changes a connector by id using an updater function and saves it to the database. -func (d *Database) UpdateConnector(id string, updater func(old storage.Connector) (storage.Connector, error)) error { - tx, err := d.BeginTx(context.TODO()) +func (d *Database) UpdateConnector(ctx context.Context, id string, updater func(old storage.Connector) (storage.Connector, error)) error { + tx, err := d.BeginTx(ctx) if err != nil { return convertDBError("update connector tx: %w", err) } - connector, err := tx.Connector.Get(context.TODO(), id) + connector, err := tx.Connector.Get(ctx, id) if err != nil { return rollback(tx, "update connector database: %w", err) } @@ -75,7 +75,7 @@ func (d *Database) UpdateConnector(id string, updater func(old storage.Connector SetType(newConnector.Type). SetResourceVersion(newConnector.ResourceVersion). SetConfig(newConnector.Config). - Save(context.TODO()) + Save(ctx) if err != nil { return rollback(tx, "update connector uploading: %w", err) } diff --git a/storage/ent/client/devicerequest.go b/storage/ent/client/devicerequest.go index d8d371c9ba..5673395567 100644 --- a/storage/ent/client/devicerequest.go +++ b/storage/ent/client/devicerequest.go @@ -25,10 +25,10 @@ func (d *Database) CreateDeviceRequest(ctx context.Context, request storage.Devi } // GetDeviceRequest extracts a device request from the database by user code. -func (d *Database) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) { +func (d *Database) GetDeviceRequest(ctx context.Context, userCode string) (storage.DeviceRequest, error) { deviceRequest, err := d.client.DeviceRequest.Query(). Where(devicerequest.UserCode(userCode)). - Only(context.TODO()) + Only(ctx) if err != nil { return storage.DeviceRequest{}, convertDBError("get device request: %w", err) } diff --git a/storage/ent/client/devicetoken.go b/storage/ent/client/devicetoken.go index 18d483b98a..759812b196 100644 --- a/storage/ent/client/devicetoken.go +++ b/storage/ent/client/devicetoken.go @@ -27,10 +27,10 @@ func (d *Database) CreateDeviceToken(ctx context.Context, token storage.DeviceTo } // GetDeviceToken extracts a token from the database by device code. -func (d *Database) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) { +func (d *Database) GetDeviceToken(ctx context.Context, deviceCode string) (storage.DeviceToken, error) { deviceToken, err := d.client.DeviceToken.Query(). Where(devicetoken.DeviceCode(deviceCode)). - Only(context.TODO()) + Only(ctx) if err != nil { return storage.DeviceToken{}, convertDBError("get device token: %w", err) } @@ -38,15 +38,15 @@ func (d *Database) GetDeviceToken(deviceCode string) (storage.DeviceToken, error } // UpdateDeviceToken changes a token by device code using an updater function and saves it to the database. -func (d *Database) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error { - tx, err := d.BeginTx(context.TODO()) +func (d *Database) UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error { + tx, err := d.BeginTx(ctx) if err != nil { return convertDBError("update device token tx: %w", err) } token, err := tx.DeviceToken.Query(). Where(devicetoken.DeviceCode(deviceCode)). - Only(context.TODO()) + Only(ctx) if err != nil { return rollback(tx, "update device token database: %w", err) } @@ -67,7 +67,7 @@ func (d *Database) UpdateDeviceToken(deviceCode string, updater func(old storage SetStatus(newToken.Status). SetCodeChallenge(newToken.PKCE.CodeChallenge). SetCodeChallengeMethod(newToken.PKCE.CodeChallengeMethod). - Save(context.TODO()) + Save(ctx) if err != nil { return rollback(tx, "update device token uploading: %w", err) } diff --git a/storage/ent/client/keys.go b/storage/ent/client/keys.go index f65d40fc21..c4e972026f 100644 --- a/storage/ent/client/keys.go +++ b/storage/ent/client/keys.go @@ -8,8 +8,8 @@ import ( "github.com/dexidp/dex/storage/ent/db" ) -func getKeys(client *db.KeysClient) (storage.Keys, error) { - rawKeys, err := client.Get(context.TODO(), keysRowID) +func getKeys(ctx context.Context, client *db.KeysClient) (storage.Keys, error) { + rawKeys, err := client.Get(ctx, keysRowID) if err != nil { return storage.Keys{}, convertDBError("get keys: %w", err) } @@ -18,20 +18,20 @@ func getKeys(client *db.KeysClient) (storage.Keys, error) { } // GetKeys returns signing keys, public keys and verification keys from the database. -func (d *Database) GetKeys() (storage.Keys, error) { - return getKeys(d.client.Keys) +func (d *Database) GetKeys(ctx context.Context) (storage.Keys, error) { + return getKeys(ctx, d.client.Keys) } // UpdateKeys rotates keys using updater function. -func (d *Database) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error { +func (d *Database) UpdateKeys(ctx context.Context, updater func(old storage.Keys) (storage.Keys, error)) error { firstUpdate := false - tx, err := d.BeginTx(context.TODO()) + tx, err := d.BeginTx(ctx) if err != nil { return convertDBError("update keys tx: %w", err) } - storageKeys, err := getKeys(tx.Keys) + storageKeys, err := getKeys(ctx, tx.Keys) if err != nil { if !errors.Is(err, storage.ErrNotFound) { return rollback(tx, "update keys get: %w", err) @@ -53,7 +53,7 @@ func (d *Database) UpdateKeys(updater func(old storage.Keys) (storage.Keys, erro SetSigningKey(*newKeys.SigningKey). SetSigningKeyPub(*newKeys.SigningKeyPub). SetVerificationKeys(newKeys.VerificationKeys). - Save(context.TODO()) + Save(ctx) if err != nil { return rollback(tx, "create keys: %w", err) } @@ -68,7 +68,7 @@ func (d *Database) UpdateKeys(updater func(old storage.Keys) (storage.Keys, erro SetSigningKey(*newKeys.SigningKey). SetSigningKeyPub(*newKeys.SigningKeyPub). SetVerificationKeys(newKeys.VerificationKeys). - Exec(context.TODO()) + Exec(ctx) if err != nil { return rollback(tx, "update keys uploading: %w", err) } diff --git a/storage/ent/client/main.go b/storage/ent/client/main.go index bc4c1600ac..a78830fc76 100644 --- a/storage/ent/client/main.go +++ b/storage/ent/client/main.go @@ -70,13 +70,13 @@ func (d *Database) BeginTx(ctx context.Context) (*db.Tx, error) { } // GarbageCollect removes expired entities from the database. -func (d *Database) GarbageCollect(now time.Time) (storage.GCResult, error) { +func (d *Database) GarbageCollect(ctx context.Context, now time.Time) (storage.GCResult, error) { result := storage.GCResult{} utcNow := now.UTC() q, err := d.client.AuthRequest.Delete(). Where(authrequest.ExpiryLT(utcNow)). - Exec(context.TODO()) + Exec(ctx) if err != nil { return result, convertDBError("gc auth request: %w", err) } @@ -84,7 +84,7 @@ func (d *Database) GarbageCollect(now time.Time) (storage.GCResult, error) { q, err = d.client.AuthCode.Delete(). Where(authcode.ExpiryLT(utcNow)). - Exec(context.TODO()) + Exec(ctx) if err != nil { return result, convertDBError("gc auth code: %w", err) } @@ -92,7 +92,7 @@ func (d *Database) GarbageCollect(now time.Time) (storage.GCResult, error) { q, err = d.client.DeviceRequest.Delete(). Where(devicerequest.ExpiryLT(utcNow)). - Exec(context.TODO()) + Exec(ctx) if err != nil { return result, convertDBError("gc device request: %w", err) } @@ -100,7 +100,7 @@ func (d *Database) GarbageCollect(now time.Time) (storage.GCResult, error) { q, err = d.client.DeviceToken.Delete(). Where(devicetoken.ExpiryLT(utcNow)). - Exec(context.TODO()) + Exec(ctx) if err != nil { return result, convertDBError("gc device token: %w", err) } diff --git a/storage/ent/client/offlinesession.go b/storage/ent/client/offlinesession.go index 22469eced9..9d608cb6f3 100644 --- a/storage/ent/client/offlinesession.go +++ b/storage/ent/client/offlinesession.go @@ -30,10 +30,10 @@ func (d *Database) CreateOfflineSessions(ctx context.Context, session storage.Of } // GetOfflineSessions extracts an offline session from the database by user id and connector id. -func (d *Database) GetOfflineSessions(userID, connID string) (storage.OfflineSessions, error) { +func (d *Database) GetOfflineSessions(ctx context.Context, userID, connID string) (storage.OfflineSessions, error) { id := offlineSessionID(userID, connID, d.hasher) - offlineSession, err := d.client.OfflineSession.Get(context.TODO(), id) + offlineSession, err := d.client.OfflineSession.Get(ctx, id) if err != nil { return storage.OfflineSessions{}, convertDBError("get offline session: %w", err) } @@ -41,10 +41,10 @@ func (d *Database) GetOfflineSessions(userID, connID string) (storage.OfflineSes } // DeleteOfflineSessions deletes an offline session from the database by user id and connector id. -func (d *Database) DeleteOfflineSessions(userID, connID string) error { +func (d *Database) DeleteOfflineSessions(ctx context.Context, userID, connID string) error { id := offlineSessionID(userID, connID, d.hasher) - err := d.client.OfflineSession.DeleteOneID(id).Exec(context.TODO()) + err := d.client.OfflineSession.DeleteOneID(id).Exec(ctx) if err != nil { return convertDBError("delete offline session: %w", err) } @@ -52,15 +52,15 @@ func (d *Database) DeleteOfflineSessions(userID, connID string) error { } // UpdateOfflineSessions changes an offline session by user id and connector id using an updater function. -func (d *Database) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { +func (d *Database) UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { id := offlineSessionID(userID, connID, d.hasher) - tx, err := d.BeginTx(context.TODO()) + tx, err := d.BeginTx(ctx) if err != nil { return convertDBError("update offline session tx: %w", err) } - offlineSession, err := tx.OfflineSession.Get(context.TODO(), id) + offlineSession, err := tx.OfflineSession.Get(ctx, id) if err != nil { return rollback(tx, "update offline session database: %w", err) } @@ -80,7 +80,7 @@ func (d *Database) UpdateOfflineSessions(userID string, connID string, updater f SetConnID(newOfflineSession.ConnID). SetConnectorData(newOfflineSession.ConnectorData). SetRefresh(encodedRefresh). - Save(context.TODO()) + Save(ctx) if err != nil { return rollback(tx, "update offline session uploading: %w", err) } diff --git a/storage/ent/client/password.go b/storage/ent/client/password.go index 3e4aace8ae..2845fa8f76 100644 --- a/storage/ent/client/password.go +++ b/storage/ent/client/password.go @@ -23,8 +23,8 @@ func (d *Database) CreatePassword(ctx context.Context, password storage.Password } // ListPasswords extracts an array of passwords from the database. -func (d *Database) ListPasswords() ([]storage.Password, error) { - passwords, err := d.client.Password.Query().All(context.TODO()) +func (d *Database) ListPasswords(ctx context.Context) ([]storage.Password, error) { + passwords, err := d.client.Password.Query().All(ctx) if err != nil { return nil, convertDBError("list passwords: %w", err) } @@ -37,11 +37,11 @@ func (d *Database) ListPasswords() ([]storage.Password, error) { } // GetPassword extracts a password from the database by email. -func (d *Database) GetPassword(email string) (storage.Password, error) { +func (d *Database) GetPassword(ctx context.Context, email string) (storage.Password, error) { email = strings.ToLower(email) passwordFromStorage, err := d.client.Password.Query(). Where(password.Email(email)). - Only(context.TODO()) + Only(ctx) if err != nil { return storage.Password{}, convertDBError("get password: %w", err) } @@ -49,11 +49,11 @@ func (d *Database) GetPassword(email string) (storage.Password, error) { } // DeletePassword deletes a password from the database by email. -func (d *Database) DeletePassword(email string) error { +func (d *Database) DeletePassword(ctx context.Context, email string) error { email = strings.ToLower(email) _, err := d.client.Password.Delete(). Where(password.Email(email)). - Exec(context.TODO()) + Exec(ctx) if err != nil { return convertDBError("delete password: %w", err) } @@ -61,17 +61,17 @@ func (d *Database) DeletePassword(email string) error { } // UpdatePassword changes a password by email using an updater function and saves it to the database. -func (d *Database) UpdatePassword(email string, updater func(old storage.Password) (storage.Password, error)) error { +func (d *Database) UpdatePassword(ctx context.Context, email string, updater func(old storage.Password) (storage.Password, error)) error { email = strings.ToLower(email) - tx, err := d.BeginTx(context.TODO()) + tx, err := d.BeginTx(ctx) if err != nil { return convertDBError("update connector tx: %w", err) } passwordToUpdate, err := tx.Password.Query(). Where(password.Email(email)). - Only(context.TODO()) + Only(ctx) if err != nil { return rollback(tx, "update password database: %w", err) } @@ -87,7 +87,7 @@ func (d *Database) UpdatePassword(email string, updater func(old storage.Passwor SetHash(newPassword.Hash). SetUsername(newPassword.Username). SetUserID(newPassword.UserID). - Save(context.TODO()) + Save(ctx) if err != nil { return rollback(tx, "update password uploading: %w", err) } diff --git a/storage/ent/client/refreshtoken.go b/storage/ent/client/refreshtoken.go index 6861b07916..d423565439 100644 --- a/storage/ent/client/refreshtoken.go +++ b/storage/ent/client/refreshtoken.go @@ -34,8 +34,8 @@ func (d *Database) CreateRefresh(ctx context.Context, refresh storage.RefreshTok } // ListRefreshTokens extracts an array of refresh tokens from the database. -func (d *Database) ListRefreshTokens() ([]storage.RefreshToken, error) { - refreshTokens, err := d.client.RefreshToken.Query().All(context.TODO()) +func (d *Database) ListRefreshTokens(ctx context.Context) ([]storage.RefreshToken, error) { + refreshTokens, err := d.client.RefreshToken.Query().All(ctx) if err != nil { return nil, convertDBError("list refresh tokens: %w", err) } @@ -48,8 +48,8 @@ func (d *Database) ListRefreshTokens() ([]storage.RefreshToken, error) { } // GetRefresh extracts a refresh token from the database by id. -func (d *Database) GetRefresh(id string) (storage.RefreshToken, error) { - refreshToken, err := d.client.RefreshToken.Get(context.TODO(), id) +func (d *Database) GetRefresh(ctx context.Context, id string) (storage.RefreshToken, error) { + refreshToken, err := d.client.RefreshToken.Get(ctx, id) if err != nil { return storage.RefreshToken{}, convertDBError("get refresh token: %w", err) } @@ -57,8 +57,8 @@ func (d *Database) GetRefresh(id string) (storage.RefreshToken, error) { } // DeleteRefresh deletes a refresh token from the database by id. -func (d *Database) DeleteRefresh(id string) error { - err := d.client.RefreshToken.DeleteOneID(id).Exec(context.TODO()) +func (d *Database) DeleteRefresh(ctx context.Context, id string) error { + err := d.client.RefreshToken.DeleteOneID(id).Exec(ctx) if err != nil { return convertDBError("delete refresh token: %w", err) } @@ -66,13 +66,13 @@ func (d *Database) DeleteRefresh(id string) error { } // UpdateRefreshToken changes a refresh token by id using an updater function and saves it to the database. -func (d *Database) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { - tx, err := d.BeginTx(context.TODO()) +func (d *Database) UpdateRefreshToken(ctx context.Context, id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { + tx, err := d.BeginTx(ctx) if err != nil { return convertDBError("update refresh token tx: %w", err) } - token, err := tx.RefreshToken.Get(context.TODO(), id) + token, err := tx.RefreshToken.Get(ctx, id) if err != nil { return rollback(tx, "update refresh token database: %w", err) } @@ -99,7 +99,7 @@ func (d *Database) UpdateRefreshToken(id string, updater func(old storage.Refres // Save utc time into database because ent doesn't support comparing dates with different timezones SetLastUsed(newtToken.LastUsed.UTC()). SetCreatedAt(newtToken.CreatedAt.UTC()). - Save(context.TODO()) + Save(ctx) if err != nil { return rollback(tx, "update refresh token uploading: %w", err) } diff --git a/storage/etcd/etcd.go b/storage/etcd/etcd.go index f65701ff1f..8ccf502f2e 100644 --- a/storage/etcd/etcd.go +++ b/storage/etcd/etcd.go @@ -40,8 +40,8 @@ func (c *conn) Close() error { return c.db.Close() } -func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error) { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) GarbageCollect(ctx context.Context, now time.Time) (result storage.GCResult, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() authRequests, err := c.listAuthRequests(ctx) if err != nil { @@ -113,8 +113,9 @@ func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) err return c.txnCreate(ctx, keyID(authRequestPrefix, a.ID), fromStorageAuthRequest(a)) } -func (c *conn) GetAuthRequest(id string) (a storage.AuthRequest, err error) { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) GetAuthRequest(ctx context.Context, id string) (a storage.AuthRequest, err error) { + // TODO: Add this to other funcs?? + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() var req AuthRequest if err = c.getKey(ctx, keyID(authRequestPrefix, id), &req); err != nil { @@ -123,8 +124,8 @@ func (c *conn) GetAuthRequest(id string) (a storage.AuthRequest, err error) { return toStorageAuthRequest(req), nil } -func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) UpdateAuthRequest(ctx context.Context, id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() return c.txnUpdate(ctx, keyID(authRequestPrefix, id), func(currentValue []byte) ([]byte, error) { var current AuthRequest @@ -141,8 +142,8 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) }) } -func (c *conn) DeleteAuthRequest(id string) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) DeleteAuthRequest(ctx context.Context, id string) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() return c.deleteKey(ctx, keyID(authRequestPrefix, id)) } @@ -151,8 +152,8 @@ func (c *conn) CreateAuthCode(ctx context.Context, a storage.AuthCode) error { return c.txnCreate(ctx, keyID(authCodePrefix, a.ID), fromStorageAuthCode(a)) } -func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) GetAuthCode(ctx context.Context, id string) (a storage.AuthCode, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() var ac AuthCode err = c.getKey(ctx, keyID(authCodePrefix, id), &ac) @@ -162,8 +163,8 @@ func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) { return a, err } -func (c *conn) DeleteAuthCode(id string) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) DeleteAuthCode(ctx context.Context, id string) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() return c.deleteKey(ctx, keyID(authCodePrefix, id)) } @@ -172,8 +173,8 @@ func (c *conn) CreateRefresh(ctx context.Context, r storage.RefreshToken) error return c.txnCreate(ctx, keyID(refreshTokenPrefix, r.ID), fromStorageRefreshToken(r)) } -func (c *conn) GetRefresh(id string) (r storage.RefreshToken, err error) { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) GetRefresh(ctx context.Context, id string) (r storage.RefreshToken, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() var token RefreshToken if err = c.getKey(ctx, keyID(refreshTokenPrefix, id), &token); err != nil { @@ -182,8 +183,8 @@ func (c *conn) GetRefresh(id string) (r storage.RefreshToken, err error) { return toStorageRefreshToken(token), nil } -func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) UpdateRefreshToken(ctx context.Context, id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() return c.txnUpdate(ctx, keyID(refreshTokenPrefix, id), func(currentValue []byte) ([]byte, error) { var current RefreshToken @@ -200,14 +201,14 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok }) } -func (c *conn) DeleteRefresh(id string) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) DeleteRefresh(ctx context.Context, id string) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() return c.deleteKey(ctx, keyID(refreshTokenPrefix, id)) } -func (c *conn) ListRefreshTokens() (tokens []storage.RefreshToken, err error) { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) ListRefreshTokens(ctx context.Context) (tokens []storage.RefreshToken, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() res, err := c.db.Get(ctx, refreshTokenPrefix, clientv3.WithPrefix()) if err != nil { @@ -227,15 +228,15 @@ func (c *conn) CreateClient(ctx context.Context, cli storage.Client) error { return c.txnCreate(ctx, keyID(clientPrefix, cli.ID), cli) } -func (c *conn) GetClient(id string) (cli storage.Client, err error) { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) GetClient(ctx context.Context, id string) (cli storage.Client, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() err = c.getKey(ctx, keyID(clientPrefix, id), &cli) return cli, err } -func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) UpdateClient(ctx context.Context, id string, updater func(old storage.Client) (storage.Client, error)) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() return c.txnUpdate(ctx, keyID(clientPrefix, id), func(currentValue []byte) ([]byte, error) { var current storage.Client @@ -252,14 +253,14 @@ func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage }) } -func (c *conn) DeleteClient(id string) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) DeleteClient(ctx context.Context, id string) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() return c.deleteKey(ctx, keyID(clientPrefix, id)) } -func (c *conn) ListClients() (clients []storage.Client, err error) { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) ListClients(ctx context.Context) (clients []storage.Client, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() res, err := c.db.Get(ctx, clientPrefix, clientv3.WithPrefix()) if err != nil { @@ -279,15 +280,15 @@ func (c *conn) CreatePassword(ctx context.Context, p storage.Password) error { return c.txnCreate(ctx, passwordPrefix+strings.ToLower(p.Email), p) } -func (c *conn) GetPassword(email string) (p storage.Password, err error) { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) GetPassword(ctx context.Context, email string) (p storage.Password, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() err = c.getKey(ctx, keyEmail(passwordPrefix, email), &p) return p, err } -func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) UpdatePassword(ctx context.Context, email string, updater func(p storage.Password) (storage.Password, error)) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() return c.txnUpdate(ctx, keyEmail(passwordPrefix, email), func(currentValue []byte) ([]byte, error) { var current storage.Password @@ -304,14 +305,14 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st }) } -func (c *conn) DeletePassword(email string) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) DeletePassword(ctx context.Context, email string) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() return c.deleteKey(ctx, keyEmail(passwordPrefix, email)) } -func (c *conn) ListPasswords() (passwords []storage.Password, err error) { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) ListPasswords(ctx context.Context) (passwords []storage.Password, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() res, err := c.db.Get(ctx, passwordPrefix, clientv3.WithPrefix()) if err != nil { @@ -331,8 +332,8 @@ func (c *conn) CreateOfflineSessions(ctx context.Context, s storage.OfflineSessi return c.txnCreate(ctx, keySession(s.UserID, s.ConnID), fromStorageOfflineSessions(s)) } -func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() return c.txnUpdate(ctx, keySession(userID, connID), func(currentValue []byte) ([]byte, error) { var current OfflineSessions @@ -349,8 +350,8 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func( }) } -func (c *conn) GetOfflineSessions(userID string, connID string) (s storage.OfflineSessions, err error) { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) GetOfflineSessions(ctx context.Context, userID string, connID string) (s storage.OfflineSessions, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() var os OfflineSessions if err = c.getKey(ctx, keySession(userID, connID), &os); err != nil { @@ -359,8 +360,8 @@ func (c *conn) GetOfflineSessions(userID string, connID string) (s storage.Offli return toStorageOfflineSessions(os), nil } -func (c *conn) DeleteOfflineSessions(userID string, connID string) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) DeleteOfflineSessions(ctx context.Context, userID string, connID string) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() return c.deleteKey(ctx, keySession(userID, connID)) } @@ -369,15 +370,15 @@ func (c *conn) CreateConnector(ctx context.Context, connector storage.Connector) return c.txnCreate(ctx, keyID(connectorPrefix, connector.ID), connector) } -func (c *conn) GetConnector(id string) (conn storage.Connector, err error) { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) GetConnector(ctx context.Context, id string) (conn storage.Connector, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() err = c.getKey(ctx, keyID(connectorPrefix, id), &conn) return conn, err } -func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) UpdateConnector(ctx context.Context, id string, updater func(s storage.Connector) (storage.Connector, error)) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() return c.txnUpdate(ctx, keyID(connectorPrefix, id), func(currentValue []byte) ([]byte, error) { var current storage.Connector @@ -394,14 +395,14 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto }) } -func (c *conn) DeleteConnector(id string) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) DeleteConnector(ctx context.Context, id string) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() return c.deleteKey(ctx, keyID(connectorPrefix, id)) } -func (c *conn) ListConnectors() (connectors []storage.Connector, err error) { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) ListConnectors(ctx context.Context) (connectors []storage.Connector, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() res, err := c.db.Get(ctx, connectorPrefix, clientv3.WithPrefix()) if err != nil { @@ -417,8 +418,8 @@ func (c *conn) ListConnectors() (connectors []storage.Connector, err error) { return connectors, nil } -func (c *conn) GetKeys() (keys storage.Keys, err error) { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) GetKeys(ctx context.Context) (keys storage.Keys, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() res, err := c.db.Get(ctx, keysName) if err != nil { @@ -430,8 +431,8 @@ func (c *conn) GetKeys() (keys storage.Keys, err error) { return keys, err } -func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) UpdateKeys(ctx context.Context, updater func(old storage.Keys) (storage.Keys, error)) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() return c.txnUpdate(ctx, keysName, func(currentValue []byte) ([]byte, error) { var current storage.Keys @@ -560,8 +561,8 @@ func (c *conn) CreateDeviceRequest(ctx context.Context, d storage.DeviceRequest) return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d)) } -func (c *conn) GetDeviceRequest(userCode string) (r storage.DeviceRequest, err error) { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) GetDeviceRequest(ctx context.Context, userCode string) (r storage.DeviceRequest, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() var dr DeviceRequest if err = c.getKey(ctx, keyID(deviceRequestPrefix, userCode), &dr); err == nil { @@ -589,8 +590,8 @@ func (c *conn) CreateDeviceToken(ctx context.Context, t storage.DeviceToken) err return c.txnCreate(ctx, keyID(deviceTokenPrefix, t.DeviceCode), fromStorageDeviceToken(t)) } -func (c *conn) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) GetDeviceToken(ctx context.Context, deviceCode string) (t storage.DeviceToken, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() var dt DeviceToken if err = c.getKey(ctx, keyID(deviceTokenPrefix, deviceCode), &dt); err == nil { @@ -614,8 +615,8 @@ func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken return deviceTokens, nil } -func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) +func (c *conn) UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() return c.txnUpdate(ctx, keyID(deviceTokenPrefix, deviceCode), func(currentValue []byte) ([]byte, error) { var current DeviceToken diff --git a/storage/health.go b/storage/health.go index 8cdefddf32..fe75a2da74 100644 --- a/storage/health.go +++ b/storage/health.go @@ -23,7 +23,7 @@ func NewCustomHealthCheckFunc(s Storage, now func() time.Time) func(context.Cont return nil, fmt.Errorf("create auth request: %v", err) } - if err := s.DeleteAuthRequest(a.ID); err != nil { + if err := s.DeleteAuthRequest(ctx, a.ID); err != nil { return nil, fmt.Errorf("delete auth request: %v", err) } diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index 8b6d5c9c2e..6ff522837c 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -262,7 +262,7 @@ func (cli *client) CreateConnector(ctx context.Context, c storage.Connector) err return cli.post(resourceConnector, cli.fromStorageConnector(c)) } -func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) { +func (cli *client) GetAuthRequest(ctx context.Context, id string) (storage.AuthRequest, error) { var req AuthRequest if err := cli.get(resourceAuthRequest, id, &req); err != nil { return storage.AuthRequest{}, err @@ -270,7 +270,7 @@ func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) { return toStorageAuthRequest(req), nil } -func (cli *client) GetAuthCode(id string) (storage.AuthCode, error) { +func (cli *client) GetAuthCode(ctx context.Context, id string) (storage.AuthCode, error) { var code AuthCode if err := cli.get(resourceAuthCode, id, &code); err != nil { return storage.AuthCode{}, err @@ -278,7 +278,7 @@ func (cli *client) GetAuthCode(id string) (storage.AuthCode, error) { return toStorageAuthCode(code), nil } -func (cli *client) GetClient(id string) (storage.Client, error) { +func (cli *client) GetClient(ctx context.Context, id string) (storage.Client, error) { c, err := cli.getClient(id) if err != nil { return storage.Client{}, err @@ -298,7 +298,7 @@ func (cli *client) getClient(id string) (Client, error) { return c, nil } -func (cli *client) GetPassword(email string) (storage.Password, error) { +func (cli *client) GetPassword(ctx context.Context, email string) (storage.Password, error) { p, err := cli.getPassword(email) if err != nil { return storage.Password{}, err @@ -320,7 +320,7 @@ func (cli *client) getPassword(email string) (Password, error) { return p, nil } -func (cli *client) GetKeys() (storage.Keys, error) { +func (cli *client) GetKeys(ctx context.Context) (storage.Keys, error) { var keys Keys if err := cli.get(resourceKeys, keysName, &keys); err != nil { return storage.Keys{}, err @@ -328,7 +328,7 @@ func (cli *client) GetKeys() (storage.Keys, error) { return toStorageKeys(keys), nil } -func (cli *client) GetRefresh(id string) (storage.RefreshToken, error) { +func (cli *client) GetRefresh(ctx context.Context, id string) (storage.RefreshToken, error) { r, err := cli.getRefreshToken(id) if err != nil { return storage.RefreshToken{}, err @@ -341,7 +341,7 @@ func (cli *client) getRefreshToken(id string) (r RefreshToken, err error) { return } -func (cli *client) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) { +func (cli *client) GetOfflineSessions(ctx context.Context, userID string, connID string) (storage.OfflineSessions, error) { o, err := cli.getOfflineSessions(userID, connID) if err != nil { return storage.OfflineSessions{}, err @@ -360,7 +360,7 @@ func (cli *client) getOfflineSessions(userID string, connID string) (o OfflineSe return o, nil } -func (cli *client) GetConnector(id string) (storage.Connector, error) { +func (cli *client) GetConnector(ctx context.Context, id string) (storage.Connector, error) { var c Connector if err := cli.get(resourceConnector, id, &c); err != nil { return storage.Connector{}, err @@ -368,15 +368,15 @@ func (cli *client) GetConnector(id string) (storage.Connector, error) { return toStorageConnector(c), nil } -func (cli *client) ListClients() ([]storage.Client, error) { +func (cli *client) ListClients(ctx context.Context) ([]storage.Client, error) { return nil, errors.New("not implemented") } -func (cli *client) ListRefreshTokens() ([]storage.RefreshToken, error) { +func (cli *client) ListRefreshTokens(ctx context.Context) ([]storage.RefreshToken, error) { return nil, errors.New("not implemented") } -func (cli *client) ListPasswords() (passwords []storage.Password, err error) { +func (cli *client) ListPasswords(ctx context.Context) (passwords []storage.Password, err error) { var passwordList PasswordList if err = cli.list(resourcePassword, &passwordList); err != nil { return passwords, fmt.Errorf("failed to list passwords: %v", err) @@ -395,7 +395,7 @@ func (cli *client) ListPasswords() (passwords []storage.Password, err error) { return } -func (cli *client) ListConnectors() (connectors []storage.Connector, err error) { +func (cli *client) ListConnectors(ctx context.Context) (connectors []storage.Connector, err error) { var connectorList ConnectorList if err = cli.list(resourceConnector, &connectorList); err != nil { return connectors, fmt.Errorf("failed to list connectors: %v", err) @@ -409,15 +409,15 @@ func (cli *client) ListConnectors() (connectors []storage.Connector, err error) return } -func (cli *client) DeleteAuthRequest(id string) error { +func (cli *client) DeleteAuthRequest(ctx context.Context, id string) error { return cli.delete(resourceAuthRequest, id) } -func (cli *client) DeleteAuthCode(code string) error { +func (cli *client) DeleteAuthCode(ctx context.Context, code string) error { return cli.delete(resourceAuthCode, code) } -func (cli *client) DeleteClient(id string) error { +func (cli *client) DeleteClient(ctx context.Context, id string) error { // Check for hash collision. c, err := cli.getClient(id) if err != nil { @@ -426,11 +426,11 @@ func (cli *client) DeleteClient(id string) error { return cli.delete(resourceClient, c.ObjectMeta.Name) } -func (cli *client) DeleteRefresh(id string) error { +func (cli *client) DeleteRefresh(ctx context.Context, id string) error { return cli.delete(resourceRefreshToken, id) } -func (cli *client) DeletePassword(email string) error { +func (cli *client) DeletePassword(ctx context.Context, email string) error { // Check for hash collision. p, err := cli.getPassword(email) if err != nil { @@ -439,7 +439,7 @@ func (cli *client) DeletePassword(email string) error { return cli.delete(resourcePassword, p.ObjectMeta.Name) } -func (cli *client) DeleteOfflineSessions(userID string, connID string) error { +func (cli *client) DeleteOfflineSessions(ctx context.Context, userID string, connID string) error { // Check for hash collision. o, err := cli.getOfflineSessions(userID, connID) if err != nil { @@ -448,11 +448,11 @@ func (cli *client) DeleteOfflineSessions(userID string, connID string) error { return cli.delete(resourceOfflineSessions, o.ObjectMeta.Name) } -func (cli *client) DeleteConnector(id string) error { +func (cli *client) DeleteConnector(ctx context.Context, id string) error { return cli.delete(resourceConnector, id) } -func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { +func (cli *client) UpdateRefreshToken(ctx context.Context, id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { lock := newRefreshTokenLock(cli) if err := lock.Lock(id); err != nil { @@ -460,7 +460,7 @@ func (cli *client) UpdateRefreshToken(id string, updater func(old storage.Refres } defer lock.Unlock(id) - return retryOnConflict(context.TODO(), func() error { + return retryOnConflict(ctx, func() error { r, err := cli.getRefreshToken(id) if err != nil { return err @@ -479,7 +479,7 @@ func (cli *client) UpdateRefreshToken(id string, updater func(old storage.Refres }) } -func (cli *client) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { +func (cli *client) UpdateClient(ctx context.Context, id string, updater func(old storage.Client) (storage.Client, error)) error { c, err := cli.getClient(id) if err != nil { return err @@ -496,7 +496,7 @@ func (cli *client) UpdateClient(id string, updater func(old storage.Client) (sto return cli.put(resourceClient, c.ObjectMeta.Name, newClient) } -func (cli *client) UpdatePassword(email string, updater func(old storage.Password) (storage.Password, error)) error { +func (cli *client) UpdatePassword(ctx context.Context, email string, updater func(old storage.Password) (storage.Password, error)) error { p, err := cli.getPassword(email) if err != nil { return err @@ -513,8 +513,8 @@ func (cli *client) UpdatePassword(email string, updater func(old storage.Passwor return cli.put(resourcePassword, p.ObjectMeta.Name, newPassword) } -func (cli *client) UpdateOfflineSessions(userID string, connID string, updater func(old storage.OfflineSessions) (storage.OfflineSessions, error)) error { - return retryOnConflict(context.TODO(), func() error { +func (cli *client) UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(old storage.OfflineSessions) (storage.OfflineSessions, error)) error { + return retryOnConflict(ctx, func() error { o, err := cli.getOfflineSessions(userID, connID) if err != nil { return err @@ -531,7 +531,7 @@ func (cli *client) UpdateOfflineSessions(userID string, connID string, updater f }) } -func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error { +func (cli *client) UpdateKeys(ctx context.Context, updater func(old storage.Keys) (storage.Keys, error)) error { firstUpdate := false var keys Keys if err := cli.get(resourceKeys, keysName, &keys); err != nil { @@ -576,7 +576,7 @@ func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, erro return err } -func (cli *client) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { +func (cli *client) UpdateAuthRequest(ctx context.Context, id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { var req AuthRequest err := cli.get(resourceAuthRequest, id, &req) if err != nil { @@ -593,8 +593,8 @@ func (cli *client) UpdateAuthRequest(id string, updater func(a storage.AuthReque return cli.put(resourceAuthRequest, id, newReq) } -func (cli *client) UpdateConnector(id string, updater func(a storage.Connector) (storage.Connector, error)) error { - return retryOnConflict(context.TODO(), func() error { +func (cli *client) UpdateConnector(ctx context.Context, id string, updater func(a storage.Connector) (storage.Connector, error)) error { + return retryOnConflict(ctx, func() error { var c Connector err := cli.get(resourceConnector, id, &c) if err != nil { @@ -612,7 +612,7 @@ func (cli *client) UpdateConnector(id string, updater func(a storage.Connector) }) } -func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err error) { +func (cli *client) GarbageCollect(ctx context.Context, now time.Time) (result storage.GCResult, err error) { var authRequests AuthRequestList if err := cli.listN(resourceAuthRequest, &authRequests, gcResultLimit); err != nil { return result, fmt.Errorf("failed to list auth requests: %v", err) @@ -687,7 +687,7 @@ func (cli *client) CreateDeviceRequest(ctx context.Context, d storage.DeviceRequ return cli.post(resourceDeviceRequest, cli.fromStorageDeviceRequest(d)) } -func (cli *client) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) { +func (cli *client) GetDeviceRequest(ctx context.Context, userCode string) (storage.DeviceRequest, error) { var req DeviceRequest if err := cli.get(resourceDeviceRequest, strings.ToLower(userCode), &req); err != nil { return storage.DeviceRequest{}, err @@ -699,7 +699,7 @@ func (cli *client) CreateDeviceToken(ctx context.Context, t storage.DeviceToken) return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t)) } -func (cli *client) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) { +func (cli *client) GetDeviceToken(ctx context.Context, deviceCode string) (storage.DeviceToken, error) { var token DeviceToken if err := cli.get(resourceDeviceToken, deviceCode, &token); err != nil { return storage.DeviceToken{}, err @@ -712,8 +712,8 @@ func (cli *client) getDeviceToken(deviceCode string) (t DeviceToken, err error) return } -func (cli *client) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error { - return retryOnConflict(context.TODO(), func() error { +func (cli *client) UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error { + return retryOnConflict(ctx, func() error { r, err := cli.getDeviceToken(deviceCode) if err != nil { return err diff --git a/storage/kubernetes/storage_test.go b/storage/kubernetes/storage_test.go index d8bfd1f689..525bd904cf 100644 --- a/storage/kubernetes/storage_test.go +++ b/storage/kubernetes/storage_test.go @@ -221,7 +221,7 @@ func TestUpdateKeys(t *testing.T) { for _, test := range tests { client := newStatusCodesResponseTestClient(test.getResponseCode, test.actionResponseCode) - err := client.UpdateKeys(test.updater) + err := client.UpdateKeys(context.TODO(), test.updater) if err != nil { if !test.wantErr { t.Fatalf("Test %q: %v", test.name, err) @@ -339,9 +339,9 @@ func TestRefreshTokenLock(t *testing.T) { require.NoError(t, err) t.Run("Timeout lock error", func(t *testing.T) { - err = kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) { + err = kubeClient.UpdateRefreshToken(ctx, r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) { r.Token = "update-result-1" - err := kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) { + err := kubeClient.UpdateRefreshToken(ctx, r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) { r.Token = "timeout-err" return r, nil }) @@ -350,7 +350,7 @@ func TestRefreshTokenLock(t *testing.T) { }) require.NoError(t, err) - token, err := kubeClient.GetRefresh(r.ID) + token, err := kubeClient.GetRefresh(context.TODO(), r.ID) require.NoError(t, err) require.Equal(t, "update-result-1", token.Token) }) @@ -359,13 +359,13 @@ func TestRefreshTokenLock(t *testing.T) { var lockBroken bool lockTimeout = -time.Hour - err = kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) { + err = kubeClient.UpdateRefreshToken(ctx, r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) { r.Token = "update-result-2" if lockBroken { return r, nil } - err := kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) { + err := kubeClient.UpdateRefreshToken(ctx, r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) { r.Token = "should-break-the-lock-and-finish-updating" return r, nil }) @@ -376,7 +376,7 @@ func TestRefreshTokenLock(t *testing.T) { }) require.NoError(t, err) - token, err := kubeClient.GetRefresh(r.ID) + token, err := kubeClient.GetRefresh(context.TODO(), r.ID) require.NoError(t, err) // Because concurrent update breaks the lock, the final result will be the value of the first update require.Equal(t, "update-result-2", token.Token) diff --git a/storage/memory/memory.go b/storage/memory/memory.go index 4399c61df1..eff75e716d 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -71,7 +71,7 @@ func (s *memStorage) tx(f func()) { func (s *memStorage) Close() error { return nil } -func (s *memStorage) GarbageCollect(now time.Time) (result storage.GCResult, err error) { +func (s *memStorage) GarbageCollect(ctx context.Context, now time.Time) (result storage.GCResult, err error) { s.tx(func() { for id, a := range s.authCodes { if now.After(a.Expiry) { @@ -183,7 +183,7 @@ func (s *memStorage) CreateConnector(ctx context.Context, connector storage.Conn return } -func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) { +func (s *memStorage) GetAuthCode(ctx context.Context, id string) (c storage.AuthCode, err error) { s.tx(func() { var ok bool if c, ok = s.authCodes[id]; !ok { @@ -194,7 +194,7 @@ func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) { return } -func (s *memStorage) GetPassword(email string) (p storage.Password, err error) { +func (s *memStorage) GetPassword(ctx context.Context, email string) (p storage.Password, err error) { email = strings.ToLower(email) s.tx(func() { var ok bool @@ -205,7 +205,7 @@ func (s *memStorage) GetPassword(email string) (p storage.Password, err error) { return } -func (s *memStorage) GetClient(id string) (client storage.Client, err error) { +func (s *memStorage) GetClient(ctx context.Context, id string) (client storage.Client, err error) { s.tx(func() { var ok bool if client, ok = s.clients[id]; !ok { @@ -215,12 +215,12 @@ func (s *memStorage) GetClient(id string) (client storage.Client, err error) { return } -func (s *memStorage) GetKeys() (keys storage.Keys, err error) { +func (s *memStorage) GetKeys(ctx context.Context) (keys storage.Keys, err error) { s.tx(func() { keys = s.keys }) return } -func (s *memStorage) GetRefresh(id string) (tok storage.RefreshToken, err error) { +func (s *memStorage) GetRefresh(ctx context.Context, id string) (tok storage.RefreshToken, err error) { s.tx(func() { var ok bool if tok, ok = s.refreshTokens[id]; !ok { @@ -231,7 +231,7 @@ func (s *memStorage) GetRefresh(id string) (tok storage.RefreshToken, err error) return } -func (s *memStorage) GetAuthRequest(id string) (req storage.AuthRequest, err error) { +func (s *memStorage) GetAuthRequest(ctx context.Context, id string) (req storage.AuthRequest, err error) { s.tx(func() { var ok bool if req, ok = s.authReqs[id]; !ok { @@ -242,7 +242,7 @@ func (s *memStorage) GetAuthRequest(id string) (req storage.AuthRequest, err err return } -func (s *memStorage) GetOfflineSessions(userID string, connID string) (o storage.OfflineSessions, err error) { +func (s *memStorage) GetOfflineSessions(ctx context.Context, userID string, connID string) (o storage.OfflineSessions, err error) { id := offlineSessionID{ userID: userID, connID: connID, @@ -257,7 +257,7 @@ func (s *memStorage) GetOfflineSessions(userID string, connID string) (o storage return } -func (s *memStorage) GetConnector(id string) (connector storage.Connector, err error) { +func (s *memStorage) GetConnector(ctx context.Context, id string) (connector storage.Connector, err error) { s.tx(func() { var ok bool if connector, ok = s.connectors[id]; !ok { @@ -267,7 +267,7 @@ func (s *memStorage) GetConnector(id string) (connector storage.Connector, err e return } -func (s *memStorage) ListClients() (clients []storage.Client, err error) { +func (s *memStorage) ListClients(ctx context.Context) (clients []storage.Client, err error) { s.tx(func() { for _, client := range s.clients { clients = append(clients, client) @@ -276,7 +276,7 @@ func (s *memStorage) ListClients() (clients []storage.Client, err error) { return } -func (s *memStorage) ListRefreshTokens() (tokens []storage.RefreshToken, err error) { +func (s *memStorage) ListRefreshTokens(ctx context.Context) (tokens []storage.RefreshToken, err error) { s.tx(func() { for _, refresh := range s.refreshTokens { tokens = append(tokens, refresh) @@ -285,7 +285,7 @@ func (s *memStorage) ListRefreshTokens() (tokens []storage.RefreshToken, err err return } -func (s *memStorage) ListPasswords() (passwords []storage.Password, err error) { +func (s *memStorage) ListPasswords(ctx context.Context) (passwords []storage.Password, err error) { s.tx(func() { for _, password := range s.passwords { passwords = append(passwords, password) @@ -294,7 +294,7 @@ func (s *memStorage) ListPasswords() (passwords []storage.Password, err error) { return } -func (s *memStorage) ListConnectors() (conns []storage.Connector, err error) { +func (s *memStorage) ListConnectors(ctx context.Context) (conns []storage.Connector, err error) { s.tx(func() { for _, c := range s.connectors { conns = append(conns, c) @@ -303,7 +303,7 @@ func (s *memStorage) ListConnectors() (conns []storage.Connector, err error) { return } -func (s *memStorage) DeletePassword(email string) (err error) { +func (s *memStorage) DeletePassword(ctx context.Context, email string) (err error) { email = strings.ToLower(email) s.tx(func() { if _, ok := s.passwords[email]; !ok { @@ -315,7 +315,7 @@ func (s *memStorage) DeletePassword(email string) (err error) { return } -func (s *memStorage) DeleteClient(id string) (err error) { +func (s *memStorage) DeleteClient(ctx context.Context, id string) (err error) { s.tx(func() { if _, ok := s.clients[id]; !ok { err = storage.ErrNotFound @@ -326,7 +326,7 @@ func (s *memStorage) DeleteClient(id string) (err error) { return } -func (s *memStorage) DeleteRefresh(id string) (err error) { +func (s *memStorage) DeleteRefresh(ctx context.Context, id string) (err error) { s.tx(func() { if _, ok := s.refreshTokens[id]; !ok { err = storage.ErrNotFound @@ -337,7 +337,7 @@ func (s *memStorage) DeleteRefresh(id string) (err error) { return } -func (s *memStorage) DeleteAuthCode(id string) (err error) { +func (s *memStorage) DeleteAuthCode(ctx context.Context, id string) (err error) { s.tx(func() { if _, ok := s.authCodes[id]; !ok { err = storage.ErrNotFound @@ -348,7 +348,7 @@ func (s *memStorage) DeleteAuthCode(id string) (err error) { return } -func (s *memStorage) DeleteAuthRequest(id string) (err error) { +func (s *memStorage) DeleteAuthRequest(ctx context.Context, id string) (err error) { s.tx(func() { if _, ok := s.authReqs[id]; !ok { err = storage.ErrNotFound @@ -359,7 +359,7 @@ func (s *memStorage) DeleteAuthRequest(id string) (err error) { return } -func (s *memStorage) DeleteOfflineSessions(userID string, connID string) (err error) { +func (s *memStorage) DeleteOfflineSessions(ctx context.Context, userID string, connID string) (err error) { id := offlineSessionID{ userID: userID, connID: connID, @@ -374,7 +374,7 @@ func (s *memStorage) DeleteOfflineSessions(userID string, connID string) (err er return } -func (s *memStorage) DeleteConnector(id string) (err error) { +func (s *memStorage) DeleteConnector(ctx context.Context, id string) (err error) { s.tx(func() { if _, ok := s.connectors[id]; !ok { err = storage.ErrNotFound @@ -385,7 +385,7 @@ func (s *memStorage) DeleteConnector(id string) (err error) { return } -func (s *memStorage) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) (err error) { +func (s *memStorage) UpdateClient(ctx context.Context, id string, updater func(old storage.Client) (storage.Client, error)) (err error) { s.tx(func() { client, ok := s.clients[id] if !ok { @@ -399,7 +399,7 @@ func (s *memStorage) UpdateClient(id string, updater func(old storage.Client) (s return } -func (s *memStorage) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) (err error) { +func (s *memStorage) UpdateKeys(ctx context.Context, updater func(old storage.Keys) (storage.Keys, error)) (err error) { s.tx(func() { var keys storage.Keys if keys, err = updater(s.keys); err == nil { @@ -409,7 +409,7 @@ func (s *memStorage) UpdateKeys(updater func(old storage.Keys) (storage.Keys, er return } -func (s *memStorage) UpdateAuthRequest(id string, updater func(old storage.AuthRequest) (storage.AuthRequest, error)) (err error) { +func (s *memStorage) UpdateAuthRequest(ctx context.Context, id string, updater func(old storage.AuthRequest) (storage.AuthRequest, error)) (err error) { s.tx(func() { req, ok := s.authReqs[id] if !ok { @@ -423,7 +423,7 @@ func (s *memStorage) UpdateAuthRequest(id string, updater func(old storage.AuthR return } -func (s *memStorage) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) (err error) { +func (s *memStorage) UpdatePassword(ctx context.Context, email string, updater func(p storage.Password) (storage.Password, error)) (err error) { email = strings.ToLower(email) s.tx(func() { req, ok := s.passwords[email] @@ -438,7 +438,7 @@ func (s *memStorage) UpdatePassword(email string, updater func(p storage.Passwor return } -func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.RefreshToken) (storage.RefreshToken, error)) (err error) { +func (s *memStorage) UpdateRefreshToken(ctx context.Context, id string, updater func(p storage.RefreshToken) (storage.RefreshToken, error)) (err error) { s.tx(func() { r, ok := s.refreshTokens[id] if !ok { @@ -452,7 +452,7 @@ func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.Refres return } -func (s *memStorage) UpdateOfflineSessions(userID string, connID string, updater func(o storage.OfflineSessions) (storage.OfflineSessions, error)) (err error) { +func (s *memStorage) UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(o storage.OfflineSessions) (storage.OfflineSessions, error)) (err error) { id := offlineSessionID{ userID: userID, connID: connID, @@ -470,7 +470,7 @@ func (s *memStorage) UpdateOfflineSessions(userID string, connID string, updater return } -func (s *memStorage) UpdateConnector(id string, updater func(c storage.Connector) (storage.Connector, error)) (err error) { +func (s *memStorage) UpdateConnector(ctx context.Context, id string, updater func(c storage.Connector) (storage.Connector, error)) (err error) { s.tx(func() { r, ok := s.connectors[id] if !ok { @@ -495,7 +495,7 @@ func (s *memStorage) CreateDeviceRequest(ctx context.Context, d storage.DeviceRe return } -func (s *memStorage) GetDeviceRequest(userCode string) (req storage.DeviceRequest, err error) { +func (s *memStorage) GetDeviceRequest(ctx context.Context, userCode string) (req storage.DeviceRequest, err error) { s.tx(func() { var ok bool if req, ok = s.deviceRequests[userCode]; !ok { @@ -517,7 +517,7 @@ func (s *memStorage) CreateDeviceToken(ctx context.Context, t storage.DeviceToke return } -func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) { +func (s *memStorage) GetDeviceToken(ctx context.Context, deviceCode string) (t storage.DeviceToken, err error) { s.tx(func() { var ok bool if t, ok = s.deviceTokens[deviceCode]; !ok { @@ -528,7 +528,7 @@ func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, e return } -func (s *memStorage) UpdateDeviceToken(deviceCode string, updater func(p storage.DeviceToken) (storage.DeviceToken, error)) (err error) { +func (s *memStorage) UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(p storage.DeviceToken) (storage.DeviceToken, error)) (err error) { s.tx(func() { r, ok := s.deviceTokens[deviceCode] if !ok { diff --git a/storage/memory/static_test.go b/storage/memory/static_test.go index b913874231..8f58c845ad 100644 --- a/storage/memory/static_test.go +++ b/storage/memory/static_test.go @@ -31,14 +31,14 @@ func TestStaticClients(t *testing.T) { { name: "get client from static storage", action: func() error { - _, err := s.GetClient(c2.ID) + _, err := s.GetClient(ctx, c2.ID) return err }, }, { name: "get client from backing storage", action: func() error { - _, err := s.GetClient(c1.ID) + _, err := s.GetClient(ctx, c1.ID) return err }, }, @@ -49,7 +49,7 @@ func TestStaticClients(t *testing.T) { c.Secret = "new_" + c.Secret return c, nil } - return s.UpdateClient(c2.ID, updater) + return s.UpdateClient(ctx, c2.ID, updater) }, wantErr: true, }, @@ -60,13 +60,13 @@ func TestStaticClients(t *testing.T) { c.Secret = "new_" + c.Secret return c, nil } - return s.UpdateClient(c1.ID, updater) + return s.UpdateClient(ctx, c1.ID, updater) }, }, { name: "list clients", action: func() error { - clients, err := s.ListClients() + clients, err := s.ListClients(ctx) if err != nil { return err } @@ -116,21 +116,21 @@ func TestStaticPasswords(t *testing.T) { { name: "get password from static storage", action: func() error { - _, err := s.GetPassword(p2.Email) + _, err := s.GetPassword(ctx, p2.Email) return err }, }, { name: "get password from backing storage", action: func() error { - _, err := s.GetPassword(p1.Email) + _, err := s.GetPassword(ctx, p1.Email) return err }, }, { name: "get password from static storage with casing", action: func() error { - _, err := s.GetPassword(strings.ToUpper(p2.Email)) + _, err := s.GetPassword(ctx, strings.ToUpper(p2.Email)) return err }, }, @@ -141,7 +141,7 @@ func TestStaticPasswords(t *testing.T) { p.Username = "new_" + p.Username return p, nil } - return s.UpdatePassword(p2.Email, updater) + return s.UpdatePassword(ctx, p2.Email, updater) }, wantErr: true, }, @@ -152,7 +152,7 @@ func TestStaticPasswords(t *testing.T) { p.Username = "new_" + p.Username return p, nil } - return s.UpdatePassword(p1.Email, updater) + return s.UpdatePassword(ctx, p1.Email, updater) }, }, { @@ -168,7 +168,7 @@ func TestStaticPasswords(t *testing.T) { { name: "get password", action: func() error { - p, err := s.GetPassword(p4.Email) + p, err := s.GetPassword(ctx, p4.Email) if err != nil { return err } @@ -181,7 +181,7 @@ func TestStaticPasswords(t *testing.T) { { name: "list passwords", action: func() error { - passwords, err := s.ListPasswords() + passwords, err := s.ListPasswords(ctx) if err != nil { return err } @@ -228,14 +228,14 @@ func TestStaticConnectors(t *testing.T) { { name: "get connector from static storage", action: func() error { - _, err := s.GetConnector(c2.ID) + _, err := s.GetConnector(ctx, c2.ID) return err }, }, { name: "get connector from backing storage", action: func() error { - _, err := s.GetConnector(c1.ID) + _, err := s.GetConnector(ctx, c1.ID) return err }, }, @@ -246,7 +246,7 @@ func TestStaticConnectors(t *testing.T) { c.Name = "New" return c, nil } - return s.UpdateConnector(c2.ID, updater) + return s.UpdateConnector(ctx, c2.ID, updater) }, wantErr: true, }, @@ -257,13 +257,13 @@ func TestStaticConnectors(t *testing.T) { c.Name = "New" return c, nil } - return s.UpdateConnector(c1.ID, updater) + return s.UpdateConnector(ctx, c1.ID, updater) }, }, { name: "list connectors", action: func() error { - connectors, err := s.ListConnectors() + connectors, err := s.ListConnectors(ctx) if err != nil { return err } diff --git a/storage/sql/crud.go b/storage/sql/crud.go index 1249243ced..a9ca38167d 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -86,7 +86,7 @@ type scanner interface { var _ storage.Storage = (*conn)(nil) -func (c *conn) GarbageCollect(now time.Time) (storage.GCResult, error) { +func (c *conn) GarbageCollect(ctc context.Context, now time.Time) (storage.GCResult, error) { result := storage.GCResult{} r, err := c.Exec(`delete from auth_request where expiry < $1`, now) @@ -158,9 +158,9 @@ func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) err return nil } -func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { +func (c *conn) UpdateAuthRequest(ctx context.Context, id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { return c.ExecTx(func(tx *trans) error { - r, err := getAuthRequest(tx, id) + r, err := getAuthRequest(ctx, tx, id) if err != nil { return err } @@ -200,11 +200,11 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) }) } -func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) { - return getAuthRequest(c, id) +func (c *conn) GetAuthRequest(ctx context.Context, id string) (storage.AuthRequest, error) { + return getAuthRequest(ctx, c, id) } -func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { +func getAuthRequest(ctx context.Context, q querier, id string) (a storage.AuthRequest, err error) { err = q.QueryRow(` select id, client_id, response_types, scopes, redirect_uri, nonce, state, @@ -258,7 +258,7 @@ func (c *conn) CreateAuthCode(ctx context.Context, a storage.AuthCode) error { return nil } -func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) { +func (c *conn) GetAuthCode(ctx context.Context, id string) (a storage.AuthCode, err error) { err = c.QueryRow(` select id, client_id, scopes, nonce, redirect_uri, @@ -310,9 +310,9 @@ func (c *conn) CreateRefresh(ctx context.Context, r storage.RefreshToken) error return nil } -func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { +func (c *conn) UpdateRefreshToken(ctx context.Context, id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { return c.ExecTx(func(tx *trans) error { - r, err := getRefresh(tx, id) + r, err := getRefresh(ctx, tx, id) if err != nil { return err } @@ -354,11 +354,11 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok }) } -func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) { - return getRefresh(c, id) +func (c *conn) GetRefresh(ctx context.Context, id string) (storage.RefreshToken, error) { + return getRefresh(ctx, c, id) } -func getRefresh(q querier, id string) (storage.RefreshToken, error) { +func getRefresh(ctx context.Context, q querier, id string) (storage.RefreshToken, error) { return scanRefresh(q.QueryRow(` select id, client_id, scopes, nonce, @@ -371,7 +371,7 @@ func getRefresh(q querier, id string) (storage.RefreshToken, error) { `, id)) } -func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { +func (c *conn) ListRefreshTokens(ctx context.Context) ([]storage.RefreshToken, error) { rows, err := c.Query(` select id, client_id, scopes, nonce, @@ -418,12 +418,12 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) { return r, nil } -func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error { +func (c *conn) UpdateKeys(ctx context.Context, updater func(old storage.Keys) (storage.Keys, error)) error { return c.ExecTx(func(tx *trans) error { firstUpdate := false // TODO(ericchiang): errors may cause a transaction be rolled back by the SQL // server. Test this, and consider adding a COUNT() command beforehand. - old, err := getKeys(tx) + old, err := getKeys(ctx, tx) if err != nil { if err != storage.ErrNotFound { return fmt.Errorf("get keys: %v", err) @@ -471,11 +471,11 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) }) } -func (c *conn) GetKeys() (keys storage.Keys, err error) { - return getKeys(c) +func (c *conn) GetKeys(ctx context.Context) (keys storage.Keys, err error) { + return getKeys(ctx, c) } -func getKeys(q querier) (keys storage.Keys, err error) { +func getKeys(ctx context.Context, q querier) (keys storage.Keys, err error) { err = q.QueryRow(` select verification_keys, signing_key, signing_key_pub, next_rotation @@ -494,9 +494,9 @@ func getKeys(q querier) (keys storage.Keys, err error) { return keys, nil } -func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { +func (c *conn) UpdateClient(ctx context.Context, id string, updater func(old storage.Client) (storage.Client, error)) error { return c.ExecTx(func(tx *trans) error { - cli, err := getClient(tx, id) + cli, err := getClient(ctx, tx, id) if err != nil { return err } @@ -543,7 +543,7 @@ func (c *conn) CreateClient(ctx context.Context, cli storage.Client) error { return nil } -func getClient(q querier, id string) (storage.Client, error) { +func getClient(ctx context.Context, q querier, id string) (storage.Client, error) { return scanClient(q.QueryRow(` select id, secret, redirect_uris, trusted_peers, public, name, logo_url @@ -551,11 +551,11 @@ func getClient(q querier, id string) (storage.Client, error) { `, id)) } -func (c *conn) GetClient(id string) (storage.Client, error) { - return getClient(c, id) +func (c *conn) GetClient(ctx context.Context, id string) (storage.Client, error) { + return getClient(ctx, c, id) } -func (c *conn) ListClients() ([]storage.Client, error) { +func (c *conn) ListClients(ctx context.Context) ([]storage.Client, error) { rows, err := c.Query(` select id, secret, redirect_uris, trusted_peers, public, name, logo_url @@ -615,9 +615,9 @@ func (c *conn) CreatePassword(ctx context.Context, p storage.Password) error { return nil } -func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error { +func (c *conn) UpdatePassword(ctx context.Context, email string, updater func(p storage.Password) (storage.Password, error)) error { return c.ExecTx(func(tx *trans) error { - p, err := getPassword(tx, email) + p, err := getPassword(ctx, tx, email) if err != nil { return err } @@ -641,11 +641,11 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st }) } -func (c *conn) GetPassword(email string) (storage.Password, error) { - return getPassword(c, email) +func (c *conn) GetPassword(ctx context.Context, email string) (storage.Password, error) { + return getPassword(ctx, c, email) } -func getPassword(q querier, email string) (p storage.Password, err error) { +func getPassword(ctx context.Context, q querier, email string) (p storage.Password, err error) { return scanPassword(q.QueryRow(` select email, hash, username, user_id @@ -653,7 +653,7 @@ func getPassword(q querier, email string) (p storage.Password, err error) { `, strings.ToLower(email))) } -func (c *conn) ListPasswords() ([]storage.Password, error) { +func (c *conn) ListPasswords(ctx context.Context) ([]storage.Password, error) { rows, err := c.Query(` select email, hash, username, user_id @@ -711,9 +711,9 @@ func (c *conn) CreateOfflineSessions(ctx context.Context, s storage.OfflineSessi return nil } -func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { +func (c *conn) UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { return c.ExecTx(func(tx *trans) error { - s, err := getOfflineSessions(tx, userID, connID) + s, err := getOfflineSessions(ctx, tx, userID, connID) if err != nil { return err } @@ -738,11 +738,11 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func( }) } -func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) { - return getOfflineSessions(c, userID, connID) +func (c *conn) GetOfflineSessions(ctx context.Context, userID string, connID string) (storage.OfflineSessions, error) { + return getOfflineSessions(ctx, c, userID, connID) } -func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) { +func getOfflineSessions(ctx context.Context, q querier, userID string, connID string) (storage.OfflineSessions, error) { return scanOfflineSessions(q.QueryRow(` select user_id, conn_id, refresh, connector_data @@ -784,9 +784,9 @@ func (c *conn) CreateConnector(ctx context.Context, connector storage.Connector) return nil } -func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error { +func (c *conn) UpdateConnector(ctx context.Context, id string, updater func(s storage.Connector) (storage.Connector, error)) error { return c.ExecTx(func(tx *trans) error { - connector, err := getConnector(tx, id) + connector, err := getConnector(ctx, tx, id) if err != nil { return err } @@ -813,11 +813,11 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto }) } -func (c *conn) GetConnector(id string) (storage.Connector, error) { - return getConnector(c, id) +func (c *conn) GetConnector(ctx context.Context, id string) (storage.Connector, error) { + return getConnector(ctx, c, id) } -func getConnector(q querier, id string) (storage.Connector, error) { +func getConnector(ctx context.Context, q querier, id string) (storage.Connector, error) { return scanConnector(q.QueryRow(` select id, type, name, resource_version, config @@ -839,7 +839,7 @@ func scanConnector(s scanner) (c storage.Connector, err error) { return c, nil } -func (c *conn) ListConnectors() ([]storage.Connector, error) { +func (c *conn) ListConnectors(ctx context.Context) ([]storage.Connector, error) { rows, err := c.Query(` select id, type, name, resource_version, config @@ -864,16 +864,31 @@ func (c *conn) ListConnectors() ([]storage.Connector, error) { return connectors, nil } -func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", "id", id) } -func (c *conn) DeleteAuthCode(id string) error { return c.delete("auth_code", "id", id) } -func (c *conn) DeleteClient(id string) error { return c.delete("client", "id", id) } -func (c *conn) DeleteRefresh(id string) error { return c.delete("refresh_token", "id", id) } -func (c *conn) DeletePassword(email string) error { +func (c *conn) DeleteAuthRequest(ctx context.Context, id string) error { + return c.delete("auth_request", "id", id) +} + +func (c *conn) DeleteAuthCode(ctx context.Context, id string) error { + return c.delete("auth_code", "id", id) +} + +func (c *conn) DeleteClient(ctx context.Context, id string) error { + return c.delete("client", "id", id) +} + +func (c *conn) DeleteRefresh(ctx context.Context, id string) error { + return c.delete("refresh_token", "id", id) +} + +func (c *conn) DeletePassword(ctx context.Context, email string) error { return c.delete("password", "email", strings.ToLower(email)) } -func (c *conn) DeleteConnector(id string) error { return c.delete("connector", "id", id) } -func (c *conn) DeleteOfflineSessions(userID string, connID string) error { +func (c *conn) DeleteConnector(ctx context.Context, id string) error { + return c.delete("connector", "id", id) +} + +func (c *conn) DeleteOfflineSessions(ctx context.Context, userID string, connID string) error { result, err := c.Exec(`delete from offline_session where user_id = $1 AND conn_id = $2`, userID, connID) if err != nil { return fmt.Errorf("delete offline_session: user_id = %s, conn_id = %s", userID, connID) @@ -948,11 +963,11 @@ func (c *conn) CreateDeviceToken(ctx context.Context, t storage.DeviceToken) err return nil } -func (c *conn) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) { - return getDeviceRequest(c, userCode) +func (c *conn) GetDeviceRequest(ctx context.Context, userCode string) (storage.DeviceRequest, error) { + return getDeviceRequest(ctx, c, userCode) } -func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err error) { +func getDeviceRequest(ctx context.Context, q querier, userCode string) (d storage.DeviceRequest, err error) { err = q.QueryRow(` select device_code, client_id, client_secret, scopes, expiry @@ -970,11 +985,11 @@ func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err return d, nil } -func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) { - return getDeviceToken(c, deviceCode) +func (c *conn) GetDeviceToken(ctx context.Context, deviceCode string) (storage.DeviceToken, error) { + return getDeviceToken(ctx, c, deviceCode) } -func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) { +func getDeviceToken(ctx context.Context, q querier, deviceCode string) (a storage.DeviceToken, err error) { err = q.QueryRow(` select status, token, expiry, last_request, poll_interval, code_challenge, code_challenge_method @@ -992,9 +1007,9 @@ func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err er return a, nil } -func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error { +func (c *conn) UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error { return c.ExecTx(func(tx *trans) error { - r, err := getDeviceToken(tx, deviceCode) + r, err := getDeviceToken(ctx, tx, deviceCode) if err != nil { return err } diff --git a/storage/static.go b/storage/static.go index ca04937acf..386b2b2883 100644 --- a/storage/static.go +++ b/storage/static.go @@ -31,11 +31,11 @@ func WithStaticClients(s Storage, staticClients []Client) Storage { return staticClientsStorage{s, staticClients, clientsByID} } -func (s staticClientsStorage) GetClient(id string) (Client, error) { +func (s staticClientsStorage) GetClient(ctx context.Context, id string) (Client, error) { if client, ok := s.clientsByID[id]; ok { return client, nil } - return s.Storage.GetClient(id) + return s.Storage.GetClient(ctx, id) } func (s staticClientsStorage) isStatic(id string) bool { @@ -43,8 +43,8 @@ func (s staticClientsStorage) isStatic(id string) bool { return ok } -func (s staticClientsStorage) ListClients() ([]Client, error) { - clients, err := s.Storage.ListClients() +func (s staticClientsStorage) ListClients(ctx context.Context) ([]Client, error) { + clients, err := s.Storage.ListClients(ctx) if err != nil { return nil, err } @@ -67,18 +67,18 @@ func (s staticClientsStorage) CreateClient(ctx context.Context, c Client) error return s.Storage.CreateClient(ctx, c) } -func (s staticClientsStorage) DeleteClient(id string) error { +func (s staticClientsStorage) DeleteClient(ctx context.Context, id string) error { if s.isStatic(id) { return errors.New("static clients: read-only cannot delete client") } - return s.Storage.DeleteClient(id) + return s.Storage.DeleteClient(ctx, id) } -func (s staticClientsStorage) UpdateClient(id string, updater func(old Client) (Client, error)) error { +func (s staticClientsStorage) UpdateClient(ctx context.Context, id string, updater func(old Client) (Client, error)) error { if s.isStatic(id) { return errors.New("static clients: read-only cannot update client") } - return s.Storage.UpdateClient(id, updater) + return s.Storage.UpdateClient(ctx, id, updater) } type staticPasswordsStorage struct { @@ -112,18 +112,18 @@ func (s staticPasswordsStorage) isStatic(email string) bool { return ok } -func (s staticPasswordsStorage) GetPassword(email string) (Password, error) { +func (s staticPasswordsStorage) GetPassword(ctx context.Context, email string) (Password, error) { // TODO(ericchiang): BLAH. We really need to figure out how to handle // lower cased emails better. email = strings.ToLower(email) if password, ok := s.passwordsByEmail[email]; ok { return password, nil } - return s.Storage.GetPassword(email) + return s.Storage.GetPassword(ctx, email) } -func (s staticPasswordsStorage) ListPasswords() ([]Password, error) { - passwords, err := s.Storage.ListPasswords() +func (s staticPasswordsStorage) ListPasswords(ctx context.Context) ([]Password, error) { + passwords, err := s.Storage.ListPasswords(ctx) if err != nil { return nil, err } @@ -147,18 +147,18 @@ func (s staticPasswordsStorage) CreatePassword(ctx context.Context, p Password) return s.Storage.CreatePassword(ctx, p) } -func (s staticPasswordsStorage) DeletePassword(email string) error { +func (s staticPasswordsStorage) DeletePassword(ctx context.Context, email string) error { if s.isStatic(email) { return errors.New("static passwords: read-only cannot delete password") } - return s.Storage.DeletePassword(email) + return s.Storage.DeletePassword(ctx, email) } -func (s staticPasswordsStorage) UpdatePassword(email string, updater func(old Password) (Password, error)) error { +func (s staticPasswordsStorage) UpdatePassword(ctx context.Context, email string, updater func(old Password) (Password, error)) error { if s.isStatic(email) { return errors.New("static passwords: read-only cannot update password") } - return s.Storage.UpdatePassword(email, updater) + return s.Storage.UpdatePassword(ctx, email, updater) } // staticConnectorsStorage represents a storage with read-only set of connectors. @@ -185,15 +185,15 @@ func (s staticConnectorsStorage) isStatic(id string) bool { return ok } -func (s staticConnectorsStorage) GetConnector(id string) (Connector, error) { +func (s staticConnectorsStorage) GetConnector(ctx context.Context, id string) (Connector, error) { if connector, ok := s.connectorsByID[id]; ok { return connector, nil } - return s.Storage.GetConnector(id) + return s.Storage.GetConnector(ctx, id) } -func (s staticConnectorsStorage) ListConnectors() ([]Connector, error) { - connectors, err := s.Storage.ListConnectors() +func (s staticConnectorsStorage) ListConnectors(ctx context.Context) ([]Connector, error) { + connectors, err := s.Storage.ListConnectors(ctx) if err != nil { return nil, err } @@ -217,16 +217,16 @@ func (s staticConnectorsStorage) CreateConnector(ctx context.Context, c Connecto return s.Storage.CreateConnector(ctx, c) } -func (s staticConnectorsStorage) DeleteConnector(id string) error { +func (s staticConnectorsStorage) DeleteConnector(ctx context.Context, id string) error { if s.isStatic(id) { return errors.New("static connectors: read-only cannot delete connector") } - return s.Storage.DeleteConnector(id) + return s.Storage.DeleteConnector(ctx, id) } -func (s staticConnectorsStorage) UpdateConnector(id string, updater func(old Connector) (Connector, error)) error { +func (s staticConnectorsStorage) UpdateConnector(ctx context.Context, id string, updater func(old Connector) (Connector, error)) error { if s.isStatic(id) { return errors.New("static connectors: read-only cannot update connector") } - return s.Storage.UpdateConnector(id, updater) + return s.Storage.UpdateConnector(ctx, id, updater) } diff --git a/storage/storage.go b/storage/storage.go index 03883ef5aa..574b0a5a5e 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -89,30 +89,30 @@ type Storage interface { // TODO(ericchiang): return (T, bool, error) so we can indicate not found // requests that way instead of using ErrNotFound. - GetAuthRequest(id string) (AuthRequest, error) - GetAuthCode(id string) (AuthCode, error) - GetClient(id string) (Client, error) - GetKeys() (Keys, error) - GetRefresh(id string) (RefreshToken, error) - GetPassword(email string) (Password, error) - GetOfflineSessions(userID string, connID string) (OfflineSessions, error) - GetConnector(id string) (Connector, error) - GetDeviceRequest(userCode string) (DeviceRequest, error) - GetDeviceToken(deviceCode string) (DeviceToken, error) - - ListClients() ([]Client, error) - ListRefreshTokens() ([]RefreshToken, error) - ListPasswords() ([]Password, error) - ListConnectors() ([]Connector, error) + GetAuthRequest(ctx context.Context, id string) (AuthRequest, error) + GetAuthCode(ctx context.Context, id string) (AuthCode, error) + GetClient(ctx context.Context, id string) (Client, error) + GetKeys(ctx context.Context) (Keys, error) + GetRefresh(ctx context.Context, id string) (RefreshToken, error) + GetPassword(ctx context.Context, email string) (Password, error) + GetOfflineSessions(ctx context.Context, userID string, connID string) (OfflineSessions, error) + GetConnector(ctx context.Context, id string) (Connector, error) + GetDeviceRequest(ctx context.Context, userCode string) (DeviceRequest, error) + GetDeviceToken(ctx context.Context, deviceCode string) (DeviceToken, error) + + ListClients(ctx context.Context) ([]Client, error) + ListRefreshTokens(ctx context.Context) ([]RefreshToken, error) + ListPasswords(ctx context.Context) ([]Password, error) + ListConnectors(ctx context.Context) ([]Connector, error) // Delete methods MUST be atomic. - DeleteAuthRequest(id string) error - DeleteAuthCode(code string) error - DeleteClient(id string) error - DeleteRefresh(id string) error - DeletePassword(email string) error - DeleteOfflineSessions(userID string, connID string) error - DeleteConnector(id string) error + DeleteAuthRequest(ctx context.Context, id string) error + DeleteAuthCode(ctx context.Context, code string) error + DeleteClient(ctx context.Context, id string) error + DeleteRefresh(ctx context.Context, id string) error + DeletePassword(ctx context.Context, email string) error + DeleteOfflineSessions(ctx context.Context, userID string, connID string) error + DeleteConnector(ctx context.Context, id string) error // Update methods take a function for updating an object then performs that update within // a transaction. "updater" functions may be called multiple times by a single update call. @@ -128,18 +128,18 @@ type Storage interface { // // update failed, handle error // } // - UpdateClient(id string, updater func(old Client) (Client, error)) error - UpdateKeys(updater func(old Keys) (Keys, error)) error - UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error - UpdateRefreshToken(id string, updater func(r RefreshToken) (RefreshToken, error)) error - UpdatePassword(email string, updater func(p Password) (Password, error)) error - UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error - UpdateConnector(id string, updater func(c Connector) (Connector, error)) error - UpdateDeviceToken(deviceCode string, updater func(t DeviceToken) (DeviceToken, error)) error + UpdateClient(ctx context.Context, id string, updater func(old Client) (Client, error)) error + UpdateKeys(ctx context.Context, updater func(old Keys) (Keys, error)) error + UpdateAuthRequest(ctx context.Context, id string, updater func(a AuthRequest) (AuthRequest, error)) error + UpdateRefreshToken(ctx context.Context, id string, updater func(r RefreshToken) (RefreshToken, error)) error + UpdatePassword(ctx context.Context, email string, updater func(p Password) (Password, error)) error + UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error + UpdateConnector(ctx context.Context, id string, updater func(c Connector) (Connector, error)) error + UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(t DeviceToken) (DeviceToken, error)) error // GarbageCollect deletes all expired AuthCodes, // AuthRequests, DeviceRequests, and DeviceTokens. - GarbageCollect(now time.Time) (GCResult, error) + GarbageCollect(ctx context.Context, now time.Time) (GCResult, error) } // Client represents an OAuth2 client.