Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix borrowing #782

Merged
merged 8 commits into from
Apr 12, 2021
16 changes: 16 additions & 0 deletions runtime/interpreter/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,19 @@ type MissingMemberValueError struct {
func (e MissingMemberValueError) Error() string {
return fmt.Sprintf("missing value for member `%s`", e.Name)
}

// InvocationArgumentTypeError

type InvocationArgumentTypeError struct {
Index int
ParameterType sema.Type
LocationRange
}

func (e InvocationArgumentTypeError) Error() string {
return fmt.Sprintf(
"invalid invocation with argument at index %d: expected %s",
e.Index,
e.ParameterType.QualifiedString(),
)
}
17 changes: 17 additions & 0 deletions runtime/interpreter/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,23 @@ func (InterpretedFunctionValue) SetModified(_ bool) {
func (InterpretedFunctionValue) isFunctionValue() {}

func (f InterpretedFunctionValue) Invoke(invocation Invocation) Value {

// Check arguments' dynamic types match parameter types

for i, argument := range invocation.Arguments {
parameterType := f.Type.Parameters[i].TypeAnnotation.Type

argumentDynamicType := argument.DynamicType(f.Interpreter)

if !IsSubType(argumentDynamicType, parameterType) {
panic(InvocationArgumentTypeError{
Index: i,
ParameterType: parameterType,
LocationRange: invocation.GetLocationRange(),
})
}
}

return f.Interpreter.invokeInterpretedFunction(f, invocation)
}

Expand Down
90 changes: 55 additions & 35 deletions runtime/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,7 @@ func (interpreter *Interpreter) prepareInvoke(
GetLocationRange: ReturnEmptyLocationRange,
Interpreter: interpreter,
}

return functionValue.Invoke(invocation), nil
}

Expand Down Expand Up @@ -793,10 +794,12 @@ func (interpreter *Interpreter) recoverErrors(onError func(error)) {
// wrap the error with position information if needed

_, ok := err.(ast.HasPosition)
if !ok {
if !ok && interpreter.statement != nil {
r := ast.NewRangeFromPositioned(interpreter.statement)

err = PositionedError{
Err: err,
Range: ast.NewRangeFromPositioned(interpreter.statement),
Range: r,
}
}

Expand Down Expand Up @@ -2560,42 +2563,31 @@ func (interpreter *Interpreter) authAccountBorrowFunction(addressValue AddressVa
path := invocation.Arguments[0].(PathValue)
key := storageKey(path)

value := interpreter.readStored(address, key, false)

switch value := value.(type) {
case NilValue:
return value

case *SomeValue:

// If there is value stored for the given path,
// check that it satisfies the type given as the type argument.

typeParameterPair := invocation.TypeParameterTypes.Oldest()
if typeParameterPair == nil {
panic(errors.NewUnreachableError())
}

ty := typeParameterPair.Value
typeParameterPair := invocation.TypeParameterTypes.Oldest()
if typeParameterPair == nil {
panic(errors.NewUnreachableError())
}

referenceType := ty.(*sema.ReferenceType)
ty := typeParameterPair.Value

dynamicType := value.Value.DynamicType(interpreter)
if !IsSubType(dynamicType, referenceType.Type) {
return NilValue{}
}
referenceType := ty.(*sema.ReferenceType)

reference := &StorageReferenceValue{
Authorized: referenceType.Authorized,
TargetStorageAddress: address,
TargetKey: key,
}
reference := &StorageReferenceValue{
Authorized: referenceType.Authorized,
TargetStorageAddress: address,
TargetKey: key,
BorrowedType: referenceType.Type,
}

return NewSomeValueOwningNonCopying(reference)
// Attempt to dereference,
// which reads the stored value
// and performs a dynamic type check

default:
panic(errors.NewUnreachableError())
if reference.ReferencedValue(interpreter) == nil {
return NilValue{}
}

return NewSomeValueOwningNonCopying(reference)
})
}

Expand Down Expand Up @@ -2736,6 +2728,15 @@ func (interpreter *Interpreter) capabilityBorrowFunction(
Authorized: authorized,
TargetStorageAddress: address,
TargetKey: targetStorageKey,
BorrowedType: borrowType.Type,
}

// Attempt to dereference,
// which reads the stored value
// and performs a dynamic type check

if reference.ReferencedValue(interpreter) == nil {
return NilValue{}
}

return NewSomeValueOwningNonCopying(reference)
Expand Down Expand Up @@ -2765,17 +2766,36 @@ func (interpreter *Interpreter) capabilityCheckFunction(
panic(errors.NewUnreachableError())
}

targetStorageKey, _ :=
targetStorageKey, authorized :=
interpreter.getCapabilityFinalTargetStorageKey(
addressValue,
pathValue,
borrowType,
invocation.GetLocationRange,
)

isValid := targetStorageKey != ""
if targetStorageKey == "" {
return BoolValue(false)
}

address := addressValue.ToAddress()

reference := &StorageReferenceValue{
Authorized: authorized,
TargetStorageAddress: address,
TargetKey: targetStorageKey,
BorrowedType: borrowType.Type,
}

// Attempt to dereference,
// which reads the stored value
// and performs a dynamic type check

if reference.ReferencedValue(interpreter) == nil {
return BoolValue(false)
}

return BoolValue(isValid)
return BoolValue(true)
},
)
}
Expand Down
15 changes: 14 additions & 1 deletion runtime/interpreter/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -6522,6 +6522,7 @@ type StorageReferenceValue struct {
Authorized bool
TargetStorageAddress common.Address
TargetKey string
BorrowedType sema.Type
}

