Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: rbac middleware #26159

Merged
merged 38 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
f210043
CREATE_METHODS -> CREATE_ACTIONS
zlwaterfield Nov 12, 2024
9df71cb
Create user_access_control.py
zlwaterfield Nov 12, 2024
8597073
Create access_control.py
zlwaterfield Nov 12, 2024
20ab359
Update permissions.py
zlwaterfield Nov 12, 2024
4c109e9
Create access_control_api_mixin.py
zlwaterfield Nov 12, 2024
9c0767d
Update utils.py
zlwaterfield Nov 12, 2024
b9be974
Update routing.py
zlwaterfield Nov 12, 2024
1b8020c
Update ee tests query counts
zlwaterfield Nov 13, 2024
071edd0
Update org access error messages
zlwaterfield Nov 13, 2024
1c6b123
Fix recursion issue
zlwaterfield Nov 13, 2024
9d7e653
Update snapshots
zlwaterfield Nov 13, 2024
006b011
Update query tests
zlwaterfield Nov 14, 2024
c5f767a
Merge master in
zlwaterfield Nov 14, 2024
8506643
Update test_dashboard.py
zlwaterfield Nov 14, 2024
b6e5590
Create test_user_access_control.py
zlwaterfield Nov 14, 2024
982881b
Add access control to viewsets
zlwaterfield Nov 14, 2024
a3ec406
Create test_access_control.py
zlwaterfield Nov 14, 2024
54070f9
Update organization.py
zlwaterfield Nov 14, 2024
57752fa
Update role.py
zlwaterfield Nov 14, 2024
e55965c
Update test_organization_feature_flag.py
zlwaterfield Nov 14, 2024
cdfb2f3
Fix more tests
zlwaterfield Nov 14, 2024
9b95962
LINT
zlwaterfield Nov 14, 2024
0452dd4
Update query snapshots
github-actions[bot] Nov 14, 2024
0e6effb
update some more quest assertions
zlwaterfield Nov 14, 2024
d639fb4
Update query snapshots
github-actions[bot] Nov 14, 2024
e634a02
Update test_feature_flag.py
zlwaterfield Nov 15, 2024
f209ae8
Merge remote-tracking branch 'origin/master' into zach/rbac/3
zlwaterfield Nov 19, 2024
c04e1ee
Push back role permissions
zlwaterfield Nov 19, 2024
42cd9fc
Update query snapshots
github-actions[bot] Nov 19, 2024
33be661
Update query snapshots
github-actions[bot] Nov 19, 2024
f7ca1c6
Put back role permissions
zlwaterfield Nov 19, 2024
eb11ca7
Remove dup function
zlwaterfield Nov 19, 2024
3fff547
Add comment
zlwaterfield Nov 19, 2024
fc649d8
push image
bciaraldi Nov 20, 2024
431899a
missing depot cli
bciaraldi Nov 20, 2024
85bbbc5
Add checks for organization id
zlwaterfield Nov 21, 2024
3d4fc8d
Add org deletion test
zlwaterfield Nov 21, 2024
c451d6e
revert ci changes
zlwaterfield Nov 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 194 additions & 0 deletions ee/api/rbac/access_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
from typing import TYPE_CHECKING, cast


from rest_framework import exceptions, serializers, status
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet

from ee.models.rbac.access_control import AccessControl
from posthog.models.scopes import API_SCOPE_OBJECTS, APIScopeObjectOrNotSupported
from posthog.models.team.team import Team
from posthog.rbac.user_access_control import (
ACCESS_CONTROL_LEVELS_RESOURCE,
UserAccessControl,
default_access_level,
highest_access_level,
ordered_access_levels,
)


if TYPE_CHECKING:
_GenericViewSet = GenericViewSet
else:
_GenericViewSet = object


class AccessControlSerializer(serializers.ModelSerializer):
access_level = serializers.CharField(allow_null=True)

class Meta:
model = AccessControl
fields = [
"access_level",
"resource",
"resource_id",
"organization_member",
"role",
"created_by",
"created_at",
"updated_at",
]
read_only_fields = ["id", "created_at", "created_by"]

# Validate that resource is a valid option from the API_SCOPE_OBJECTS
def validate_resource(self, resource):
if resource not in API_SCOPE_OBJECTS:
raise serializers.ValidationError("Invalid resource. Must be one of: {}".format(API_SCOPE_OBJECTS))

return resource

# Validate that access control is a valid option
def validate_access_level(self, access_level):
if access_level and access_level not in ordered_access_levels(self.initial_data["resource"]):
raise serializers.ValidationError(
f"Invalid access level. Must be one of: {', '.join(ordered_access_levels(self.initial_data['resource']))}"
)

return access_level

def validate(self, data):
context = self.context

# Ensure that only one of organization_member or role is set
if data.get("organization_member") and data.get("role"):
raise serializers.ValidationError("You can not scope an access control to both a member and a role.")

access_control = cast(UserAccessControl, self.context["view"].user_access_control)
resource = data["resource"]
resource_id = data.get("resource_id")

# We assume the highest level is required for the given resource to edit access controls
required_level = highest_access_level(resource)
team = context["view"].team
the_object = context["view"].get_object()

if resource_id:
# Check that they have the right access level for this specific resource object
if not access_control.check_can_modify_access_levels_for_object(the_object):
raise exceptions.PermissionDenied(f"Must be {required_level} to modify {resource} permissions.")
else:
# If modifying the base resource rules then we are checking the parent membership (project or organization)
# NOTE: Currently we only support org level in the UI so its simply an org level check
if not access_control.check_can_modify_access_levels_for_object(team):
raise exceptions.PermissionDenied("Must be an Organization admin to modify project-wide permissions.")

return data


