Skip to content

Commit

Permalink
Merge pull request #3601 from onflow/supun/port-263
Browse files Browse the repository at this point in the history
Fix invocation boxing
  • Loading branch information
SupunS authored Oct 8, 2024
2 parents 70606a5 + 9f55a1d commit 9e12a05
Show file tree
Hide file tree
Showing 16 changed files with 741 additions and 257 deletions.
32 changes: 32 additions & 0 deletions runtime/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,38 @@ func TestRuntimePublicAccountKeys(t *testing.T) {
keys[keyIdx] = nil // no key should be passed to the callback twice
}
})

t.Run("keys.forEach, box and convert argument", func(t *testing.T) {
t.Parallel()

testEnv := initTestEnv(revokedAccountKeyA, accountKeyB)
test := accountKeyTestCase{
//language=Cadence
code: `
access(all)
fun main(): String? {
var res: String? = nil
// NOTE: The function has a parameter of type AccountKey? instead of just AccountKey
getAccount(0x02).keys.forEach(fun(key: AccountKey?): Bool {
// The map should call Optional.map, not fail,
// because path is AccountKey?, not AccountKey
res = key.map(fun(string: AnyStruct): String {
return "Optional.map"
})
return true
})
return res
}
`,
}

value, err := test.executeScript(testEnv.runtime, testEnv.runtimeInterface)
require.NoError(t, err)
utils.AssertEqualWithDiff(t,
cadence.NewOptional(cadence.String("Optional.map")),
value,
)
})
}

func TestRuntimeHashAlgorithm(t *testing.T) {
Expand Down
81 changes: 81 additions & 0 deletions runtime/capabilitycontrollers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2036,6 +2036,48 @@ func TestRuntimeCapabilityControllers(t *testing.T) {
nonDeploymentEventStrings(events),
)
})

t.Run("forEachController, box and convert argument", func(t *testing.T) {

t.Parallel()

err, _, _ := test(
t,
// language=cadence
`
import Test from 0x1
transaction {
prepare(signer: auth(Capabilities) &Account) {
let storagePath = /storage/r
// Arrange
signer.capabilities.storage.issue<&Test.R>(storagePath)
// Act
var res: String? = nil
signer.capabilities.storage.forEachController(
forPath: storagePath,
// NOTE: The function has a parameter of type &StorageCapabilityController?
// instead of just &StorageCapabilityController
fun (controller: &StorageCapabilityController?): Bool {
// The map should call Optional.map, not fail,
// because path is PublicPath?, not PublicPath
res = controller.map(fun(string: AnyStruct): String {
return "Optional.map"
})
return true
}
)
// Assert
assert(res == "Optional.map")
}
}
`,
)
require.NoError(t, err)
})
})

t.Run("Account.AccountCapabilities", func(t *testing.T) {
Expand Down Expand Up @@ -2606,6 +2648,45 @@ func TestRuntimeCapabilityControllers(t *testing.T) {
nonDeploymentEventStrings(events),
)
})

t.Run("forEachController, box and convert argument", func(t *testing.T) {

t.Parallel()

err, _, _ := test(
t,
// language=cadence
`
import Test from 0x1
transaction {
prepare(signer: auth(Capabilities) &Account) {
// Arrange
signer.capabilities.account.issue<&Account>()
// Act
var res: String? = nil
signer.capabilities.account.forEachController(
// NOTE: The function has a parameter of type &AccountCapabilityController?
// instead of just &AccountCapabilityController
fun (controller: &AccountCapabilityController?): Bool {
// The map should call Optional.map, not fail,
// because path is PublicPath?, not PublicPath
res = controller.map(fun(string: AnyStruct): String {
return "Optional.map"
})
return true
}
)
// Assert
assert(res == "Optional.map")
}
}
`,
)
require.NoError(t, err)
})
})

