Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up user's expired login sessions #1113

Merged
merged 4 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 123 additions & 3 deletions codecov_auth/tests/unit/views/test_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from unittest.mock import Mock, patch

import pytest
Expand All @@ -8,10 +8,15 @@
from django.http import HttpResponse
from django.test import RequestFactory, TestCase, override_settings
from freezegun import freeze_time
from shared.django_apps.codecov_auth.tests.factories import OwnerFactory, UserFactory
from shared.django_apps.codecov_auth.tests.factories import (
OwnerFactory,
SessionFactory,
UserFactory,
)
from shared.license import LicenseInformation

from codecov_auth.models import Owner, OwnerProfile
from codecov_auth.models import DjangoSession, Owner, OwnerProfile, Session
from codecov_auth.tests.factories import DjangoSessionFactory
from codecov_auth.views.base import LoginMixin, StateMixin


Expand Down Expand Up @@ -729,3 +734,118 @@ def test_login_authenticated_with_claimed_owner(self):
# does not re-claim owner
assert owner.user is not None
assert owner.user != user

@patch("services.refresh.RefreshService.trigger_refresh", lambda *args: None)
def test_login_owner_with_expired_login_session(self):
user = UserFactory()
owner = OwnerFactory(service="github", user=user)

another_user = UserFactory()
another_owner = OwnerFactory(service="github", user=another_user)

now = datetime.now(timezone.utc)

# Create a session that will be deleted
to_be_deleted_1 = SessionFactory(
owner=owner,
type="login",
name="to_be_deleted",
lastseen="2021-01-01T00:00:00+00:00",
login_session=DjangoSessionFactory(expire_date=now - timedelta(days=1)),
)
to_be_deleted_1_session_key = to_be_deleted_1.login_session.session_key

# Create a session that will not be deleted because its not a login session
to_be_kept_1 = SessionFactory(
owner=owner,
type="api",
name="to_be_kept",
lastseen="2021-01-01T00:00:00+00:00",
login_session=DjangoSessionFactory(expire_date=now + timedelta(days=1)),
)

# Create a session that will not be deleted because it's not expired
to_be_kept_2 = SessionFactory(
owner=owner,
type="login",
name="to_be_kept",
lastseen="2021-01-01T00:00:00+00:00",
login_session=DjangoSessionFactory(expire_date=now + timedelta(days=1)),
)

# Create a session that will not be deleted because it's not the owner's session
to_be_kept_3 = SessionFactory(
owner=another_owner,
type="login",
name="to_be_kept",
lastseen="2021-01-01T00:00:00+00:00",
login_session=DjangoSessionFactory(expire_date=now - timedelta(seconds=1)),
)

assert (
len(DjangoSession.objects.filter(session_key=to_be_deleted_1_session_key))
== 1
)
assert (
len(
DjangoSession.objects.filter(
session_key=to_be_kept_1.login_session.session_key
)
)
== 1
)
assert (
len(
DjangoSession.objects.filter(
session_key=to_be_kept_2.login_session.session_key
)
)
== 1
)
assert (
len(
DjangoSession.objects.filter(
session_key=to_be_kept_3.login_session.session_key
)
)
== 1
)

self.request.user = user
self.mixin_instance.login_owner(owner, self.request, HttpResponse())
owner.refresh_from_db()

new_login_session = Session.objects.filter(name=None)

assert len(new_login_session) == 1
assert len(Session.objects.filter(name="to_be_deleted").all()) == 0
assert len(Session.objects.filter(name="to_be_kept").all()) == 3

assert (
len(DjangoSession.objects.filter(session_key=to_be_deleted_1_session_key))
== 0
)
assert (
len(
DjangoSession.objects.filter(
session_key=to_be_kept_1.login_session.session_key
)
)
== 1
)
assert (
len(
DjangoSession.objects.filter(
session_key=to_be_kept_2.login_session.session_key
)
)
== 1
)
assert (
len(
DjangoSession.objects.filter(
session_key=to_be_kept_3.login_session.session_key
)
)
== 1
)
64 changes: 49 additions & 15 deletions codecov_auth/views/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@
import re
import uuid
from functools import reduce
from typing import Any
from urllib.parse import parse_qs, urlencode, urlparse

