Skip to content

Commit

Permalink
Return all supported entitlements for owned value's fields
Browse files Browse the repository at this point in the history
  • Loading branch information
SupunS committed Jul 7, 2023
1 parent 899a0fa commit a54cad8
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 28 deletions.
30 changes: 26 additions & 4 deletions runtime/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -5215,10 +5215,17 @@ func (interpreter *Interpreter) getAccessOfMember(self Value, identifier string)
return &member.Resolve(interpreter, identifier, ast.EmptyRange, func(err error) {}).Access
}

func (interpreter *Interpreter) mapMemberValueAuthorization(self Value, memberAccess *sema.Access, resultValue Value) Value {
func (interpreter *Interpreter) mapMemberValueAuthorization(
self Value,
memberAccess *sema.Access,
resultValue Value,
resultingType sema.Type,
) Value {

if memberAccess == nil {
return resultValue
}

if mappedAccess, isMappedAccess := (*memberAccess).(sema.EntitlementMapAccess); isMappedAccess {
var auth Authorization
switch selfValue := self.(type) {
Expand All @@ -5230,7 +5237,16 @@ func (interpreter *Interpreter) mapMemberValueAuthorization(self Value, memberAc
}
auth = ConvertSemaAccesstoStaticAuthorization(interpreter, imageAccess)
default:
auth = ConvertSemaAccesstoStaticAuthorization(interpreter, mappedAccess.Codomain())
var access sema.Access
if mappedAccess.Type == sema.IdentityMappingType {
access = sema.AllSupportedEntitlements(resultingType)
}

if access == nil {
access = mappedAccess.Codomain()
}

auth = ConvertSemaAccesstoStaticAuthorization(interpreter, access)
}

switch refValue := resultValue.(type) {
Expand All @@ -5245,15 +5261,21 @@ func (interpreter *Interpreter) mapMemberValueAuthorization(self Value, memberAc
return resultValue
}

func (interpreter *Interpreter) getMemberWithAuthMapping(self Value, locationRange LocationRange, identifier string) Value {
func (interpreter *Interpreter) getMemberWithAuthMapping(
self Value,
locationRange LocationRange,
identifier string,
memberAccessInfo sema.MemberAccessInfo,
) Value {

result := interpreter.getMember(self, locationRange, identifier)
if result == nil {
return nil
}
// once we have obtained the member, if it was declared with entitlement-mapped access, we must compute the output of the map based
// on the runtime authorizations of the accessing reference or composite
memberAccess := interpreter.getAccessOfMember(self, identifier)
return interpreter.mapMemberValueAuthorization(self, memberAccess, result)
return interpreter.mapMemberValueAuthorization(self, memberAccess, result, memberAccessInfo.ResultingType)
}

// getMember gets the member value by the given identifier from the given Value depending on its type.
Expand Down
3 changes: 2 additions & 1 deletion runtime/interpreter/interpreter_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,9 @@ func (interpreter *Interpreter) memberExpressionGetterSetter(memberExpression *a
if isNestedResourceMove {
resultValue = target.(MemberAccessibleValue).RemoveMember(interpreter, locationRange, identifier)
} else {
resultValue = interpreter.getMemberWithAuthMapping(target, locationRange, identifier)
resultValue = interpreter.getMemberWithAuthMapping(target, locationRange, identifier, memberAccessInfo)
}

if resultValue == nil && !allowMissing {
panic(UseBeforeInitializationError{
Name: identifier,
Expand Down
36 changes: 18 additions & 18 deletions runtime/sema/check_member_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,19 +283,19 @@ func (checker *Checker) visitMember(expression *ast.MemberExpression) (accessedT
)
}

// Check access and report if inaccessible
accessRange := func() ast.Range { return ast.NewRangeFromPositioned(checker.memoryGauge, expression) }
isReadable, resultingAuthorization := checker.isReadableMember(accessedType, member, resultingType, accessRange)
if !isReadable {
checker.report(
&InvalidAccessError{
Name: member.Identifier.Identifier,
RestrictingAccess: member.Access,
DeclarationKind: member.DeclarationKind,
Range: accessRange(),
},
)
}
// Check access and report if inaccessible
accessRange := func() ast.Range { return ast.NewRangeFromPositioned(checker.memoryGauge, expression) }
isReadable, resultingAuthorization := checker.isReadableMember(accessedType, member, resultingType, accessRange)
if !isReadable {
checker.report(
&InvalidAccessError{
Name: member.Identifier.Identifier,
RestrictingAccess: member.Access,
DeclarationKind: member.DeclarationKind,
Range: accessRange(),
},
)
}

// the resulting authorization was mapped through an entitlement map, so we need to substitute this new authorization into the resulting type
// i.e. if the field was declared with `access(M) let x: auth(M) &T?`, and we computed that the output of the map would give entitlement `E`,
Expand Down Expand Up @@ -451,7 +451,7 @@ func (checker *Checker) mapAccess(

default:
if mappedAccess.Type == IdentityMappingType {
access := allSupportedEntitlements(resultingType)
access := AllSupportedEntitlements(resultingType)
if access != nil {
return true, access
}
Expand All @@ -463,14 +463,14 @@ func (checker *Checker) mapAccess(
}
}

func allSupportedEntitlements(typ Type) Access {
func AllSupportedEntitlements(typ Type) Access {
switch typ := typ.(type) {
case *ReferenceType:
return allSupportedEntitlements(typ.Type)
return AllSupportedEntitlements(typ.Type)
case *OptionalType:
return allSupportedEntitlements(typ.Type)
return AllSupportedEntitlements(typ.Type)
case *FunctionType:
return allSupportedEntitlements(typ.ReturnTypeAnnotation.Type)
return AllSupportedEntitlements(typ.ReturnTypeAnnotation.Type)
case EntitlementSupportingType:
supportedEntitlements := typ.SupportedEntitlements()
if supportedEntitlements != nil && supportedEntitlements.Len() > 0 {
Expand Down
6 changes: 3 additions & 3 deletions runtime/tests/checker/entitlements_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5505,11 +5505,11 @@ func TestCheckIdentityMapping(t *testing.T) {
let ref1: auth(A, B, C) &X = y.x1
//let ref2: auth(A, B, C) &X? = y.x2
let ref2: auth(A, B, C) &X? = y.x2
//let ref3: auth(A, B, C) &X = y.getX()
let ref3: auth(A, B, C) &X = y.getX()
//let ref4: auth(A, B, C) &X? = y.getOptionalX()
let ref4: auth(A, B, C) &X? = y.getOptionalX()
}
`)

Expand Down
2 changes: 0 additions & 2 deletions runtime/tests/interpreter/entitlements_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3090,11 +3090,9 @@ func TestInterpretIdentityMapping(t *testing.T) {
struct X {
access(A | B) var s: String
init() {
self.s = "hello"
}
access(C) fun foo() {}
}
Expand Down

0 comments on commit a54cad8

Please sign in to comment.