Skip to content

Commit

Permalink
feat: rbac middleware (#26159)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Bryan Ciaraldi <[email protected]>
  • Loading branch information
3 people authored and thmsobrmlr committed Nov 25, 2024
1 parent 3b1da84 commit 6dd8a4a
Show file tree
Hide file tree
Showing 56 changed files with 11,103 additions and 5,145 deletions.
194 changes: 194 additions & 0 deletions ee/api/rbac/access_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
from typing import TYPE_CHECKING, cast


from rest_framework import exceptions, serializers, status
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet

from ee.models.rbac.access_control import AccessControl
from posthog.models.scopes import API_SCOPE_OBJECTS, APIScopeObjectOrNotSupported
from posthog.models.team.team import Team
from posthog.rbac.user_access_control import (
ACCESS_CONTROL_LEVELS_RESOURCE,
UserAccessControl,
default_access_level,
highest_access_level,
ordered_access_levels,
)


if TYPE_CHECKING:
_GenericViewSet = GenericViewSet
else:
_GenericViewSet = object


class AccessControlSerializer(serializers.ModelSerializer):
access_level = serializers.CharField(allow_null=True)

class Meta:
model = AccessControl
fields = [
"access_level",
"resource",
"resource_id",
"organization_member",
"role",
"created_by",
"created_at",
"updated_at",
]
read_only_fields = ["id", "created_at", "created_by"]

# Validate that resource is a valid option from the API_SCOPE_OBJECTS
def validate_resource(self, resource):
if resource not in API_SCOPE_OBJECTS:
raise serializers.ValidationError("Invalid resource. Must be one of: {}".format(API_SCOPE_OBJECTS))

return resource

# Validate that access control is a valid option
def validate_access_level(self, access_level):
if access_level and access_level not in ordered_access_levels(self.initial_data["resource"]):
raise serializers.ValidationError(
f"Invalid access level. Must be one of: {', '.join(ordered_access_levels(self.initial_data['resource']))}"
)

return access_level

def validate(self, data):
context = self.context

# Ensure that only one of organization_member or role is set
if data.get("organization_member") and data.get("role"):
raise serializers.ValidationError("You can not scope an access control to both a member and a role.")

access_control = cast(UserAccessControl, self.context["view"].user_access_control)
resource = data["resource"]
resource_id = data.get("resource_id")

# We assume the highest level is required for the given resource to edit access controls
required_level = highest_access_level(resource)
team = context["view"].team
the_object = context["view"].get_object()

if resource_id:
# Check that they have the right access level for this specific resource object
if not access_control.check_can_modify_access_levels_for_object(the_object):
raise exceptions.PermissionDenied(f"Must be {required_level} to modify {resource} permissions.")
else:
# If modifying the base resource rules then we are checking the parent membership (project or organization)
# NOTE: Currently we only support org level in the UI so its simply an org level check
if not access_control.check_can_modify_access_levels_for_object(team):
raise exceptions.PermissionDenied("Must be an Organization admin to modify project-wide permissions.")

return data


class AccessControlViewSetMixin(_GenericViewSet):
"""
Adds an "access_controls" action to the viewset that handles access control for the given resource
Why a mixin? We want to easily add this to any existing resource, including providing easy helpers for adding access control info such
as the current users access level to any response.
"""

# 1. Know that the project level access is covered by the Permission check
# 2. Get the actual object which we can pass to the serializer to check if the user created it
# 3. We can also use the serializer to check the access level for the object

def _get_access_control_serializer(self, *args, **kwargs):
kwargs.setdefault("context", self.get_serializer_context())
return AccessControlSerializer(*args, **kwargs)

def _get_access_controls(self, request: Request, is_global=False):
resource = cast(APIScopeObjectOrNotSupported, getattr(self, "scope_object", None))
user_access_control = cast(UserAccessControl, self.user_access_control) # type: ignore
team = cast(Team, self.team) # type: ignore

if is_global and resource != "project" or not resource or resource == "INTERNAL":
raise exceptions.NotFound("Role based access controls are only available for projects.")

obj = self.get_object()
resource_id = obj.id

if is_global:
# If role based then we are getting all controls for the project that aren't specific to a resource
access_controls = AccessControl.objects.filter(team=team, resource_id=None).all()
else:
# Otherwise we are getting all controls for the specific resource
access_controls = AccessControl.objects.filter(team=team, resource=resource, resource_id=resource_id).all()

serializer = self._get_access_control_serializer(instance=access_controls, many=True)
user_access_level = user_access_control.access_level_for_object(obj, resource)

return Response(
{
"access_controls": serializer.data,
# NOTE: For Role based controls we are always configuring resource level items
"available_access_levels": ACCESS_CONTROL_LEVELS_RESOURCE
if is_global
else ordered_access_levels(resource),
"default_access_level": "editor" if is_global else default_access_level(resource),
"user_access_level": user_access_level,
"user_can_edit_access_levels": user_access_control.check_can_modify_access_levels_for_object(obj),
}
)

def _update_access_controls(self, request: Request, is_global=False):
resource = getattr(self, "scope_object", None)
obj = self.get_object()
resource_id = str(obj.id)
team = cast(Team, self.team) # type: ignore

# Generically validate the incoming data
if not is_global:
# If not role based we are deriving from the viewset
data = request.data
data["resource"] = resource
data["resource_id"] = resource_id

partial_serializer = self._get_access_control_serializer(data=request.data)
partial_serializer.is_valid(raise_exception=True)
params = partial_serializer.validated_data

instance = AccessControl.objects.filter(
team=team,
resource=params["resource"],
resource_id=params.get("resource_id"),
organization_member=params.get("organization_member"),
role=params.get("role"),
).first()

if params["access_level"] is None:
if instance:
instance.delete()
return Response(status=status.HTTP_204_NO_CONTENT)

# Perform the upsert
if instance:
serializer = self._get_access_control_serializer(instance, data=request.data)
else:
serializer = self._get_access_control_serializer(data=request.data)

serializer.is_valid(raise_exception=True)
serializer.validated_data["team"] = team
serializer.save()

return Response(serializer.data, status=status.HTTP_200_OK)

@action(methods=["GET", "PUT"], detail=True)
def access_controls(self, request: Request, *args, **kwargs):
if request.method == "PUT":
return self._update_access_controls(request)

return self._get_access_controls(request)

@action(methods=["GET", "PUT"], detail=True)
def global_access_controls(self, request: Request, *args, **kwargs):
if request.method == "PUT":
return self._update_access_controls(request, is_global=True)

return self._get_access_controls(request, is_global=True)
25 changes: 2 additions & 23 deletions ee/api/rbac/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
from rest_framework import mixins, serializers, viewsets
from rest_framework.permissions import SAFE_METHODS, BasePermission

from ee.models.feature_flag_role_access import FeatureFlagRoleAccess
from ee.models.rbac.organization_resource_access import OrganizationResourceAccess
from ee.models.rbac.role import Role, RoleMembership
from posthog.api.organization_member import OrganizationMemberSerializer
from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.api.shared import UserBasicSerializer
from posthog.models import OrganizationMembership
from posthog.models.feature_flag import FeatureFlag
from posthog.models.user import User


Expand All @@ -38,7 +36,6 @@ def has_permission(self, request, view):
class RoleSerializer(serializers.ModelSerializer):
created_by = UserBasicSerializer(read_only=True)
members = serializers.SerializerMethodField()
associated_flags = serializers.SerializerMethodField()

class Meta:
model = Role
Expand All @@ -49,7 +46,6 @@ class Meta:
"created_at",
"created_by",
"members",
"associated_flags",
]
read_only_fields = ["id", "created_at", "created_by"]

Expand All @@ -75,29 +71,12 @@ def get_members(self, role: Role):
members = RoleMembership.objects.filter(role=role)
return RoleMembershipSerializer(members, many=True).data

def get_associated_flags(self, role: Role):
associated_flags: list[dict] = []

role_access_objects = FeatureFlagRoleAccess.objects.filter(role=role).values_list("feature_flag_id")
flags = FeatureFlag.objects.filter(id__in=role_access_objects)
for flag in flags:
associated_flags.append({"id": flag.id, "key": flag.key})
return associated_flags


class RoleViewSet(
TeamAndOrgViewSetMixin,
mixins.ListModelMixin,
mixins.CreateModelMixin,
mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
viewsets.GenericViewSet,
):
class RoleViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
scope_object = "organization"
permission_classes = [RolePermissions]
serializer_class = RoleSerializer
queryset = Role.objects.all()
permission_classes = [RolePermissions]

def safely_get_queryset(self, queryset):
return queryset.filter(**self.request.GET.dict())
Expand Down
Loading

0 comments on commit 6dd8a4a

Please sign in to comment.