from django.conf import settings
from django.contrib.auth import login, logout
from django.contrib.sessions.models import Session as DjangoSession
from django.core.exceptions import PermissionDenied
from django.db import transaction
from django.http.request import HttpRequest
from django.http.response import HttpResponse
from django.utils import timezone
from django.utils.timezone import now
from shared.encryption.token import encode_token
from shared.license import LICENSE_ERRORS_MESSAGES, get_current_license

Expand Down Expand Up @@ -59,7 +62,7 @@ class StateMixin(object):

"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.redis = get_redis_connection()
super().__init__(*args, **kwargs)

Expand All @@ -69,7 +72,7 @@ def _session_key(self) -> str:
def _get_key_redis(self, state: str) -> str:
return f"oauth-state-{state}"

def _is_matching_cors_domains(self, url_domain) -> bool:
def _is_matching_cors_domains(self, url_domain: str) -> bool:
# make sure the domain is part of the CORS so that's a safe domain to
# redirect to.
if url_domain in settings.CORS_ALLOWED_ORIGINS:
Expand All @@ -79,7 +82,7 @@ def _is_matching_cors_domains(self, url_domain) -> bool:
return True
return False

def _is_valid_redirection(self, to) -> bool:
def _is_valid_redirection(self, to: str) -> bool:
# make sure the redirect url is from a domain we own
try:
url = urlparse(to)
Expand Down Expand Up @@ -115,11 +118,11 @@ def generate_state(self) -> str:

return state

def verify_state(self, state) -> bool:
def verify_state(self, state: str) -> bool:
state_from_session = self.request.session.get(self._session_key(), None)
return state_from_session and state == state_from_session

def get_redirection_url_from_state(self, state) -> (str, bool):
def get_redirection_url_from_state(self, state: str) -> tuple[str, bool]:
cached_url = self.redis.get(self._get_key_redis(state))

if not cached_url:
Expand Down Expand Up @@ -149,7 +152,7 @@ def get_redirection_url_from_state(self, state) -> (str, bool):
# Return the final redirect URL to complete the login.
return (cached_url.decode("utf-8"), True)

