From a54cad8da9841493e9f5002869e07a0518a908b9 Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Thu, 6 Jul 2023 13:34:44 -0700 Subject: [PATCH] Return all supported entitlements for owned value's fields --- runtime/interpreter/interpreter.go | 30 +++++++++++++--- runtime/interpreter/interpreter_expression.go | 3 +- runtime/sema/check_member_expression.go | 36 +++++++++---------- runtime/tests/checker/entitlements_test.go | 6 ++-- .../tests/interpreter/entitlements_test.go | 2 -- 5 files changed, 49 insertions(+), 28 deletions(-) diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index c84ebc8101..c5823954a0 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -5215,10 +5215,17 @@ func (interpreter *Interpreter) getAccessOfMember(self Value, identifier string) return &member.Resolve(interpreter, identifier, ast.EmptyRange, func(err error) {}).Access } -func (interpreter *Interpreter) mapMemberValueAuthorization(self Value, memberAccess *sema.Access, resultValue Value) Value { +func (interpreter *Interpreter) mapMemberValueAuthorization( + self Value, + memberAccess *sema.Access, + resultValue Value, + resultingType sema.Type, +) Value { + if memberAccess == nil { return resultValue } + if mappedAccess, isMappedAccess := (*memberAccess).(sema.EntitlementMapAccess); isMappedAccess { var auth Authorization switch selfValue := self.(type) { @@ -5230,7 +5237,16 @@ func (interpreter *Interpreter) mapMemberValueAuthorization(self Value, memberAc } auth = ConvertSemaAccesstoStaticAuthorization(interpreter, imageAccess) default: - auth = ConvertSemaAccesstoStaticAuthorization(interpreter, mappedAccess.Codomain()) + var access sema.Access + if mappedAccess.Type == sema.IdentityMappingType { + access = sema.AllSupportedEntitlements(resultingType) + } + + if access == nil { + access = mappedAccess.Codomain() + } + + auth = ConvertSemaAccesstoStaticAuthorization(interpreter, access) } switch refValue := resultValue.(type) { @@ -5245,7 +5261,13 @@ func (interpreter *Interpreter) mapMemberValueAuthorization(self Value, memberAc return resultValue } -func (interpreter *Interpreter) getMemberWithAuthMapping(self Value, locationRange LocationRange, identifier string) Value { +func (interpreter *Interpreter) getMemberWithAuthMapping( + self Value, + locationRange LocationRange, + identifier string, + memberAccessInfo sema.MemberAccessInfo, +) Value { + result := interpreter.getMember(self, locationRange, identifier) if result == nil { return nil @@ -5253,7 +5275,7 @@ func (interpreter *Interpreter) getMemberWithAuthMapping(self Value, locationRan // once we have obtained the member, if it was declared with entitlement-mapped access, we must compute the output of the map based // on the runtime authorizations of the accessing reference or composite memberAccess := interpreter.getAccessOfMember(self, identifier) - return interpreter.mapMemberValueAuthorization(self, memberAccess, result) + return interpreter.mapMemberValueAuthorization(self, memberAccess, result, memberAccessInfo.ResultingType) } // getMember gets the member value by the given identifier from the given Value depending on its type. diff --git a/runtime/interpreter/interpreter_expression.go b/runtime/interpreter/interpreter_expression.go index f306c5372a..df3c2cee69 100644 --- a/runtime/interpreter/interpreter_expression.go +++ b/runtime/interpreter/interpreter_expression.go @@ -202,8 +202,9 @@ func (interpreter *Interpreter) memberExpressionGetterSetter(memberExpression *a if isNestedResourceMove { resultValue = target.(MemberAccessibleValue).RemoveMember(interpreter, locationRange, identifier) } else { - resultValue = interpreter.getMemberWithAuthMapping(target, locationRange, identifier) + resultValue = interpreter.getMemberWithAuthMapping(target, locationRange, identifier, memberAccessInfo) } + if resultValue == nil && !allowMissing { panic(UseBeforeInitializationError{ Name: identifier, diff --git a/runtime/sema/check_member_expression.go b/runtime/sema/check_member_expression.go index c09ac9c087..e1a8ee257f 100644 --- a/runtime/sema/check_member_expression.go +++ b/runtime/sema/check_member_expression.go @@ -283,19 +283,19 @@ func (checker *Checker) visitMember(expression *ast.MemberExpression) (accessedT ) } - // Check access and report if inaccessible - accessRange := func() ast.Range { return ast.NewRangeFromPositioned(checker.memoryGauge, expression) } - isReadable, resultingAuthorization := checker.isReadableMember(accessedType, member, resultingType, accessRange) - if !isReadable { - checker.report( - &InvalidAccessError{ - Name: member.Identifier.Identifier, - RestrictingAccess: member.Access, - DeclarationKind: member.DeclarationKind, - Range: accessRange(), - }, - ) - } + // Check access and report if inaccessible + accessRange := func() ast.Range { return ast.NewRangeFromPositioned(checker.memoryGauge, expression) } + isReadable, resultingAuthorization := checker.isReadableMember(accessedType, member, resultingType, accessRange) + if !isReadable { + checker.report( + &InvalidAccessError{ + Name: member.Identifier.Identifier, + RestrictingAccess: member.Access, + DeclarationKind: member.DeclarationKind, + Range: accessRange(), + }, + ) + } // the resulting authorization was mapped through an entitlement map, so we need to substitute this new authorization into the resulting type // i.e. if the field was declared with `access(M) let x: auth(M) &T?`, and we computed that the output of the map would give entitlement `E`, @@ -451,7 +451,7 @@ func (checker *Checker) mapAccess( default: if mappedAccess.Type == IdentityMappingType { - access := allSupportedEntitlements(resultingType) + access := AllSupportedEntitlements(resultingType) if access != nil { return true, access } @@ -463,14 +463,14 @@ func (checker *Checker) mapAccess( } } -func allSupportedEntitlements(typ Type) Access { +func AllSupportedEntitlements(typ Type) Access { switch typ := typ.(type) { case *ReferenceType: - return allSupportedEntitlements(typ.Type) + return AllSupportedEntitlements(typ.Type) case *OptionalType: - return allSupportedEntitlements(typ.Type) + return AllSupportedEntitlements(typ.Type) case *FunctionType: - return allSupportedEntitlements(typ.ReturnTypeAnnotation.Type) + return AllSupportedEntitlements(typ.ReturnTypeAnnotation.Type) case EntitlementSupportingType: supportedEntitlements := typ.SupportedEntitlements() if supportedEntitlements != nil && supportedEntitlements.Len() > 0 { diff --git a/runtime/tests/checker/entitlements_test.go b/runtime/tests/checker/entitlements_test.go index b1ee376774..3216be8a69 100644 --- a/runtime/tests/checker/entitlements_test.go +++ b/runtime/tests/checker/entitlements_test.go @@ -5505,11 +5505,11 @@ func TestCheckIdentityMapping(t *testing.T) { let ref1: auth(A, B, C) &X = y.x1 - //let ref2: auth(A, B, C) &X? = y.x2 + let ref2: auth(A, B, C) &X? = y.x2 - //let ref3: auth(A, B, C) &X = y.getX() + let ref3: auth(A, B, C) &X = y.getX() - //let ref4: auth(A, B, C) &X? = y.getOptionalX() + let ref4: auth(A, B, C) &X? = y.getOptionalX() } `) diff --git a/runtime/tests/interpreter/entitlements_test.go b/runtime/tests/interpreter/entitlements_test.go index 937307d5c7..7d7313a6ad 100644 --- a/runtime/tests/interpreter/entitlements_test.go +++ b/runtime/tests/interpreter/entitlements_test.go @@ -3090,11 +3090,9 @@ func TestInterpretIdentityMapping(t *testing.T) { struct X { access(A | B) var s: String - init() { self.s = "hello" } - access(C) fun foo() {} }