class AccessControlViewSetMixin(_GenericViewSet):
"""
Adds an "access_controls" action to the viewset that handles access control for the given resource
Why a mixin? We want to easily add this to any existing resource, including providing easy helpers for adding access control info such
as the current users access level to any response.
"""

# 1. Know that the project level access is covered by the Permission check
# 2. Get the actual object which we can pass to the serializer to check if the user created it
# 3. We can also use the serializer to check the access level for the object

def _get_access_control_serializer(self, *args, **kwargs):
kwargs.setdefault("context", self.get_serializer_context())
return AccessControlSerializer(*args, **kwargs)

def _get_access_controls(self, request: Request, is_global=False):
resource = cast(APIScopeObjectOrNotSupported, getattr(self, "scope_object", None))
user_access_control = cast(UserAccessControl, self.user_access_control) # type: ignore
team = cast(Team, self.team) # type: ignore

if is_global and resource != "project" or not resource or resource == "INTERNAL":
raise exceptions.NotFound("Role based access controls are only available for projects.")

obj = self.get_object()
resource_id = obj.id

if is_global:
# If role based then we are getting all controls for the project that aren't specific to a resource
access_controls = AccessControl.objects.filter(team=team, resource_id=None).all()
else:
# Otherwise we are getting all controls for the specific resource
access_controls = AccessControl.objects.filter(team=team, resource=resource, resource_id=resource_id).all()

serializer = self._get_access_control_serializer(instance=access_controls, many=True)
user_access_level = user_access_control.access_level_for_object(obj, resource)

return Response(
{
"access_controls": serializer.data,
# NOTE: For Role based controls we are always configuring resource level items
"available_access_levels": ACCESS_CONTROL_LEVELS_RESOURCE
if is_global
else ordered_access_levels(resource),
"default_access_level": "editor" if is_global else default_access_level(resource),
"user_access_level": user_access_level,
"user_can_edit_access_levels": user_access_control.check_can_modify_access_levels_for_object(obj),
}
)

def _update_access_controls(self, request: Request, is_global=False):
resource = getattr(self, "scope_object", None)
obj = self.get_object()
resource_id = str(obj.id)
team = cast(Team, self.team) # type: ignore

# Generically validate the incoming data
if not is_global:
# If not role based we are deriving from the viewset
data = request.data
data["resource"] = resource
data["resource_id"] = resource_id

partial_serializer = self._get_access_control_serializer(data=request.data)
partial_serializer.is_valid(raise_exception=True)
params = partial_serializer.validated_data

instance = AccessControl.objects.filter(
team=team,
resource=params["resource"],
resource_id=params.get("resource_id"),
organization_member=params.get("organization_member"),
role=params.get("role"),
).first()

if params["access_level"] is None:
if instance:
instance.delete()
return Response(status=status.HTTP_204_NO_CONTENT)

# Perform the upsert
if instance:
serializer = self._get_access_control_serializer(instance, data=request.data)
else:
serializer = self._get_access_control_serializer(data=request.data)

serializer.is_valid(raise_exception=True)
serializer.validated_data["team"] = team
serializer.save()

return Response(serializer.data, status=status.HTTP_200_OK)

@action(methods=["GET", "PUT"], detail=True)
def access_controls(self, request: Request, *args, **kwargs):
if request.method == "PUT":
return self._update_access_controls(request)

return self._get_access_controls(request)

@action(methods=["GET", "PUT"], detail=True)
def global_access_controls(self, request: Request, *args, **kwargs):
if request.method == "PUT":
return self._update_access_controls(request, is_global=True)

return self._get_access_controls(request, is_global=True)
25 changes: 2 additions & 23 deletions ee/api/rbac/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
from rest_framework import mixins, serializers, viewsets
from rest_framework.permissions import SAFE_METHODS, BasePermission

from ee.models.feature_flag_role_access import FeatureFlagRoleAccess
from ee.models.rbac.organization_resource_access import OrganizationResourceAccess
from ee.models.rbac.role import Role, RoleMembership
from posthog.api.organization_member import OrganizationMemberSerializer
from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.api.shared import UserBasicSerializer
from posthog.models import OrganizationMembership
from posthog.models.feature_flag import FeatureFlag
from posthog.models.user import User


Expand All @@ -38,7 +36,6 @@ def has_permission(self, request, view):
class RoleSerializer(serializers.ModelSerializer):
created_by = UserBasicSerializer(read_only=True)
members = serializers.SerializerMethodField()
associated_flags = serializers.SerializerMethodField()

class Meta:
model = Role
Expand All @@ -49,7 +46,6 @@ class Meta:
"created_at",
"created_by",
"members",
"associated_flags",
]
read_only_fields = ["id", "created_at", "created_by"]

Expand All @@ -75,29 +71,12 @@ def get_members(self, role: Role):
members = RoleMembership.objects.filter(role=role)
return RoleMembershipSerializer(members, many=True).data

def get_associated_flags(self, role: Role):
associated_flags: list[dict] = []

role_access_objects = FeatureFlagRoleAccess.objects.filter(role=role).values_list("feature_flag_id")
flags = FeatureFlag.objects.filter(id__in=role_access_objects)
for flag in flags:
associated_flags.append({"id": flag.id, "key": flag.key})
return associated_flags


class RoleViewSet(
TeamAndOrgViewSetMixin,
mixins.ListModelMixin,
mixins.CreateModelMixin,
mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
viewsets.GenericViewSet,
):
class RoleViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
scope_object = "organization"
permission_classes = [RolePermissions]
serializer_class = RoleSerializer
queryset = Role.objects.all()
permission_classes = [RolePermissions]

def safely_get_queryset(self, queryset):
return queryset.filter(**self.request.GET.dict())
Expand Down
Loading
Loading