t.Run("StorageCapabilityController", func(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions runtime/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,7 @@ func (e *interpreterEnvironment) newContractValueHandler() interpreter.ContractV
invocation.ConstructorArguments,
invocation.ArgumentTypes,
invocation.ParameterTypes,
invocation.ContractType,
invocationRange,
)
if err != nil {
Expand Down
19 changes: 0 additions & 19 deletions runtime/interpreter/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -620,25 +620,6 @@ func (e UseBeforeInitializationError) Error() string {
return fmt.Sprintf("member `%s` is used before it has been initialized", e.Name)
}

// InvocationArgumentTypeError
type InvocationArgumentTypeError struct {
LocationRange
ParameterType sema.Type
Index int
}

var _ errors.UserError = InvocationArgumentTypeError{}

func (InvocationArgumentTypeError) IsUserError() {}

func (e InvocationArgumentTypeError) Error() string {
return fmt.Sprintf(
"invalid invocation with argument at index %d: expected `%s`",
e.Index,
e.ParameterType.QualifiedString(),
)
}

// MemberAccessTypeError
type MemberAccessTypeError struct {
ExpectedType sema.Type
Expand Down
25 changes: 14 additions & 11 deletions runtime/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4152,21 +4152,24 @@ func (interpreter *Interpreter) newStorageIterationFunction(
storageValue,
functionType,
func(_ *SimpleCompositeValue, invocation Invocation) Value {
interpreter := invocation.Interpreter
inter := invocation.Interpreter
locationRange := invocation.LocationRange

fn, ok := invocation.Arguments[0].(FunctionValue)
if !ok {
panic(errors.NewUnreachableError())
}

locationRange := invocation.LocationRange
inter := invocation.Interpreter
fnType := fn.FunctionType()
parameterTypes := fnType.ParameterTypes()
returnType := fnType.ReturnTypeAnnotation.Type

storageMap := config.Storage.GetStorageMap(address, domain.Identifier(), false)
if storageMap == nil {
// if nothing is stored, no iteration is required
return Void
}
storageIterator := storageMap.Iterator(interpreter)
storageIterator := storageMap.Iterator(inter)

invocationArgumentTypes := []sema.Type{pathType, sema.MetaType}

Expand All @@ -4178,7 +4181,7 @@ func (interpreter *Interpreter) newStorageIterationFunction(

for key, value := storageIterator.Next(); key != nil && value != nil; key, value = storageIterator.Next() {

staticType := value.StaticType(interpreter)
staticType := value.StaticType(inter)

// Perform a forced value de-referencing to see if the associated type is not broken.
// If broken, skip this value from the iteration.
Expand All @@ -4197,18 +4200,18 @@ func (interpreter *Interpreter) newStorageIterationFunction(
pathValue := NewPathValue(inter, domain, identifier)
runtimeType := NewTypeValue(inter, staticType)

subInvocation := NewInvocation(
inter,
nil,
nil,
nil,
result := inter.invokeFunctionValue(
fn,
[]Value{pathValue, runtimeType},
nil,
invocationArgumentTypes,
parameterTypes,
returnType,
nil,
locationRange,
)

shouldContinue, ok := fn.invoke(subInvocation).(BoolValue)
shouldContinue, ok := result.(BoolValue)
if !ok {
panic(errors.NewUnreachableError())
}
Expand Down
47 changes: 14 additions & 33 deletions runtime/interpreter/interpreter_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,18 @@ func (interpreter *Interpreter) checkMemberAccess(

targetStaticType := target.StaticType(interpreter)

if _, ok := expectedType.(*sema.OptionalType); ok {
if _, ok := targetStaticType.(*OptionalStaticType); !ok {
targetSemaType := interpreter.MustConvertStaticToSemaType(targetStaticType)

panic(MemberAccessTypeError{
ExpectedType: expectedType,
ActualType: targetSemaType,
LocationRange: locationRange,
})
}
}

if !interpreter.IsSubTypeOfSemaType(targetStaticType, expectedType) {
targetSemaType := interpreter.MustConvertStaticToSemaType(targetStaticType)

Expand Down Expand Up @@ -1207,6 +1219,7 @@ func (interpreter *Interpreter) visitInvocationExpressionWithImplicitArgument(in
typeParameterTypes := invocationExpressionTypes.TypeArguments
argumentTypes := invocationExpressionTypes.ArgumentTypes
parameterTypes := invocationExpressionTypes.TypeParameterTypes
returnType := invocationExpressionTypes.ReturnType

// add the implicit argument to the end of the argument list, if it exists
if implicitArg != nil {
Expand All @@ -1222,45 +1235,13 @@ func (interpreter *Interpreter) visitInvocationExpressionWithImplicitArgument(in
argumentExpressions,
argumentTypes,
parameterTypes,
returnType,
typeParameterTypes,
invocationExpression,
)

interpreter.reportInvokedFunctionReturn()

locationRange := LocationRange{
Location: interpreter.Location,
HasPosition: invocationExpression.InvokedExpression,
}

functionReturnType := function.FunctionType().ReturnTypeAnnotation.Type

// Only convert and box.
// No need to transfer, since transfer would happen later, when the return value gets assigned.
//
// The conversion is needed because, the runtime function's return type could be a
// subtype of the invocation's return type.
// e.g:
// struct interface I {
// fun foo(): T?
// }
//
// struct S: I {
// fun foo(): T {...}
// }
//
// var i: {I} = S()
// return i.foo()?.bar
//
// Here runtime function's return type is `T`, but invocation's return type is `T?`.

resultValue = interpreter.ConvertAndBox(
locationRange,
resultValue,
functionReturnType,
invocationExpressionTypes.ReturnType,
)

// If this is invocation is optional chaining, wrap the result
// as an optional, as the result is expected to be an optional
if isOptionalChaining {
Expand Down
33 changes: 32 additions & 1 deletion runtime/interpreter/interpreter_invocation.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func (interpreter *Interpreter) InvokeFunctionValue(
arguments []Value,
argumentTypes []sema.Type,
parameterTypes []sema.Type,
returnType sema.Type,
invocationPosition ast.HasPosition,
) (
value Value,
Expand All @@ -47,6 +48,7 @@ func (interpreter *Interpreter) InvokeFunctionValue(
nil,
argumentTypes,
parameterTypes,
returnType,
nil,
invocationPosition,
), nil
Expand All @@ -58,6 +60,7 @@ func (interpreter *Interpreter) invokeFunctionValue(
expressions []ast.Expression,
argumentTypes []sema.Type,
parameterTypes []sema.Type,
returnType sema.Type,
typeParameterTypes *sema.TypeParameterTypeOrderedMap,
invocationPosition ast.HasPosition,
) Value {
Expand Down Expand Up @@ -123,7 +126,35 @@ func (interpreter *Interpreter) invokeFunctionValue(
locationRange,
)

return function.invoke(invocation)
resultValue := function.invoke(invocation)

functionReturnType := function.FunctionType().ReturnTypeAnnotation.Type

// Only convert and box.
// No need to transfer, since transfer would happen later, when the return value gets assigned.
//
// The conversion is needed because, the runtime function's return type could be a
// subtype of the invocation's return type.
// e.g:
// struct interface I {
// fun foo(): T?
// }
//
// struct S: I {
// fun foo(): T {...}
// }
//
// var i: {I} = S()
// return i.foo()?.bar
//
// Here runtime function's return type is `T`, but invocation's return type is `T?`.

return interpreter.ConvertAndBox(
locationRange,
resultValue,
functionReturnType,
returnType,
)
}

func (interpreter *Interpreter) invokeInterpretedFunction(
Expand Down
Loading

0 comments on commit 9e12a05

Please sign in to comment.