def remove_state(self, state, delay=0) -> None:
def remove_state(self, state: str, delay: int = 0) -> None:
redirection_url, _ = self.get_redirection_url_from_state(state)
if delay == 0:
self.redis.delete(self._get_key_redis(state))
Expand Down Expand Up @@ -182,15 +185,17 @@ def modify_redirection_url_based_on_default_user_org(
url += f"/{owner_profile.default_org.username}"
return url

def get_or_create_org(self, single_organization):
def get_or_create_org(self, single_organization: dict) -> Owner:
owner, was_created = Owner.objects.get_or_create(
service=self.service,
service_id=single_organization["id"],
defaults={"createstamp": timezone.now()},
)
return owner

def login_owner(self, owner: Owner, request: HttpRequest, response: HttpResponse):
def login_owner(
self, owner: Owner, request: HttpRequest, response: HttpResponse
) -> None:
# if there's a currently authenticated user
if request.user is not None and not request.user.is_anonymous:
if owner.user is None:
Expand Down Expand Up @@ -253,9 +258,11 @@ def login_owner(self, owner: Owner, request: HttpRequest, response: HttpResponse

request.session["current_owner_id"] = owner.pk
RefreshService().trigger_refresh(owner.ownerid, owner.username)

self.delete_expired_sessions_and_django_sessions(owner)
self.store_login_session(owner)

def get_and_modify_owner(self, user_dict, request) -> Owner:
def get_and_modify_owner(self, user_dict: dict, request: HttpRequest) -> Owner:
user_orgs = user_dict["orgs"]
formatted_orgs = [
dict(username=org["username"], id=str(org["id"])) for org in user_orgs
Expand Down Expand Up @@ -298,7 +305,9 @@ def get_and_modify_owner(self, user_dict, request) -> Owner:

return owner

def _check_enterprise_organizations_membership(self, user_dict, orgs):
def _check_enterprise_organizations_membership(
self, user_dict: dict, orgs: list[dict]
) -> None:
"""Checks if a user belongs to the restricted organizations (or teams if GitHub) allowed in settings."""
if settings.IS_ENTERPRISE and get_config(self.service, "organizations"):
orgs_in_settings = set(get_config(self.service, "organizations"))
Expand All @@ -315,7 +324,7 @@ def _check_enterprise_organizations_membership(self, user_dict, orgs):
"You must be a member of an allowed team in your organization."
)

def _check_user_count_limitations(self, login_data):
def _check_user_count_limitations(self, login_data: dict) -> None:
if not settings.IS_ENTERPRISE:
return
license = get_current_license()
Expand All @@ -339,7 +348,7 @@ def _check_user_count_limitations(self, login_data):
owners_with_activated_users = Owner.objects.exclude(
plan_activated_users__len=0
).exclude(plan_activated_users__isnull=True)
all_distinct_actiaved_users = reduce(
all_distinct_actiaved_users: set[str] = reduce(
lambda acc, curr: set(curr.plan_activated_users) | acc,
owners_with_activated_users,
set(),
Expand All @@ -357,7 +366,9 @@ def _check_user_count_limitations(self, login_data):
if users_on_service_count > license.number_allowed_users:
raise PermissionDenied(LICENSE_ERRORS_MESSAGES["users-exceeded"])

def _get_or_create_owner(self, user_dict, request):
def _get_or_create_owner(
self, user_dict: dict, request: HttpRequest
) -> tuple[Owner, bool]:
fields_to_update = ["oauth_token", "private_access", "updatestamp"]
login_data = user_dict["user"]
owner, was_created = Owner.objects.get_or_create(
Expand Down Expand Up @@ -403,7 +414,7 @@ def _get_utm_params(self, params: dict) -> dict:
# remove None values from the dict
return {k: v for k, v in filtered_params.items() if v is not None}

def store_to_cookie_utm_tags(self, response) -> None:
def store_to_cookie_utm_tags(self, response: HttpResponse) -> None:
if not settings.IS_ENTERPRISE:
data = urlencode(self._get_utm_params(self.request.GET))
response.set_cookie(
Expand All @@ -423,7 +434,7 @@ def retrieve_marketing_tags_from_cookie(self) -> dict:
else:
return {}

def store_login_session(self, owner: Owner):
def store_login_session(self, owner: Owner) -> None:
# Store user's login session info after logging in
http_x_forwarded_for = self.request.META.get("HTTP_X_FORWARDED_FOR")
if http_x_forwarded_for:
Expand All @@ -443,3 +454,26 @@ def store_login_session(self, owner: Owner):
type=Session.SessionType.LOGIN,
owner=owner,
)

def delete_expired_sessions_and_django_sessions(self, owner: Owner) -> None:
"""
This function deletes expired login sessions for a given owner
"""
with transaction.atomic():
# Get the primary keys of expired DjangoSessions for the given owner
expired_sessions = Session.objects.filter(
owner=owner,
type="login",
login_session__isnull=False,
login_session__expire_date__lt=now(),
)

# Delete the rows in the Session table using sessionid
Session.objects.filter(
sessionid__in=[es.sessionid for es in expired_sessions]
).delete()

# Delete the rows in the DjangoSession table using the extracted keys
DjangoSession.objects.filter(
session_key__in=[es.login_session for es in expired_sessions]
).delete()