diff --git a/runtime/account_test.go b/runtime/account_test.go index 8966e94abb..d36f94ac48 100644 --- a/runtime/account_test.go +++ b/runtime/account_test.go @@ -1451,10 +1451,23 @@ func TestRuntimePublicKey(t *testing.T) { } addPublicKeyValidation(runtimeInterface, nil) - _, err := executeScript(script, runtimeInterface) - errs := checker.RequireCheckerErrors(t, err, 1) + value, err := executeScript(script, runtimeInterface) + require.NoError(t, err) - assert.IsType(t, &sema.ExternalMutationError{}, errs[0]) + expected := cadence.Struct{ + StructType: PublicKeyType, + Fields: []cadence.Value{ + // Public key (bytes) + newBytesValue([]byte{1, 2}), + + // Signature Algo + newSignAlgoValue(sema.SignatureAlgorithmECDSA_P256), + }, + } + + expected = cadence.ValueWithCachedTypeID(expected) + + assert.Equal(t, expected, value) }) t.Run("raw-key reference mutability", func(t *testing.T) { @@ -1467,7 +1480,7 @@ func TestRuntimePublicKey(t *testing.T) { signatureAlgorithm: SignatureAlgorithm.ECDSA_P256 ) - var publickeyRef = &publicKey.publicKey as &[UInt8] + var publickeyRef = &publicKey.publicKey as auth(Mutate) &[UInt8] publickeyRef[0] = 3 return publicKey @@ -1885,9 +1898,7 @@ func TestAuthAccountContracts(t *testing.T) { Location: nextTransactionLocation(), }, ) - errs := checker.RequireCheckerErrors(t, err, 1) - - assert.IsType(t, &sema.ExternalMutationError{}, errs[0]) + require.NoError(t, err) }) t.Run("update names through reference", func(t *testing.T) { @@ -1898,7 +1909,7 @@ func TestAuthAccountContracts(t *testing.T) { script := []byte(` transaction { prepare(signer: AuthAccount) { - var namesRef = &signer.contracts.names as &[String] + var namesRef = &signer.contracts.names as auth(Mutate) &[String] namesRef[0] = "baz" assert(signer.contracts.names[0] == "foo") @@ -2097,7 +2108,7 @@ func TestPublicAccountContracts(t *testing.T) { }, } - _, err := rt.ExecuteScript( + result, err := rt.ExecuteScript( Script{ Source: script, }, @@ -2106,9 +2117,14 @@ func TestPublicAccountContracts(t *testing.T) { Location: common.ScriptLocation{}, }, ) - errs := checker.RequireCheckerErrors(t, err, 1) + require.NoError(t, err) + + require.IsType(t, cadence.Array{}, result) + array := result.(cadence.Array) - assert.IsType(t, &sema.ExternalMutationError{}, errs[0]) + require.Len(t, array.Values, 2) + assert.Equal(t, cadence.String("foo"), array.Values[0]) + assert.Equal(t, cadence.String("bar"), array.Values[1]) }) t.Run("append names", func(t *testing.T) { @@ -2133,7 +2149,7 @@ func TestPublicAccountContracts(t *testing.T) { }, } - _, err := rt.ExecuteScript( + result, err := rt.ExecuteScript( Script{ Source: script, }, @@ -2142,9 +2158,14 @@ func TestPublicAccountContracts(t *testing.T) { Location: common.ScriptLocation{}, }, ) - errs := checker.RequireCheckerErrors(t, err, 1) + require.NoError(t, err) + + require.IsType(t, cadence.Array{}, result) + array := result.(cadence.Array) - assert.IsType(t, &sema.ExternalMutationError{}, errs[0]) + require.Len(t, array.Values, 2) + assert.Equal(t, cadence.String("foo"), array.Values[0]) + assert.Equal(t, cadence.String("bar"), array.Values[1]) }) } diff --git a/runtime/attachments_test.go b/runtime/attachments_test.go index 07adf640fb..de0ede93e7 100644 --- a/runtime/attachments_test.go +++ b/runtime/attachments_test.go @@ -171,10 +171,20 @@ func TestAccountAttachmentExportFailure(t *testing.T) { import Test from 0x1 access(all) fun main(): &Test.A? { let r <- Test.makeRWithA() - let a = r[Test.A] + var a = r[Test.A] + + // Life span of attachments (references) are validated statically. + // This indirection helps to trick the checker and causes to perform the validation at runtime, + // which is the intention of this test. + a = returnSameRef(a) + destroy r return a } + + access(all) fun returnSameRef(_ ref: &Test.A?): &Test.A? { + return ref + } `) runtimeInterface1 := &testRuntimeInterface{ diff --git a/runtime/capabilitycontrollers_test.go b/runtime/capabilitycontrollers_test.go index 4268e9ee62..b73273b4fc 100644 --- a/runtime/capabilitycontrollers_test.go +++ b/runtime/capabilitycontrollers_test.go @@ -82,7 +82,7 @@ func TestRuntimeCapabilityControllers(t *testing.T) { /// > Our version of quicksort is not the fastest possible, /// > but it's one of the simplest. /// - access(all) fun quickSort(_ items: &[AnyStruct], isLess: fun(Int, Int): Bool) { + access(all) fun quickSort(_ items: auth(Mutate) &[AnyStruct], isLess: fun(Int, Int): Bool) { fun quickSortPart(leftIndex: Int, rightIndex: Int) { @@ -92,6 +92,7 @@ func TestRuntimeCapabilityControllers(t *testing.T) { let pivotIndex = (leftIndex + rightIndex) / 2 + items[pivotIndex] <-> items[leftIndex] items[pivotIndex] <-> items[leftIndex] var lastIndex = leftIndex @@ -1195,7 +1196,7 @@ func TestRuntimeCapabilityControllers(t *testing.T) { assert(controllers1.length == 3) Test.quickSort( - &controllers1 as &[AnyStruct], + &controllers1 as auth(Mutate) &[AnyStruct], isLess: fun(i: Int, j: Int): Bool { let a = controllers1[i] let b = controllers1[j] @@ -1293,7 +1294,7 @@ func TestRuntimeCapabilityControllers(t *testing.T) { assert(controllers1.length == 3) Test.quickSort( - &controllers1 as &[AnyStruct], + &controllers1 as auth(Mutate) &[AnyStruct], isLess: fun(i: Int, j: Int): Bool { let a = controllers1[i] let b = controllers1[j] @@ -1644,7 +1645,7 @@ func TestRuntimeCapabilityControllers(t *testing.T) { assert(controllers.length == 3) Test.quickSort( - &controllers as &[AnyStruct], + &controllers as auth(Mutate) &[AnyStruct], isLess: fun(i: Int, j: Int): Bool { let a = controllers[i] let b = controllers[j] @@ -1722,7 +1723,7 @@ func TestRuntimeCapabilityControllers(t *testing.T) { assert(controllers.length == 3) Test.quickSort( - &controllers as &[AnyStruct], + &controllers as auth(Mutate) &[AnyStruct], isLess: fun(i: Int, j: Int): Bool { let a = controllers[i] let b = controllers[j] @@ -1977,7 +1978,7 @@ func TestRuntimeCapabilityControllers(t *testing.T) { let controllers1Before = signer.capabilities.storage.getControllers(forPath: storagePath1) Test.quickSort( - &controllers1Before as &[AnyStruct], + &controllers1Before as auth(Mutate) &[AnyStruct], isLess: fun(i: Int, j: Int): Bool { let a = controllers1Before[i] let b = controllers1Before[j] @@ -1991,7 +1992,7 @@ func TestRuntimeCapabilityControllers(t *testing.T) { let controllers2Before = signer.capabilities.storage.getControllers(forPath: storagePath2) Test.quickSort( - &controllers2Before as &[AnyStruct], + &controllers2Before as auth(Mutate) &[AnyStruct], isLess: fun(i: Int, j: Int): Bool { let a = controllers2Before[i] let b = controllers2Before[j] @@ -2015,7 +2016,7 @@ func TestRuntimeCapabilityControllers(t *testing.T) { let controllers1After = signer.capabilities.storage.getControllers(forPath: storagePath1) Test.quickSort( - &controllers1After as &[AnyStruct], + &controllers1After as auth(Mutate) &[AnyStruct], isLess: fun(i: Int, j: Int): Bool { let a = controllers1After[i] let b = controllers1After[j] @@ -2028,7 +2029,7 @@ func TestRuntimeCapabilityControllers(t *testing.T) { let controllers2After = signer.capabilities.storage.getControllers(forPath: storagePath2) Test.quickSort( - &controllers2After as &[AnyStruct], + &controllers2After as auth(Mutate) &[AnyStruct], isLess: fun(i: Int, j: Int): Bool { let a = controllers2After[i] let b = controllers2After[j] diff --git a/runtime/convertValues_test.go b/runtime/convertValues_test.go index 961ca19875..c41b67a039 100644 --- a/runtime/convertValues_test.go +++ b/runtime/convertValues_test.go @@ -1912,7 +1912,7 @@ func TestExportReferenceValue(t *testing.T) { var v:[AnyStruct] = [] acct.save(v, to: /storage/x) - var ref = acct.borrow<&[AnyStruct]>(from: /storage/x)! + var ref = acct.borrow(from: /storage/x)! ref.append(ref) return ref } @@ -1947,7 +1947,7 @@ func TestExportReferenceValue(t *testing.T) { var v:[AnyStruct] = [] acct.save(v, to: /storage/x) - var ref1 = acct.borrow<&[AnyStruct]>(from: /storage/x)! + var ref1 = acct.borrow(from: /storage/x)! var ref2 = acct.borrow<&[AnyStruct]>(from: /storage/x)! ref1.append(ref2) diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index 02fdbec8c6..bb0903067b 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -4451,7 +4451,22 @@ var AuthAccountReferenceStaticType = ReferenceStaticType{ } func (interpreter *Interpreter) getEntitlement(typeID common.TypeID) (*sema.EntitlementType, error) { - location, _, _ := common.DecodeTypeID(interpreter, string(typeID)) + location, qualifiedIdentifier, err := common.DecodeTypeID(interpreter, string(typeID)) + if err != nil { + return nil, err + } + + if location == nil { + ty := sema.BuiltinEntitlements[qualifiedIdentifier] + if ty == nil { + return nil, TypeLoadingError{ + TypeID: typeID, + } + } + + return ty, nil + } + elaboration := interpreter.getElaboration(location) if elaboration == nil { return nil, TypeLoadingError{ @@ -4470,7 +4485,22 @@ func (interpreter *Interpreter) getEntitlement(typeID common.TypeID) (*sema.Enti } func (interpreter *Interpreter) getEntitlementMapType(typeID common.TypeID) (*sema.EntitlementMapType, error) { - location, _, _ := common.DecodeTypeID(interpreter, string(typeID)) + location, qualifiedIdentifier, err := common.DecodeTypeID(interpreter, string(typeID)) + if err != nil { + return nil, err + } + + if location == nil { + ty := sema.BuiltinEntitlementMappings[qualifiedIdentifier] + if ty == nil { + return nil, TypeLoadingError{ + TypeID: typeID, + } + } + + return ty, nil + } + elaboration := interpreter.getElaboration(location) if elaboration == nil { return nil, TypeLoadingError{ @@ -4691,10 +4721,17 @@ func (interpreter *Interpreter) getAccessOfMember(self Value, identifier string) return &member.Resolve(interpreter, identifier, ast.EmptyRange, func(err error) {}).Access } -func (interpreter *Interpreter) mapMemberValueAuthorization(self Value, memberAccess *sema.Access, resultValue Value) Value { +func (interpreter *Interpreter) mapMemberValueAuthorization( + self Value, + memberAccess *sema.Access, + resultValue Value, + resultingType sema.Type, +) Value { + if memberAccess == nil { return resultValue } + if mappedAccess, isMappedAccess := (*memberAccess).(sema.EntitlementMapAccess); isMappedAccess { var auth Authorization switch selfValue := self.(type) { @@ -4706,7 +4743,16 @@ func (interpreter *Interpreter) mapMemberValueAuthorization(self Value, memberAc } auth = ConvertSemaAccesstoStaticAuthorization(interpreter, imageAccess) default: - auth = ConvertSemaAccesstoStaticAuthorization(interpreter, mappedAccess.Codomain()) + var access sema.Access + if mappedAccess.Type == sema.IdentityMappingType { + access = sema.AllSupportedEntitlements(resultingType) + } + + if access == nil { + access = mappedAccess.Codomain() + } + + auth = ConvertSemaAccesstoStaticAuthorization(interpreter, access) } switch refValue := resultValue.(type) { @@ -4721,15 +4767,21 @@ func (interpreter *Interpreter) mapMemberValueAuthorization(self Value, memberAc return resultValue } -func (interpreter *Interpreter) getMemberWithAuthMapping(self Value, locationRange LocationRange, identifier string) Value { +func (interpreter *Interpreter) getMemberWithAuthMapping( + self Value, + locationRange LocationRange, + identifier string, + memberAccessInfo sema.MemberAccessInfo, +) Value { + result := interpreter.getMember(self, locationRange, identifier) if result == nil { return nil } // once we have obtained the member, if it was declared with entitlement-mapped access, we must compute the output of the map based - // on the runtime authorizations of the acccessing reference or composite + // on the runtime authorizations of the accessing reference or composite memberAccess := interpreter.getAccessOfMember(self, identifier) - return interpreter.mapMemberValueAuthorization(self, memberAccess, result) + return interpreter.mapMemberValueAuthorization(self, memberAccess, result, memberAccessInfo.ResultingType) } // getMember gets the member value by the given identifier from the given Value depending on its type. diff --git a/runtime/interpreter/interpreter_expression.go b/runtime/interpreter/interpreter_expression.go index 417e588184..f185de3b7e 100644 --- a/runtime/interpreter/interpreter_expression.go +++ b/runtime/interpreter/interpreter_expression.go @@ -144,7 +144,10 @@ func (interpreter *Interpreter) valueIndexExpressionGetterSetter(indexExpression if isNestedResourceMove { return target.RemoveKey(interpreter, locationRange, transferredIndexingValue) } else { - return target.GetKey(interpreter, locationRange, transferredIndexingValue) + value := target.GetKey(interpreter, locationRange, transferredIndexingValue) + + // If the indexing value is a reference, then return a reference for the resulting value. + return interpreter.maybeGetReference(indexExpression, value) } }, set: func(value Value) { @@ -169,6 +172,11 @@ func (interpreter *Interpreter) memberExpressionGetterSetter(memberExpression *a isNestedResourceMove := interpreter.Program.Elaboration.IsNestedResourceMoveExpression(memberExpression) + memberAccessInfo, ok := interpreter.Program.Elaboration.MemberExpressionMemberAccessInfo(memberExpression) + if !ok { + panic(errors.NewUnreachableError()) + } + return getterSetter{ target: target, get: func(allowMissing bool) Value { @@ -194,8 +202,9 @@ func (interpreter *Interpreter) memberExpressionGetterSetter(memberExpression *a if isNestedResourceMove { resultValue = target.(MemberAccessibleValue).RemoveMember(interpreter, locationRange, identifier) } else { - resultValue = interpreter.getMemberWithAuthMapping(target, locationRange, identifier) + resultValue = interpreter.getMemberWithAuthMapping(target, locationRange, identifier, memberAccessInfo) } + if resultValue == nil && !allowMissing { panic(UseBeforeInitializationError{ Name: identifier, @@ -212,6 +221,13 @@ func (interpreter *Interpreter) memberExpressionGetterSetter(memberExpression *a } } + // Return a reference, if the member is accessed via a reference. + // This is pre-computed at the checker. + if memberAccessInfo.ReturnReference { + // Get a reference to the value + resultValue = interpreter.getReferenceValue(resultValue, memberAccessInfo.ResultingType) + } + return resultValue }, set: func(value Value) { @@ -222,12 +238,51 @@ func (interpreter *Interpreter) memberExpressionGetterSetter(memberExpression *a } } +// getReferenceValue Returns a reference to a given value. +// Reference to an optional should return an optional reference. +// This has to be done recursively for nested optionals. +// e.g.1: Given type T, this method returns &T. +// e.g.2: Given T?, this returns (&T)? +func (interpreter *Interpreter) getReferenceValue(value Value, resultType sema.Type) Value { + switch value := value.(type) { + case NilValue, ReferenceValue: + // Reference to a nil, should return a nil. + // If the value is already a reference then return the same reference. + return value + case *SomeValue: + innerValue := interpreter.getReferenceValue(value.value, resultType) + return NewSomeValueNonCopying(interpreter, innerValue) + } + + // `resultType` is always an [optional] reference. + // This is guaranteed by the checker. + referenceType, ok := sema.UnwrapOptionalType(resultType).(*sema.ReferenceType) + if !ok { + panic(errors.NewUnreachableError()) + } + + auth := interpreter.getEffectiveAuthorization(referenceType) + + interpreter.maybeTrackReferencedResourceKindedValue(value) + return NewEphemeralReferenceValue(interpreter, auth, value, referenceType.Type) +} + +func (interpreter *Interpreter) getEffectiveAuthorization(referenceType *sema.ReferenceType) Authorization { + _, isMapped := referenceType.Authorization.(sema.EntitlementMapAccess) + + if isMapped && interpreter.SharedState.currentEntitlementMappedValue != nil { + return interpreter.SharedState.currentEntitlementMappedValue + } + + return ConvertSemaAccesstoStaticAuthorization(interpreter, referenceType.Authorization) +} + func (interpreter *Interpreter) checkMemberAccess( memberExpression *ast.MemberExpression, target Value, locationRange LocationRange, ) { - memberInfo, _ := interpreter.Program.Elaboration.MemberExpressionMemberInfo(memberExpression) + memberInfo, _ := interpreter.Program.Elaboration.MemberExpressionMemberAccessInfo(memberExpression) expectedType := memberInfo.AccessedType switch expectedType := expectedType.(type) { @@ -860,10 +915,28 @@ func (interpreter *Interpreter) VisitIndexExpression(expression *ast.IndexExpres Location: interpreter.Location, HasPosition: expression, } - return typedResult.GetKey(interpreter, locationRange, indexingValue) + value := typedResult.GetKey(interpreter, locationRange, indexingValue) + + // If the indexing value is a reference, then return a reference for the resulting value. + return interpreter.maybeGetReference(expression, value) } } +func (interpreter *Interpreter) maybeGetReference( + expression *ast.IndexExpression, + memberValue Value, +) Value { + indexExpressionTypes := interpreter.Program.Elaboration.IndexExpressionTypes(expression) + if indexExpressionTypes.ReturnReference { + expectedType := indexExpressionTypes.ResultType + + // Get a reference to the value + memberValue = interpreter.getReferenceValue(memberValue, expectedType) + } + + return memberValue +} + func (interpreter *Interpreter) VisitConditionalExpression(expression *ast.ConditionalExpression) Value { value, ok := interpreter.evalExpression(expression.Test).(BoolValue) if !ok { @@ -1155,15 +1228,9 @@ func (interpreter *Interpreter) VisitReferenceExpression(referenceExpression *as interpreter.maybeTrackReferencedResourceKindedValue(result) makeReference := func(value Value, typ *sema.ReferenceType) *EphemeralReferenceValue { - var auth Authorization - // if we are currently interpretering a function that was declared with mapped entitlement access, any appearances // of that mapped access in the body of the function should be replaced with the computed output of the map - if _, isMapped := typ.Authorization.(sema.EntitlementMapAccess); isMapped && interpreter.SharedState.currentEntitlementMappedValue != nil { - auth = interpreter.SharedState.currentEntitlementMappedValue - } else { - auth = ConvertSemaAccesstoStaticAuthorization(interpreter, typ.Authorization) - } + auth := interpreter.getEffectiveAuthorization(typ) return NewEphemeralReferenceValue( interpreter, diff --git a/runtime/resource_duplicate_test.go b/runtime/resource_duplicate_test.go index 743cf198bf..69047f1237 100644 --- a/runtime/resource_duplicate_test.go +++ b/runtime/resource_duplicate_test.go @@ -67,10 +67,10 @@ func TestRuntimeResourceDuplicationUsingDestructorIteration(t *testing.T) { // --- this code actually makes use of the vuln --- access(all) resource DummyResource { - access(all) var dictRef: &{Bool: AnyResource}; - access(all) var arrRef: &[Vault]; + access(all) var dictRef: auth(Mutate) &{Bool: AnyResource}; + access(all) var arrRef: auth(Mutate) &[Vault]; access(all) var victim: @Vault; - init(dictRef: &{Bool: AnyResource}, arrRef: &[Vault], victim: @Vault) { + init(dictRef: auth(Mutate) &{Bool: AnyResource}, arrRef: auth(Mutate) &[Vault], victim: @Vault) { self.dictRef = dictRef; self.arrRef = arrRef; self.victim <- victim; @@ -85,8 +85,8 @@ func TestRuntimeResourceDuplicationUsingDestructorIteration(t *testing.T) { access(all) fun duplicateResource(victim1: @Vault, victim2: @Vault): @[Vault]{ let arr : @[Vault] <- []; let dict: @{Bool: DummyResource} <- { } - let ref = &dict as &{Bool: AnyResource}; - let arrRef = &arr as &[Vault]; + let ref = &dict as auth(Mutate) &{Bool: AnyResource}; + let arrRef = &arr as auth(Mutate) &[Vault]; var v1: @DummyResource? <- create DummyResource(dictRef: ref, arrRef: arrRef, victim: <- victim1); dict[false] <-> v1; @@ -168,7 +168,12 @@ func TestRuntimeResourceDuplicationUsingDestructorIteration(t *testing.T) { }, ) - require.ErrorAs(t, err, &interpreter.ContainerMutatedDuringIterationError{}) + var checkerErr *sema.CheckerError + require.ErrorAs(t, err, &checkerErr) + + errs := checker.RequireCheckerErrors(t, checkerErr, 2) + assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) + assert.IsType(t, &sema.TypeMismatchError{}, errs[1]) }) t.Run("simplified", func(t *testing.T) { @@ -178,9 +183,9 @@ func TestRuntimeResourceDuplicationUsingDestructorIteration(t *testing.T) { script := ` access(all) resource Vault { access(all) var balance: UFix64 - access(all) var dictRef: &{Bool: Vault}; + access(all) var dictRef: auth(Mutate) &{Bool: Vault}; - init(balance: UFix64, _ dictRef: &{Bool: Vault}) { + init(balance: UFix64, _ dictRef: auth(Mutate) &{Bool: Vault}) { self.balance = balance self.dictRef = dictRef; } @@ -203,7 +208,7 @@ func TestRuntimeResourceDuplicationUsingDestructorIteration(t *testing.T) { access(all) fun main(): UFix64 { let dict: @{Bool: Vault} <- { } - let dictRef = &dict as &{Bool: Vault}; + let dictRef = &dict as auth(Mutate) &{Bool: Vault}; var v1 <- create Vault(balance: 1000.0, dictRef); // This will be duplicated var v2 <- create Vault(balance: 1.0, dictRef); // This will be lost @@ -268,10 +273,11 @@ func TestRuntimeResourceDuplicationUsingDestructorIteration(t *testing.T) { var checkerErr *sema.CheckerError require.ErrorAs(t, err, &checkerErr) - errs := checker.RequireCheckerErrors(t, checkerErr, 1) - - assert.IsType(t, &sema.InvalidatedResourceReferenceError{}, errs[0]) + errs := checker.RequireCheckerErrors(t, checkerErr, 3) + assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) + assert.IsType(t, &sema.TypeMismatchError{}, errs[1]) + assert.IsType(t, &sema.InvalidatedResourceReferenceError{}, errs[2]) }) t.Run("forEachKey", func(t *testing.T) { @@ -299,7 +305,7 @@ func TestRuntimeResourceDuplicationUsingDestructorIteration(t *testing.T) { let acc = getAuthAccount(0x1) acc.save(<-dict, to: /storage/foo) - let ref = acc.borrow<&{Int: R}>(from: /storage/foo)! + let ref = acc.borrow(from: /storage/foo)! ref.forEachKey(fun(i: Int): Bool { var r4: @R? <- create R() @@ -355,7 +361,8 @@ func TestRuntimeResourceDuplicationUsingDestructorIteration(t *testing.T) { }, ) - require.ErrorAs(t, err, &interpreter.ContainerMutatedDuringIterationError{}) + errs := checker.RequireCheckerErrors(t, err, 1) + assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) }) t.Run("array", func(t *testing.T) { @@ -365,9 +372,9 @@ func TestRuntimeResourceDuplicationUsingDestructorIteration(t *testing.T) { script := ` access(all) resource Vault { access(all) var balance: UFix64 - access(all) var arrRef: &[Vault] + access(all) var arrRef: auth(Mutate) &[Vault] - init(balance: UFix64, _ arrRef: &[Vault]) { + init(balance: UFix64, _ arrRef: auth(Mutate) &[Vault]) { self.balance = balance self.arrRef = arrRef; } @@ -390,7 +397,7 @@ func TestRuntimeResourceDuplicationUsingDestructorIteration(t *testing.T) { access(all) fun main(): UFix64 { let arr: @[Vault] <- [] - let arrRef = &arr as &[Vault]; + let arrRef = &arr as auth(Mutate) &[Vault]; var v1 <- create Vault(balance: 1000.0, arrRef); // This will be duplicated var v2 <- create Vault(balance: 1.0, arrRef); // This will be lost diff --git a/runtime/runtime_test.go b/runtime/runtime_test.go index 66cdb3d2ed..8925cec867 100644 --- a/runtime/runtime_test.go +++ b/runtime/runtime_test.go @@ -2013,7 +2013,7 @@ func TestRuntimeStorageMultipleTransactionsResourceWithArray(t *testing.T) { prepare(signer: AuthAccount) { signer.save(<-createContainer(), to: /storage/container) - let cap = signer.capabilities.storage.issue<&Container>(/storage/container) + let cap = signer.capabilities.storage.issue(/storage/container) signer.capabilities.publish(cap, at: /public/container) } } @@ -2025,7 +2025,7 @@ func TestRuntimeStorageMultipleTransactionsResourceWithArray(t *testing.T) { transaction { prepare(signer: AuthAccount) { let publicAccount = getAccount(signer.address) - let ref = publicAccount.capabilities.borrow<&Container>(/public/container)! + let ref = publicAccount.capabilities.borrow(/public/container)! let length = ref.values.length ref.appendValue(1) @@ -2040,7 +2040,7 @@ func TestRuntimeStorageMultipleTransactionsResourceWithArray(t *testing.T) { transaction { prepare(signer: AuthAccount) { let publicAccount = getAccount(signer.address) - let ref = publicAccount.capabilities.borrow<&Container>(/public/container)! + let ref = publicAccount.capabilities.borrow(/public/container)! let length = ref.values.length ref.appendValue(2) @@ -7022,34 +7022,102 @@ func TestRuntimeInvalidContainerTypeConfusion(t *testing.T) { t.Parallel() - runtime := newTestInterpreterRuntime() + t.Run("invalid: auth account used as public account", func(t *testing.T) { + t.Parallel() - script := []byte(` + runtime := newTestInterpreterRuntime() + + script := []byte(` + access(all) fun main() { + let dict: {Int: PublicAccount} = {} + let ref = &dict as auth(Mutate) &{Int: AnyStruct} + ref[0] = getAuthAccount(0x01) as AnyStruct + } + `) + + runtimeInterface := &testRuntimeInterface{} + + _, err := runtime.ExecuteScript( + Script{ + Source: script, + }, + Context{ + Interface: runtimeInterface, + Location: common.ScriptLocation{}, + }, + ) + + RequireError(t, err) + + assertRuntimeErrorIsUserError(t, err) + + var typeErr interpreter.ContainerMutationError + require.ErrorAs(t, err, &typeErr) + }) + + t.Run("invalid: public account used as auth account", func(t *testing.T) { + + t.Parallel() + + runtime := newTestInterpreterRuntime() + + script := []byte(` access(all) fun main() { let dict: {Int: AuthAccount} = {} - let ref = &dict as &{Int: AnyStruct} + let ref = &dict as auth(Mutate) &{Int: AnyStruct} ref[0] = getAccount(0x01) as AnyStruct } `) - runtimeInterface := &testRuntimeInterface{} + runtimeInterface := &testRuntimeInterface{} - _, err := runtime.ExecuteScript( - Script{ - Source: script, - }, - Context{ - Interface: runtimeInterface, - Location: common.ScriptLocation{}, - }, - ) + _, err := runtime.ExecuteScript( + Script{ + Source: script, + }, + Context{ + Interface: runtimeInterface, + Location: common.ScriptLocation{}, + }, + ) - RequireError(t, err) + RequireError(t, err) - assertRuntimeErrorIsUserError(t, err) + assertRuntimeErrorIsUserError(t, err) + + var typeErr interpreter.ContainerMutationError + require.ErrorAs(t, err, &typeErr) + }) + + t.Run("valid: public account used as public account", func(t *testing.T) { + + t.Parallel() + + runtime := newTestInterpreterRuntime() - var typeErr interpreter.ContainerMutationError - require.ErrorAs(t, err, &typeErr) + script := []byte(` + access(all) fun main() { + let dict: {Int: PublicAccount} = {} + let ref = &dict as auth(Mutate) &{Int: AnyStruct} + ref[0] = getAccount(0x01) as AnyStruct + } + `) + + runtimeInterface := &testRuntimeInterface{ + storage: newTestLedger(nil, nil), + } + + _, err := runtime.ExecuteScript( + Script{ + Source: script, + }, + Context{ + Interface: runtimeInterface, + Location: common.ScriptLocation{}, + }, + ) + require.NoError(t, err) + }) } func TestRuntimeStackOverflow(t *testing.T) { diff --git a/runtime/sema/access.go b/runtime/sema/access.go index af45f41028..488c22d746 100644 --- a/runtime/sema/access.go +++ b/runtime/sema/access.go @@ -339,6 +339,11 @@ func (e EntitlementMapAccess) entitlementImage(entitlement *EntitlementType) (ou // defined by the map in `e`, producing a new entitlement set of the image of the // arguments. func (e EntitlementMapAccess) Image(inputs Access, astRange func() ast.Range) (Access, error) { + + if e.Type == IdentityMappingType { + return inputs, nil + } + switch inputs := inputs.(type) { // primitive access always passes trivially through the map case PrimitiveAccess: diff --git a/runtime/sema/account_capability_controller.cdc b/runtime/sema/account_capability_controller.cdc index e1813cfe0d..3442c24560 100644 --- a/runtime/sema/account_capability_controller.cdc +++ b/runtime/sema/account_capability_controller.cdc @@ -1,4 +1,4 @@ -access(all) struct AccountCapabilityController { +access(all) struct AccountCapabilityController: ContainFields { /// An arbitrary "tag" for the controller. /// For example, it could be used to describe the purpose of the capability. diff --git a/runtime/sema/account_capability_controller.gen.go b/runtime/sema/account_capability_controller.gen.go index 2d2d0841a9..e66d6f8ca1 100644 --- a/runtime/sema/account_capability_controller.gen.go +++ b/runtime/sema/account_capability_controller.gen.go @@ -101,6 +101,7 @@ var AccountCapabilityControllerType = &SimpleType{ Comparable: false, Exportable: false, Importable: false, + ContainFields: true, } func init() { diff --git a/runtime/sema/anyresource_type.go b/runtime/sema/anyresource_type.go index d1de774aa8..b19d72040f 100644 --- a/runtime/sema/anyresource_type.go +++ b/runtime/sema/anyresource_type.go @@ -30,6 +30,7 @@ var AnyResourceType = &SimpleType{ Equatable: false, Comparable: false, // The actual returnability of a value is checked at run-time - Exportable: true, - Importable: false, + Exportable: true, + Importable: false, + ContainFields: true, } diff --git a/runtime/sema/anystruct_type.go b/runtime/sema/anystruct_type.go index 14548e0eeb..edbc5e8220 100644 --- a/runtime/sema/anystruct_type.go +++ b/runtime/sema/anystruct_type.go @@ -31,7 +31,8 @@ var AnyStructType = &SimpleType{ Comparable: false, Exportable: true, // The actual importability is checked at runtime - Importable: true, + Importable: true, + ContainFields: true, } var AnyStructTypeAnnotation = NewTypeAnnotation(AnyStructType) diff --git a/runtime/sema/block.cdc b/runtime/sema/block.cdc index df8562d954..97515f8f3c 100644 --- a/runtime/sema/block.cdc +++ b/runtime/sema/block.cdc @@ -1,5 +1,5 @@ -access(all) struct Block { +access(all) struct Block: ContainFields { /// The height of the block. /// diff --git a/runtime/sema/block.gen.go b/runtime/sema/block.gen.go index 0e1b5d11cb..3bdd5a712e 100644 --- a/runtime/sema/block.gen.go +++ b/runtime/sema/block.gen.go @@ -76,6 +76,7 @@ var BlockType = &SimpleType{ Comparable: false, Exportable: false, Importable: false, + ContainFields: true, } func init() { diff --git a/runtime/sema/character.gen.go b/runtime/sema/character.gen.go index 14ca517ea7..fdaedb6da2 100644 --- a/runtime/sema/character.gen.go +++ b/runtime/sema/character.gen.go @@ -56,6 +56,7 @@ var CharacterType = &SimpleType{ Comparable: true, Exportable: true, Importable: true, + ContainFields: false, } func init() { diff --git a/runtime/sema/check_assignment.go b/runtime/sema/check_assignment.go index 8d278a9207..daaa23613d 100644 --- a/runtime/sema/check_assignment.go +++ b/runtime/sema/check_assignment.go @@ -312,27 +312,33 @@ func (checker *Checker) visitIdentifierExpressionAssignment( return variable.Type } +var mutableEntitledAccess = NewEntitlementSetAccess( + []*EntitlementType{MutateEntitlement}, + Disjunction, +) + +var insertableAndRemovableEntitledAccess = NewEntitlementSetAccess( + []*EntitlementType{InsertEntitlement, RemoveEntitlement}, + Conjunction, +) + func (checker *Checker) visitIndexExpressionAssignment( indexExpression *ast.IndexExpression, ) (elementType Type) { elementType = checker.visitIndexExpression(indexExpression, true) - if targetExpression, ok := indexExpression.TargetExpression.(*ast.MemberExpression); ok { - // visitMember caches its result, so visiting the target expression again, - // after it had been previously visited by visiting the outer index expression, - // performs no computation - _, _, member, _ := checker.visitMember(targetExpression) - if member != nil && !checker.isMutatableMember(member) { - checker.report( - &ExternalMutationError{ - Name: member.Identifier.Identifier, - DeclarationKind: member.DeclarationKind, - Range: ast.NewRangeFromPositioned(checker.memoryGauge, targetExpression), - ContainerType: member.ContainerType, - }, - ) - } + indexExprTypes := checker.Elaboration.IndexExpressionTypes(indexExpression) + indexedRefType, isReference := referenceType(indexExprTypes.IndexedType) + + if isReference && + !mutableEntitledAccess.PermitsAccess(indexedRefType.Authorization) && + !insertableAndRemovableEntitledAccess.PermitsAccess(indexedRefType.Authorization) { + checker.report(&UnauthorizedReferenceAssignmentError{ + RequiredAccess: [2]Access{mutableEntitledAccess, insertableAndRemovableEntitledAccess}, + FoundAccess: indexedRefType.Authorization, + Range: ast.NewRangeFromPositioned(checker.memoryGauge, indexExpression), + }) } if elementType == nil { diff --git a/runtime/sema/check_expression.go b/runtime/sema/check_expression.go index a3300156c3..cad1040e2b 100644 --- a/runtime/sema/check_expression.go +++ b/runtime/sema/check_expression.go @@ -316,11 +316,26 @@ func (checker *Checker) visitIndexExpression( checker.checkUnusedExpressionResourceLoss(elementType, targetExpression) + // If the element, + // 1) is accessed via a reference, and + // 2) is container-typed, + // then the element type should also be a reference. + returnReference := false + if !isAssignment && shouldReturnReference(valueIndexedType, elementType) { + // For index expressions, element are un-authorized. + elementType = checker.getReferenceType(elementType, false, UnauthorizedAccess) + + // Store the result in elaboration, so the interpreter can re-use this. + returnReference = true + } + checker.Elaboration.SetIndexExpressionTypes( indexExpression, IndexExpressionTypes{ - IndexedType: valueIndexedType, - IndexingType: indexingType, + IndexedType: valueIndexedType, + IndexingType: indexingType, + ResultType: elementType, + ReturnReference: returnReference, }, ) diff --git a/runtime/sema/check_member_expression.go b/runtime/sema/check_member_expression.go index 8063abfad7..582d5d420c 100644 --- a/runtime/sema/check_member_expression.go +++ b/runtime/sema/check_member_expression.go @@ -83,30 +83,65 @@ func (checker *Checker) VisitMemberExpression(expression *ast.MemberExpression) // If the member access is optional chaining, only wrap the result value // in an optional, if it is not already an optional value - if isOptional { if _, ok := memberType.(*OptionalType); !ok { - return &OptionalType{Type: memberType} + memberType = NewOptionalType(checker.memoryGauge, memberType) } } return memberType } +// getReferenceType Returns a reference type to a given type. +// Reference to an optional should return an optional reference. +// This has to be done recursively for nested optionals. +// e.g.1: Given type T, this method returns &T. +// e.g.2: Given T?, this returns (&T)? +func (checker *Checker) getReferenceType(typ Type, substituteAuthorization bool, authorization Access) Type { + if optionalType, ok := typ.(*OptionalType); ok { + innerType := checker.getReferenceType(optionalType.Type, substituteAuthorization, authorization) + return NewOptionalType(checker.memoryGauge, innerType) + } + + auth := UnauthorizedAccess + if substituteAuthorization && authorization != nil { + auth = authorization + } + + return NewReferenceType(checker.memoryGauge, typ, auth) +} + +func shouldReturnReference(parentType, memberType Type) bool { + if _, isReference := referenceType(parentType); !isReference { + return false + } + + return memberType.ContainFieldsOrElements() +} + +func referenceType(typ Type) (*ReferenceType, bool) { + unwrappedType := UnwrapOptionalType(typ) + refType, isReference := unwrappedType.(*ReferenceType) + return refType, isReference +} + func (checker *Checker) visitMember(expression *ast.MemberExpression) (accessedType Type, resultingType Type, member *Member, isOptional bool) { - memberInfo, ok := checker.Elaboration.MemberExpressionMemberInfo(expression) + memberInfo, ok := checker.Elaboration.MemberExpressionMemberAccessInfo(expression) if ok { return memberInfo.AccessedType, memberInfo.ResultingType, memberInfo.Member, memberInfo.IsOptional } + returnReference := false + defer func() { - checker.Elaboration.SetMemberExpressionMemberInfo( + checker.Elaboration.SetMemberExpressionMemberAccessInfo( expression, - MemberInfo{ - AccessedType: accessedType, - ResultingType: resultingType, - Member: member, - IsOptional: isOptional, + MemberAccessInfo{ + AccessedType: accessedType, + ResultingType: resultingType, + Member: member, + IsOptional: isOptional, + ReturnReference: returnReference, }, ) }() @@ -166,24 +201,6 @@ func (checker *Checker) visitMember(expression *ast.MemberExpression) (accessedT targetRange := ast.NewRangeFromPositioned(checker.memoryGauge, expression.Expression) member = resolver.Resolve(checker.memoryGauge, identifier, targetRange, checker.report) resultingType = member.TypeAnnotation.Type - if resolver.Mutating { - if targetExpression, ok := accessedExpression.(*ast.MemberExpression); ok { - // visitMember caches its result, so visiting the target expression again, - // after it had been previously visited to get the resolver, - // performs no computation - _, _, subMember, _ := checker.visitMember(targetExpression) - if subMember != nil && !checker.isMutatableMember(subMember) { - checker.report( - &ExternalMutationError{ - Name: subMember.Identifier.Identifier, - DeclarationKind: subMember.DeclarationKind, - Range: ast.NewRangeFromPositioned(checker.memoryGauge, targetRange), - ContainerType: subMember.ContainerType, - }, - ) - } - } - } } // Get the member from the accessed value based @@ -249,111 +266,115 @@ func (checker *Checker) visitMember(expression *ast.MemberExpression) (accessedT }, ) } - } else { - if checker.PositionInfo != nil { - checker.PositionInfo.recordMemberOccurrence( - accessedType, - identifier, - identifierStartPosition, - identifierEndPosition, - ) - } + return + } - // Check access and report if inaccessible - accessRange := func() ast.Range { return ast.NewRangeFromPositioned(checker.memoryGauge, expression) } - isReadable, resultingAuthorization := checker.isReadableMember(accessedType, member, accessRange) - if !isReadable { - checker.report( - &InvalidAccessError{ - Name: member.Identifier.Identifier, - RestrictingAccess: member.Access, - DeclarationKind: member.DeclarationKind, - Range: accessRange(), - }, - ) - } + if checker.PositionInfo != nil { + checker.PositionInfo.recordMemberOccurrence( + accessedType, + identifier, + identifierStartPosition, + identifierEndPosition, + ) + } - // the resulting authorization was mapped through an entitlement map, so we need to substitute this new authorization into the resulting type - // i.e. if the field was declared with `access(M) let x: auth(M) &T?`, and we computed that the output of the map would give entitlement `E`, - // we substitute this entitlement in for the "variable" `M` to produce `auth(E) &T?`, the access with which the type is actually produced. - // Equivalently, this can be thought of like generic instantiation. - substituteConcreteAuthorization := func(resultingType Type) Type { - switch ty := resultingType.(type) { + // Check access and report if inaccessible + accessRange := func() ast.Range { return ast.NewRangeFromPositioned(checker.memoryGauge, expression) } + isReadable, resultingAuthorization := checker.isReadableMember(accessedType, member, resultingType, accessRange) + if !isReadable { + checker.report( + &InvalidAccessError{ + Name: member.Identifier.Identifier, + RestrictingAccess: member.Access, + DeclarationKind: member.DeclarationKind, + Range: accessRange(), + }, + ) + } + + // the resulting authorization was mapped through an entitlement map, so we need to substitute this new authorization into the resulting type + // i.e. if the field was declared with `access(M) let x: auth(M) &T?`, and we computed that the output of the map would give entitlement `E`, + // we substitute this entitlement in for the "variable" `M` to produce `auth(E) &T?`, the access with which the type is actually produced. + // Equivalently, this can be thought of like generic instantiation. + substituteConcreteAuthorization := func(resultingType Type) Type { + switch ty := resultingType.(type) { + case *ReferenceType: + return NewReferenceType(checker.memoryGauge, ty.Type, resultingAuthorization) + case *OptionalType: + switch innerTy := ty.Type.(type) { case *ReferenceType: - return NewReferenceType(checker.memoryGauge, ty.Type, resultingAuthorization) - case *OptionalType: - switch innerTy := ty.Type.(type) { - case *ReferenceType: - return NewOptionalType(checker.memoryGauge, - NewReferenceType(checker.memoryGauge, innerTy.Type, resultingAuthorization)) - } + return NewOptionalType(checker.memoryGauge, + NewReferenceType(checker.memoryGauge, innerTy.Type, resultingAuthorization)) } - return resultingType } - if !member.Access.Equal(resultingAuthorization) { - switch ty := resultingType.(type) { - case *FunctionType: - resultingType = NewSimpleFunctionType( - ty.Purity, - ty.Parameters, - NewTypeAnnotation(substituteConcreteAuthorization(ty.ReturnTypeAnnotation.Type)), - ) - default: - resultingType = substituteConcreteAuthorization(resultingType) - } - } - - // Check that the member access is not to a function of resource type - // outside of an invocation of it. - // - // This would result in a bound method for a resource, which is invalid. + return resultingType + } - if !checker.inAssignment && - !checker.inInvocation && - member.DeclarationKind == common.DeclarationKindFunction && - !accessedType.IsInvalidType() && - accessedType.IsResourceType() { + shouldSubstituteAuthorization := !member.Access.Equal(resultingAuthorization) - checker.report( - &ResourceMethodBindingError{ - Range: ast.NewRangeFromPositioned(checker.memoryGauge, expression), - }, + if shouldSubstituteAuthorization { + switch ty := resultingType.(type) { + case *FunctionType: + resultingType = NewSimpleFunctionType( + ty.Purity, + ty.Parameters, + NewTypeAnnotation(substituteConcreteAuthorization(ty.ReturnTypeAnnotation.Type)), ) + default: + resultingType = substituteConcreteAuthorization(resultingType) } } + + // Check that the member access is not to a function of resource type + // outside of an invocation of it. + // + // This would result in a bound method for a resource, which is invalid. + + if !checker.inAssignment && + !checker.inInvocation && + member.DeclarationKind == common.DeclarationKindFunction && + !accessedType.IsInvalidType() && + accessedType.IsResourceType() { + + checker.report( + &ResourceMethodBindingError{ + Range: ast.NewRangeFromPositioned(checker.memoryGauge, expression), + }, + ) + } + + // If the member, + // 1) is accessed via a reference, and + // 2) is container-typed, + // then the member type should also be a reference. + + // Note: For attachments, `self` is always a reference. + // But we do not want to return a reference for `self.something`. + // Otherwise, things like `destroy self.something` would become invalid. + // Hence, special case `self`, and return a reference only if the member is not accessed via self. + // i.e: `accessedSelfMember == nil` + + if accessedSelfMember == nil && + shouldReturnReference(accessedType, resultingType) && + member.DeclarationKind == common.DeclarationKindField { + + // Get a reference to the type + resultingType = checker.getReferenceType(resultingType, shouldSubstituteAuthorization, resultingAuthorization) + returnReference = true + } + return accessedType, resultingType, member, isOptional } // isReadableMember returns true if the given member can be read from // in the current location of the checker, along with the authorzation with which the result can be used -func (checker *Checker) isReadableMember(accessedType Type, member *Member, accessRange func() ast.Range) (bool, Access) { - var mapAccess func(EntitlementMapAccess, Type) (bool, Access) - mapAccess = func(mappedAccess EntitlementMapAccess, accessedType Type) (bool, Access) { - switch ty := accessedType.(type) { - case *ReferenceType: - // when accessing a member on a reference, the read is allowed, but the - // granted entitlements are based on the image through the map of the reference's entitlements - grantedAccess, err := mappedAccess.Image(ty.Authorization, accessRange) - if err != nil { - checker.report(err) - return false, member.Access - } - return true, grantedAccess - case *OptionalType: - return mapAccess(mappedAccess, ty.Type) - default: - // when accessing a member on a non-reference, the resulting mapped entitlement - // should be the entire codomain of the map - return true, mappedAccess.Codomain() - } - } - +func (checker *Checker) isReadableMember(accessedType Type, member *Member, resultingType Type, accessRange func() ast.Range) (bool, Access) { if checker.Config.AccessCheckMode.IsReadableAccess(member.Access) || checker.containerTypes[member.ContainerType] { if mappedAccess, isMappedAccess := member.Access.(EntitlementMapAccess); isMappedAccess { - return mapAccess(mappedAccess, accessedType) + return checker.mapAccess(mappedAccess, accessedType, resultingType, accessRange) } return true, member.Access @@ -397,12 +418,81 @@ func (checker *Checker) isReadableMember(accessedType Type, member *Member, acce return true, member.Access } case EntitlementMapAccess: - return mapAccess(access, accessedType) + return checker.mapAccess(access, accessedType, resultingType, accessRange) } return false, member.Access } +func (checker *Checker) mapAccess( + mappedAccess EntitlementMapAccess, + accessedType Type, + resultingType Type, + accessRange func() ast.Range, +) (bool, Access) { + + switch ty := accessedType.(type) { + case *ReferenceType: + // when accessing a member on a reference, the read is allowed, but the + // granted entitlements are based on the image through the map of the reference's entitlements + grantedAccess, err := mappedAccess.Image(ty.Authorization, accessRange) + if err != nil { + checker.report(err) + return false, mappedAccess + } + return true, grantedAccess + + case *OptionalType: + return checker.mapAccess(mappedAccess, ty.Type, resultingType, accessRange) + + default: + if mappedAccess.Type == IdentityMappingType { + access := AllSupportedEntitlements(resultingType) + if access != nil { + return true, access + } + } + + // when accessing a member on a non-reference, the resulting mapped entitlement + // should be the entire codomain of the map + return true, mappedAccess.Codomain() + } +} + +func AllSupportedEntitlements(typ Type) Access { + return allSupportedEntitlements(typ, false) +} + +func allSupportedEntitlements(typ Type, isInnerType bool) Access { + switch typ := typ.(type) { + case *ReferenceType: + return allSupportedEntitlements(typ.Type, true) + case *OptionalType: + return allSupportedEntitlements(typ.Type, true) + case *FunctionType: + // Entitlements must be returned only for function definitions. + // Other than func-definitions, a member can be a function type in two ways: + // 1) Function-typed field - Mappings are not allowed on function typed fields + // 2) Function reference typed field - A function type inside a reference/optional-reference + // (i.e: an inner function type) should not be considered for entitlements. + // + if !isInnerType { + return allSupportedEntitlements(typ.ReturnTypeAnnotation.Type, true) + } + case EntitlementSupportingType: + supportedEntitlements := typ.SupportedEntitlements() + if supportedEntitlements != nil && supportedEntitlements.Len() > 0 { + access := EntitlementSetAccess{ + SetKind: Conjunction, + Entitlements: supportedEntitlements, + } + return access + } + } + + return nil +} + // isWriteableMember returns true if the given member can be written to // in the current location of the checker func (checker *Checker) isWriteableMember(member *Member) bool { @@ -410,13 +500,6 @@ func (checker *Checker) isWriteableMember(member *Member) bool { checker.containerTypes[member.ContainerType] } -// isMutatableMember returns true if the given member can be mutated -// in the current location of the checker. Currently equivalent to -// isWriteableMember above, but separate in case this changes -func (checker *Checker) isMutatableMember(member *Member) bool { - return checker.isWriteableMember(member) -} - // containingContractKindedType returns the containing contract-kinded type // of the given type, if any. // diff --git a/runtime/sema/check_swap.go b/runtime/sema/check_swap.go index 2724a7e8f1..da30e6df50 100644 --- a/runtime/sema/check_swap.go +++ b/runtime/sema/check_swap.go @@ -25,38 +25,28 @@ import ( func (checker *Checker) VisitSwapStatement(swap *ast.SwapStatement) (_ struct{}) { - leftType := checker.VisitExpression(swap.Left, nil) - rightType := checker.VisitExpression(swap.Right, nil) + // First visit the two expressions as if they were the target of the assignment. + leftTargetType := checker.checkSwapStatementExpression(swap.Left, common.OperandSideLeft) + rightTargetType := checker.checkSwapStatementExpression(swap.Right, common.OperandSideRight) + + // Then re-visit the same expressions, this time treat them as the value-expr of the assignment. + // The 'expected type' of the two expression would be the types obtained from the previous visit, swapped. + leftValueType := checker.VisitExpression(swap.Left, rightTargetType) + rightValueType := checker.VisitExpression(swap.Right, leftTargetType) checker.Elaboration.SetSwapStatementTypes( swap, SwapStatementTypes{ - LeftType: leftType, - RightType: rightType, + LeftType: leftValueType, + RightType: rightValueType, }, ) - lhsValid := checker.checkSwapStatementExpression(swap.Left, leftType, common.OperandSideLeft) - rhsValid := checker.checkSwapStatementExpression(swap.Right, rightType, common.OperandSideRight) - - // The types of both sides must be subtypes of each other, - // so that assignment can be performed in both directions. - // i.e: The two types have to be equal. - if lhsValid && rhsValid && !leftType.Equal(rightType) { - checker.report( - &TypeMismatchError{ - ExpectedType: leftType, - ActualType: rightType, - Range: ast.NewRangeFromPositioned(checker.memoryGauge, swap.Right), - }, - ) - } - - if leftType.IsResourceType() { + if leftValueType.IsResourceType() { checker.elaborateNestedResourceMoveExpression(swap.Left) } - if rightType.IsResourceType() { + if rightValueType.IsResourceType() { checker.elaborateNestedResourceMoveExpression(swap.Right) } @@ -65,9 +55,8 @@ func (checker *Checker) VisitSwapStatement(swap *ast.SwapStatement) (_ struct{}) func (checker *Checker) checkSwapStatementExpression( expression ast.Expression, - exprType Type, opSide common.OperandSide, -) bool { +) Type { // Expression in either side of the swap statement must be a target expression. // (e.g. identifier expression, indexing expression, or member access expression) @@ -78,13 +67,8 @@ func (checker *Checker) checkSwapStatementExpression( Range: ast.NewRangeFromPositioned(checker.memoryGauge, expression), }, ) - return false - } - - if exprType.IsInvalidType() { - return false + return InvalidType } - checker.visitAssignmentValueType(expression) - return true + return checker.visitAssignmentValueType(expression) } diff --git a/runtime/sema/check_variable_declaration.go b/runtime/sema/check_variable_declaration.go index 8dee560d51..971abbbe5d 100644 --- a/runtime/sema/check_variable_declaration.go +++ b/runtime/sema/check_variable_declaration.go @@ -264,8 +264,7 @@ func (checker *Checker) recordReference(targetVariable *Variable, expr ast.Expre return } - unwrappedVarType := UnwrapOptionalType(targetVariable.Type) - if _, isReferenceType := unwrappedVarType.(*ReferenceType); !isReferenceType { + if _, isReference := referenceType(targetVariable.Type); !isReference { return } @@ -287,6 +286,16 @@ func (checker *Checker) referencedVariables(expr ast.Expression) (variables []*V variableRefExpr = rootVariableOfExpression(refExpr.Expression) case *ast.IdentifierExpression: variableRefExpr = &refExpr.Identifier + case *ast.IndexExpression: + // If it is a reference expression, then find the "root variable". + // As nested resources cannot be tracked, at least track the "root" if possible. + // For example, for an expression `a[b][c]`, the "root variable" is `a`. + variableRefExpr = rootVariableOfExpression(refExpr.TargetExpression) + case *ast.MemberExpression: + // If it is a reference expression, then find the "root variable". + // As nested resources cannot be tracked, at least track the "root" if possible. + // For example, for an expression `a.b.c`, the "root variable" is `a`. + variableRefExpr = rootVariableOfExpression(refExpr.Expression) default: continue } @@ -373,7 +382,11 @@ func referenceExpressions(expr ast.Expression) []ast.Expression { } return refExpressions - case *ast.IdentifierExpression: + case *ast.IdentifierExpression, + *ast.IndexExpression, + *ast.MemberExpression: + // For all these expressions, we reach here only if the expression's type is a reference. + // Hence, no need to check it here again. return []ast.Expression{expr} default: return nil diff --git a/runtime/sema/checker.go b/runtime/sema/checker.go index 5dcf2b8d7c..4965ef53f2 100644 --- a/runtime/sema/checker.go +++ b/runtime/sema/checker.go @@ -1718,7 +1718,6 @@ func (checker *Checker) checkDeclarationAccessModifier( isConstant bool, ) { if checker.functionActivations.IsLocal() { - if !access.Equal(PrimitiveAccess(ast.AccessNotSpecified)) { checker.report( &InvalidAccessModifierError{ @@ -1729,158 +1728,189 @@ func (checker *Checker) checkDeclarationAccessModifier( }, ) } - } else { + return + } - isTypeDeclaration := declarationKind.IsTypeDeclaration() + switch access := access.(type) { + case PrimitiveAccess: + checker.checkPrimitiveAccess(access, isConstant, declarationKind, startPos) + case EntitlementMapAccess: + checker.checkEntitlementMapAccess(access, declarationKind, declarationType, containerKind, startPos) + case EntitlementSetAccess: + checker.checkEntitlementSetAccess(declarationType, containerKind, startPos) + } +} - switch access := access.(type) { - case PrimitiveAccess: - switch ast.PrimitiveAccess(access) { - case ast.AccessSelf: - // Type declarations must be public for now +func (checker *Checker) checkPrimitiveAccess( + access PrimitiveAccess, + isConstant bool, + declarationKind common.DeclarationKind, + startPos ast.Position, +) { - if isTypeDeclaration { + isTypeDeclaration := declarationKind.IsTypeDeclaration() - checker.report( - &InvalidAccessModifierError{ - Access: access, - Explanation: invalidTypeDeclarationAccessModifierExplanation, - DeclarationKind: declarationKind, - Pos: startPos, - }, - ) - } + switch ast.PrimitiveAccess(access) { + case ast.AccessSelf: + // Type declarations must be public for now - case ast.AccessContract, - ast.AccessAccount: + if isTypeDeclaration { + + checker.report( + &InvalidAccessModifierError{ + Access: access, + Explanation: invalidTypeDeclarationAccessModifierExplanation, + DeclarationKind: declarationKind, + Pos: startPos, + }, + ) + } - // Type declarations must be public for now + case ast.AccessContract, + ast.AccessAccount: - if isTypeDeclaration { - checker.report( - &InvalidAccessModifierError{ - Access: access, - Explanation: invalidTypeDeclarationAccessModifierExplanation, - DeclarationKind: declarationKind, - Pos: startPos, - }, - ) - } + // Type declarations must be public for now - case ast.AccessNotSpecified: + if isTypeDeclaration { + checker.report( + &InvalidAccessModifierError{ + Access: access, + Explanation: invalidTypeDeclarationAccessModifierExplanation, + DeclarationKind: declarationKind, + Pos: startPos, + }, + ) + } - // Type declarations cannot be effectively private for now + case ast.AccessNotSpecified: - if isTypeDeclaration && - checker.Config.AccessCheckMode == AccessCheckModeNotSpecifiedRestricted { + // Type declarations cannot be effectively private for now - checker.report( - &MissingAccessModifierError{ - DeclarationKind: declarationKind, - Explanation: invalidTypeDeclarationAccessModifierExplanation, - Pos: startPos, - }, - ) - } + if isTypeDeclaration && + checker.Config.AccessCheckMode == AccessCheckModeNotSpecifiedRestricted { - // In strict mode, access modifiers must be given + checker.report( + &MissingAccessModifierError{ + DeclarationKind: declarationKind, + Explanation: invalidTypeDeclarationAccessModifierExplanation, + Pos: startPos, + }, + ) + } - if checker.Config.AccessCheckMode == AccessCheckModeStrict { - checker.report( - &MissingAccessModifierError{ - DeclarationKind: declarationKind, - Pos: startPos, - }, - ) - } - } + // In strict mode, access modifiers must be given - case EntitlementMapAccess: - // attachments may be declared with an entitlement map access - if declarationKind == common.DeclarationKindAttachment { - return - } + if checker.Config.AccessCheckMode == AccessCheckModeStrict { + checker.report( + &MissingAccessModifierError{ + DeclarationKind: declarationKind, + Pos: startPos, + }, + ) + } + } +} - // otherwise, mapped entitlements may only be used in structs and resources - if containerKind == nil || - (*containerKind != common.CompositeKindResource && - *containerKind != common.CompositeKindStructure) { - checker.report( - &InvalidMappedEntitlementMemberError{ - Pos: startPos, - }, - ) +func (checker *Checker) checkEntitlementMapAccess( + access EntitlementMapAccess, + declarationKind common.DeclarationKind, + declarationType Type, + containerKind *common.CompositeKind, + startPos ast.Position, +) { + // attachments may be declared with an entitlement map access + if declarationKind == common.DeclarationKindAttachment { + return + } + + // otherwise, mapped entitlements may only be used in structs, resources and attachments + if containerKind == nil || + (*containerKind != common.CompositeKindResource && + *containerKind != common.CompositeKindStructure && + *containerKind != common.CompositeKindAttachment) { + checker.report( + &InvalidMappedEntitlementMemberError{ + Pos: startPos, + }, + ) + return + } + + // mapped entitlement fields must be, one of: + // 1) An [optional] reference that is authorized to the same mapped entitlement. + // 2) A function that return an [optional] reference authorized to the same mapped entitlement. + // 3) A container - So if the parent is a reference, entitlements can be granted to the resulting field reference. + + entitledType := declarationType + + if functionType, isFunction := declarationType.(*FunctionType); isFunction { + if declarationKind == common.DeclarationKindFunction { + entitledType = functionType.ReturnTypeAnnotation.Type + } + } + + switch ty := entitledType.(type) { + case *ReferenceType: + if ty.Authorization.Equal(access) { + return + } + case *OptionalType: + switch optionalType := ty.Type.(type) { + case *ReferenceType: + if optionalType.Authorization.Equal(access) { return } + } + default: + // Also allow entitlement mappings for container-typed fields + if declarationType.ContainFieldsOrElements() { + return + } + } - // mapped entitlement fields must be (optional) references that are authorized to the same mapped entitlement, - // or functions that return an (optional) reference authorized to the same mapped entitlement - requireIsPotentiallyOptionalReference := func(typ Type) { - switch ty := typ.(type) { - case *ReferenceType: - if ty.Authorization.Equal(access) { - return - } - case *OptionalType: - switch optionalType := ty.Type.(type) { - case *ReferenceType: - if optionalType.Authorization.Equal(access) { - return - } - } - } - checker.report( - &InvalidMappedEntitlementMemberError{ - Pos: startPos, - }, - ) - } + checker.report( + &InvalidMappedEntitlementMemberError{ + Pos: startPos, + }, + ) +} - switch ty := declarationType.(type) { - case *FunctionType: - if declarationKind == common.DeclarationKindFunction { - requireIsPotentiallyOptionalReference(ty.ReturnTypeAnnotation.Type) - } else { - requireIsPotentiallyOptionalReference(ty) - } - default: - requireIsPotentiallyOptionalReference(ty) - } +func (checker *Checker) checkEntitlementSetAccess( + declarationType Type, + containerKind *common.CompositeKind, + startPos ast.Position, +) { + if containerKind == nil || + (*containerKind != common.CompositeKindResource && + *containerKind != common.CompositeKindStructure && + *containerKind != common.CompositeKindAttachment) { + checker.report( + &InvalidEntitlementAccessError{ + Pos: startPos, + }, + ) + return + } - case EntitlementSetAccess: - if containerKind == nil || - (*containerKind != common.CompositeKindResource && - *containerKind != common.CompositeKindStructure && - *containerKind != common.CompositeKindAttachment) { + // when using entitlement set access, it is not permitted for the value to be declared with a mapped entitlement + switch ty := declarationType.(type) { + case *ReferenceType: + if _, isMap := ty.Authorization.(EntitlementMapAccess); isMap { + checker.report( + &InvalidMappedEntitlementMemberError{ + Pos: startPos, + }, + ) + } + case *OptionalType: + switch optionalType := ty.Type.(type) { + case *ReferenceType: + if _, isMap := optionalType.Authorization.(EntitlementMapAccess); isMap { checker.report( - &InvalidEntitlementAccessError{ + &InvalidMappedEntitlementMemberError{ Pos: startPos, }, ) - return - } - - // when using entitlement set access, it is not permitted for the value to be declared with a mapped entitlement - switch ty := declarationType.(type) { - case *ReferenceType: - if _, isMap := ty.Authorization.(EntitlementMapAccess); isMap { - checker.report( - &InvalidMappedEntitlementMemberError{ - Pos: startPos, - }, - ) - } - case *OptionalType: - switch optionalType := ty.Type.(type) { - case *ReferenceType: - if _, isMap := optionalType.Authorization.(EntitlementMapAccess); isMap { - checker.report( - &InvalidMappedEntitlementMemberError{ - Pos: startPos, - }, - ) - } - } } } } diff --git a/runtime/sema/deployedcontract.cdc b/runtime/sema/deployedcontract.cdc index 61f2195aa8..11e611bfbc 100644 --- a/runtime/sema/deployedcontract.cdc +++ b/runtime/sema/deployedcontract.cdc @@ -1,5 +1,5 @@ -access(all) struct DeployedContract { +access(all) struct DeployedContract: ContainFields { /// The address of the account where the contract is deployed at. access(all) let address: Address diff --git a/runtime/sema/deployedcontract.gen.go b/runtime/sema/deployedcontract.gen.go index 87b3883b84..621b08d473 100644 --- a/runtime/sema/deployedcontract.gen.go +++ b/runtime/sema/deployedcontract.gen.go @@ -84,6 +84,7 @@ var DeployedContractType = &SimpleType{ Comparable: false, Exportable: false, Importable: false, + ContainFields: true, } func init() { diff --git a/runtime/sema/elaboration.go b/runtime/sema/elaboration.go index f0aba25c08..43350d744a 100644 --- a/runtime/sema/elaboration.go +++ b/runtime/sema/elaboration.go @@ -25,11 +25,12 @@ import ( "github.com/onflow/cadence/runtime/common" ) -type MemberInfo struct { - AccessedType Type - ResultingType Type - Member *Member - IsOptional bool +type MemberAccessInfo struct { + AccessedType Type + ResultingType Type + Member *Member + IsOptional bool + ReturnReference bool } type CastTypes struct { @@ -88,8 +89,10 @@ type SwapStatementTypes struct { } type IndexExpressionTypes struct { - IndexedType ValueIndexableType - IndexingType Type + IndexedType ValueIndexableType + IndexingType Type + ResultType Type + ReturnReference bool } type NumberConversionArgumentTypes struct { @@ -108,33 +111,33 @@ type ExpressionTypes struct { } type Elaboration struct { - fixedPointExpressionTypes map[*ast.FixedPointExpression]Type - interfaceTypeDeclarations map[*InterfaceType]*ast.InterfaceDeclaration - entitlementTypeDeclarations map[*EntitlementType]*ast.EntitlementDeclaration - entitlementMapTypeDeclarations map[*EntitlementMapType]*ast.EntitlementMappingDeclaration - swapStatementTypes map[*ast.SwapStatement]SwapStatementTypes - assignmentStatementTypes map[*ast.AssignmentStatement]AssignmentStatementTypes - compositeDeclarationTypes map[ast.CompositeLikeDeclaration]*CompositeType - compositeTypeDeclarations map[*CompositeType]ast.CompositeLikeDeclaration - interfaceDeclarationTypes map[*ast.InterfaceDeclaration]*InterfaceType - entitlementDeclarationTypes map[*ast.EntitlementDeclaration]*EntitlementType - entitlementMapDeclarationTypes map[*ast.EntitlementMappingDeclaration]*EntitlementMapType - transactionDeclarationTypes map[*ast.TransactionDeclaration]*TransactionType - constructorFunctionTypes map[*ast.SpecialFunctionDeclaration]*FunctionType - functionExpressionFunctionTypes map[*ast.FunctionExpression]*FunctionType - invocationExpressionTypes map[*ast.InvocationExpression]InvocationExpressionTypes - castingExpressionTypes map[*ast.CastingExpression]CastingExpressionTypes - lock *sync.RWMutex - binaryExpressionTypes map[*ast.BinaryExpression]BinaryExpressionTypes - memberExpressionMemberInfos map[*ast.MemberExpression]MemberInfo - memberExpressionExpectedTypes map[*ast.MemberExpression]Type - arrayExpressionTypes map[*ast.ArrayExpression]ArrayExpressionTypes - dictionaryExpressionTypes map[*ast.DictionaryExpression]DictionaryExpressionTypes - integerExpressionTypes map[*ast.IntegerExpression]Type - stringExpressionTypes map[*ast.StringExpression]Type - returnStatementTypes map[*ast.ReturnStatement]ReturnStatementTypes - functionDeclarationFunctionTypes map[*ast.FunctionDeclaration]*FunctionType - variableDeclarationTypes map[*ast.VariableDeclaration]VariableDeclarationTypes + fixedPointExpressionTypes map[*ast.FixedPointExpression]Type + interfaceTypeDeclarations map[*InterfaceType]*ast.InterfaceDeclaration + entitlementTypeDeclarations map[*EntitlementType]*ast.EntitlementDeclaration + entitlementMapTypeDeclarations map[*EntitlementMapType]*ast.EntitlementMappingDeclaration + swapStatementTypes map[*ast.SwapStatement]SwapStatementTypes + assignmentStatementTypes map[*ast.AssignmentStatement]AssignmentStatementTypes + compositeDeclarationTypes map[ast.CompositeLikeDeclaration]*CompositeType + compositeTypeDeclarations map[*CompositeType]ast.CompositeLikeDeclaration + interfaceDeclarationTypes map[*ast.InterfaceDeclaration]*InterfaceType + entitlementDeclarationTypes map[*ast.EntitlementDeclaration]*EntitlementType + entitlementMapDeclarationTypes map[*ast.EntitlementMappingDeclaration]*EntitlementMapType + transactionDeclarationTypes map[*ast.TransactionDeclaration]*TransactionType + constructorFunctionTypes map[*ast.SpecialFunctionDeclaration]*FunctionType + functionExpressionFunctionTypes map[*ast.FunctionExpression]*FunctionType + invocationExpressionTypes map[*ast.InvocationExpression]InvocationExpressionTypes + castingExpressionTypes map[*ast.CastingExpression]CastingExpressionTypes + lock *sync.RWMutex + binaryExpressionTypes map[*ast.BinaryExpression]BinaryExpressionTypes + memberExpressionMemberAccessInfos map[*ast.MemberExpression]MemberAccessInfo + memberExpressionExpectedTypes map[*ast.MemberExpression]Type + arrayExpressionTypes map[*ast.ArrayExpression]ArrayExpressionTypes + dictionaryExpressionTypes map[*ast.DictionaryExpression]DictionaryExpressionTypes + integerExpressionTypes map[*ast.IntegerExpression]Type + stringExpressionTypes map[*ast.StringExpression]Type + returnStatementTypes map[*ast.ReturnStatement]ReturnStatementTypes + functionDeclarationFunctionTypes map[*ast.FunctionDeclaration]*FunctionType + variableDeclarationTypes map[*ast.VariableDeclaration]VariableDeclarationTypes // nestedResourceMoveExpressions indicates the index or member expression // is implicitly moving a resource out of the container, e.g. in a shift or swap statement. nestedResourceMoveExpressions map[ast.Expression]struct{} @@ -635,20 +638,20 @@ func (e *Elaboration) SetIntegerExpressionType(expression *ast.IntegerExpression e.integerExpressionTypes[expression] = actualType } -func (e *Elaboration) MemberExpressionMemberInfo(expression *ast.MemberExpression) (memberInfo MemberInfo, ok bool) { - if e.memberExpressionMemberInfos == nil { +func (e *Elaboration) MemberExpressionMemberAccessInfo(expression *ast.MemberExpression) (memberInfo MemberAccessInfo, ok bool) { + if e.memberExpressionMemberAccessInfos == nil { ok = false return } - memberInfo, ok = e.memberExpressionMemberInfos[expression] + memberInfo, ok = e.memberExpressionMemberAccessInfos[expression] return } -func (e *Elaboration) SetMemberExpressionMemberInfo(expression *ast.MemberExpression, memberInfo MemberInfo) { - if e.memberExpressionMemberInfos == nil { - e.memberExpressionMemberInfos = map[*ast.MemberExpression]MemberInfo{} +func (e *Elaboration) SetMemberExpressionMemberAccessInfo(expression *ast.MemberExpression, memberAccessInfo MemberAccessInfo) { + if e.memberExpressionMemberAccessInfos == nil { + e.memberExpressionMemberAccessInfos = map[*ast.MemberExpression]MemberAccessInfo{} } - e.memberExpressionMemberInfos[expression] = memberInfo + e.memberExpressionMemberAccessInfos[expression] = memberAccessInfo } func (e *Elaboration) MemberExpressionExpectedType(expression *ast.MemberExpression) Type { diff --git a/runtime/sema/entitlements.cdc b/runtime/sema/entitlements.cdc new file mode 100644 index 0000000000..07c447eac7 --- /dev/null +++ b/runtime/sema/entitlements.cdc @@ -0,0 +1,6 @@ + +entitlement Mutate + +entitlement Insert + +entitlement Remove diff --git a/runtime/sema/entitlements.gen.go b/runtime/sema/entitlements.gen.go new file mode 100644 index 0000000000..55e7316f51 --- /dev/null +++ b/runtime/sema/entitlements.gen.go @@ -0,0 +1,41 @@ +// Code generated from entitlements.cdc. DO NOT EDIT. +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sema + +var MutateEntitlement = &EntitlementType{ + Identifier: "Mutate", +} + +var InsertEntitlement = &EntitlementType{ + Identifier: "Insert", +} + +var RemoveEntitlement = &EntitlementType{ + Identifier: "Remove", +} + +func init() { + BuiltinEntitlements[MutateEntitlement.Identifier] = MutateEntitlement + addToBaseActivation(MutateEntitlement) + BuiltinEntitlements[InsertEntitlement.Identifier] = InsertEntitlement + addToBaseActivation(InsertEntitlement) + BuiltinEntitlements[RemoveEntitlement.Identifier] = RemoveEntitlement + addToBaseActivation(RemoveEntitlement) +} diff --git a/runtime/sema/entitlements.go b/runtime/sema/entitlements.go new file mode 100644 index 0000000000..d2cbeebd04 --- /dev/null +++ b/runtime/sema/entitlements.go @@ -0,0 +1,21 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sema + +//go:generate go run ./gen entitlements.cdc entitlements.gen.go diff --git a/runtime/sema/errors.go b/runtime/sema/errors.go index 0efa538b92..54fcf8b049 100644 --- a/runtime/sema/errors.go +++ b/runtime/sema/errors.go @@ -3077,6 +3077,46 @@ func (e *InvalidAssignmentAccessError) SecondaryError() string { ) } +// UnauthorizedReferenceAssignmentError + +type UnauthorizedReferenceAssignmentError struct { + RequiredAccess [2]Access + FoundAccess Access + ast.Range +} + +var _ SemanticError = &UnauthorizedReferenceAssignmentError{} +var _ errors.UserError = &UnauthorizedReferenceAssignmentError{} +var _ errors.SecondaryError = &UnauthorizedReferenceAssignmentError{} + +func (*UnauthorizedReferenceAssignmentError) isSemanticError() {} + +func (*UnauthorizedReferenceAssignmentError) IsUserError() {} + +func (e *UnauthorizedReferenceAssignmentError) Error() string { + var foundAccess string + if e.FoundAccess == UnauthorizedAccess { + foundAccess = "non-auth" + } else { + foundAccess = fmt.Sprintf("(%s)", e.FoundAccess.Description()) + } + + return fmt.Sprintf( + "invalid assignment: can only assign to a reference with (%s) or (%s) access, but found a %s reference", + e.RequiredAccess[0].Description(), + e.RequiredAccess[1].Description(), + foundAccess, + ) +} + +func (e *UnauthorizedReferenceAssignmentError) SecondaryError() string { + return fmt.Sprintf( + "consider taking a reference with `%s` or `%s` access", + e.RequiredAccess[0].Description(), + e.RequiredAccess[1].Description(), + ) +} + // InvalidCharacterLiteralError type InvalidCharacterLiteralError struct { @@ -4036,40 +4076,6 @@ func (e *InvalidEntryPointTypeError) Error() string { ) } -// ExternalMutationError - -type ExternalMutationError struct { - ContainerType Type - Name string - ast.Range - DeclarationKind common.DeclarationKind -} - -var _ SemanticError = &ExternalMutationError{} -var _ errors.UserError = &ExternalMutationError{} -var _ errors.SecondaryError = &ExternalMutationError{} - -func (*ExternalMutationError) isSemanticError() {} - -func (*ExternalMutationError) IsUserError() {} - -func (e *ExternalMutationError) Error() string { - return fmt.Sprintf( - "cannot mutate `%s`: %s is only mutable inside `%s`", - e.Name, - e.DeclarationKind.Name(), - e.ContainerType.QualifiedString(), - ) -} - -func (e *ExternalMutationError) SecondaryError() string { - return fmt.Sprintf( - "Consider adding a setter for `%s` to `%s`", - e.Name, - e.ContainerType.QualifiedString(), - ) -} - type PurityError struct { ast.Range } diff --git a/runtime/sema/gen/main.go b/runtime/sema/gen/main.go index 5a3666cd1c..31da9da27b 100644 --- a/runtime/sema/gen/main.go +++ b/runtime/sema/gen/main.go @@ -151,6 +151,7 @@ type typeDecl struct { exportable bool comparable bool importable bool + memberAccessible bool memberDeclarations []ast.Declaration nestedTypes []*typeDecl } @@ -423,6 +424,15 @@ func (g *generator) VisitCompositeDeclaration(decl *ast.CompositeDeclaration) (_ case "Importable": typeDecl.importable = true + + case "ContainFields": + if !canGenerateSimpleType { + panic(fmt.Errorf( + "composite types cannot be explicitly marked as having fields: %s", + g.currentTypeID(), + )) + } + typeDecl.memberAccessible = true } } @@ -563,9 +573,19 @@ func (*generator) VisitTransactionDeclaration(_ *ast.TransactionDeclaration) str panic("transaction declarations are not supported") } -func (*generator) VisitEntitlementDeclaration(_ *ast.EntitlementDeclaration) struct{} { - // TODO - panic("entitlement declarations are not supported") +func (g *generator) VisitEntitlementDeclaration(decl *ast.EntitlementDeclaration) (_ struct{}) { + entitlementName := decl.Identifier.Identifier + typeVarName := entitlementVarName(entitlementName) + typeVarDecl := entitlementTypeLiteral(entitlementName) + + g.addDecls( + goVarDecl( + typeVarName, + typeVarDecl, + ), + ) + + return } func (*generator) VisitEntitlementMappingDeclaration(_ *ast.EntitlementMappingDeclaration) struct{} { @@ -990,6 +1010,70 @@ func (g *generator) currentMemberID(memberName string) string { return b.String() } +func (g *generator) generateTypeInit(program *ast.Program) { + + // Currently this only generate registering of entitlements. + // It is possible to extend this to register other types as well. + // So they are not needed to be manually added to the base activation. + + /* Generates the following: + + func init() { + BuiltinEntitlements[Foo.Identifier] = Foo + addToBaseActivation(Foo) + ... + } + */ + + if len(program.EntitlementDeclarations()) == 0 { + return + } + + stmts := make([]dst.Stmt, 0) + + for _, declaration := range program.EntitlementDeclarations() { + const entitlementsName = "BuiltinEntitlements" + varName := entitlementVarName(declaration.Identifier.Identifier) + + mapUpdateStmt := &dst.AssignStmt{ + Lhs: []dst.Expr{ + &dst.IndexExpr{ + X: dst.NewIdent(entitlementsName), + Index: &dst.SelectorExpr{ + X: dst.NewIdent(varName), + Sel: dst.NewIdent("Identifier"), + }, + }, + }, + Tok: token.ASSIGN, + Rhs: []dst.Expr{ + dst.NewIdent(varName), + }, + } + + typeRegisterStmt := &dst.ExprStmt{ + X: &dst.CallExpr{ + Fun: dst.NewIdent("addToBaseActivation"), + Args: []dst.Expr{ + dst.NewIdent(varName), + }, + }, + } + + stmts = append(stmts, mapUpdateStmt, typeRegisterStmt) + } + + initDecl := &dst.FuncDecl{ + Name: dst.NewIdent("init"), + Type: &dst.FuncType{}, + Body: &dst.BlockStmt{ + List: stmts, + }, + } + + g.addDecls(initDecl) +} + func goField(name string, ty dst.Expr) *dst.Field { return &dst.Field{ Names: []*dst.Ident{ @@ -1071,6 +1155,10 @@ func typeVarName(typeName string) string { return fmt.Sprintf("%sType", typeName) } +func entitlementVarName(typeName string) string { + return fmt.Sprintf("%sEntitlement", typeName) +} + func typeVarIdent(typeName string) *dst.Ident { return dst.NewIdent(typeVarName(typeName)) } @@ -1152,6 +1240,7 @@ func simpleTypeLiteral(ty *typeDecl) dst.Expr { goKeyValue("Comparable", goBoolLit(ty.comparable)), goKeyValue("Exportable", goBoolLit(ty.exportable)), goKeyValue("Importable", goBoolLit(ty.importable)), + goKeyValue("ContainFields", goBoolLit(ty.memberAccessible)), } return &dst.UnaryExpr{ @@ -1482,6 +1571,24 @@ func typeParameterExpr(name string, typeBound dst.Expr) dst.Expr { } } +func entitlementTypeLiteral(name string) dst.Expr { + // &sema.EntitlementType{ + // Identifier: "Foo", + //} + + elements := []dst.Expr{ + goKeyValue("Identifier", goStringLit(name)), + } + + return &dst.UnaryExpr{ + Op: token.AND, + X: &dst.CompositeLit{ + Type: dst.NewIdent("EntitlementType"), + Elts: elements, + }, + } +} + func parseCadenceFile(path string) *ast.Program { program, code, err := parser.ParseProgramFromFile(nil, path, parserConfig) if err != nil { @@ -1496,7 +1603,7 @@ func parseCadenceFile(path string) *ast.Program { return program } -func gen(inPath string, outFile *os.File) { +func gen(inPath string, outFile *os.File, registerTypes bool) { program := parseCadenceFile(inPath) var gen generator @@ -1505,6 +1612,10 @@ func gen(inPath string, outFile *os.File) { _ = ast.AcceptDeclaration[struct{}](declaration, &gen) } + if registerTypes { + gen.generateTypeInit(program) + } + writeGoFile(inPath, outFile, gen.decls) } @@ -1544,5 +1655,8 @@ func main() { } defer outFile.Close() - gen(inPath, outFile) + // Register generated test types in base activation. + const registerTypes = true + + gen(inPath, outFile, registerTypes) } diff --git a/runtime/sema/gen/main_test.go b/runtime/sema/gen/main_test.go index 173adc1df1..32d2b49462 100644 --- a/runtime/sema/gen/main_test.go +++ b/runtime/sema/gen/main_test.go @@ -50,7 +50,10 @@ func TestFiles(t *testing.T) { require.NoError(t, err) defer outFile.Close() - gen(inputPath, outFile) + // Do not register generated test types in base activation. + const registerTypes = false + + gen(inputPath, outFile, registerTypes) goldenPath := filepath.Join(testDataDirectory, testname+".golden.go") want, err := os.ReadFile(goldenPath) diff --git a/runtime/sema/gen/testdata/comparable.golden.go b/runtime/sema/gen/testdata/comparable.golden.go index 4dd82cb6d0..610a531517 100644 --- a/runtime/sema/gen/testdata/comparable.golden.go +++ b/runtime/sema/gen/testdata/comparable.golden.go @@ -32,4 +32,5 @@ var TestType = &SimpleType{ Comparable: true, Exportable: false, Importable: false, + ContainFields: false, } diff --git a/runtime/sema/gen/testdata/docstrings.golden.go b/runtime/sema/gen/testdata/docstrings.golden.go index e27183f0e2..846a220c43 100644 --- a/runtime/sema/gen/testdata/docstrings.golden.go +++ b/runtime/sema/gen/testdata/docstrings.golden.go @@ -114,6 +114,7 @@ var DocstringsType = &SimpleType{ Comparable: false, Exportable: false, Importable: false, + ContainFields: false, } func init() { diff --git a/runtime/sema/gen/testdata/entitlement.cdc b/runtime/sema/gen/testdata/entitlement.cdc new file mode 100644 index 0000000000..e31709a3ad --- /dev/null +++ b/runtime/sema/gen/testdata/entitlement.cdc @@ -0,0 +1 @@ +entitlement Foo diff --git a/runtime/sema/gen/testdata/entitlement.golden.go b/runtime/sema/gen/testdata/entitlement.golden.go new file mode 100644 index 0000000000..1da6ff2297 --- /dev/null +++ b/runtime/sema/gen/testdata/entitlement.golden.go @@ -0,0 +1,24 @@ +// Code generated from testdata/entitlement.cdc. DO NOT EDIT. +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sema + +var FooEntitlement = &EntitlementType{ + Identifier: "Foo", +} diff --git a/runtime/sema/gen/testdata/equatable.golden.go b/runtime/sema/gen/testdata/equatable.golden.go index 291dfaadd5..82320957ee 100644 --- a/runtime/sema/gen/testdata/equatable.golden.go +++ b/runtime/sema/gen/testdata/equatable.golden.go @@ -32,4 +32,5 @@ var TestType = &SimpleType{ Comparable: false, Exportable: false, Importable: false, + ContainFields: false, } diff --git a/runtime/sema/gen/testdata/exportable.golden.go b/runtime/sema/gen/testdata/exportable.golden.go index c184ae6b2b..db6afd0593 100644 --- a/runtime/sema/gen/testdata/exportable.golden.go +++ b/runtime/sema/gen/testdata/exportable.golden.go @@ -32,4 +32,5 @@ var TestType = &SimpleType{ Comparable: false, Exportable: true, Importable: false, + ContainFields: false, } diff --git a/runtime/sema/gen/testdata/fields.golden.go b/runtime/sema/gen/testdata/fields.golden.go index a361770494..9ecc3a2bc7 100644 --- a/runtime/sema/gen/testdata/fields.golden.go +++ b/runtime/sema/gen/testdata/fields.golden.go @@ -140,6 +140,7 @@ var TestType = &SimpleType{ Comparable: false, Exportable: false, Importable: false, + ContainFields: false, } func init() { diff --git a/runtime/sema/gen/testdata/functions.golden.go b/runtime/sema/gen/testdata/functions.golden.go index 0e108fb480..fafd4cfaf4 100644 --- a/runtime/sema/gen/testdata/functions.golden.go +++ b/runtime/sema/gen/testdata/functions.golden.go @@ -186,6 +186,7 @@ var TestType = &SimpleType{ Comparable: false, Exportable: false, Importable: false, + ContainFields: false, } func init() { diff --git a/runtime/sema/gen/testdata/importable.golden.go b/runtime/sema/gen/testdata/importable.golden.go index 42778d5dcb..39496013d4 100644 --- a/runtime/sema/gen/testdata/importable.golden.go +++ b/runtime/sema/gen/testdata/importable.golden.go @@ -32,4 +32,5 @@ var TestType = &SimpleType{ Comparable: false, Exportable: false, Importable: true, + ContainFields: false, } diff --git a/runtime/sema/gen/testdata/member_accessible.cdc b/runtime/sema/gen/testdata/member_accessible.cdc new file mode 100644 index 0000000000..1c83dfc60a --- /dev/null +++ b/runtime/sema/gen/testdata/member_accessible.cdc @@ -0,0 +1 @@ +access(all) struct Test: ContainFields {} diff --git a/runtime/sema/gen/testdata/member_accessible.golden.go b/runtime/sema/gen/testdata/member_accessible.golden.go new file mode 100644 index 0000000000..d3553ebafd --- /dev/null +++ b/runtime/sema/gen/testdata/member_accessible.golden.go @@ -0,0 +1,36 @@ +// Code generated from testdata/member_accessible.cdc. DO NOT EDIT. +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sema + +const TestTypeName = "Test" + +var TestType = &SimpleType{ + Name: TestTypeName, + QualifiedName: TestTypeName, + TypeID: TestTypeName, + tag: TestTypeTag, + IsResource: false, + Storable: false, + Equatable: false, + Comparable: false, + Exportable: false, + Importable: false, + ContainFields: true, +} diff --git a/runtime/sema/gen/testdata/simple-resource.golden.go b/runtime/sema/gen/testdata/simple-resource.golden.go index 4028abf652..53ee6c24f7 100644 --- a/runtime/sema/gen/testdata/simple-resource.golden.go +++ b/runtime/sema/gen/testdata/simple-resource.golden.go @@ -32,4 +32,5 @@ var TestType = &SimpleType{ Comparable: false, Exportable: false, Importable: false, + ContainFields: false, } diff --git a/runtime/sema/gen/testdata/simple-struct.golden.go b/runtime/sema/gen/testdata/simple-struct.golden.go index 04cf626437..0429d008f3 100644 --- a/runtime/sema/gen/testdata/simple-struct.golden.go +++ b/runtime/sema/gen/testdata/simple-struct.golden.go @@ -32,4 +32,5 @@ var TestType = &SimpleType{ Comparable: false, Exportable: false, Importable: false, + ContainFields: false, } diff --git a/runtime/sema/gen/testdata/storable.golden.go b/runtime/sema/gen/testdata/storable.golden.go index adcaf51f6c..c9c5526991 100644 --- a/runtime/sema/gen/testdata/storable.golden.go +++ b/runtime/sema/gen/testdata/storable.golden.go @@ -32,4 +32,5 @@ var TestType = &SimpleType{ Comparable: false, Exportable: false, Importable: false, + ContainFields: false, } diff --git a/runtime/sema/simple_type.go b/runtime/sema/simple_type.go index 9418e9a0bb..f137fdf61a 100644 --- a/runtime/sema/simple_type.go +++ b/runtime/sema/simple_type.go @@ -50,6 +50,7 @@ type SimpleType struct { Comparable bool Storable bool IsResource bool + ContainFields bool } var _ Type = &SimpleType{} @@ -106,6 +107,10 @@ func (t *SimpleType) IsImportable(_ map[*Member]bool) bool { return t.Importable } +func (t *SimpleType) ContainFieldsOrElements() bool { + return t.ContainFields +} + func (*SimpleType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateValid } diff --git a/runtime/sema/storage_capability_controller.cdc b/runtime/sema/storage_capability_controller.cdc index 7d70961630..665c4a73c4 100644 --- a/runtime/sema/storage_capability_controller.cdc +++ b/runtime/sema/storage_capability_controller.cdc @@ -1,4 +1,4 @@ -access(all) struct StorageCapabilityController { +access(all) struct StorageCapabilityController: ContainFields { /// An arbitrary "tag" for the controller. /// For example, it could be used to describe the purpose of the capability. diff --git a/runtime/sema/storage_capability_controller.gen.go b/runtime/sema/storage_capability_controller.gen.go index ac353f248e..801ae613b9 100644 --- a/runtime/sema/storage_capability_controller.gen.go +++ b/runtime/sema/storage_capability_controller.gen.go @@ -133,6 +133,7 @@ var StorageCapabilityControllerType = &SimpleType{ Comparable: false, Exportable: false, Importable: false, + ContainFields: true, } func init() { diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 11ac4a3695..c7d61d5293 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -135,6 +135,27 @@ type Type interface { // IsComparable returns true if values of the type can be compared IsComparable() bool + // ContainFieldsOrElements returns true if value of the type can have nested values (fields or elements). + // This notion is to indicate that a type can be used to access its nested values using + // either index-expression or member-expression. e.g. `foo.bar` or `foo[bar]`. + // This is used to determine if a field/element of this type should be returning a reference or not. + // + // Only a subset of types has this characteristic. e.g: + // - Composites + // - Interfaces + // - Arrays (Variable/Constant sized) + // - Dictionaries + // - Restricted types + // - Optionals of the above. + // - Then there are also built-in simple types, like StorageCapabilityControllerType, BlockType, etc. + // where the type is implemented as a simple type, but they also have fields. + // + // This is different from the existing `ValueIndexableType` in the sense that it is also implemented by simple types + // but not all simple types are indexable. + // On the other-hand, some indexable types (e.g. String) shouldn't be treated/returned as references. + // + ContainFieldsOrElements() bool + TypeAnnotationState() TypeAnnotationState RewriteWithIntersectionTypes() (result Type, rewritten bool) @@ -204,8 +225,7 @@ type MemberResolver struct { targetRange ast.Range, report func(error), ) *Member - Kind common.DeclarationKind - Mutating bool + Kind common.DeclarationKind } // supertype of interfaces and composites @@ -638,6 +658,10 @@ func (*OptionalType) IsComparable() bool { return false } +func (t *OptionalType) ContainFieldsOrElements() bool { + return t.Type.ContainFieldsOrElements() +} + func (t *OptionalType) TypeAnnotationState() TypeAnnotationState { return t.Type.TypeAnnotationState() } @@ -843,6 +867,10 @@ func (*GenericType) IsComparable() bool { return false } +func (t *GenericType) ContainFieldsOrElements() bool { + return false +} + func (*GenericType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateValid } @@ -1139,6 +1167,10 @@ func (t *NumericType) IsComparable() bool { return !t.IsSuperType() } +func (t *NumericType) ContainFieldsOrElements() bool { + return false +} + func (*NumericType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateValid } @@ -1324,6 +1356,10 @@ func (t *FixedPointNumericType) IsComparable() bool { return !t.IsSuperType() } +func (t *FixedPointNumericType) ContainFieldsOrElements() bool { + return false +} + func (*FixedPointNumericType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateValid } @@ -1853,6 +1889,7 @@ const UFix64TypeMaxFractional = fixedpoint.UFix64TypeMaxFractional type ArrayType interface { ValueIndexableType + EntitlementSupportingType isArrayType() } @@ -1926,6 +1963,16 @@ Returns a new array with contents in the reversed order. Available if the array element type is not resource-kinded. ` +var insertableEntitledAccess = NewEntitlementSetAccess( + []*EntitlementType{InsertEntitlement, MutateEntitlement}, + Disjunction, +) + +var removableEntitledAccess = NewEntitlementSetAccess( + []*EntitlementType{RemoveEntitlement, MutateEntitlement}, + Disjunction, +) + func getArrayMembers(arrayType ArrayType) map[string]MemberResolver { members := map[string]MemberResolver{ @@ -2051,13 +2098,13 @@ func getArrayMembers(arrayType ArrayType) map[string]MemberResolver { if _, ok := arrayType.(*VariableSizedType); ok { members["append"] = MemberResolver{ - Kind: common.DeclarationKindFunction, - Mutating: true, + Kind: common.DeclarationKindFunction, Resolve: func(memoryGauge common.MemoryGauge, identifier string, targetRange ast.Range, report func(error)) *Member { elementType := arrayType.ElementType(false) - return NewPublicFunctionMember( + return NewFunctionMember( memoryGauge, arrayType, + insertableEntitledAccess, identifier, ArrayAppendFunctionType(elementType), arrayTypeAppendFunctionDocString, @@ -2066,8 +2113,7 @@ func getArrayMembers(arrayType ArrayType) map[string]MemberResolver { } members["appendAll"] = MemberResolver{ - Kind: common.DeclarationKindFunction, - Mutating: true, + Kind: common.DeclarationKindFunction, Resolve: func(memoryGauge common.MemoryGauge, identifier string, targetRange ast.Range, report func(error)) *Member { elementType := arrayType.ElementType(false) @@ -2082,9 +2128,10 @@ func getArrayMembers(arrayType ArrayType) map[string]MemberResolver { ) } - return NewPublicFunctionMember( + return NewFunctionMember( memoryGauge, arrayType, + insertableEntitledAccess, identifier, ArrayAppendAllFunctionType(arrayType), arrayTypeAppendAllFunctionDocString, @@ -2147,15 +2194,15 @@ func getArrayMembers(arrayType ArrayType) map[string]MemberResolver { } members["insert"] = MemberResolver{ - Kind: common.DeclarationKindFunction, - Mutating: true, + Kind: common.DeclarationKindFunction, Resolve: func(memoryGauge common.MemoryGauge, identifier string, _ ast.Range, _ func(error)) *Member { elementType := arrayType.ElementType(false) - return NewPublicFunctionMember( + return NewFunctionMember( memoryGauge, arrayType, + insertableEntitledAccess, identifier, ArrayInsertFunctionType(elementType), arrayTypeInsertFunctionDocString, @@ -2164,15 +2211,15 @@ func getArrayMembers(arrayType ArrayType) map[string]MemberResolver { } members["remove"] = MemberResolver{ - Kind: common.DeclarationKindFunction, - Mutating: true, + Kind: common.DeclarationKindFunction, Resolve: func(memoryGauge common.MemoryGauge, identifier string, _ ast.Range, _ func(error)) *Member { elementType := arrayType.ElementType(false) - return NewPublicFunctionMember( + return NewFunctionMember( memoryGauge, arrayType, + removableEntitledAccess, identifier, ArrayRemoveFunctionType(elementType), arrayTypeRemoveFunctionDocString, @@ -2181,33 +2228,32 @@ func getArrayMembers(arrayType ArrayType) map[string]MemberResolver { } members["removeFirst"] = MemberResolver{ - Kind: common.DeclarationKindFunction, - Mutating: true, + Kind: common.DeclarationKindFunction, Resolve: func(memoryGauge common.MemoryGauge, identifier string, _ ast.Range, _ func(error)) *Member { elementType := arrayType.ElementType(false) - return NewPublicFunctionMember( + return NewFunctionMember( memoryGauge, arrayType, + removableEntitledAccess, identifier, ArrayRemoveFirstFunctionType(elementType), - arrayTypeRemoveFirstFunctionDocString, ) }, } members["removeLast"] = MemberResolver{ - Kind: common.DeclarationKindFunction, - Mutating: true, + Kind: common.DeclarationKindFunction, Resolve: func(memoryGauge common.MemoryGauge, identifier string, _ ast.Range, _ func(error)) *Member { elementType := arrayType.ElementType(false) - return NewPublicFunctionMember( + return NewFunctionMember( memoryGauge, arrayType, + removableEntitledAccess, identifier, ArrayRemoveLastFunctionType(elementType), arrayTypeRemoveLastFunctionDocString, @@ -2373,6 +2419,7 @@ type VariableSizedType struct { var _ Type = &VariableSizedType{} var _ ArrayType = &VariableSizedType{} var _ ValueIndexableType = &VariableSizedType{} +var _ EntitlementSupportingType = &VariableSizedType{} func NewVariableSizedType(memoryGauge common.MemoryGauge, typ Type) *VariableSizedType { common.UseMemory(memoryGauge, common.VariableSizedSemaTypeMemoryUsage) @@ -2453,6 +2500,10 @@ func (t *VariableSizedType) IsComparable() bool { return t.Type.IsComparable() } +func (t *VariableSizedType) ContainFieldsOrElements() bool { + return true +} + func (t *VariableSizedType) TypeAnnotationState() TypeAnnotationState { return t.Type.TypeAnnotationState() } @@ -2510,6 +2561,18 @@ func (t *VariableSizedType) Resolve(typeArguments *TypeParameterTypeOrderedMap) } } +func (t *VariableSizedType) SupportedEntitlements() *EntitlementOrderedSet { + return arrayDictionaryEntitlements +} + +var arrayDictionaryEntitlements = func() *EntitlementOrderedSet { + set := orderedmap.New[EntitlementOrderedSet](3) + set.Set(MutateEntitlement, struct{}{}) + set.Set(InsertEntitlement, struct{}{}) + set.Set(RemoveEntitlement, struct{}{}) + return set +}() + // ConstantSizedType is a constant sized array type type ConstantSizedType struct { Type Type @@ -2521,6 +2584,7 @@ type ConstantSizedType struct { var _ Type = &ConstantSizedType{} var _ ArrayType = &ConstantSizedType{} var _ ValueIndexableType = &ConstantSizedType{} +var _ EntitlementSupportingType = &ConstantSizedType{} func NewConstantSizedType(memoryGauge common.MemoryGauge, typ Type, size int64) *ConstantSizedType { common.UseMemory(memoryGauge, common.ConstantSizedSemaTypeMemoryUsage) @@ -2603,6 +2667,10 @@ func (t *ConstantSizedType) IsComparable() bool { return t.Type.IsComparable() } +func (t *ConstantSizedType) ContainFieldsOrElements() bool { + return true +} + func (t *ConstantSizedType) TypeAnnotationState() TypeAnnotationState { return t.Type.TypeAnnotationState() } @@ -2666,6 +2734,10 @@ func (t *ConstantSizedType) Resolve(typeArguments *TypeParameterTypeOrderedMap) } } +func (t *ConstantSizedType) SupportedEntitlements() *EntitlementOrderedSet { + return arrayDictionaryEntitlements +} + // Parameter func formatParameter(spaces bool, label, identifier, typeAnnotation string) string { @@ -3148,6 +3220,10 @@ func (*FunctionType) IsComparable() bool { return false } +func (*FunctionType) ContainFieldsOrElements() bool { + return false +} + func (t *FunctionType) TypeAnnotationState() TypeAnnotationState { for _, typeParameter := range t.TypeParameters { @@ -3481,20 +3557,11 @@ func init() { ) for _, ty := range types { - typeName := ty.String() - - // Check that the type is not accidentally redeclared - - if BaseTypeActivation.Find(typeName) != nil { - panic(errors.NewUnreachableError()) - } - - BaseTypeActivation.Set( - typeName, - baseTypeVariable(typeName, ty), - ) + addToBaseActivation(ty) } + addToBaseActivation(IdentityMappingType) + // The AST contains empty type annotations, resolve them to Void BaseTypeActivation.Set( @@ -3503,6 +3570,23 @@ func init() { ) } +func addToBaseActivation(ty Type) { + typeName := ty.String() + + // Check that the type is not accidentally redeclared + + if BaseTypeActivation.Find(typeName) != nil { + panic(errors.NewUnreachableError()) + } + + BaseTypeActivation.Set( + typeName, + baseTypeVariable(typeName, ty), + ) +} + +var IdentityMappingType = NewEntitlementMapType(nil, nil, "Identity") + func baseTypeVariable(name string, ty Type) *Variable { return &Variable{ Identifier: name, @@ -3580,6 +3664,12 @@ var AllNumberTypes = common.Concat( }, ) +var BuiltinEntitlements = map[string]*EntitlementType{} + +var BuiltinEntitlementMappings = map[string]*EntitlementMapType{ + IdentityMappingType.QualifiedIdentifier(): IdentityMappingType, +} + const NumberTypeMinFieldName = "min" const NumberTypeMaxFieldName = "max" @@ -4328,8 +4418,12 @@ func (*CompositeType) IsComparable() bool { return false } -func (c *CompositeType) TypeAnnotationState() TypeAnnotationState { - if c.Kind == common.CompositeKindAttachment { +func (*CompositeType) ContainFieldsOrElements() bool { + return true +} + +func (t *CompositeType) TypeAnnotationState() TypeAnnotationState { + if t.Kind == common.CompositeKindAttachment { return TypeAnnotationStateDirectAttachmentTypeAnnotation } return TypeAnnotationStateValid @@ -5032,6 +5126,10 @@ func (*InterfaceType) IsComparable() bool { return false } +func (*InterfaceType) ContainFieldsOrElements() bool { + return true +} + func (*InterfaceType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateValid } @@ -5160,6 +5258,7 @@ type DictionaryType struct { var _ Type = &DictionaryType{} var _ ValueIndexableType = &DictionaryType{} +var _ EntitlementSupportingType = &DictionaryType{} func NewDictionaryType(memoryGauge common.MemoryGauge, keyType, valueType Type) *DictionaryType { common.UseMemory(memoryGauge, common.DictionarySemaTypeMemoryUsage) @@ -5243,6 +5342,10 @@ func (*DictionaryType) IsComparable() bool { return false } +func (*DictionaryType) ContainFieldsOrElements() bool { + return true +} + func (t *DictionaryType) TypeAnnotationState() TypeAnnotationState { keyTypeAnnotationState := t.KeyType.TypeAnnotationState() if keyTypeAnnotationState != TypeAnnotationStateValid { @@ -5398,12 +5501,12 @@ func (t *DictionaryType) initializeMemberResolvers() { }, }, "insert": { - Kind: common.DeclarationKindFunction, - Mutating: true, + Kind: common.DeclarationKindFunction, Resolve: func(memoryGauge common.MemoryGauge, identifier string, _ ast.Range, _ func(error)) *Member { - return NewPublicFunctionMember( + return NewFunctionMember( memoryGauge, t, + insertableEntitledAccess, identifier, DictionaryInsertFunctionType(t), dictionaryTypeInsertFunctionDocString, @@ -5411,12 +5514,12 @@ func (t *DictionaryType) initializeMemberResolvers() { }, }, "remove": { - Kind: common.DeclarationKindFunction, - Mutating: true, + Kind: common.DeclarationKindFunction, Resolve: func(memoryGauge common.MemoryGauge, identifier string, _ ast.Range, _ func(error)) *Member { - return NewPublicFunctionMember( + return NewFunctionMember( memoryGauge, t, + removableEntitledAccess, identifier, DictionaryRemoveFunctionType(t), dictionaryTypeRemoveFunctionDocString, @@ -5586,6 +5689,10 @@ func (t *DictionaryType) Resolve(typeArguments *TypeParameterTypeOrderedMap) Typ } } +func (t *DictionaryType) SupportedEntitlements() *EntitlementOrderedSet { + return arrayDictionaryEntitlements +} + // ReferenceType represents the reference to a value type ReferenceType struct { Type Type @@ -5708,6 +5815,10 @@ func (*ReferenceType) IsComparable() bool { return false } +func (*ReferenceType) ContainFieldsOrElements() bool { + return false +} + func (r *ReferenceType) TypeAnnotationState() TypeAnnotationState { if r.Type.TypeAnnotationState() == TypeAnnotationStateDirectEntitlementTypeAnnotation { return TypeAnnotationStateDirectEntitlementTypeAnnotation @@ -5907,6 +6018,10 @@ func (*AddressType) IsComparable() bool { return false } +func (*AddressType) ContainFieldsOrElements() bool { + return false +} + func (*AddressType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateValid } @@ -6511,6 +6626,10 @@ func (*TransactionType) IsComparable() bool { return false } +func (*TransactionType) ContainFieldsOrElements() bool { + return false +} + func (*TransactionType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateValid } @@ -6725,6 +6844,10 @@ func (t *IntersectionType) IsComparable() bool { return false } +func (*IntersectionType) ContainFieldsOrElements() bool { + return true +} + func (*IntersectionType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateValid } @@ -6957,6 +7080,10 @@ func (*CapabilityType) IsComparable() bool { return false } +func (*CapabilityType) ContainFieldsOrElements() bool { + return false +} + func (t *CapabilityType) RewriteWithIntersectionTypes() (Type, bool) { if t.BorrowType == nil { return t, false @@ -7510,6 +7637,10 @@ func (*EntitlementType) IsResourceType() bool { return false } +func (*EntitlementType) ContainFieldsOrElements() bool { + return false +} + func (*EntitlementType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateDirectEntitlementTypeAnnotation } @@ -7640,6 +7771,10 @@ func (*EntitlementMapType) IsResourceType() bool { return false } +func (*EntitlementMapType) ContainFieldsOrElements() bool { + return false +} + func (*EntitlementMapType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateDirectEntitlementTypeAnnotation } diff --git a/runtime/sema/type_test.go b/runtime/sema/type_test.go index 2552d49fca..75faf04bf0 100644 --- a/runtime/sema/type_test.go +++ b/runtime/sema/type_test.go @@ -668,7 +668,16 @@ func TestCommonSuperType(t *testing.T) { var tests []testCase err := BaseTypeActivation.ForEach(func(name string, variable *Variable) error { + // Entitlements are not typical types. So skip. + if _, ok := BuiltinEntitlements[name]; ok { + return nil + } + if _, ok := BuiltinEntitlementMappings[name]; ok { + return nil + } + typ := variable.Type + tests = append(tests, testCase{ name: name, types: []Type{ @@ -1770,6 +1779,14 @@ func TestTypeInclusions(t *testing.T) { t.Parallel() err := BaseTypeActivation.ForEach(func(name string, variable *Variable) error { + // Entitlements are not typical types. So skip. + if _, ok := BuiltinEntitlements[name]; ok { + return nil + } + if _, ok := BuiltinEntitlementMappings[name]; ok { + return nil + } + t.Run(name, func(t *testing.T) { typ := variable.Type @@ -1790,6 +1807,14 @@ func TestTypeInclusions(t *testing.T) { t.Parallel() err := BaseTypeActivation.ForEach(func(name string, variable *Variable) error { + // Entitlements are not typical types. So skip. + if _, ok := BuiltinEntitlements[name]; ok { + return nil + } + if _, ok := BuiltinEntitlementMappings[name]; ok { + return nil + } + t.Run(name, func(t *testing.T) { typ := variable.Type diff --git a/runtime/storage_test.go b/runtime/storage_test.go index e66bd0d845..4a26716014 100644 --- a/runtime/storage_test.go +++ b/runtime/storage_test.go @@ -2276,7 +2276,7 @@ func TestRuntimeReferenceOwnerAccess(t *testing.T) { account.capabilities.publish(cap, at: /public/test) let ref2 = account.capabilities.borrow<&[TestContract.TestResource]>(/public/test)! - let ref3 = &ref2[0] as &TestContract.TestResource + let ref3 = ref2[0] log(ref3.owner?.address) } } @@ -2414,7 +2414,7 @@ func TestRuntimeReferenceOwnerAccess(t *testing.T) { account.capabilities.publish(cap, at: /public/test) nestingResourceRef = account.capabilities.borrow<&TestContract.TestNestingResource>(/public/test)! - nestedElementResourceRef = &nestingResourceRef.nestedResources[0] as &TestContract.TestNestedResource + nestedElementResourceRef = nestingResourceRef.nestedResources[0] log(nestingResourceRef.owner?.address) log(nestedElementResourceRef.owner?.address) @@ -2542,7 +2542,7 @@ func TestRuntimeReferenceOwnerAccess(t *testing.T) { account.capabilities.publish(cap, at: /public/test) let testResourcesRef = account.capabilities.borrow<&[[TestContract.TestResource]]>(/public/test)! - ref = &testResourcesRef[0] as &[TestContract.TestResource] + ref = testResourcesRef[0] log(ref[0].owner?.address) } } @@ -2661,12 +2661,13 @@ func TestRuntimeReferenceOwnerAccess(t *testing.T) { account.save(<-testResources, to: /storage/test) + // At this point the resource is in storage let cap = account.capabilities.storage.issue<&[{Int: TestContract.TestResource}]>(/storage/test) account.capabilities.publish(cap, at: /public/test) let testResourcesRef = account.capabilities.borrow<&[{Int: TestContract.TestResource}]>(/public/test)! - ref = &testResourcesRef[0] as &{Int: TestContract.TestResource} + ref = testResourcesRef[0] log(ref[0]?.owner?.address) } } diff --git a/runtime/tests/checker/arrays_dictionaries_test.go b/runtime/tests/checker/arrays_dictionaries_test.go index 97fd02aa05..2aa5a296c0 100644 --- a/runtime/tests/checker/arrays_dictionaries_test.go +++ b/runtime/tests/checker/arrays_dictionaries_test.go @@ -1534,3 +1534,747 @@ func TestNilAssignmentToDictionary(t *testing.T) { require.NoError(t, err) }) } + +func TestCheckArrayFunctionEntitlements(t *testing.T) { + t.Parallel() + + t.Run("inserting functions", func(t *testing.T) { + t.Parallel() + + t.Run("mutable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as auth(Mutate) &[String] + arrayRef.append("baz") + arrayRef.appendAll(["baz"]) + arrayRef.insert(at:0, "baz") + } + `) + + require.NoError(t, err) + }) + + t.Run("non auth reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as &[String] + arrayRef.append("baz") + arrayRef.appendAll(["baz"]) + arrayRef.insert(at:0, "baz") + } + `) + + errors := RequireCheckerErrors(t, err, 3) + + var invalidAccessError = &sema.InvalidAccessError{} + assert.ErrorAs(t, errors[0], &invalidAccessError) + assert.ErrorAs(t, errors[1], &invalidAccessError) + assert.ErrorAs(t, errors[1], &invalidAccessError) + }) + + t.Run("insertable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as auth(Insert) &[String] + arrayRef.append("baz") + arrayRef.appendAll(["baz"]) + arrayRef.insert(at:0, "baz") + } + `) + + require.NoError(t, err) + }) + + t.Run("removable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as auth(Remove) &[String] + arrayRef.append("baz") + arrayRef.appendAll(["baz"]) + arrayRef.insert(at:0, "baz") + } + `) + + errors := RequireCheckerErrors(t, err, 3) + + var invalidAccessError = &sema.InvalidAccessError{} + assert.ErrorAs(t, errors[0], &invalidAccessError) + assert.ErrorAs(t, errors[1], &invalidAccessError) + assert.ErrorAs(t, errors[1], &invalidAccessError) + }) + }) + + t.Run("removing functions", func(t *testing.T) { + t.Parallel() + + t.Run("mutable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as auth(Mutate) &[String] + arrayRef.remove(at: 1) + arrayRef.removeFirst() + arrayRef.removeLast() + } + `) + + require.NoError(t, err) + }) + + t.Run("non auth reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as &[String] + arrayRef.remove(at: 1) + arrayRef.removeFirst() + arrayRef.removeLast() + } + `) + + errors := RequireCheckerErrors(t, err, 3) + + var invalidAccessError = &sema.InvalidAccessError{} + assert.ErrorAs(t, errors[0], &invalidAccessError) + assert.ErrorAs(t, errors[1], &invalidAccessError) + assert.ErrorAs(t, errors[1], &invalidAccessError) + }) + + t.Run("insertable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as auth(Insert) &[String] + arrayRef.remove(at: 1) + arrayRef.removeFirst() + arrayRef.removeLast() + } + `) + + errors := RequireCheckerErrors(t, err, 3) + + var invalidAccessError = &sema.InvalidAccessError{} + assert.ErrorAs(t, errors[0], &invalidAccessError) + assert.ErrorAs(t, errors[1], &invalidAccessError) + assert.ErrorAs(t, errors[1], &invalidAccessError) + }) + + t.Run("removable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as auth(Remove) &[String] + arrayRef.remove(at: 1) + arrayRef.removeFirst() + arrayRef.removeLast() + } + `) + + require.NoError(t, err) + }) + }) + + t.Run("public functions", func(t *testing.T) { + t.Parallel() + + t.Run("mutable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as auth(Mutate) &[String] + arrayRef.contains("hello") + arrayRef.firstIndex(of: "hello") + arrayRef.slice(from: 2, upTo: 4) + arrayRef.concat(["hello"]) + } + `) + + require.NoError(t, err) + }) + + t.Run("non auth reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as &[String] + arrayRef.contains("hello") + arrayRef.firstIndex(of: "hello") + arrayRef.slice(from: 2, upTo: 4) + arrayRef.concat(["hello"]) + } + `) + + require.NoError(t, err) + }) + + t.Run("insertable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as auth(Insert) &[String] + arrayRef.contains("hello") + arrayRef.firstIndex(of: "hello") + arrayRef.slice(from: 2, upTo: 4) + arrayRef.concat(["hello"]) + } + `) + + require.NoError(t, err) + }) + + t.Run("removable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as auth(Remove) &[String] + arrayRef.contains("hello") + arrayRef.firstIndex(of: "hello") + arrayRef.slice(from: 2, upTo: 4) + arrayRef.concat(["hello"]) + } + `) + + require.NoError(t, err) + }) + }) + + t.Run("assignment", func(t *testing.T) { + t.Parallel() + + t.Run("mutable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as auth(Mutate) &[String] + arrayRef[0] = "baz" + } + `) + + require.NoError(t, err) + }) + + t.Run("non auth reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as &[String] + arrayRef[0] = "baz" + } + `) + + errors := RequireCheckerErrors(t, err, 1) + + var invalidAccessError = &sema.UnauthorizedReferenceAssignmentError{} + assert.ErrorAs(t, errors[0], &invalidAccessError) + + assert.Contains( + t, + errors[0].Error(), + "can only assign to a reference with (Mutate) or (Insert, Remove) access, but found a non-auth reference", + ) + }) + + t.Run("insertable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as auth(Insert) &[String] + arrayRef[0] = "baz" + } + `) + + errors := RequireCheckerErrors(t, err, 1) + + var invalidAccessError = &sema.UnauthorizedReferenceAssignmentError{} + assert.ErrorAs(t, errors[0], &invalidAccessError) + + assert.Contains( + t, + errors[0].Error(), + "can only assign to a reference with (Mutate) or (Insert, Remove) access, but found a (Insert) reference", + ) + }) + + t.Run("removable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as auth(Remove) &[String] + arrayRef[0] = "baz" + } + `) + + errors := RequireCheckerErrors(t, err, 1) + + var invalidAccessError = &sema.UnauthorizedReferenceAssignmentError{} + assert.ErrorAs(t, errors[0], &invalidAccessError) + + assert.Contains( + t, + errors[0].Error(), + "can only assign to a reference with (Mutate) or (Insert, Remove) access, but found a (Remove) reference", + ) + }) + + t.Run("insertable and removable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as auth(Insert, Remove) &[String] + arrayRef[0] = "baz" + } + `) + + require.NoError(t, err) + }) + }) + + t.Run("swap", func(t *testing.T) { + t.Parallel() + + t.Run("mutable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as auth(Mutate) &[String] + arrayRef[0] <-> arrayRef[1] + } + `) + + require.NoError(t, err) + }) + + t.Run("non auth reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as &[String] + arrayRef[0] <-> arrayRef[1] + } + `) + + errors := RequireCheckerErrors(t, err, 2) + + var invalidAccessError = &sema.UnauthorizedReferenceAssignmentError{} + assert.ErrorAs(t, errors[0], &invalidAccessError) + assert.ErrorAs(t, errors[1], &invalidAccessError) + }) + }) +} + +func TestCheckDictionaryFunctionEntitlements(t *testing.T) { + t.Parallel() + + t.Run("inserting functions", func(t *testing.T) { + t.Parallel() + + t.Run("mutable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as auth(Mutate) &{String: String} + dictionaryRef.insert(key: "three", "baz") + } + `) + + require.NoError(t, err) + }) + + t.Run("non auth reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as &{String: String} + dictionaryRef.insert(key: "three", "baz") + } + `) + + errors := RequireCheckerErrors(t, err, 1) + + var invalidAccessError = &sema.InvalidAccessError{} + assert.ErrorAs(t, errors[0], &invalidAccessError) + }) + + t.Run("insertable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as auth(Insert) &{String: String} + dictionaryRef.insert(key: "three", "baz") + } + `) + + require.NoError(t, err) + }) + + t.Run("removable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as &{String: String} + dictionaryRef.insert(key: "three", "baz") + } + `) + + errors := RequireCheckerErrors(t, err, 1) + + var invalidAccessError = &sema.InvalidAccessError{} + assert.ErrorAs(t, errors[0], &invalidAccessError) + }) + }) + + t.Run("removing functions", func(t *testing.T) { + t.Parallel() + + t.Run("mutable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as auth(Mutate) &{String: String} + dictionaryRef.remove(key: "foo") + } + `) + + require.NoError(t, err) + }) + + t.Run("non auth reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as &{String: String} + dictionaryRef.remove(key: "foo") + } + `) + + errors := RequireCheckerErrors(t, err, 1) + + var invalidAccessError = &sema.InvalidAccessError{} + assert.ErrorAs(t, errors[0], &invalidAccessError) + }) + + t.Run("insertable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as auth(Insert) &{String: String} + dictionaryRef.remove(key: "foo") + } + `) + + errors := RequireCheckerErrors(t, err, 1) + + var invalidAccessError = &sema.InvalidAccessError{} + assert.ErrorAs(t, errors[0], &invalidAccessError) + }) + + t.Run("removable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as auth(Remove) &{String: String} + dictionaryRef.remove(key: "foo") + } + `) + + require.NoError(t, err) + }) + }) + + t.Run("public functions", func(t *testing.T) { + t.Parallel() + + t.Run("mutable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as auth(Mutate) &{String: String} + dictionaryRef.containsKey("foo") + dictionaryRef.forEachKey(fun(key: String): Bool {return true} ) + } + `) + + require.NoError(t, err) + }) + + t.Run("non auth reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as &{String: String} + dictionaryRef.containsKey("foo") + dictionaryRef.forEachKey(fun(key: String): Bool {return true} ) + } + `) + + require.NoError(t, err) + }) + + t.Run("insertable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as auth(Insert) &{String: String} + dictionaryRef.containsKey("foo") + dictionaryRef.forEachKey(fun(key: String): Bool {return true} ) + } + `) + + require.NoError(t, err) + }) + + t.Run("removable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as auth(Remove) &{String: String} + dictionaryRef.containsKey("foo") + dictionaryRef.forEachKey(fun(key: String): Bool {return true} ) + } + `) + + require.NoError(t, err) + }) + }) + + t.Run("assignment", func(t *testing.T) { + t.Parallel() + + t.Run("mutable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as auth(Mutate) &{String: String} + dictionaryRef["three"] = "baz" + } + `) + + require.NoError(t, err) + }) + + t.Run("non auth reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as &{String: String} + dictionaryRef["three"] = "baz" + } + `) + + errors := RequireCheckerErrors(t, err, 1) + + var invalidAccessError = &sema.UnauthorizedReferenceAssignmentError{} + assert.ErrorAs(t, errors[0], &invalidAccessError) + + assert.Contains( + t, + errors[0].Error(), + "can only assign to a reference with (Mutate) or (Insert, Remove) access, but found a non-auth reference", + ) + }) + + t.Run("insertable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as auth(Remove) &{String: String} + dictionaryRef["three"] = "baz" + } + `) + + errors := RequireCheckerErrors(t, err, 1) + + var invalidAccessError = &sema.UnauthorizedReferenceAssignmentError{} + assert.ErrorAs(t, errors[0], &invalidAccessError) + + assert.Contains( + t, + errors[0].Error(), + "can only assign to a reference with (Mutate) or (Insert, Remove) access, but found a (Remove) reference", + ) + }) + + t.Run("removable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as auth(Insert) &{String: String} + dictionaryRef["three"] = "baz" + } + `) + + errors := RequireCheckerErrors(t, err, 1) + + var invalidAccessError = &sema.UnauthorizedReferenceAssignmentError{} + assert.ErrorAs(t, errors[0], &invalidAccessError) + + assert.Contains( + t, + errors[0].Error(), + "can only assign to a reference with (Mutate) or (Insert, Remove) access, but found a (Insert) reference", + ) + }) + + t.Run("insertable and removable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as auth(Insert, Remove) &{String: String} + dictionaryRef["three"] = "baz" + } + `) + + require.NoError(t, err) + }) + }) + + t.Run("swap", func(t *testing.T) { + t.Parallel() + + t.Run("mutable reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: AnyStruct} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as auth(Mutate) &{String: AnyStruct} + dictionaryRef["one"] <-> dictionaryRef["two"] + } + `) + + require.NoError(t, err) + }) + + t.Run("non auth reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as &{String: String} + dictionaryRef["one"] <-> dictionaryRef["two"] + } + `) + + errors := RequireCheckerErrors(t, err, 2) + + var invalidAccessError = &sema.UnauthorizedReferenceAssignmentError{} + assert.ErrorAs(t, errors[0], &invalidAccessError) + assert.ErrorAs(t, errors[1], &invalidAccessError) + }) + }) +} diff --git a/runtime/tests/checker/attachments_test.go b/runtime/tests/checker/attachments_test.go index c50a66eb5a..d96a244bf2 100644 --- a/runtime/tests/checker/attachments_test.go +++ b/runtime/tests/checker/attachments_test.go @@ -4016,7 +4016,7 @@ func TestCheckAttachmentsExternalMutation(t *testing.T) { ` access(all) resource R {} access(all) attachment A for R { - access(all) let x: [String] + access(all) let x: [String] init() { self.x = ["x"] } @@ -4030,7 +4030,37 @@ func TestCheckAttachmentsExternalMutation(t *testing.T) { ) errs := RequireCheckerErrors(t, err, 1) - assert.IsType(t, &sema.ExternalMutationError{}, errs[0]) + assert.IsType(t, &sema.InvalidAccessError{}, errs[0]) + }) + + t.Run("basic, with entitlements", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, + ` + access(all) resource R {} + + entitlement mapping M { + Mutate -> Insert + } + + access(M) attachment A for R { + access(Identity) let x: [String] + init() { + self.x = ["x"] + } + } + + fun main(r: @R) { + var xRef = r[A]!.x + xRef.append("y") + destroy r + } + `, + ) + + require.NoError(t, err) }) t.Run("in base", func(t *testing.T) { @@ -4045,7 +4075,7 @@ func TestCheckAttachmentsExternalMutation(t *testing.T) { } } access(all) attachment A for R { - access(all) let x: [String] + access(all) let x: [String] init() { self.x = ["x"] } @@ -4055,7 +4085,35 @@ func TestCheckAttachmentsExternalMutation(t *testing.T) { ) errs := RequireCheckerErrors(t, err, 1) - assert.IsType(t, &sema.ExternalMutationError{}, errs[0]) + assert.IsType(t, &sema.InvalidAccessError{}, errs[0]) + }) + + t.Run("in base, with entitlements", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, + ` + entitlement mapping M { + Mutate -> Insert + } + + access(all) resource R { + access(all) fun foo() { + var xRef = self[A]!.x + xRef.append("y") + } + } + access(M) attachment A for R { + access(Identity) let x: [String] + init() { + self.x = ["x"] + } + } + `, + ) + + require.NoError(t, err) }) t.Run("in self, through base", func(t *testing.T) { @@ -4066,7 +4124,7 @@ func TestCheckAttachmentsExternalMutation(t *testing.T) { ` access(all) resource R {} access(all) attachment A for R { - access(all) let x: [String] + access(all) let x: [String] init() { self.x = ["x"] } @@ -4078,7 +4136,8 @@ func TestCheckAttachmentsExternalMutation(t *testing.T) { `, ) - require.NoError(t, err) + errs := RequireCheckerErrors(t, err, 1) + assert.IsType(t, &sema.InvalidAccessError{}, errs[0]) }) } diff --git a/runtime/tests/checker/entitlements_test.go b/runtime/tests/checker/entitlements_test.go index aee60d9740..c4b0384b95 100644 --- a/runtime/tests/checker/entitlements_test.go +++ b/runtime/tests/checker/entitlements_test.go @@ -600,6 +600,18 @@ func TestCheckBasicEntitlementMappingAccess(t *testing.T) { require.IsType(t, &sema.InvalidMappedEntitlementMemberError{}, errs[0]) }) + t.Run("non-reference container field", func(t *testing.T) { + t.Parallel() + _, err := ParseAndCheck(t, ` + entitlement mapping M {} + struct interface S { + access(M) let foo: [String] + } + `) + + assert.NoError(t, err) + }) + t.Run("mismatched entitlement mapping", func(t *testing.T) { t.Parallel() _, err := ParseAndCheck(t, ` @@ -1321,6 +1333,18 @@ func TestCheckBasicEntitlementMappingAccess(t *testing.T) { assert.NoError(t, err) }) + + t.Run("ref array field", func(t *testing.T) { + t.Parallel() + _, err := ParseAndCheck(t, ` + entitlement mapping M {} + resource interface R { + access(M) let foo: [auth(M) &Int] + } + `) + + assert.NoError(t, err) + }) } func TestCheckInvalidEntitlementAccess(t *testing.T) { @@ -1540,20 +1564,6 @@ func TestCheckInvalidEntitlementMappingAuth(t *testing.T) { require.IsType(t, &sema.InvalidMappedAuthorizationOutsideOfFieldError{}, errs[0]) }) - t.Run("ref array field", func(t *testing.T) { - t.Parallel() - _, err := ParseAndCheck(t, ` - entitlement mapping M {} - resource interface R { - access(M) let foo: [auth(M) &Int] - } - `) - - errs := RequireCheckerErrors(t, err, 1) - - require.IsType(t, &sema.InvalidMappedEntitlementMemberError{}, errs[0]) - }) - t.Run("capability field", func(t *testing.T) { t.Parallel() _, err := ParseAndCheck(t, ` @@ -4305,8 +4315,7 @@ func TestCheckAttachmentEntitlements(t *testing.T) { } `) - errs := RequireCheckerErrors(t, err, 1) - require.IsType(t, &sema.InvalidMappedEntitlementMemberError{}, errs[0]) + assert.NoError(t, err) }) t.Run("access(all) decl", func(t *testing.T) { @@ -4549,6 +4558,7 @@ func TestCheckAttachmentAccessEntitlements(t *testing.T) { func TestCheckEntitlementConditions(t *testing.T) { t.Parallel() + t.Run("use of function on owned value", func(t *testing.T) { t.Parallel() _, err := ParseAndCheck(t, ` @@ -4772,6 +4782,72 @@ func TestCheckEntitlementConditions(t *testing.T) { assert.NoError(t, err) }) + + t.Run("result value usage, variable-sized resource array", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + resource R {} + + fun foo(r: @[R]): @[R] { + post { + bar(result): "" + } + return <-r + } + + // 'result' variable should have all the entitlements available for arrays. + view fun bar(_ r: auth(Mutate, Insert, Remove) &[R]): Bool { + return true + } + `) + + assert.NoError(t, err) + }) + + t.Run("result value usage, constant-sized resource array", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + resource R {} + + fun foo(r: @[R; 5]): @[R; 5] { + post { + bar(result): "" + } + return <-r + } + + // 'result' variable should have all the entitlements available for arrays. + view fun bar(_ r: auth(Mutate, Insert, Remove) &[R; 5]): Bool { + return true + } + `) + + assert.NoError(t, err) + }) + + t.Run("result value usage, resource dictionary", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + resource R {} + + fun foo(r: @{String:R}): @{String:R} { + post { + bar(result): "" + } + return <-r + } + + // 'result' variable should have all the entitlements available for dictionaries. + view fun bar(_ r: auth(Mutate, Insert, Remove) &{String:R}): Bool { + return true + } + `) + + assert.NoError(t, err) + }) } func TestCheckEntitledWriteAndMutateNotAllowed(t *testing.T) { @@ -4884,8 +4960,7 @@ func TestCheckEntitledWriteAndMutateNotAllowed(t *testing.T) { } `) - errs := RequireCheckerErrors(t, err, 1) - require.IsType(t, &sema.ExternalMutationError{}, errs[0]) + assert.NoError(t, err) }) t.Run("basic authorized", func(t *testing.T) { @@ -4906,7 +4981,7 @@ func TestCheckEntitledWriteAndMutateNotAllowed(t *testing.T) { `) errs := RequireCheckerErrors(t, err, 1) - require.IsType(t, &sema.ExternalMutationError{}, errs[0]) + assert.IsType(t, &sema.InvalidAccessError{}, errs[0]) }) } @@ -5255,3 +5330,321 @@ func TestCheckAttachProvidedEntitlements(t *testing.T) { require.Equal(t, errs[1].(*sema.RequiredEntitlementNotProvidedError).RequiredEntitlement.Identifier, "E") }) } + +func TestCheckBuiltinEntitlements(t *testing.T) { + + t.Parallel() + + t.Run("builtin", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct S { + access(Mutate) fun foo() {} + access(Insert) fun bar() {} + access(Remove) fun baz() {} + } + + fun main() { + let s = S() + let mutableRef = &s as auth(Mutate) &S + let insertableRef = &s as auth(Insert) &S + let removableRef = &s as auth(Remove) &S + } + `) + + assert.NoError(t, err) + }) + + t.Run("redefine", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement Mutate + entitlement Insert + entitlement Remove + `) + + errs := RequireCheckerErrors(t, err, 3) + + require.IsType(t, &sema.RedeclarationError{}, errs[0]) + require.IsType(t, &sema.RedeclarationError{}, errs[1]) + require.IsType(t, &sema.RedeclarationError{}, errs[2]) + }) + +} + +func TestCheckIdentityMapping(t *testing.T) { + + t.Parallel() + + t.Run("owned value", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct S { + access(Identity) fun foo(): auth(Identity) &AnyStruct { + let a: AnyStruct = "hello" + return &a as auth(Identity) &AnyStruct + } + } + + fun main() { + let s = S() + + // OK + let resultRef1: &AnyStruct = s.foo() + + // Error: Must return an unauthorized ref + let resultRef2: auth(Mutate) &AnyStruct = s.foo() + } + `) + + errors := RequireCheckerErrors(t, err, 1) + typeMismatchError := &sema.TypeMismatchError{} + require.ErrorAs(t, errors[0], &typeMismatchError) + + require.IsType(t, &sema.ReferenceType{}, typeMismatchError.ActualType) + actualReference := typeMismatchError.ActualType.(*sema.ReferenceType) + + require.IsType(t, sema.EntitlementSetAccess{}, actualReference.Authorization) + actualAuth := actualReference.Authorization.(sema.EntitlementSetAccess) + + assert.Equal(t, 0, actualAuth.Entitlements.Len()) + }) + + t.Run("unauthorized ref", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct S { + access(Identity) fun foo(): auth(Identity) &AnyStruct { + let a: AnyStruct = "hello" + return &a as auth(Identity) &AnyStruct + } + } + + fun main() { + let s = S() + + let ref = &s as &S + + // OK + let resultRef1: &AnyStruct = ref.foo() + + // Error: Must return an unauthorized ref + let resultRef2: auth(Mutate) &AnyStruct = ref.foo() + } + `) + + errors := RequireCheckerErrors(t, err, 1) + typeMismatchError := &sema.TypeMismatchError{} + require.ErrorAs(t, errors[0], &typeMismatchError) + }) + + t.Run("basic entitled ref", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct S { + access(Identity) fun foo(): auth(Identity) &AnyStruct { + let a: AnyStruct = "hello" + return &a as auth(Identity) &AnyStruct + } + } + + fun main() { + let s = S() + + let mutableRef = &s as auth(Mutate) &S + let ref1: auth(Mutate) &AnyStruct = mutableRef.foo() + + let insertableRef = &s as auth(Insert) &S + let ref2: auth(Insert) &AnyStruct = insertableRef.foo() + + let removableRef = &s as auth(Remove) &S + let ref3: auth(Remove) &AnyStruct = removableRef.foo() + } + `) + + assert.NoError(t, err) + }) + + t.Run("entitlement set ref", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct S { + access(Identity) fun foo(): auth(Identity) &AnyStruct { + let a: AnyStruct = "hello" + return &a as auth(Identity) &AnyStruct + } + } + + fun main() { + let s = S() + + let ref1 = &s as auth(Insert | Remove) &S + let resultRef1: auth(Insert | Remove) &AnyStruct = ref1.foo() + + let ref2 = &s as auth(Insert, Remove) &S + let resultRef2: auth(Insert, Remove) &AnyStruct = ref2.foo() + } + `) + + assert.NoError(t, err) + }) + + t.Run("owned value, with entitlements", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement A + entitlement B + entitlement C + + struct X { + access(A | B) var s: String + + init() { + self.s = "hello" + } + + access(C) fun foo() {} + } + + struct Y { + + // Reference + access(Identity) var x1: auth(Identity) &X + + // Optional reference + access(Identity) var x2: auth(Identity) &X? + + // Function returning a reference + access(Identity) fun getX(): auth(Identity) &X { + let x = X() + return &x as auth(Identity) &X + } + + // Function returning an optional reference + access(Identity) fun getOptionalX(): auth(Identity) &X? { + let x: X? = X() + return &x as auth(Identity) &X? + } + + init() { + let x = X() + self.x1 = &x as auth(A, B, C) &X + self.x2 = nil + } + } + + fun main() { + let y = Y() + + let ref1: auth(A, B, C) &X = y.x1 + + let ref2: auth(A, B, C) &X? = y.x2 + + let ref3: auth(A, B, C) &X = y.getX() + + let ref4: auth(A, B, C) &X? = y.getOptionalX() + } + `) + + assert.NoError(t, err) + }) + + t.Run("owned value, with entitlements, function typed field", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement A + entitlement B + entitlement C + + struct X { + access(A | B) var s: String + + init() { + self.s = "hello" + } + + access(C) fun foo() {} + } + + struct Y { + + access(Identity) let fn: (fun (): X) + + init() { + self.fn = fun(): X { + return X() + } + } + } + + fun main() { + let y = Y() + let v = y.fn() + } + `) + + errors := RequireCheckerErrors(t, err, 1) + invalidMapping := &sema.InvalidMappedEntitlementMemberError{} + require.ErrorAs(t, errors[0], &invalidMapping) + }) + + t.Run("owned value, with entitlements, function ref typed field", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement A + entitlement B + entitlement C + + struct X { + access(A | B) var s: String + + init() { + self.s = "hello" + } + + access(C) fun foo() {} + } + + struct Y { + + access(Identity) let fn: auth(Identity) &(fun (): X)? + + init() { + self.fn = nil + } + } + + fun main() { + let y = Y() + let v: auth(A, B, C) &(fun (): X) = y.fn + } + `) + + errors := RequireCheckerErrors(t, err, 1) + typeMismatchError := &sema.TypeMismatchError{} + require.ErrorAs(t, errors[0], &typeMismatchError) + + actualType := typeMismatchError.ActualType + require.IsType(t, &sema.OptionalType{}, actualType) + optionalType := actualType.(*sema.OptionalType) + + require.IsType(t, &sema.ReferenceType{}, optionalType.Type) + referenceType := optionalType.Type.(*sema.ReferenceType) + + require.IsType(t, sema.EntitlementSetAccess{}, referenceType.Authorization) + auth := referenceType.Authorization.(sema.EntitlementSetAccess) + + // Entitlements of function return type `X` must NOT be + // available for the reference typed field. + require.Equal(t, 0, auth.Entitlements.Len()) + }) +} diff --git a/runtime/tests/checker/external_mutation_test.go b/runtime/tests/checker/external_mutation_test.go index 50744ad276..71ea9fe869 100644 --- a/runtime/tests/checker/external_mutation_test.go +++ b/runtime/tests/checker/external_mutation_test.go @@ -22,6 +22,7 @@ import ( "fmt" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/onflow/cadence/runtime/ast" @@ -83,9 +84,7 @@ func TestCheckArrayUpdateIndexAccess(t *testing.T) { `, valueKind.Keyword(), access.Keyword(), declaration.Keywords(), assignmentOp, destroyStatement), ) - errs := RequireCheckerErrors(t, err, 1) - var externalMutationError *sema.ExternalMutationError - require.ErrorAs(t, errs[0], &externalMutationError) + require.NoError(t, err) }) } @@ -152,9 +151,7 @@ func TestCheckDictionaryUpdateIndexAccess(t *testing.T) { `, valueKind.Keyword(), access.Keyword(), declaration.Keywords(), assignmentOp, destroyStatement), ) - errs := RequireCheckerErrors(t, err, 1) - var externalMutationError *sema.ExternalMutationError - require.ErrorAs(t, errs[0], &externalMutationError) + require.NoError(t, err) }) } @@ -215,9 +212,7 @@ func TestCheckNestedArrayUpdateIndexAccess(t *testing.T) { `, access.Keyword(), declaration.Keywords()), ) - errs := RequireCheckerErrors(t, err, 1) - var externalMutationError *sema.ExternalMutationError - require.ErrorAs(t, errs[0], &externalMutationError) + require.NoError(t, err) }) } @@ -276,9 +271,7 @@ func TestCheckNestedDictionaryUpdateIndexAccess(t *testing.T) { `, access.Keyword(), declaration.Keywords()), ) - errs := RequireCheckerErrors(t, err, 1) - var externalMutationError *sema.ExternalMutationError - require.ErrorAs(t, errs[0], &externalMutationError) + require.NoError(t, err) }) } @@ -327,18 +320,15 @@ func TestCheckMutateContractIndexAccess(t *testing.T) { `, access.Keyword(), declaration.Keywords()), ) - expectedErrors := 1 - if access == ast.AccessContract { - expectedErrors++ - } + expectError := access == ast.AccessContract - errs := RequireCheckerErrors(t, err, expectedErrors) - if expectedErrors > 1 { + if expectError { + errs := RequireCheckerErrors(t, err, 1) var accessError *sema.InvalidAccessError - require.ErrorAs(t, errs[expectedErrors-2], &accessError) + require.ErrorAs(t, errs[0], &accessError) + } else { + require.NoError(t, err) } - var externalMutationError *sema.ExternalMutationError - require.ErrorAs(t, errs[expectedErrors-1], &externalMutationError) }) } @@ -394,18 +384,15 @@ func TestCheckContractNestedStructIndexAccess(t *testing.T) { `, access.Keyword(), declaration.Keywords()), ) - expectedErrors := 1 - if access == ast.AccessContract { - expectedErrors++ - } + expectError := access == ast.AccessContract - errs := RequireCheckerErrors(t, err, expectedErrors) - if expectedErrors > 1 { + if expectError { + errs := RequireCheckerErrors(t, err, 1) var accessError *sema.InvalidAccessError - require.ErrorAs(t, errs[expectedErrors-2], &accessError) + require.ErrorAs(t, errs[0], &accessError) + } else { + require.NoError(t, err) } - var externalMutationError *sema.ExternalMutationError - require.ErrorAs(t, errs[expectedErrors-1], &externalMutationError) }) } @@ -458,9 +445,7 @@ func TestCheckContractStructInitIndexAccess(t *testing.T) { `, access.Keyword(), declaration.Keywords()), ) - errs := RequireCheckerErrors(t, err, 1) - var externalMutationError *sema.ExternalMutationError - require.ErrorAs(t, errs[0], &externalMutationError) + require.NoError(t, err) }) } @@ -543,13 +528,7 @@ func TestCheckArrayUpdateMethodCall(t *testing.T) { `, valueKind.Keyword(), access.Keyword(), declaration.Keywords(), assignmentOp, member.Code, destroyStatement), ) - if member.Mutating { - errs := RequireCheckerErrors(t, err, 1) - var externalMutationError *sema.ExternalMutationError - require.ErrorAs(t, errs[0], &externalMutationError) - } else { - require.NoError(t, err) - } + require.NoError(t, err) }) } @@ -632,13 +611,7 @@ func TestCheckDictionaryUpdateMethodCall(t *testing.T) { `, valueKind.Keyword(), access.Keyword(), declaration.Keywords(), assignmentOp, member.Code, destroyStatement), ) - if member.Mutating { - errs := RequireCheckerErrors(t, err, 1) - var externalMutationError *sema.ExternalMutationError - require.ErrorAs(t, errs[0], &externalMutationError) - } else { - require.NoError(t, err) - } + require.NoError(t, err) }) } @@ -715,8 +688,7 @@ func TestCheckMutationThroughReference(t *testing.T) { `, ) errs := RequireCheckerErrors(t, err, 1) - var externalMutationError *sema.ExternalMutationError - require.ErrorAs(t, errs[0], &externalMutationError) + assert.IsType(t, &sema.InvalidAccessError{}, errs[0]) }) } @@ -732,7 +704,7 @@ func TestCheckMutationThroughInnerReference(t *testing.T) { ` access(all) fun main() { let foo = Foo() - var arrayRef = &foo.ref.arr as &[String] + var arrayRef = foo.ref.arr arrayRef[0] = "y" } @@ -751,7 +723,9 @@ func TestCheckMutationThroughInnerReference(t *testing.T) { } `, ) - require.NoError(t, err) + + errs := RequireCheckerErrors(t, err, 1) + assert.IsType(t, &sema.UnauthorizedReferenceAssignmentError{}, errs[0]) }) } @@ -790,8 +764,6 @@ func TestCheckMutationThroughAccess(t *testing.T) { } `, ) - errs := RequireCheckerErrors(t, err, 1) - var externalMutationError *sema.ExternalMutationError - require.ErrorAs(t, errs[0], &externalMutationError) + require.NoError(t, err) }) } diff --git a/runtime/tests/checker/member_test.go b/runtime/tests/checker/member_test.go index 2221b31a63..6d35d7b4c2 100644 --- a/runtime/tests/checker/member_test.go +++ b/runtime/tests/checker/member_test.go @@ -19,11 +19,13 @@ package checker import ( + "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/onflow/cadence/runtime/interpreter" "github.com/onflow/cadence/runtime/sema" ) @@ -418,3 +420,536 @@ func TestCheckMemberNotDeclaredSecondaryError(t *testing.T) { assert.Equal(t, "unknown member", memberErr.SecondaryError()) }) } + +func TestCheckMemberAccess(t *testing.T) { + + t.Parallel() + + t.Run("composite, field", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct Test { + var x: [Int] + init() { + self.x = [] + } + } + + fun test() { + let test = Test() + var x: [Int] = test.x + } + `) + + require.NoError(t, err) + }) + + t.Run("composite, function", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct Test { + access(all) fun foo(): Int { + return 1 + } + } + + fun test() { + let test = Test() + var foo: (fun(): Int) = test.foo + } + `) + + require.NoError(t, err) + }) + + t.Run("composite reference, array field", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct Test { + var x: [Int] + init() { + self.x = [] + } + } + + fun test() { + let test = Test() + let testRef = &test as &Test + var x: &[Int] = testRef.x + } + `) + + require.NoError(t, err) + }) + + t.Run("composite reference, optional field", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct Test { + var x: [Int]? + init() { + self.x = [] + } + } + + fun test() { + let test = Test() + let testRef = &test as &Test + var x: &[Int]? = testRef.x + } + `) + + require.NoError(t, err) + }) + + t.Run("composite reference, primitive field", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct Test { + var x: Int + init() { + self.x = 1 + } + } + + fun test() { + let test = Test() + let testRef = &test as &Test + var x: Int = testRef.x + } + `) + + require.NoError(t, err) + }) + + t.Run("composite reference, non-existing field", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct Test {} + + fun test() { + let test = Test() + let testRef = &test as &Test + var x: Int = testRef.x + } + `) + + errs := RequireCheckerErrors(t, err, 1) + var memberErr *sema.NotDeclaredMemberError + require.ErrorAs(t, errs[0], &memberErr) + }) + + t.Run("composite reference, function", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct Test { + access(all) fun foo(): Int { + return 1 + } + } + + fun test() { + let test = Test() + let testRef = &test as &Test + var foo: (fun(): Int) = testRef.foo + } + `) + + require.NoError(t, err) + }) + + t.Run("array, element", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() { + let array: [[Int]] = [[1, 2]] + var x: [Int] = array[0] + } + `) + + require.NoError(t, err) + }) + + t.Run("array reference, element", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() { + let array: [[Int]] = [[1, 2]] + let arrayRef = &array as &[[Int]] + var x: &[Int] = arrayRef[0] + } + `) + + require.NoError(t, err) + }) + + t.Run("array authorized reference, element", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement A + + fun test() { + let array: [[Int]] = [[1, 2]] + let arrayRef = &array as auth(A) &[[Int]] + + // Must be a. err: returns an unauthorized reference. + var x: auth(A) &[Int] = arrayRef[0] + } + `) + + errors := RequireCheckerErrors(t, err, 1) + typeMismatchError := &sema.TypeMismatchError{} + require.ErrorAs(t, errors[0], &typeMismatchError) + }) + + t.Run("array reference, optional typed element", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() { + let array: [[Int]?] = [[1, 2]] + let arrayRef = &array as &[[Int]?] + var x: &[Int]? = arrayRef[0] + } + `) + + require.NoError(t, err) + }) + + t.Run("array reference, primitive typed element", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() { + let array: [Int] = [1, 2] + let arrayRef = &array as &[Int] + var x: Int = arrayRef[0] + } + `) + + require.NoError(t, err) + }) + + t.Run("dictionary, value", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() { + let dict: {String: {String: Int}} = {"one": {"two": 2}} + var x: {String: Int}? = dict["one"] + } + `) + + require.NoError(t, err) + }) + + t.Run("dictionary reference, value", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() { + let dict: {String: {String: Int} } = {"one": {"two": 2}} + let dictRef = &dict as &{String: {String: Int}} + var x: &{String: Int}? = dictRef["one"] + } + `) + + require.NoError(t, err) + }) + + t.Run("dictionary authorized reference, value", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement A + + fun test() { + let dict: {String: {String: Int} } = {"one": {"two": 2}} + let dictRef = &dict as auth(A) &{String: {String: Int}} + + // Must be a. err: returns an unauthorized reference. + var x: auth(A) &{String: Int}? = dictRef["one"] + } + `) + + errors := RequireCheckerErrors(t, err, 1) + typeMismatchError := &sema.TypeMismatchError{} + require.ErrorAs(t, errors[0], &typeMismatchError) + }) + + t.Run("dictionary reference, optional typed value", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() { + let dict: {String: {String: Int}?} = {"one": {"two": 2}} + let dictRef = &dict as &{String: {String: Int}?} + var x: (&{String: Int})?? = dictRef["one"] + } + `) + + require.NoError(t, err) + }) + + t.Run("dictionary reference, optional typed value, mismatch types", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() { + let dict: {String: {String: Int}?} = {"one": {"two": 2}} + let dictRef = &dict as &{String: {String: Int}?} + + // Must return an optional reference, not a reference to an optional + var x: &({String: Int}??) = dictRef["one"] + } + `) + + errors := RequireCheckerErrors(t, err, 1) + typeMismatchError := &sema.TypeMismatchError{} + require.ErrorAs(t, errors[0], &typeMismatchError) + }) + + t.Run("dictionary reference, primitive typed value", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() { + let dict: {String: Int} = {"one": 1} + let dictRef = &dict as &{String: Int} + var x: Int? = dictRef["one"] + } + `) + + require.NoError(t, err) + }) + + t.Run("resource reference, attachment", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + resource R {} + + attachment A for R {} + + fun test() { + let r <- create R() + let rRef = &r as &R + + var a: &A? = rRef[A] + destroy r + } + `) + + require.NoError(t, err) + }) + + t.Run("entitlement map access", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement A + entitlement B + entitlement mapping M { + A -> B + } + + struct S { + access(M) let foo: [String] + init() { + self.foo = [] + } + } + + fun test() { + let s = S() + let sRef = &s as auth(A) &S + var foo: auth(B) &[String] = sRef.foo + } + `) + + require.NoError(t, err) + }) + + t.Run("entitlement map access nested", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement A + entitlement B + entitlement C + entitlement D + + entitlement mapping FooMapping { + A -> B + } + entitlement mapping BarMapping { + C -> D + } + struct Foo { + access(FooMapping) let bars: [Bar] + init() { + self.bars = [Bar()] + } + } + struct Bar { + access(BarMapping) let baz: Baz + init() { + self.baz = Baz() + } + } + struct Baz { + access(D) fun canOnlyCallOnAuthD() {} + } + fun test() { + let foo = Foo() + let fooRef = &foo as auth(A) &Foo + + let bazRef: &Baz = fooRef.bars[0].baz + + // Error: 'fooRef.bars[0].baz' returns an unauthorized reference + bazRef.canOnlyCallOnAuthD() + } + `) + + errors := RequireCheckerErrors(t, err, 1) + invalidAccessError := &sema.InvalidAccessError{} + require.ErrorAs(t, errors[0], &invalidAccessError) + }) + + t.Run("entitlement map access nested", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + entitlement A + entitlement B + entitlement C + + entitlement mapping FooMapping { + A -> B + } + + entitlement mapping BarMapping { + B -> C + } + + struct Foo { + access(FooMapping) let bars: [Bar] + init() { + self.bars = [Bar()] + } + } + + struct Bar { + access(BarMapping) let baz: Baz + init() { + self.baz = Baz() + } + } + + struct Baz { + access(C) fun canOnlyCallOnAuthC() {} + } + + fun test() { + let foo = Foo() + let fooRef = &foo as auth(A) &Foo + + let barArrayRef: auth(B) &[Bar] = fooRef.bars + + // Must be a. err: returns an unauthorized reference. + let barRef: auth(B) &Bar = barArrayRef[0] + + let bazRef: auth(C) &Baz = barRef.baz + + bazRef.canOnlyCallOnAuthC() + } + `) + + errors := RequireCheckerErrors(t, err, 1) + typeMismatchError := &sema.TypeMismatchError{} + require.ErrorAs(t, errors[0], &typeMismatchError) + }) + + t.Run("anyresource swap on reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + resource Foo {} + + fun test() { + let dict: @{String: AnyResource} <- {"foo": <- create Foo(), "bar": <- create Foo()} + let dictRef = &dict as &{String: AnyResource} + + dictRef["foo"] <-> dictRef["bar"] + + destroy dict + } + `) + + errs := RequireCheckerErrors(t, err, 4) + assert.IsType(t, &sema.UnauthorizedReferenceAssignmentError{}, errs[0]) + assert.IsType(t, &sema.UnauthorizedReferenceAssignmentError{}, errs[1]) + assert.IsType(t, &sema.TypeMismatchError{}, errs[2]) + assert.IsType(t, &sema.TypeMismatchError{}, errs[3]) + }) + + t.Run("all member types", func(t *testing.T) { + t.Parallel() + + test := func(tt *testing.T, typeName string) { + code := fmt.Sprintf(` + struct Foo { + var a: %[1]s? + + init() { + self.a = nil + } + } + + struct Bar {} + + struct interface I {} + + fun test() { + let foo = Foo() + let fooRef = &foo as &Foo + var a: &%[1]s? = fooRef.a + }`, + + typeName, + ) + + _, err := ParseAndCheck(t, code) + require.NoError(t, err) + } + + types := []string{ + "Bar", + "{I}", + "AnyStruct", + "Block", + } + + // Test all built-in composite types + for i := interpreter.PrimitiveStaticTypeAuthAccount; i < interpreter.PrimitiveStaticType_Count; i++ { + semaType := i.SemaType() + types = append(types, semaType.QualifiedString()) + } + + for _, typeName := range types { + t.Run(typeName, func(t *testing.T) { + test(t, typeName) + }) + } + }) +} diff --git a/runtime/tests/checker/reference_test.go b/runtime/tests/checker/reference_test.go index e804153e9f..dc9527b1e8 100644 --- a/runtime/tests/checker/reference_test.go +++ b/runtime/tests/checker/reference_test.go @@ -790,10 +790,8 @@ func TestCheckReferenceIndexingIfReferencedIndexable(t *testing.T) { fun test() { let rs <- [<-create R()] let ref = &rs as &[R] - var other <- create R() - ref[0] <-> other + ref[0] destroy rs - destroy other } `) @@ -811,8 +809,7 @@ func TestCheckReferenceIndexingIfReferencedIndexable(t *testing.T) { fun test() { let s = [S()] let ref = &s as &[S] - var other = S() - ref[0] <-> other + ref[0] } `) @@ -820,7 +817,7 @@ func TestCheckReferenceIndexingIfReferencedIndexable(t *testing.T) { }) } -func TestCheckInvalidReferenceResourceLoss(t *testing.T) { +func TestCheckReferenceResourceLoss(t *testing.T) { t.Parallel() @@ -830,17 +827,15 @@ func TestCheckInvalidReferenceResourceLoss(t *testing.T) { fun test() { let rs <- [<-create R()] let ref = &rs as &[R] - ref[0] + ref[0] // This result in a reference, so no resource loss destroy rs } `) - errs := RequireCheckerErrors(t, err, 1) - - assert.IsType(t, &sema.ResourceLossError{}, errs[0]) + require.NoError(t, err) } -func TestCheckInvalidReferenceResourceLoss2(t *testing.T) { +func TestCheckInvalidReferenceResourceLoss(t *testing.T) { t.Parallel() @@ -1710,7 +1705,7 @@ func TestCheckInvalidatedReferenceUse(t *testing.T) { authAccount.save(<-[<-create R()], to: /storage/a) let collectionRef = authAccount.borrow<&[R]>(from: /storage/a)! - let ref = &collectionRef[0] as &R + let ref = collectionRef[0] let collection <- authAccount.load<@[R]>(from: /storage/a)! authAccount.save(<- collection, to: /storage/b) @@ -1838,8 +1833,8 @@ func TestCheckInvalidatedReferenceUse(t *testing.T) { access(all) fun test() { var r: @{UInt64: {UInt64: [R]}} <- {} let ref1 = (&r[0] as &{UInt64: [R]}?)! - let ref2 = (&ref1[0] as &[R]?)! - let ref3 = &ref2[0] as &R + let ref2 = ref1[0]! + let ref3 = ref2[0] ref3.a destroy r @@ -1858,7 +1853,7 @@ func TestCheckInvalidatedReferenceUse(t *testing.T) { require.NoError(t, err) }) - t.Run("ref to ref invalid", func(t *testing.T) { + t.Run("ref to ref invalid, index expr", func(t *testing.T) { t.Parallel() @@ -1867,8 +1862,8 @@ func TestCheckInvalidatedReferenceUse(t *testing.T) { access(all) fun test() { var r: @{UInt64: {UInt64: [R]}} <- {} let ref1 = (&r[0] as &{UInt64: [R]}?)! - let ref2 = (&ref1[0] as &[R]?)! - let ref3 = &ref2[0] as &R + let ref2 = ref1[0]! + let ref3 = ref2[0] destroy r ref3.a } @@ -1887,6 +1882,55 @@ func TestCheckInvalidatedReferenceUse(t *testing.T) { assert.ErrorAs(t, errors[0], &invalidatedRefError) }) + t.Run("ref to ref invalid, member expr", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, + ` + access(all) fun test() { + var r: @R1 <- create R1() + let ref1 = &r as &R1 + let ref2 = ref1.r2 + let ref3 = ref2.r3 + destroy r + ref3.a + } + + access(all) resource R1 { + access(all) let r2: @R2 + init() { + self.r2 <- create R2() + } + destroy() { + destroy self.r2 + } + } + + access(all) resource R2 { + access(all) let r3: @R3 + init() { + self.r3 <- create R3() + } + destroy() { + destroy self.r3 + } + } + + access(all) resource R3 { + access(all) let a: Int + init() { + self.a = 5 + } + } + `, + ) + + errors := RequireCheckerErrors(t, err, 1) + invalidatedRefError := &sema.InvalidatedResourceReferenceError{} + assert.ErrorAs(t, errors[0], &invalidatedRefError) + }) + t.Run("create ref with force expr", func(t *testing.T) { t.Parallel() @@ -2009,7 +2053,7 @@ func TestCheckInvalidatedReferenceUse(t *testing.T) { access(all) resource R { access(all) fun test() { if let storage = &Test.a[0] as &{UInt64: Test.R}? { - let nftRef = (&storage[0] as &Test.R?)! + let nftRef = storage[0]! nftRef } } @@ -2066,9 +2110,9 @@ func TestCheckInvalidatedReferenceUse(t *testing.T) { access(all) contract Test { access(all) resource R { access(all) fun test(packList: &[Test.R]) { - var i = 0; + var i = 0 while i < packList.length { - let pack = &packList[i] as &Test.R; + let pack = packList[i] pack i = i + 1 } @@ -2653,7 +2697,7 @@ func TestCheckReferenceUseAfterCopy(t *testing.T) { fun test() { let rs <- [<-create R()] - let ref = &rs as &[R] + let ref = &rs as auth(Mutate) &[R] let container <- [<-rs] ref.insert(at: 1, <-create R()) destroy container @@ -2674,7 +2718,7 @@ func TestCheckReferenceUseAfterCopy(t *testing.T) { fun test() { let rs <- [<-create R()] - let ref = &rs as &[R] + let ref = &rs as auth(Mutate) &[R] let container <- [<-rs] ref.append(<-create R()) destroy container @@ -2704,10 +2748,19 @@ func TestCheckReferenceUseAfterCopy(t *testing.T) { } `) - errs := RequireCheckerErrors(t, err, 2) + errs := RequireCheckerErrors(t, err, 4) + invalidatedRefError := &sema.InvalidatedResourceReferenceError{} assert.ErrorAs(t, errs[0], &invalidatedRefError) - assert.ErrorAs(t, errs[1], &invalidatedRefError) + + unauthorizedReferenceAssignmentError := &sema.UnauthorizedReferenceAssignmentError{} + assert.ErrorAs(t, errs[1], &unauthorizedReferenceAssignmentError) + + assert.ErrorAs(t, errs[2], &invalidatedRefError) + + typeMismatchError := &sema.TypeMismatchError{} + assert.ErrorAs(t, errs[3], &typeMismatchError) + }) t.Run("resource array, remove", func(t *testing.T) { @@ -2719,7 +2772,7 @@ func TestCheckReferenceUseAfterCopy(t *testing.T) { fun test() { let rs <- [<-create R()] - let ref = &rs as &[R] + let ref = &rs as auth(Mutate) &[R] let container <- [<-rs] let r <- ref.remove(at: 0) destroy container @@ -2748,9 +2801,12 @@ func TestCheckReferenceUseAfterCopy(t *testing.T) { } `) - errs := RequireCheckerErrors(t, err, 1) + errs := RequireCheckerErrors(t, err, 2) invalidatedRefError := &sema.InvalidatedResourceReferenceError{} assert.ErrorAs(t, errs[0], &invalidatedRefError) + + unauthorizedReferenceAssignmentError := &sema.UnauthorizedReferenceAssignmentError{} + assert.ErrorAs(t, errs[1], &unauthorizedReferenceAssignmentError) }) t.Run("resource dictionary, remove", func(t *testing.T) { @@ -2762,7 +2818,7 @@ func TestCheckReferenceUseAfterCopy(t *testing.T) { fun test() { let rs <- {0: <-create R()} - let ref = &rs as &{Int: R} + let ref = &rs as auth(Remove) &{Int: R} let container <- [<-rs] let r <- ref.remove(key: 0) destroy container @@ -2774,6 +2830,37 @@ func TestCheckReferenceUseAfterCopy(t *testing.T) { invalidatedRefError := &sema.InvalidatedResourceReferenceError{} assert.ErrorAs(t, errs[0], &invalidatedRefError) }) + + t.Run("attachments", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + resource R {} + + attachment A for R { + access(all) var id: UInt8 + init() { + self.id = 1 + } + } + + fun test() { + let r <- create R() + let r2 <- attach A() to <-r + + let a = r2[A]! + destroy r2 + + // Access attachment ref, after destroying the resource + a.id + } + `) + + errs := RequireCheckerErrors(t, err, 1) + invalidatedRefError := &sema.InvalidatedResourceReferenceError{} + assert.ErrorAs(t, errs[0], &invalidatedRefError) + }) } func TestCheckResourceReferenceMethodInvocationAfterMove(t *testing.T) { diff --git a/runtime/tests/checker/resources_test.go b/runtime/tests/checker/resources_test.go index 249b7f3d4b..0155571c93 100644 --- a/runtime/tests/checker/resources_test.go +++ b/runtime/tests/checker/resources_test.go @@ -9447,6 +9447,24 @@ func TestCheckConditionalResourceCreationAndReturn(t *testing.T) { require.NoError(t, err) } +func TestCheckIndexExpressionResourceLoss(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + resource R {} + + fun test() { + let rs <- [<-create R()] + rs[0] + destroy rs + } + `) + + errs := RequireCheckerErrors(t, err, 1) + assert.IsType(t, &sema.ResourceLossError{}, errs[0]) +} + func TestCheckResourceWithFunction(t *testing.T) { t.Parallel() diff --git a/runtime/tests/checker/swap_test.go b/runtime/tests/checker/swap_test.go index adf4a2579b..ccb93a38ad 100644 --- a/runtime/tests/checker/swap_test.go +++ b/runtime/tests/checker/swap_test.go @@ -39,9 +39,9 @@ func TestCheckInvalidUnknownDeclarationSwap(t *testing.T) { } `) - errs := RequireCheckerErrors(t, err, 1) - + errs := RequireCheckerErrors(t, err, 2) assert.IsType(t, &sema.NotDeclaredError{}, errs[0]) + assert.IsType(t, &sema.NotDeclaredError{}, errs[1]) } func TestCheckInvalidLeftConstantSwap(t *testing.T) { @@ -105,9 +105,10 @@ func TestCheckInvalidTypesSwap(t *testing.T) { } `) - errs := RequireCheckerErrors(t, err, 1) + errs := RequireCheckerErrors(t, err, 2) assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) + assert.IsType(t, &sema.TypeMismatchError{}, errs[1]) } func TestCheckInvalidTypesSwap2(t *testing.T) { @@ -122,9 +123,10 @@ func TestCheckInvalidTypesSwap2(t *testing.T) { } `) - errs := RequireCheckerErrors(t, err, 1) + errs := RequireCheckerErrors(t, err, 2) assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) + assert.IsType(t, &sema.TypeMismatchError{}, errs[1]) } func TestCheckInvalidSwapTargetExpressionLeft(t *testing.T) { diff --git a/runtime/tests/interpreter/array_test.go b/runtime/tests/interpreter/array_test.go new file mode 100644 index 0000000000..165872c980 --- /dev/null +++ b/runtime/tests/interpreter/array_test.go @@ -0,0 +1,134 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package interpreter_test + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestInterpretArrayFunctionEntitlements(t *testing.T) { + + t.Parallel() + + t.Run("mutable reference", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as auth(Mutate) &[String] + + // Public functions + arrayRef.contains("hello") + arrayRef.firstIndex(of: "hello") + arrayRef.slice(from: 1, upTo: 1) + arrayRef.concat(["hello"]) + + // Insertable functions + arrayRef.append("baz") + arrayRef.appendAll(["baz"]) + arrayRef.insert(at:0, "baz") + + // Removable functions + arrayRef.remove(at: 1) + arrayRef.removeFirst() + arrayRef.removeLast() + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("non auth reference", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as &[String] + + // Public functions + arrayRef.contains("hello") + arrayRef.firstIndex(of: "hello") + arrayRef.slice(from: 1, upTo: 1) + arrayRef.concat(["hello"]) + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("insertable reference", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let array: [String] = ["foo", "bar"] + + fun test() { + var arrayRef = &array as auth(Insert) &[String] + + // Public functions + arrayRef.contains("hello") + arrayRef.firstIndex(of: "hello") + arrayRef.slice(from: 1, upTo: 1) + arrayRef.concat(["hello"]) + + // Insertable functions + arrayRef.append("baz") + arrayRef.appendAll(["baz"]) + arrayRef.insert(at:0, "baz") + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("removable reference", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let array: [String] = ["foo", "bar", "baz"] + + fun test() { + var arrayRef = &array as auth(Remove) &[String] + + // Public functions + arrayRef.contains("hello") + arrayRef.firstIndex(of: "hello") + arrayRef.slice(from: 1, upTo: 1) + arrayRef.concat(["hello"]) + + // Removable functions + arrayRef.remove(at: 1) + arrayRef.removeFirst() + arrayRef.removeLast() + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) +} diff --git a/runtime/tests/interpreter/attachments_test.go b/runtime/tests/interpreter/attachments_test.go index 5f95fa552f..9d70555169 100644 --- a/runtime/tests/interpreter/attachments_test.go +++ b/runtime/tests/interpreter/attachments_test.go @@ -1455,7 +1455,7 @@ func TestInterpretAttachmentResourceReferenceInvalidation(t *testing.T) { fun test(): UInt8 { let r <- create R() let r2 <- attach A() to <-r - let a = r2[A]! + let a = returnSameRef(r2[A]!) // Move the resource after taking a reference to the attachment. @@ -1467,7 +1467,11 @@ func TestInterpretAttachmentResourceReferenceInvalidation(t *testing.T) { // Access the attachment filed from the previous reference. return a.id - }`, + } + + access(all) fun returnSameRef(_ ref: &A): &A { + return ref + }`, sema.Config{ AttachmentsEnabled: true, }, @@ -1491,13 +1495,17 @@ func TestInterpretAttachmentResourceReferenceInvalidation(t *testing.T) { fun test() { let r <- create R() let r2 <- attach A() to <-r - let a = r2[A]! + let a = returnSameRef(r2[A]!) destroy r2 let i = a.foo() } - `, sema.Config{ - AttachmentsEnabled: true, - }, + + access(all) fun returnSameRef(_ ref: &A): &A { + return ref + }`, + sema.Config{ + AttachmentsEnabled: true, + }, ) _, err := inter.Invoke("test") @@ -1534,7 +1542,7 @@ func TestInterpretAttachmentResourceReferenceInvalidation(t *testing.T) { } fun test(): UInt8 { let r2 <- create R2(r: <-attach A() to <-create R()) - let a = r2.r[A]! + let a = returnSameRef(r2.r[A]!) // Move the resource after taking a reference to the attachment. // Then update the field of the attachment. @@ -1545,7 +1553,11 @@ func TestInterpretAttachmentResourceReferenceInvalidation(t *testing.T) { // Access the attachment filed from the previous reference. return a.id - }`, + } + + access(all) fun returnSameRef(_ ref: &A): &A { + return ref + }`, sema.Config{ AttachmentsEnabled: true, }, @@ -1624,14 +1636,17 @@ func TestInterpretAttachmentResourceReferenceInvalidation(t *testing.T) { } fun test() { let r2 <- create R2(r: <-attach A() to <-create R()) - let a = r2.r[A]! + let a = returnSameRef(r2.r[A]!) destroy r2 let i = a.foo() } - - `, sema.Config{ - AttachmentsEnabled: true, - }, + + access(all) fun returnSameRef(_ ref: &A): &A { + return ref + }`, + sema.Config{ + AttachmentsEnabled: true, + }, ) _, err := inter.Invoke("test") diff --git a/runtime/tests/interpreter/container_mutation_test.go b/runtime/tests/interpreter/container_mutation_test.go index e558539c5a..e39dc4a578 100644 --- a/runtime/tests/interpreter/container_mutation_test.go +++ b/runtime/tests/interpreter/container_mutation_test.go @@ -288,7 +288,7 @@ func TestArrayMutation(t *testing.T) { inter := parseCheckAndInterpret(t, ` fun test() { let names: [AnyStruct] = ["foo", "bar"] as [String] - let namesRef = &names as &[AnyStruct] + let namesRef = &names as auth(Mutate) &[AnyStruct] namesRef[0] = 5 } `) @@ -667,7 +667,7 @@ func TestDictionaryMutation(t *testing.T) { inter := parseCheckAndInterpret(t, ` fun test() { let names: {String: AnyStruct} = {"foo": "bar"} as {String: String} - let namesRef = &names as &{String: AnyStruct} + let namesRef = &names as auth(Mutate) &{String: AnyStruct} namesRef["foo"] = 5 } `) diff --git a/runtime/tests/interpreter/dictionary_test.go b/runtime/tests/interpreter/dictionary_test.go new file mode 100644 index 0000000000..dd07b40234 --- /dev/null +++ b/runtime/tests/interpreter/dictionary_test.go @@ -0,0 +1,118 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package interpreter_test + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestInterpretDictionaryFunctionEntitlements(t *testing.T) { + + t.Parallel() + + t.Run("mutable reference", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as auth(Mutate) &{String: String} + + // Public functions + dictionaryRef.containsKey("foo") + dictionaryRef.forEachKey(fun(key: String): Bool {return true} ) + + // Insertable functions + dictionaryRef.insert(key: "three", "baz") + + // Removable functions + dictionaryRef.remove(key: "foo") + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("non auth reference", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as &{String: String} + + // Public functions + dictionaryRef.containsKey("foo") + dictionaryRef.forEachKey(fun(key: String): Bool {return true} ) + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("insertable reference", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as auth(Mutate) &{String: String} + + // Public functions + dictionaryRef.containsKey("foo") + dictionaryRef.forEachKey(fun(key: String): Bool {return true} ) + + // Insertable functions + dictionaryRef.insert(key: "three", "baz") + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("removable reference", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let dictionary: {String: String} = {"one" : "foo", "two" : "bar"} + + fun test() { + var dictionaryRef = &dictionary as auth(Mutate) &{String: String} + + // Public functions + dictionaryRef.containsKey("foo") + dictionaryRef.forEachKey(fun(key: String): Bool {return true} ) + + // Removable functions + dictionaryRef.remove(key: "foo") + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) +} diff --git a/runtime/tests/interpreter/entitlements_test.go b/runtime/tests/interpreter/entitlements_test.go index 6ea3e671b6..99f4d03a61 100644 --- a/runtime/tests/interpreter/entitlements_test.go +++ b/runtime/tests/interpreter/entitlements_test.go @@ -21,6 +21,7 @@ package interpreter_test import ( "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/onflow/cadence/runtime/common" @@ -2704,3 +2705,194 @@ func TestInterpretEntitlementSetEquality(t *testing.T) { require.False(t, two.Equal(one)) }) } + +func TestInterpretBuiltinEntitlements(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct S { + access(Mutate) fun foo() {} + access(Insert) fun bar() {} + access(Remove) fun baz() {} + } + + fun main() { + let s = S() + let mutableRef = &s as auth(Mutate) &S + let insertableRef = &s as auth(Insert) &S + let removableRef = &s as auth(Remove) &S + } + `) + + _, err := inter.Invoke("main") + assert.NoError(t, err) +} + +func TestInterpretIdentityMapping(t *testing.T) { + + t.Parallel() + + t.Run("owned value", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct S { + access(Identity) fun foo(): auth(Identity) &AnyStruct { + let a: AnyStruct = "hello" + return &a as auth(Identity) &AnyStruct + } + } + + fun main() { + let s = S() + + // OK: Must return an unauthorized ref + let resultRef1: &AnyStruct = s.foo() + } + `) + + _, err := inter.Invoke("main") + assert.NoError(t, err) + }) + + t.Run("unauthorized ref", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct S { + access(Identity) fun foo(): auth(Identity) &AnyStruct { + let a: AnyStruct = "hello" + return &a as auth(Identity) &AnyStruct + } + } + + fun main() { + let s = S() + + let ref = &s as &S + + // OK: Must return an unauthorized ref + let resultRef1: &AnyStruct = ref.foo() + } + `) + + _, err := inter.Invoke("main") + assert.NoError(t, err) + }) + + t.Run("basic entitled ref", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct S { + access(Identity) fun foo(): auth(Identity) &AnyStruct { + let a: AnyStruct = "hello" + return &a as auth(Identity) &AnyStruct + } + } + + fun main() { + let s = S() + + let mutableRef = &s as auth(Mutate) &S + let ref1: auth(Mutate) &AnyStruct = mutableRef.foo() + + let insertableRef = &s as auth(Insert) &S + let ref2: auth(Insert) &AnyStruct = insertableRef.foo() + + let removableRef = &s as auth(Remove) &S + let ref3: auth(Remove) &AnyStruct = removableRef.foo() + } + `) + + _, err := inter.Invoke("main") + assert.NoError(t, err) + }) + + t.Run("entitlement set ref", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct S { + access(Identity) fun foo(): auth(Identity) &AnyStruct { + let a: AnyStruct = "hello" + return &a as auth(Identity) &AnyStruct + } + } + + fun main() { + let s = S() + + let ref1 = &s as auth(Insert | Remove) &S + let resultRef1: auth(Insert | Remove) &AnyStruct = ref1.foo() + + let ref2 = &s as auth(Insert, Remove) &S + let resultRef2: auth(Insert, Remove) &AnyStruct = ref2.foo() + } + `) + + _, err := inter.Invoke("main") + assert.NoError(t, err) + }) + + t.Run("owned value, with entitlements", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + entitlement A + entitlement B + entitlement C + + struct X { + access(A | B) var s: String + init() { + self.s = "hello" + } + access(C) fun foo() {} + } + + struct Y { + + // Reference + access(Identity) var x1: auth(Identity) &X + + // Optional reference + access(Identity) var x2: auth(Identity) &X? + + // Function returning a reference + access(Identity) fun getX(): auth(Identity) &X { + let x = X() + return &x as auth(Identity) &X + } + + // Function returning an optional reference + access(Identity) fun getOptionalX(): auth(Identity) &X? { + let x: X? = X() + return &x as auth(Identity) &X? + } + + init() { + let x = X() + self.x1 = &x as auth(A, B, C) &X + self.x2 = nil + } + } + + fun main() { + let y = Y() + + let ref1: auth(A, B, C) &X = y.x1 + + let ref2: auth(A, B, C) &X? = y.x2 + + let ref3: auth(A, B, C) &X = y.getX() + + let ref4: auth(A, B, C) &X? = y.getOptionalX() + } + `) + + _, err := inter.Invoke("main") + assert.NoError(t, err) + }) +} diff --git a/runtime/tests/interpreter/member_test.go b/runtime/tests/interpreter/member_test.go index 15af95ae85..9e9af3e31d 100644 --- a/runtime/tests/interpreter/member_test.go +++ b/runtime/tests/interpreter/member_test.go @@ -19,6 +19,7 @@ package interpreter_test import ( + "fmt" "testing" "github.com/stretchr/testify/require" @@ -557,3 +558,591 @@ func TestInterpretMemberAccessType(t *testing.T) { }) }) } + +func TestInterpretMemberAccess(t *testing.T) { + + t.Parallel() + + t.Run("composite, field", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct Test { + var x: [Int] + init() { + self.x = [] + } + } + + fun test(): [Int] { + let test = Test() + return test.x + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("composite, function", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct Test { + access(all) fun foo(): Int { + return 1 + } + } + + fun test() { + let test = Test() + var foo: (fun(): Int) = test.foo + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("composite reference, field", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct Test { + var x: [Int] + init() { + self.x = [] + } + } + + fun test() { + let test = Test() + let testRef = &test as &Test + var x: &[Int] = testRef.x + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("composite reference, optional field", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct Test { + var x: [Int]? + init() { + self.x = [] + } + } + + fun test() { + let test = Test() + let testRef = &test as &Test + var x: &[Int]? = testRef.x + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("composite reference, primitive field", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct Test { + var x: Int + init() { + self.x = 1 + } + } + + fun test() { + let test = Test() + let testRef = &test as &Test + var x: Int = testRef.x + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("composite reference, function", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct Test { + access(all) fun foo(): Int { + return 1 + } + } + + fun test() { + let test = Test() + let testRef = &test as &Test + var foo: (fun(): Int) = testRef.foo + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("resource reference, nested", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + resource Foo { + var bar: @Bar + init() { + self.bar <- create Bar() + } + destroy() { + destroy self.bar + } + } + + resource Bar { + var baz: @Baz + init() { + self.baz <- create Baz() + } + destroy() { + destroy self.baz + } + } + + resource Baz { + var x: &[Int] + init() { + self.x = &[] as &[Int] + } + } + + fun test() { + let foo <- create Foo() + let fooRef = &foo as &Foo + + // Nested container fields must return references + var barRef: &Bar = fooRef.bar + var bazRef: &Baz = fooRef.bar.baz + + // Reference typed field should return as is (no double reference must be created) + var x: &[Int] = fooRef.bar.baz.x + + destroy foo + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("composite reference, anystruct typed field, with reference value", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct Test { + var x: AnyStruct + init() { + var s = "hello" + self.x = &s as &String + } + } + + fun test():&AnyStruct { + let test = Test() + let testRef = &test as &Test + return testRef.x + } + `) + + result, err := inter.Invoke("test") + require.NoError(t, err) + + require.IsType(t, &interpreter.EphemeralReferenceValue{}, result) + ref := result.(*interpreter.EphemeralReferenceValue) + + // Must only have one level of references. + require.IsType(t, &interpreter.StringValue{}, ref.Value) + }) + + t.Run("array, element", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun test() { + let array: [[Int]] = [[1, 2]] + var x: [Int] = array[0] + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("array reference, element", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun test() { + let array: [[Int]] = [[1, 2]] + let arrayRef = &array as &[[Int]] + var x: &[Int] = arrayRef[0] + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("array authorized reference, element", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + entitlement A + + fun test() { + let array: [[Int]] = [[1, 2]] + let arrayRef = &array as auth(A) &[[Int]] + + // Must return an unauthorized reference. + var x: &[Int] = arrayRef[0] + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("array reference, element, in assignment", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun test() { + let array: [[Int]] = [[1, 2]] + let arrayRef = &array as &[[Int]] + var x: &[Int] = &[] as &[Int] + x = arrayRef[0] + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("array reference, optional typed element", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun test() { + let array: [[Int]?] = [[1, 2]] + let arrayRef = &array as &[[Int]?] + var x: &[Int]? = arrayRef[0] + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("array reference, primitive typed element", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun test() { + let array: [Int] = [1, 2] + let arrayRef = &array as &[Int] + var x: Int = arrayRef[0] + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("array reference, anystruct typed element, with reference value", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun test(): &AnyStruct { + var s = "hello" + let array: [AnyStruct] = [&s as &String] + let arrayRef = &array as &[AnyStruct] + return arrayRef[0] + } + `) + + result, err := inter.Invoke("test") + require.NoError(t, err) + + require.IsType(t, &interpreter.EphemeralReferenceValue{}, result) + ref := result.(*interpreter.EphemeralReferenceValue) + + // Must only have one level of references. + require.IsType(t, &interpreter.StringValue{}, ref.Value) + }) + + t.Run("dictionary, value", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun test() { + let dict: {String: {String: Int}} = {"one": {"two": 2}} + var x: {String: Int}? = dict["one"] + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("dictionary reference, value", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun test() { + let dict: {String: {String: Int} } = {"one": {"two": 2}} + let dictRef = &dict as &{String: {String: Int}} + var x: &{String: Int}? = dictRef["one"] + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("dictionary authorized reference, value", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + entitlement A + + fun test() { + let dict: {String: {String: Int} } = {"one": {"two": 2}} + let dictRef = &dict as auth(A) &{String: {String: Int}} + + // Must return an unauthorized reference. + var x: &{String: Int}? = dictRef["one"] + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("dictionary reference, value, in assignment", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun test() { + let dict: {String: {String: Int} } = {"one": {"two": 2}} + let dictRef = &dict as &{String: {String: Int}} + var x: &{String: Int}? = &{} as &{String: Int} + x = dictRef["one"] + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("dictionary reference, optional typed value", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun test() { + let dict: {String: {String: Int}?} = {"one": {"two": 2}} + let dictRef = &dict as &{String: {String: Int}?} + var x: (&{String: Int})?? = dictRef["one"] + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("dictionary reference, primitive typed value", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun test() { + let dict: {String: Int} = {"one": 1} + let dictRef = &dict as &{String: Int} + var x: Int? = dictRef["one"] + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("resource reference, attachment", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + resource R {} + + attachment A for R {} + + fun test() { + let r <- create R() + let rRef = &r as &R + + var a: &A? = rRef[A] + destroy r + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("attachment nested member", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + resource R {} + + attachment A for R { + var foo: Foo + init() { + self.foo = Foo() + } + + access(all) fun getNestedMember(): [Int] { + return self.foo.array + } + } + + struct Foo { + var array: [Int] + init() { + self.array = [] + } + } + + fun test() { + let r <- attach A() to <- create R() + let rRef = &r as &R + + var a: &A? = rRef[A] + + var array = a!.getNestedMember() + + destroy r + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("anystruct swap on reference", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct Foo { + var array: [Int] + init() { + self.array = [] + } + } + + fun test() { + let dict: {String: AnyStruct} = {"foo": Foo(), "bar": Foo()} + let dictRef = &dict as auth(Mutate) &{String: AnyStruct} + + dictRef["foo"] <-> dictRef["bar"] + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("entitlement map access on field", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + entitlement A + entitlement B + entitlement mapping M { + A -> B + } + + struct S { + access(M) let foo: [String] + init() { + self.foo = [] + } + } + + fun test() { + let s = S() + let sRef = &s as auth(A) &S + var foo: auth(B) &[String] = sRef.foo + } + `) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("all member types", func(t *testing.T) { + t.Parallel() + + test := func(tt *testing.T, typeName string) { + code := fmt.Sprintf(` + struct Foo { + var a: %[1]s? + + init() { + self.a = nil + } + } + + struct Bar {} + + struct interface I {} + + fun test() { + let foo = Foo() + let fooRef = &foo as &Foo + var a: &%[1]s? = fooRef.a + }`, + + typeName, + ) + + inter := parseCheckAndInterpret(t, code) + + _, err := inter.Invoke("test") + require.NoError(t, err) + } + + types := []string{ + "Bar", + "{I}", + "[Int]", + "{Bool: String}", + "AnyStruct", + "Block", + } + + // Test all built-in composite types + for i := interpreter.PrimitiveStaticTypeAuthAccount; i < interpreter.PrimitiveStaticType_Count; i++ { + semaType := i.SemaType() + types = append(types, semaType.QualifiedString()) + } + + for _, typeName := range types { + t.Run(typeName, func(t *testing.T) { + test(t, typeName) + }) + } + }) +} diff --git a/runtime/tests/interpreter/reference_test.go b/runtime/tests/interpreter/reference_test.go index ee9559c14c..39a5446a7b 100644 --- a/runtime/tests/interpreter/reference_test.go +++ b/runtime/tests/interpreter/reference_test.go @@ -113,7 +113,7 @@ func TestInterpretContainerVariance(t *testing.T) { fun test(): Int { let dict: {Int: &S1} = {} - let dictRef = &dict as &{Int: &AnyStruct} + let dictRef = &dict as auth(Mutate) &{Int: &AnyStruct} let s2 = S2() dictRef[0] = &s2 as &AnyStruct @@ -148,7 +148,7 @@ func TestInterpretContainerVariance(t *testing.T) { fun test(): Int { let dict: {Int: S1} = {} - let dictRef = &dict as &{Int: AnyStruct} + let dictRef = &dict as auth(Mutate) &{Int: AnyStruct} dictRef[0] = S2() @@ -186,7 +186,7 @@ func TestInterpretContainerVariance(t *testing.T) { fun test(): Int { let dict: {Int: &S1} = {} - let dictRef = &dict as &{Int: &AnyStruct} + let dictRef = &dict as auth(Mutate) &{Int: &AnyStruct} let s2 = S2() dictRef[0] = &s2 as &AnyStruct @@ -225,7 +225,7 @@ func TestInterpretContainerVariance(t *testing.T) { fun test(): Int { let dict: {Int: S1} = {} - let dictRef = &dict as &{Int: AnyStruct} + let dictRef = &dict as auth(Mutate) &{Int: AnyStruct} dictRef[0] = S2() @@ -267,7 +267,7 @@ func TestInterpretContainerVariance(t *testing.T) { let s2 = S2() - let dictRef = &dict as &{Int: &AnyStruct} + let dictRef = &dict as auth(Mutate) &{Int: &AnyStruct} dictRef[0] = &s2 as &AnyStruct dict.values[0].value = 1 @@ -308,7 +308,7 @@ func TestInterpretContainerVariance(t *testing.T) { fun test() { let dict: {Int: S1} = {} - let dictRef = &dict as &{Int: AnyStruct} + let dictRef = &dict as auth(Mutate) &{Int: AnyStruct} dictRef[0] = S2() @@ -340,7 +340,7 @@ func TestInterpretContainerVariance(t *testing.T) { let s2 = S2() - let dictRef = &dict as &{Int: AnyStruct} + let dictRef = &dict as auth(Mutate) &{Int: AnyStruct} dictRef[0] = s2 let x = dict.values[0] @@ -369,7 +369,7 @@ func TestInterpretContainerVariance(t *testing.T) { fun test(): Int { let dict: {Int: fun(): Int} = {} - let dictRef = &dict as &{Int: AnyStruct} + let dictRef = &dict as auth(Mutate) &{Int: AnyStruct} dictRef[0] = f2 @@ -393,7 +393,7 @@ func TestInterpretContainerVariance(t *testing.T) { fun test() { let dict: {Int: [UInt8]} = {} - let dictRef = &dict as &{Int: AnyStruct} + let dictRef = &dict as auth(Mutate) &{Int: AnyStruct} dictRef[0] = "not an [UInt8] array, but a String" @@ -417,7 +417,7 @@ func TestInterpretContainerVariance(t *testing.T) { fun test() { let dict: {Int: [UInt8]} = {} - let dictRef = &dict as &{Int: AnyStruct} + let dictRef = &dict as auth(Mutate) &{Int: AnyStruct} dictRef[0] = "not an [UInt8] array, but a String" @@ -632,11 +632,11 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { } } - fun test(target: &[R]) { + fun test(target: auth(Mutate) &[R]) { target.append(<- create R()) // Take reference while in the account - let ref = &target[0] as &R + let ref = target[0] // Move the resource out of the account onto the stack let movedR <- target.remove(at: 0) @@ -662,7 +662,11 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { ) arrayRef := interpreter.NewUnmeteredEphemeralReferenceValue( - interpreter.UnauthorizedAccess, + interpreter.NewEntitlementSetAuthorization( + nil, + []common.TypeID{"Mutate"}, + sema.Conjunction, + ), array, &sema.VariableSizedType{ Type: rType, @@ -734,11 +738,11 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { } } - fun test(target1: &[R], target2: &[R]) { + fun test(target1: auth(Mutate) &[R], target2: auth(Mutate) &[R]) { target1.append(<- create R()) // Take reference while in the account_1 - let ref = &target1[0] as &R + let ref = target1[0] // Move the resource out of the account_1 into the account_2 target2.append(<- target1.remove(at: 0)) @@ -762,7 +766,11 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { ) arrayRef1 := interpreter.NewUnmeteredEphemeralReferenceValue( - interpreter.UnauthorizedAccess, + interpreter.NewEntitlementSetAuthorization( + nil, + []common.TypeID{"Mutate"}, + sema.Conjunction, + ), array1, &sema.VariableSizedType{ Type: rType, @@ -781,7 +789,11 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { ) arrayRef2 := interpreter.NewUnmeteredEphemeralReferenceValue( - interpreter.UnauthorizedAccess, + interpreter.NewEntitlementSetAuthorization( + nil, + []common.TypeID{"Mutate"}, + sema.Conjunction, + ), array2, &sema.VariableSizedType{ Type: rType, @@ -810,11 +822,11 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { } } - fun test(target: &[R]): Int { + fun test(target: auth(Mutate) &[R]): Int { target.append(<- create R()) // Take reference while in the account - let ref = &target[0] as &R + let ref = target[0] // Move the resource out of the account onto the stack. This should invalidate the reference. let movedR <- target.remove(at: 0) @@ -847,7 +859,11 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { ) arrayRef := interpreter.NewUnmeteredEphemeralReferenceValue( - interpreter.UnauthorizedAccess, + interpreter.NewEntitlementSetAuthorization( + nil, + []common.TypeID{"Mutate"}, + sema.Conjunction, + ), array, &sema.VariableSizedType{ Type: rType, @@ -918,11 +934,11 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { var ref2: &R? = nil var ref3: &R? = nil - fun setup(collection: &[R]) { + fun setup(collection: auth(Mutate) &[R]) { collection.append(<- create R()) // Take reference while in the account - ref1 = &collection[0] as &R + ref1 = collection[0] // Move the resource out of the account onto the stack. This should invalidate ref1. let movedR <- collection.remove(at: 0) @@ -937,7 +953,7 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { collection.append(<- movedR) // Take another reference - ref3 = &collection[1] as &R + ref3 = collection[1] } fun getRef1Id(): Int { @@ -972,7 +988,11 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { ) arrayRef := interpreter.NewUnmeteredEphemeralReferenceValue( - interpreter.UnauthorizedAccess, + interpreter.NewEntitlementSetAuthorization( + nil, + []common.TypeID{"Mutate"}, + sema.Conjunction, + ), array, &sema.VariableSizedType{ Type: rType, @@ -1247,7 +1267,7 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { var dict2 <- dict // Access the inner moved resource - var fooRef = &dictRef["levelTwo"] as &Foo? + var fooRef = dictRef["levelTwo"] destroy dict2 } @@ -1280,7 +1300,7 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { var array2 <- array // Access the inner moved resource - var fooRef = &arrayRef[0] as &Foo + var fooRef = arrayRef[0] destroy array2 } diff --git a/runtime/tests/interpreter/resources_test.go b/runtime/tests/interpreter/resources_test.go index 86d5a62747..615dd62c0d 100644 --- a/runtime/tests/interpreter/resources_test.go +++ b/runtime/tests/interpreter/resources_test.go @@ -24,8 +24,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/onflow/atree" - "github.com/onflow/cadence/runtime/sema" "github.com/onflow/cadence/runtime/tests/checker" . "github.com/onflow/cadence/runtime/tests/utils" @@ -145,80 +143,6 @@ func TestInterpretImplicitResourceRemovalFromContainer(t *testing.T) { ) }) - t.Run("reference, shift statement, member expression", func(t *testing.T) { - - t.Parallel() - - inter := parseCheckAndInterpret(t, ` - resource R2 { - let value: String - - init() { - self.value = "test" - } - } - - resource R1 { - var r2: @R2? - - init() { - self.r2 <- nil - } - - destroy() { - destroy self.r2 - } - } - - fun createR1(): @R1 { - return <- create R1() - } - - fun test(r1: &R1): String? { - r1.r2 <-! create R2() - // The second assignment should not lead to the resource being cleared, - // it must be fully moved out of this container before, - // not just assigned to the new variable - let optR2 <- r1.r2 <- nil - let value = optR2?.value - destroy optR2 - return value - } - `) - - r1, err := inter.Invoke("createR1") - require.NoError(t, err) - - r1 = r1.Transfer( - inter, - interpreter.EmptyLocationRange, - atree.Address{1}, - false, - nil, - nil, - ) - - r1Type := checker.RequireGlobalType(t, inter.Program.Elaboration, "R1") - - ref := interpreter.NewUnmeteredEphemeralReferenceValue( - interpreter.UnauthorizedAccess, - r1, - r1Type, - ) - - value, err := inter.Invoke("test", ref) - require.NoError(t, err) - - AssertValuesEqual( - t, - inter, - interpreter.NewUnmeteredSomeValueNonCopying( - interpreter.NewUnmeteredStringValue("test"), - ), - value, - ) - }) - t.Run("resource, if-let statement, member expression", func(t *testing.T) { t.Parallel() @@ -273,82 +197,6 @@ func TestInterpretImplicitResourceRemovalFromContainer(t *testing.T) { ) }) - t.Run("reference, if-let statement, member expression", func(t *testing.T) { - - t.Parallel() - - inter := parseCheckAndInterpret(t, ` - resource R2 { - let value: String - - init() { - self.value = "test" - } - } - - resource R1 { - var r2: @R2? - - init() { - self.r2 <- nil - } - - destroy() { - destroy self.r2 - } - } - - fun createR1(): @R1 { - return <- create R1() - } - - fun test(r1: &R1): String? { - r1.r2 <-! create R2() - // The second assignment should not lead to the resource being cleared, - // it must be fully moved out of this container before, - // not just assigned to the new variable - if let r2 <- r1.r2 <- nil { - let value = r2.value - destroy r2 - return value - } - return nil - } - `) - - r1, err := inter.Invoke("createR1") - require.NoError(t, err) - - r1 = r1.Transfer( - inter, - interpreter.EmptyLocationRange, - atree.Address{1}, - false, - nil, - nil, - ) - - r1Type := checker.RequireGlobalType(t, inter.Program.Elaboration, "R1") - - ref := interpreter.NewUnmeteredEphemeralReferenceValue( - interpreter.UnauthorizedAccess, - r1, - r1Type, - ) - - value, err := inter.Invoke("test", ref) - require.NoError(t, err) - - AssertValuesEqual( - t, - inter, - interpreter.NewUnmeteredSomeValueNonCopying( - interpreter.NewUnmeteredStringValue("test"), - ), - value, - ) - }) - t.Run("resource, shift statement, index expression", func(t *testing.T) { t.Parallel() @@ -400,88 +248,6 @@ func TestInterpretImplicitResourceRemovalFromContainer(t *testing.T) { ) }) - t.Run("reference, shift statement, index expression", func(t *testing.T) { - - t.Parallel() - - inter := parseCheckAndInterpret(t, ` - resource R2 { - let value: String - - init() { - self.value = "test" - } - } - - resource R1 { - access(all) var r2s: @{Int: R2} - - access(all) fun setR2(i: Int, r: @R2) { - self.r2s[i] <-! r - } - - access(all) fun move(i: Int, r: @R2?): @R2? { - let optR2 <- self.r2s[i] <- r - return <- optR2 - } - - init() { - self.r2s <- {} - } - - destroy() { - destroy self.r2s - } - } - - fun createR1(): @R1 { - return <- create R1() - } - - fun test(r1: &R1): String? { - r1.setR2(i: 0, r: <- create R2()) - // The second assignment should not lead to the resource being cleared, - // it must be fully moved out of this container before, - // not just assigned to the new variable - let optR2 <- r1.move(i: 0, r: nil) - let value = optR2?.value - destroy optR2 - return value - } - `) - - r1, err := inter.Invoke("createR1") - require.NoError(t, err) - - r1 = r1.Transfer( - inter, - interpreter.EmptyLocationRange, - atree.Address{1}, - false, - nil, - nil, - ) - - r1Type := checker.RequireGlobalType(t, inter.Program.Elaboration, "R1") - - ref := interpreter.NewUnmeteredEphemeralReferenceValue( - interpreter.UnauthorizedAccess, - r1, - r1Type, - ) - value, err := inter.Invoke("test", ref) - require.NoError(t, err) - - AssertValuesEqual( - t, - inter, - interpreter.NewUnmeteredSomeValueNonCopying( - interpreter.NewUnmeteredStringValue("test"), - ), - value, - ) - }) - t.Run("resource, if-let statement, index expression", func(t *testing.T) { t.Parallel() @@ -535,82 +301,6 @@ func TestInterpretImplicitResourceRemovalFromContainer(t *testing.T) { value, ) }) - - t.Run("reference, if-let statement, index expression", func(t *testing.T) { - - t.Parallel() - - inter := parseCheckAndInterpret(t, ` - resource R2 { - let value: String - - init() { - self.value = "test" - } - } - - resource R1 { - var r2s: @{Int: R2} - - init() { - self.r2s <- {} - } - - destroy() { - destroy self.r2s - } - } - - fun createR1(): @R1 { - return <- create R1() - } - - fun test(r1: &R1): String? { - r1.r2s[0] <-! create R2() - // The second assignment should not lead to the resource being cleared, - // it must be fully moved out of this container before, - // not just assigned to the new variable - if let r2 <- r1.r2s[0] <- nil { - let value = r2.value - destroy r2 - return value - } - return nil - } - `) - - r1, err := inter.Invoke("createR1") - require.NoError(t, err) - - r1 = r1.Transfer( - inter, - interpreter.EmptyLocationRange, - atree.Address{1}, - false, - nil, - nil, - ) - - r1Type := checker.RequireGlobalType(t, inter.Program.Elaboration, "R1") - - ref := interpreter.NewUnmeteredEphemeralReferenceValue( - interpreter.UnauthorizedAccess, - r1, - r1Type, - ) - - value, err := inter.Invoke("test", ref) - require.NoError(t, err) - - AssertValuesEqual( - t, - inter, - interpreter.NewUnmeteredSomeValueNonCopying( - interpreter.NewUnmeteredStringValue("test"), - ), - value, - ) - }) } func TestInterpretInvalidatedResourceValidation(t *testing.T) { @@ -2373,9 +2063,9 @@ func TestInterpretOptionalResourceReference(t *testing.T) { fun test() { account.save(<-{0 : <-create R()}, to: /storage/x) - let collection = account.borrow<&{Int: R}>(from: /storage/x)! + let collection = account.borrow(from: /storage/x)! - let resourceRef = (&collection[0] as &R?)! + let resourceRef = collection[0]! let token <- collection.remove(key: 0) let x = resourceRef.id @@ -2411,9 +2101,9 @@ func TestInterpretArrayOptionalResourceReference(t *testing.T) { fun test() { account.save(<-[<-create R()], to: /storage/x) - let collection = account.borrow<&[R?]>(from: /storage/x)! + let collection = account.borrow(from: /storage/x)! - let resourceRef = (&collection[0] as &R?)! + let resourceRef = collection[0]! let token <- collection.remove(at: 0) let x = resourceRef.id diff --git a/runtime/tests/interpreter/string_test.go b/runtime/tests/interpreter/string_test.go index 44d3f0ee16..89ddce3e86 100644 --- a/runtime/tests/interpreter/string_test.go +++ b/runtime/tests/interpreter/string_test.go @@ -36,7 +36,7 @@ func TestInterpretRecursiveValueString(t *testing.T) { inter := parseCheckAndInterpret(t, ` fun test(): AnyStruct { let map: {String: AnyStruct} = {} - let mapRef = &map as &{String: AnyStruct} + let mapRef = &map as auth(Mutate) &{String: AnyStruct} mapRef["mapRef"] = mapRef return map }