func (*StorageReferenceValue) IsValue() {}
Expand Down Expand Up @@ -6558,6 +6559,7 @@ func (v *StorageReferenceValue) Copy() Value {
Authorized: v.Authorized,
TargetStorageAddress: v.TargetStorageAddress,
TargetKey: v.TargetKey,
BorrowedType: v.BorrowedType,
}
}

Expand All @@ -6581,9 +6583,20 @@ func (*StorageReferenceValue) SetModified(_ bool) {
func (v *StorageReferenceValue) ReferencedValue(interpreter *Interpreter) *Value {
switch referenced := interpreter.readStored(v.TargetStorageAddress, v.TargetKey, false).(type) {
case *SomeValue:
return &referenced.Value
value := referenced.Value

if v.BorrowedType != nil {
dynamicType := value.DynamicType(interpreter)
if !IsSubType(dynamicType, v.BorrowedType) {
return nil
}
}

return &value

case NilValue:
return nil

default:
panic(errors.NewUnreachableError())
}
Expand Down
53 changes: 44 additions & 9 deletions runtime/stdlib/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,48 @@ var cryptoContractType = func() *sema.CompositeType {
return variable.Type.(*sema.CompositeType)
}()

const cryptoSignatureVerifierImplIdentifier = "SignatureVerifierImpl"

const cryptoHasherImplIdentifier = "HasherImpl"

func registerCheckerElaborationCompositeType(
identifier string,
explicitConformances []*sema.InterfaceType,
) {
typeID := CryptoChecker.Location.TypeID(identifier)
CryptoChecker.Elaboration.CompositeTypes[typeID] = &sema.CompositeType{
Location: CryptoChecker.Location,
Identifier: identifier,
Kind: common.CompositeKindStructure,
ExplicitInterfaceConformances: explicitConformances,
}
}

func init() {
signatureVerifierVariable, ok := CryptoChecker.Elaboration.GlobalTypes.Get("SignatureVerifier")
if !ok {
panic(errors2.NewUnreachableError())
}

registerCheckerElaborationCompositeType(
cryptoSignatureVerifierImplIdentifier,
[]*sema.InterfaceType{
signatureVerifierVariable.Type.(*sema.InterfaceType),
},
)

hasherVariable, ok := CryptoChecker.Elaboration.GlobalTypes.Get("Hasher")
if !ok {
panic(errors2.NewUnreachableError())
}
registerCheckerElaborationCompositeType(
cryptoHasherImplIdentifier,
[]*sema.InterfaceType{
hasherVariable.Type.(*sema.InterfaceType),
},
)
}

var cryptoContractInitializerTypes = func() (result []sema.Type) {
result = make([]sema.Type, len(cryptoContractType.ConstructorParameters))
for i, parameter := range cryptoContractType.ConstructorParameters {
Expand Down Expand Up @@ -144,13 +186,9 @@ func newCryptoContractVerifySignatureFunction(signatureVerifier CryptoSignatureV
}

func newCryptoContractSignatureVerifier(signatureVerifier CryptoSignatureVerifier) *interpreter.CompositeValue {
implIdentifier := CryptoChecker.Location.
QualifiedIdentifier(cryptoContractInitializerTypes[0].ID()) +
"Impl"

result := interpreter.NewCompositeValue(
CryptoChecker.Location,
implIdentifier,
cryptoSignatureVerifierImplIdentifier,
common.CompositeKindStructure,
nil,
nil,
Expand Down Expand Up @@ -185,13 +223,10 @@ func newCryptoContractHashFunction(hasher CryptoHasher) interpreter.FunctionValu
}

func newCryptoContractHasher(hasher CryptoHasher) *interpreter.CompositeValue {
implIdentifier := CryptoChecker.Location.
QualifiedIdentifier(cryptoContractInitializerTypes[1].ID()) +
"Impl"

result := interpreter.NewCompositeValue(
CryptoChecker.Location,
implIdentifier,
cryptoHasherImplIdentifier,
common.CompositeKindStructure,
nil,
nil,
Expand Down
Loading