diff --git a/flyteadmin/auth/config/config.go b/flyteadmin/auth/config/config.go index f96c5cf0ae..9cd0d8bde9 100644 --- a/flyteadmin/auth/config/config.go +++ b/flyteadmin/auth/config/config.go @@ -126,6 +126,13 @@ var ( }, }, }, + Rbac: Rbac{ + Enabled: false, + BypassMethodPatterns: []string{ + "/grpc.health.v1.Health/.*", // health checking for k8s + "/flyteidl.service.AuthMetadataService/.*", // auth metadata used by other Flyte services + }, + }, } cfgSection = config.MustRegisterSection("auth", DefaultConfig) @@ -164,6 +171,8 @@ type Config struct { // AppAuth settings used to authenticate and control/limit access scopes for apps. AppAuth OAuth2Options `json:"appAuth" pflag:",Defines Auth options for apps. UserAuth must be enabled for AppAuth to work."` + + Rbac Rbac `json:"rbacConfig" pflag:",Defines RBAC options for Flyte Admin."` } type AuthorizationServer struct { @@ -236,6 +245,43 @@ type UserAuthConfig struct { IDPQueryParameter string `json:"idpQueryParameter" pflag:", idp query parameter used for selecting a particular IDP for doing user authentication. Eg: for Okta passing idp= forces the authentication to happen with IDP-ID"` } +type Rbac struct { + Enabled bool `json:"enabled" pflag:",Enables RBAC."` + BypassMethodPatterns []string `json:"bypassMethodPatterns" pflag:",List of regex patterns to match against method names to bypass RBAC."` + TokenScopeRoleResolver TokenScopeRoleResolver `json:"tokenScopeRoleResolver" pflag:",Config to use for resolving roles from token scopes."` + TokenClaimRoleResolver TokenClaimRoleResolver `json:"tokenClaimRoleResolver" pflag:",Config to use for resolving roles from token claims."` + Policies []AuthorizationPolicy `json:"policies" pflag:",Authorization policies to use for RBAC."` +} + +// An AuthorizationPolicy represents authorization allow rules. +type AuthorizationPolicy struct { + Role string `json:"role" pflag:",Role to match against."` + Rules []Rule `json:"rules" pflag:",Allow rules for matching requests."` +} + +// A Rule is a struct that represents an API request to match on. +type Rule struct { + MethodPattern string `json:"methodPattern" pflag:",Regex pattern for the gRPC method of the request."` + Project string `json:"project" pflag:",Project level resource scope, empty is wildcard."` + Domain string `json:"domain" pflag:",Domain level resource scope, empty is wildcard."` + Name string `json:"name" pflag:",Scope of the rule."` +} + +// A TokenClaimRoleResolver is a struct that represents how token claims can map to RBAC roles. +type TokenClaimRoleResolver struct { + Enabled bool `json:"enabled" pflag:",Enables token claim based role resolution."` + TokenClaims []TokenClaim `json:"tokenClaims" pflag:",List of claims to use for role resolution."` +} + +type TokenScopeRoleResolver struct { + Enabled bool `json:"enabled" pflag:",Enables token scope based role resolution."` +} + +// A TokenClaim is a struct that describes which claims to look for in tokens in order to use the values as RBAC roles. +type TokenClaim struct { + Name string `json:"name" pflag:",Scope of the claim to look for in the token."` +} + //go:generate enumer --type=SameSite --trimprefix=SameSite -json type SameSite int diff --git a/flyteadmin/auth/interceptors/interceptorstest/test_utils.go b/flyteadmin/auth/interceptors/interceptorstest/test_utils.go new file mode 100644 index 0000000000..e82c393ea9 --- /dev/null +++ b/flyteadmin/auth/interceptors/interceptorstest/test_utils.go @@ -0,0 +1,36 @@ +package interceptorstest + +import "context" + +// TestUnaryHandler is an implementation of grpc.UnaryHandler for test purposes +type TestUnaryHandler struct { + Err error + handleCallCount int + capturedCtx context.Context + HandleFunc func(ctx context.Context) +} + +func (h *TestUnaryHandler) Handle(ctx context.Context, req interface{}) (interface{}, error) { + h.handleCallCount++ + h.capturedCtx = ctx + + if h.HandleFunc != nil { + h.HandleFunc(ctx) + } + + if h.Err != nil { + return nil, h.Err + } + + return nil, nil +} + +// GetHandleCallCount gets the number of times the handle method was called +func (h *TestUnaryHandler) GetHandleCallCount() int { + return h.handleCallCount +} + +// GetCapturedCtx gets the context captured during the last handle method call +func (h *TestUnaryHandler) GetCapturedCtx() context.Context { + return h.capturedCtx +} diff --git a/flyteadmin/auth/interceptors/rbac.go b/flyteadmin/auth/interceptors/rbac.go new file mode 100644 index 0000000000..c519fa84f6 --- /dev/null +++ b/flyteadmin/auth/interceptors/rbac.go @@ -0,0 +1,211 @@ +package interceptors + +import ( + "context" + "fmt" + "github.com/flyteorg/flyte/flyteadmin/auth" + "github.com/flyteorg/flyte/flyteadmin/auth/config" + "github.com/flyteorg/flyte/flyteadmin/auth/interfaces" + "github.com/flyteorg/flyte/flyteadmin/auth/isolation" + "github.com/flyteorg/flyte/flytestdlib/logger" + "golang.org/x/exp/maps" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "regexp" +) + +func GetAuthorizationInterceptor(authCtx interfaces.AuthenticationContext) (grpc.UnaryServerInterceptor, error) { + + noopFunc := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + return nil, nil + } + + opts := authCtx.Options().Rbac + for _, policy := range opts.Policies { + // FIXME: Move this to somewhere else? + err := validatePolicy(policy) + if err != nil { + return noopFunc, fmt.Errorf("failed to validate authorization policy: %w", err) + } + } + + bypassMethodPatterns := []*regexp.Regexp{} + + for _, allowedMethod := range opts.BypassMethodPatterns { + compiled, err := regexp.Compile(allowedMethod) + if err != nil { + return noopFunc, fmt.Errorf("compiling bypass method pattern %s: %w", allowedMethod, err) + } + + bypassMethodPatterns = append(bypassMethodPatterns, compiled) + } + + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + + for _, allowedMethod := range bypassMethodPatterns { + if allowedMethod.MatchString(info.FullMethod) { + logger.Debugf(ctx, "[%s] Authorization bypassed for method", info.FullMethod) + return handler(ctx, req) + } + } + + identityContext := auth.IdentityContextFromContext(ctx) + roles := resolveRoles(opts, identityContext) + + if len(roles) == 0 { + logger.Debugf(ctx, "[%s]No roles resolved. Unauthorized.", info.FullMethod) + return nil, status.Errorf(codes.PermissionDenied, "") + } + + logger.Debugf(ctx, "[%s]Found roles: %s", info.FullMethod, roles) + + authorizedResourceScopes, err := calculateAuthorizedResourceScopes(ctx, roles, opts.Policies, info) + if err != nil { + logger.Errorf(ctx, "[%s]Failed to calculate authorized scopes for user %s: %+v", info.FullMethod, identityContext.UserID(), err) + return nil, status.Errorf(codes.Internal, "") + } + + if len(authorizedResourceScopes) == 0 { + logger.Debugf(ctx, "[%s]Found no matching authorization policy rules. Unauthorized.", info.FullMethod) + return nil, status.Errorf(codes.PermissionDenied, "") + } + + // Add authorized resource scopes to context + isolationContext := isolation.NewIsolationContext(authorizedResourceScopes) + + isolationCtx := isolationContext.WithContext(ctx) + return handler(isolationCtx, req) + + }, nil +} + +func resolveRoles(rbac config.Rbac, identityContext auth.IdentityContext) []string { + + roleSet := map[string]bool{} + + if rbac.TokenScopeRoleResolver.Enabled { + + for _, scopeRole := range identityContext.Scopes().List() { + roleSet[scopeRole] = true + } + } + + if rbac.TokenClaimRoleResolver.Enabled { + claimRoles := resolveRolesViaClaims(identityContext.Claims(), rbac.TokenClaimRoleResolver.TokenClaims) + + for _, claimRole := range claimRoles { + roleSet[claimRole] = true + } + } + + return maps.Keys(roleSet) +} + +func resolveRolesViaClaims(claims map[string]interface{}, targetClaims []config.TokenClaim) []string { + roleSet := map[string]bool{} + + for _, targetClaim := range targetClaims { + claimIntf, ok := claims[targetClaim.Name] + if !ok { + continue + } + + claimString, ok := claimIntf.(string) + if ok { + roleSet[claimString] = true + continue + } + + claimListElements, ok := claimIntf.([]interface{}) + if ok { + for _, claimListElement := range claimListElements { + claimStringElement, ok := claimListElement.(string) + if ok { + roleSet[claimStringElement] = true + } + } + } + } + + return maps.Keys(roleSet) +} + +func calculateAuthorizedResourceScopes(ctx context.Context, roles []string, policies []config.AuthorizationPolicy, info *grpc.UnaryServerInfo) ([]isolation.ResourceScope, error) { + authorizedScopes := []isolation.ResourceScope{} + + policiesByRole := map[string]config.AuthorizationPolicy{} + for _, policy := range policies { + policiesByRole[policy.Role] = policy + } + + matchingPolicies := map[string]config.AuthorizationPolicy{} + for _, role := range roles { + policy, ok := policiesByRole[role] + if !ok { + continue + } + + matchingPolicies[role] = policy + } + + logger.Debugf(ctx, "[%s]Found matching authorization policies: %s", info.FullMethod, matchingPolicies) + + for role, policy := range matchingPolicies { + matchingRules, err := authorizationPolicyMatchesRequest(policy, info) + if err != nil { + return authorizedScopes, fmt.Errorf("failed to match request: %w", err) + } + + if len(matchingRules) > 0 { + logger.Debugf(ctx, "[%s]Found matching rules for role %s: %s", info.FullMethod, role, matchingRules) + for _, matchingRule := range matchingRules { + authorizedScopes = append(authorizedScopes, isolation.ResourceScope{ + Project: matchingRule.Project, + Domain: matchingRule.Domain, + }) + } + } else { + logger.Debugf(ctx, "[%s]Found no matching rules for role %s", info.FullMethod, role) + } + } + + return authorizedScopes, nil +} + +func authorizationPolicyMatchesRequest(ap config.AuthorizationPolicy, info *grpc.UnaryServerInfo) ([]config.Rule, error) { + matchingRules := []config.Rule{} + for _, rule := range ap.Rules { + matches, err := ruleMatchesRequest(rule, info) + if err != nil { + return []config.Rule{}, fmt.Errorf("matching rule against request: %w", err) + } + + if !matches { + continue + } + + matchingRules = append(matchingRules, rule) + } + + return matchingRules, nil +} + +func ruleMatchesRequest(rule config.Rule, info *grpc.UnaryServerInfo) (bool, error) { + pattern, err := regexp.Compile(rule.MethodPattern) + if err != nil { + return false, fmt.Errorf("compiling rule pattern %s: %w", rule.MethodPattern, err) + } + + return pattern.MatchString(info.FullMethod), nil +} + +func validatePolicy(ap config.AuthorizationPolicy) error { + for _, rule := range ap.Rules { + if rule.Project == "" && rule.Domain != "" { + return fmt.Errorf("authorization policy rule %s has invalid resource scope", rule.Name) + } + } + + return nil +} diff --git a/flyteadmin/auth/interceptors/rbac_test.go b/flyteadmin/auth/interceptors/rbac_test.go new file mode 100644 index 0000000000..4fd308e25f --- /dev/null +++ b/flyteadmin/auth/interceptors/rbac_test.go @@ -0,0 +1,393 @@ +package interceptors + +import ( + "context" + "github.com/flyteorg/flyte/flyteadmin/auth" + "github.com/flyteorg/flyte/flyteadmin/auth/config" + "github.com/flyteorg/flyte/flyteadmin/auth/interceptors/interceptorstest" + "github.com/flyteorg/flyte/flyteadmin/auth/interfaces/mocks" + "github.com/flyteorg/flyte/flyteadmin/auth/isolation" + "github.com/flyteorg/flyte/flytestdlib/logger" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "k8s.io/apimachinery/pkg/util/sets" + "testing" + "time" +) + +func TestGetAuthorizationInterceptor(t *testing.T) { + + t.Run("policy validation fails", func(t *testing.T) { + + cfg := &config.Config{ + Rbac: config.Rbac{ + Policies: []config.AuthorizationPolicy{ + { + Role: "admin", + Rules: []config.Rule{ + { + Name: "example", + MethodPattern: ".*", + Project: "", + Domain: "development", + }, + }, + }, + }, + }, + } + authCtx := &mocks.AuthenticationContext{} + authCtx.OnOptions().Return(cfg) + _, err := GetAuthorizationInterceptor(authCtx) + require.ErrorContains(t, err, "authorization policy rule example has invalid resource scope") + }) +} + +func TestAuthorizationInterceptor(t *testing.T) { + + logger.SetConfig(&logger.Config{Level: logger.DebugLevel}) + ctx := context.Background() + + info := &grpc.UnaryServerInfo{ + FullMethod: "ExampleMethod", + } + + adminAuthPolicy := config.AuthorizationPolicy{ + Role: "admin", + Rules: []config.Rule{ + { + Name: "example", + MethodPattern: ".*", + Project: "flytesnacks", + Domain: "development", + }, + }, + } + + t.Run("bypass method pattern wildcard match", func(t *testing.T) { + + cfg := &config.Config{ + Rbac: config.Rbac{ + BypassMethodPatterns: []string{".*"}, + }, + } + authCtx := &mocks.AuthenticationContext{} + authCtx.OnOptions().Return(cfg) + + interceptor, err := GetAuthorizationInterceptor(authCtx) + require.NoError(t, err) + + handler := &interceptorstest.TestUnaryHandler{} + + _, err = interceptor(ctx, nil, info, handler.Handle) + require.NoError(t, err) + require.Equal(t, 1, handler.GetHandleCallCount()) + + isolationCtx := isolation.IsolationContextFromContext(handler.GetCapturedCtx()) + require.Empty(t, isolationCtx.GetResourceScopes()) + }) + + t.Run("bypass method pattern exact match", func(t *testing.T) { + + cfg := &config.Config{ + Rbac: config.Rbac{ + BypassMethodPatterns: []string{"ExampleMethod"}, + }, + } + authCtx := &mocks.AuthenticationContext{} + authCtx.OnOptions().Return(cfg) + + interceptor, err := GetAuthorizationInterceptor(authCtx) + require.NoError(t, err) + + handler := &interceptorstest.TestUnaryHandler{} + + _, err = interceptor(ctx, nil, info, handler.Handle) + require.NoError(t, err) + require.Equal(t, 1, handler.GetHandleCallCount()) + + isolationCtx := isolation.IsolationContextFromContext(handler.GetCapturedCtx()) + require.Empty(t, isolationCtx.GetResourceScopes()) + }) + + t.Run("bypass method pattern no match", func(t *testing.T) { + + cfg := &config.Config{ + Rbac: config.Rbac{ + BypassMethodPatterns: []string{"NoMethod"}, + }, + } + authCtx := &mocks.AuthenticationContext{} + authCtx.OnOptions().Return(cfg) + + interceptor, err := GetAuthorizationInterceptor(authCtx) + require.NoError(t, err) + + handler := &interceptorstest.TestUnaryHandler{} + + _, err = interceptor(ctx, nil, info, handler.Handle) + require.ErrorIs(t, err, status.Errorf(codes.PermissionDenied, "")) + require.Equal(t, 0, handler.GetHandleCallCount()) + }) + + t.Run("authorization fails due to no roles", func(t *testing.T) { + + cfg := &config.Config{ + Rbac: config.Rbac{ + Policies: []config.AuthorizationPolicy{ + adminAuthPolicy, + }, + }, + } + authCtx := &mocks.AuthenticationContext{} + authCtx.OnOptions().Return(cfg) + + interceptor, err := GetAuthorizationInterceptor(authCtx) + require.NoError(t, err) + + handler := &interceptorstest.TestUnaryHandler{} + + _, err = interceptor(ctx, nil, info, handler.Handle) + require.ErrorIs(t, err, status.Errorf(codes.PermissionDenied, "")) + require.Equal(t, 0, handler.GetHandleCallCount()) + }) + + t.Run("authorization success with scope based roles resolution", func(t *testing.T) { + + cfg := &config.Config{ + Rbac: config.Rbac{ + Policies: []config.AuthorizationPolicy{ + adminAuthPolicy, + }, + TokenScopeRoleResolver: config.TokenScopeRoleResolver{ + Enabled: true, + }, + }, + } + authCtx := &mocks.AuthenticationContext{} + authCtx.OnOptions().Return(cfg) + + interceptor, err := GetAuthorizationInterceptor(authCtx) + require.NoError(t, err) + + handler := &interceptorstest.TestUnaryHandler{} + + scopes := sets.NewString("admin") + tokenIdentityContext, err := auth.NewIdentityContext("", "", "", time.Now(), scopes, nil, nil) + ctxWithIdentity := tokenIdentityContext.WithContext(ctx) + require.NoError(t, err) + + _, err = interceptor(ctxWithIdentity, nil, info, handler.Handle) + require.NoError(t, err) + require.Equal(t, 1, handler.GetHandleCallCount()) + + isolationCtx := isolation.IsolationContextFromContext(handler.GetCapturedCtx()) + require.Len(t, isolationCtx.GetResourceScopes(), 1) + + resourceScope := isolationCtx.GetResourceScopes()[0] + expectedResourceScope := isolation.ResourceScope{ + Project: "flytesnacks", + Domain: "development", + } + require.Equal(t, expectedResourceScope, resourceScope) + }) + + t.Run("authorization fails with scope based roles resolution", func(t *testing.T) { + + cfg := &config.Config{ + Rbac: config.Rbac{ + Policies: []config.AuthorizationPolicy{ + adminAuthPolicy, + }, + TokenScopeRoleResolver: config.TokenScopeRoleResolver{ + Enabled: true, + }, + }, + } + authCtx := &mocks.AuthenticationContext{} + authCtx.OnOptions().Return(cfg) + + interceptor, err := GetAuthorizationInterceptor(authCtx) + require.NoError(t, err) + + handler := &interceptorstest.TestUnaryHandler{} + + scopes := sets.NewString("notadmin") + tokenIdentityContext, err := auth.NewIdentityContext("", "", "", time.Now(), scopes, nil, nil) + ctxWithIdentity := tokenIdentityContext.WithContext(ctx) + require.NoError(t, err) + + _, err = interceptor(ctxWithIdentity, nil, info, handler.Handle) + require.ErrorIs(t, err, status.Errorf(codes.PermissionDenied, "")) + require.Equal(t, 0, handler.GetHandleCallCount()) + }) + + t.Run("authorization success with string claim based roles resolution", func(t *testing.T) { + + cfg := &config.Config{ + Rbac: config.Rbac{ + Policies: []config.AuthorizationPolicy{ + adminAuthPolicy, + }, + TokenClaimRoleResolver: config.TokenClaimRoleResolver{ + Enabled: true, + TokenClaims: []config.TokenClaim{ + { + Name: "group", + }, + }, + }, + }, + } + authCtx := &mocks.AuthenticationContext{} + authCtx.OnOptions().Return(cfg) + + interceptor, err := GetAuthorizationInterceptor(authCtx) + require.NoError(t, err) + + handler := &interceptorstest.TestUnaryHandler{} + + claims := map[string]interface{}{ + "group": "admin", + } + tokenIdentityContext, err := auth.NewIdentityContext("", "", "", time.Now(), nil, nil, claims) + ctxWithIdentity := tokenIdentityContext.WithContext(ctx) + require.NoError(t, err) + + _, err = interceptor(ctxWithIdentity, nil, info, handler.Handle) + require.NoError(t, err) + require.Equal(t, 1, handler.GetHandleCallCount()) + + isolationCtx := isolation.IsolationContextFromContext(handler.GetCapturedCtx()) + require.Len(t, isolationCtx.GetResourceScopes(), 1) + + resourceScope := isolationCtx.GetResourceScopes()[0] + expectedResourceScope := isolation.ResourceScope{ + Project: "flytesnacks", + Domain: "development", + } + require.Equal(t, expectedResourceScope, resourceScope) + }) + + t.Run("authorization fails with string claim based roles resolution", func(t *testing.T) { + + cfg := &config.Config{ + Rbac: config.Rbac{ + Policies: []config.AuthorizationPolicy{ + adminAuthPolicy, + }, + TokenClaimRoleResolver: config.TokenClaimRoleResolver{ + Enabled: true, + TokenClaims: []config.TokenClaim{ + { + Name: "group", + }, + }, + }, + }, + } + authCtx := &mocks.AuthenticationContext{} + authCtx.OnOptions().Return(cfg) + + interceptor, err := GetAuthorizationInterceptor(authCtx) + require.NoError(t, err) + + handler := &interceptorstest.TestUnaryHandler{} + + claims := map[string]interface{}{ + "group": "notadmin", + } + tokenIdentityContext, err := auth.NewIdentityContext("", "", "", time.Now(), nil, nil, claims) + ctxWithIdentity := tokenIdentityContext.WithContext(ctx) + require.NoError(t, err) + + _, err = interceptor(ctxWithIdentity, nil, info, handler.Handle) + require.ErrorIs(t, err, status.Errorf(codes.PermissionDenied, "")) + require.Equal(t, 0, handler.GetHandleCallCount()) + }) + + t.Run("authorization success with string list claim based roles resolution", func(t *testing.T) { + + cfg := &config.Config{ + Rbac: config.Rbac{ + Policies: []config.AuthorizationPolicy{ + adminAuthPolicy, + }, + TokenClaimRoleResolver: config.TokenClaimRoleResolver{ + Enabled: true, + TokenClaims: []config.TokenClaim{ + { + Name: "groups", + }, + }, + }, + }, + } + authCtx := &mocks.AuthenticationContext{} + authCtx.OnOptions().Return(cfg) + + interceptor, err := GetAuthorizationInterceptor(authCtx) + require.NoError(t, err) + + handler := &interceptorstest.TestUnaryHandler{} + + claims := map[string]interface{}{ + "groups": []interface{}{"admin", "notadmin"}, + } + tokenIdentityContext, err := auth.NewIdentityContext("", "", "", time.Now(), nil, nil, claims) + ctxWithIdentity := tokenIdentityContext.WithContext(ctx) + require.NoError(t, err) + + _, err = interceptor(ctxWithIdentity, nil, info, handler.Handle) + require.NoError(t, err) + require.Equal(t, 1, handler.GetHandleCallCount()) + + isolationCtx := isolation.IsolationContextFromContext(handler.GetCapturedCtx()) + require.Len(t, isolationCtx.GetResourceScopes(), 1) + + resourceScope := isolationCtx.GetResourceScopes()[0] + expectedResourceScope := isolation.ResourceScope{ + Project: "flytesnacks", + Domain: "development", + } + require.Equal(t, expectedResourceScope, resourceScope) + }) + + t.Run("authorization fails with string claim based roles resolution", func(t *testing.T) { + + cfg := &config.Config{ + Rbac: config.Rbac{ + Policies: []config.AuthorizationPolicy{ + adminAuthPolicy, + }, + TokenClaimRoleResolver: config.TokenClaimRoleResolver{ + Enabled: true, + TokenClaims: []config.TokenClaim{ + { + Name: "groups", + }, + }, + }, + }, + } + authCtx := &mocks.AuthenticationContext{} + authCtx.OnOptions().Return(cfg) + + interceptor, err := GetAuthorizationInterceptor(authCtx) + require.NoError(t, err) + + handler := &interceptorstest.TestUnaryHandler{} + + claims := map[string]interface{}{ + "groups": []interface{}{"notadmin"}, + } + tokenIdentityContext, err := auth.NewIdentityContext("", "", "", time.Now(), nil, nil, claims) + ctxWithIdentity := tokenIdentityContext.WithContext(ctx) + require.NoError(t, err) + _, err = interceptor(ctxWithIdentity, nil, info, handler.Handle) + require.ErrorIs(t, err, status.Errorf(codes.PermissionDenied, "")) + require.Equal(t, 0, handler.GetHandleCallCount()) + }) +} diff --git a/flyteadmin/auth/isolation/isolation_context.go b/flyteadmin/auth/isolation/isolation_context.go new file mode 100644 index 0000000000..8479ccceef --- /dev/null +++ b/flyteadmin/auth/isolation/isolation_context.go @@ -0,0 +1,59 @@ +// This is its own package to resolve a circular dependency issues between auth -> common -> auth +package isolation + +import ( + "context" + + "github.com/flyteorg/flyte/flytestdlib/contextutils" +) + +const ( + ContextKeyIsolationContext = contextutils.Key("isolation_context") +) + +// An IsolationContext provides context around how to isolate or filter resources for a particular API request. +type IsolationContext struct { + resourceScopes []ResourceScope +} + +func NewIsolationContext(resourceScopes []ResourceScope) IsolationContext { + return IsolationContext{resourceScopes: resourceScopes} +} + +// GetResourceScopes gets the scope of resources that the client can access. These are typically Flyte projects as well +// as domains within projects. +func (c IsolationContext) GetResourceScopes() []ResourceScope { + return c.resourceScopes +} + +// ResourceScope is a hierarchical representation of what scope of resources a user has access to. Project -> Domain. +// Empty strings are considered wildcard access. +type ResourceScope struct { + Project string + Domain string +} + +// TargetResourceScopeDepth represents the depth/scope of an individual resource. Sometimes only project level scope is +// applicable to a resource while the user has domain level scope within a project. In such cases, a user's resource +// scope may need to be truncated to match the depth of the target resource scope during filtering operations. +type TargetResourceScopeDepth = int + +const ( + ProjectTargetResourceScopeDepth = 0 // The resource can only be filtered at a project level + DomainTargetResourceScopeDepth = 1 // THe resource can be filtered at both a project and domain level +) + +// WithContext adds the isolation context to the go context +func (c IsolationContext) WithContext(ctx context.Context) context.Context { + return context.WithValue(ctx, ContextKeyIsolationContext, c) +} + +// IsolationContextFromContext extracts the isolation context from a go context +func IsolationContextFromContext(ctx context.Context) IsolationContext { + existing := ctx.Value(ContextKeyIsolationContext) + if existing != nil { + return existing.(IsolationContext) + } + + return NewIsolationContext([]ResourceScope{}) +} diff --git a/flyteadmin/pkg/common/filters.go b/flyteadmin/pkg/common/filters.go index d20ef21739..d8daa603f0 100644 --- a/flyteadmin/pkg/common/filters.go +++ b/flyteadmin/pkg/common/filters.go @@ -4,6 +4,8 @@ package common import ( "context" "fmt" + "github.com/flyteorg/flyte/flyteadmin/auth/isolation" + "gorm.io/gorm" "google.golang.org/grpc/codes" @@ -389,3 +391,62 @@ func NewWithDefaultValueFilter(defaultValue interface{}, filter InlineFilter) (I defaultValue: defaultValue, }, nil } + +// ResourceColumns is a struct to indicate which columns in a query represent the flyte project and domain +type ResourceColumns struct { + Project string + Domain string +} + +// IsolationFilter is an interface for filtering data based on authorization rules +type IsolationFilter interface { + GetScopes() []func(tx *gorm.DB) *gorm.DB +} + +// ResourceIsolationFilter is an implementation of IsolationFilter that provides db scopes for filtering resources. +type ResourceIsolationFilter struct { + resourceScopes []isolation.ResourceScope + columns ResourceColumns +} + +// NewResourceIsolationFilter creates a new ResourceIsolationFilter +func NewResourceIsolationFilter(resourceScopes []isolation.ResourceScope, columns ResourceColumns) *ResourceIsolationFilter { + return &ResourceIsolationFilter{resourceScopes: resourceScopes, columns: columns} +} + +// GetScopes gets a list of functions that mutates a query to apply filtering where clauses +func (a *ResourceIsolationFilter) GetScopes() []func(tx *gorm.DB) *gorm.DB { + scopes := []func(tx *gorm.DB) *gorm.DB{} + + for tempIndex, tempResourceScope := range a.resourceScopes { + // Copy to avoid issues + index := tempIndex + resourceScope := tempResourceScope + scopes = append(scopes, func(tx *gorm.DB) *gorm.DB { + // We must start of with a where clause, then after that we can chain each resource scope with OR + if index == 0 { + return tx.Where(tx.Scopes(a.toDbScopes(resourceScope))) + } + return tx.Or(tx.Scopes(a.toDbScopes(resourceScope))) + }) + } + + return scopes +} + +func (a *ResourceIsolationFilter) toDbScopes(resourceScope isolation.ResourceScope) func(db *gorm.DB) *gorm.DB { + return func(tx *gorm.DB) *gorm.DB { + + // Filter resource by project + if resourceScope.Project != "" { + tx = tx.Where(fmt.Sprintf("%s = ?", a.columns.Project), resourceScope.Project) + } + + // Filter resource by domain + if resourceScope.Domain != "" { + tx = tx.Where(fmt.Sprintf("%s = ?", a.columns.Domain), resourceScope.Domain) + } + + return tx + } +} diff --git a/flyteadmin/pkg/manager/impl/util/filters.go b/flyteadmin/pkg/manager/impl/util/filters.go index b6426a3852..4e17fd946a 100644 --- a/flyteadmin/pkg/manager/impl/util/filters.go +++ b/flyteadmin/pkg/manager/impl/util/filters.go @@ -4,6 +4,7 @@ package util import ( "context" "fmt" + "github.com/flyteorg/flyte/flyteadmin/auth/isolation" "regexp" "strconv" "strings" @@ -318,3 +319,88 @@ func GetNodeExecutionIdentifierFilters( } return append(workflowExecutionIdentifierFilters, nodeIDFilter), nil } + +// GetIsolationFilter takes in a target resource depth, a user's resource scopes, and the resource's columns and generates +// appropriate where clauses to filter resources based on a user's isolation context. +func GetIsolationFilter(ctx context.Context, resourceDepth isolation.TargetResourceScopeDepth, columns common.ResourceColumns) common.IsolationFilter { + authzCtx := isolation.IsolationContextFromContext(ctx) + resourceScopes := authzCtx.GetResourceScopes() + + // authz is disabled or has been bypassed + if len(resourceScopes) == 0 { + return nil + } + + adjustedResourceScopes := []isolation.ResourceScope{} + + for _, resourceScope := range resourceScopes { + if resourceScope.Project == "" { + // User has wilcard access + return nil + } + + tempResourceScope := isolation.ResourceScope{ + Project: resourceScope.Project, + } + + if resourceDepth == isolation.ProjectTargetResourceScopeDepth { + // truncate user's resource scope to project level + adjustedResourceScopes = append(adjustedResourceScopes, tempResourceScope) + continue + } + + if resourceScope.Domain != "" { + tempResourceScope.Domain = resourceScope.Domain + } + + if resourceDepth == isolation.DomainTargetResourceScopeDepth { + // truncate user's resource scope to domain level + adjustedResourceScopes = append(adjustedResourceScopes, tempResourceScope) + continue + } + } + + return common.NewResourceIsolationFilter(adjustedResourceScopes, columns) +} + +// FilterResourceMutation is a helper function to determine whether a use can create/modify/delete a target resource. +func FilterResourceMutation(ctx context.Context, targetProject string, targetDomain string) error { + authzCtx := isolation.IsolationContextFromContext(ctx) + resourceScopes := authzCtx.GetResourceScopes() + + if len(resourceScopes) == 0 { + // If there are no resource scopes rbac is likely disabled or the request has been whitelisted. Allow it. + return nil + } + + for _, resourceScope := range resourceScopes { + if targetProject != "" { + if resourceScope.Project == "" { + // user has wildcard scope + return nil + } + + if resourceScope.Project != targetProject { + // Project depth doesn't match + continue + } + } + + if targetDomain != "" { + if resourceScope.Domain == "" { + // user has wildcard scope + return nil + } + + if resourceScope.Domain != targetDomain { + // Domain depth doesn't match + continue + } + } + + // Exact match all the way, allow it. + return nil + } + + return errors.NewFlyteAdminErrorf(codes.NotFound, "project or domain not found") +} diff --git a/flyteadmin/pkg/repositories/gormimpl/common.go b/flyteadmin/pkg/repositories/gormimpl/common.go index 7f4d4f370a..84f9db3209 100644 --- a/flyteadmin/pkg/repositories/gormimpl/common.go +++ b/flyteadmin/pkg/repositories/gormimpl/common.go @@ -80,7 +80,7 @@ func ValidateListInput(input interfaces.ListResourceInput) adminErrors.FlyteAdmi return nil } -func applyFilters(tx *gorm.DB, inlineFilters []common.InlineFilter, mapFilters []common.MapFilter) (*gorm.DB, error) { +func applyFilters(tx *gorm.DB, inlineFilters []common.InlineFilter, mapFilters []common.MapFilter, isolationFilter common.IsolationFilter) (*gorm.DB, error) { for _, filter := range inlineFilters { gormQueryExpr, err := filter.GetGormQueryExpr() if err != nil { @@ -91,6 +91,9 @@ func applyFilters(tx *gorm.DB, inlineFilters []common.InlineFilter, mapFilters [ for _, mapFilter := range mapFilters { tx = tx.Where(mapFilter.GetFilter()) } + if isolationFilter != nil { + tx = tx.Where(tx.Scopes(isolationFilter.GetScopes()...)) + } return tx, nil } diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo.go b/flyteadmin/pkg/repositories/gormimpl/task_repo.go index 3f99172224..a30f8c8932 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo.go @@ -3,6 +3,9 @@ package gormimpl import ( "context" "errors" + "github.com/flyteorg/flyte/flyteadmin/auth/isolation" + "github.com/flyteorg/flyte/flyteadmin/pkg/common" + "github.com/flyteorg/flyte/flyteadmin/pkg/manager/impl/util" "gorm.io/gorm" @@ -13,6 +16,10 @@ import ( "github.com/flyteorg/flyte/flytestdlib/promutils" ) +var ( + taskColumnNames = common.ResourceColumns{Project: Project, Domain: Domain} +) + // Implementation of TaskRepoInterface. type TaskRepo struct { db *gorm.DB @@ -22,6 +29,11 @@ type TaskRepo struct { func (r *TaskRepo) Create(ctx context.Context, input models.Task, descriptionEntity *models.DescriptionEntity) error { timer := r.metrics.CreateDuration.Start() + + if err := util.FilterResourceMutation(ctx, input.Project, input.Domain); err != nil { + return err + } + err := r.db.WithContext(ctx).Transaction(func(_ *gorm.DB) error { if descriptionEntity == nil { tx := r.db.WithContext(ctx).Omit("id").Create(&input) @@ -49,6 +61,7 @@ func (r *TaskRepo) Create(ctx context.Context, input models.Task, descriptionEnt func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models.Task, error) { var task models.Task timer := r.metrics.GetDuration.Start() + isolationFilter := util.GetIsolationFilter(ctx, isolation.DomainTargetResourceScopeDepth, taskColumnNames) tx := r.db.WithContext(ctx).Where(&models.Task{ TaskKey: models.TaskKey{ Project: input.Project, @@ -56,7 +69,11 @@ func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models Name: input.Name, Version: input.Version, }, - }).Take(&task) + }) + if isolationFilter != nil { + tx = tx.Where(tx.Scopes(isolationFilter.GetScopes()...)) + } + tx = tx.Take(&task) timer.Stop() if errors.Is(tx.Error, gorm.ErrRecordNotFound) { return models.Task{}, flyteAdminDbErrors.GetMissingEntityError(core.ResourceType_TASK.String(), &core.Identifier{ @@ -75,6 +92,7 @@ func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models func (r *TaskRepo) List( ctx context.Context, input interfaces.ListResourceInput) (interfaces.TaskCollectionOutput, error) { + isolationFilter := util.GetIsolationFilter(ctx, isolation.DomainTargetResourceScopeDepth, taskColumnNames) // First validate input. if err := ValidateListInput(input); err != nil { return interfaces.TaskCollectionOutput{}, err @@ -82,7 +100,7 @@ func (r *TaskRepo) List( var tasks []models.Task tx := r.db.WithContext(ctx).Limit(input.Limit).Offset(input.Offset) // Apply filters - tx, err := applyFilters(tx, input.InlineFilters, input.MapFilters) + tx, err := applyFilters(tx, input.InlineFilters, input.MapFilters, isolationFilter) if err != nil { return interfaces.TaskCollectionOutput{}, err } @@ -110,10 +128,12 @@ func (r *TaskRepo) ListTaskIdentifiers(ctx context.Context, input interfaces.Lis return interfaces.TaskCollectionOutput{}, err } + isolationFilter := util.GetIsolationFilter(ctx, isolation.DomainTargetResourceScopeDepth, taskColumnNames) + tx := r.db.WithContext(ctx).Model(models.Task{}).Limit(input.Limit).Offset(input.Offset) // Apply filters - tx, err := applyFilters(tx, input.InlineFilters, input.MapFilters) + tx, err := applyFilters(tx, input.InlineFilters, input.MapFilters, isolationFilter) if err != nil { return interfaces.TaskCollectionOutput{}, err }