diff --git a/core/vm/common.go b/core/vm/common.go index ae79afc35bb0..ec3aac3cd051 100644 --- a/core/vm/common.go +++ b/core/vm/common.go @@ -17,6 +17,8 @@ package vm import ( + "strings" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" "github.com/holiman/uint256" @@ -94,3 +96,7 @@ func minInt(a int, b int) int { } return b } + +func contains(haystack []byte, needle []byte) bool { + return strings.Contains(string(haystack), string(needle)) +} diff --git a/core/vm/contracts.go b/core/vm/contracts.go index e0b24e5e1809..6ae5acd649a4 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -1102,7 +1102,7 @@ func (e *verifyCiphertext) RequiredGas(input []byte) uint64 { func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) (ret []byte, err error) { // TODO: Accept a proof from `input` too ctHash := crypto.Keccak256Hash(input) - accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext{accessibleState.Interpreter().evm.depth, input} + accessibleState.Interpreter().verifiedCiphertexts[ctHash] = &verifiedCiphertext{accessibleState.Interpreter().evm.depth, input} return ctHash.Bytes(), nil } @@ -1140,11 +1140,9 @@ func (e *delegateCiphertext) Run(accessibleState PrecompileAccessibleState, call if len(input) != 32 { return nil, errors.New("invalid ciphertext handle") } - hash := common.BytesToHash(input) - ct, ok := accessibleState.Interpreter().verifiedCiphertexts[hash] + ct, ok := accessibleState.Interpreter().verifiedCiphertexts[common.BytesToHash(input)] if ok { ct.depth = minInt(ct.depth, accessibleState.Interpreter().evm.depth-1) - accessibleState.Interpreter().verifiedCiphertexts[hash] = ct return nil, nil } return nil, errors.New("unverified ciphertext handle") diff --git a/core/vm/instructions.go b/core/vm/instructions.go index d8a0b0b064ca..59c6613ff1fb 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -569,8 +569,7 @@ func verifyIfCiphertextHandle(val common.Hash, interpreter *EVMInterpreter, cont if alreadyVerified { verifiedCt.depth = minInt(verifiedCt.depth, interpreter.evm.depth) } else { - verifiedCt.depth = interpreter.evm.depth - verifiedCt.ciphertext = ctBytes + verifiedCt = &verifiedCiphertext{interpreter.evm.depth, ctBytes} } interpreter.verifiedCiphertexts[val] = verifiedCt } @@ -941,9 +940,12 @@ func opReturn(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]b offset, size := scope.Stack.pop(), scope.Stack.pop() ret := scope.Memory.GetPtr(int64(offset.Uint64()), int64(size.Uint64())) - // Remove all verified ciphertexts that have depth > current depth - 1 for key, verifiedCiphertext := range interpreter.verifiedCiphertexts { - if verifiedCiphertext.depth > interpreter.evm.depth-1 { + if contains(ret, key.Bytes()) { + // If a handle is returned, automatically make it available to the caller. + verifiedCiphertext.depth = minInt(verifiedCiphertext.depth, interpreter.evm.depth-1) + } else if verifiedCiphertext.depth > interpreter.evm.depth-1 { + // Remove any ciphertexts that are not delegated for use by the caller. delete(interpreter.verifiedCiphertexts, key) } } diff --git a/core/vm/interpreter.go b/core/vm/interpreter.go index 7cd2869f0150..85c1c6c573a3 100644 --- a/core/vm/interpreter.go +++ b/core/vm/interpreter.go @@ -68,7 +68,7 @@ type EVMInterpreter struct { readOnly bool // Whether to throw on stateful modifications returnData []byte // Last CALL's return data for subsequent reuse - verifiedCiphertexts map[common.Hash]verifiedCiphertext // A map from a ciphertext hash to itself and stack depth at which it is verified + verifiedCiphertexts map[common.Hash]*verifiedCiphertext // A map from a ciphertext hash to itself and stack depth at which it is verified } // NewEVMInterpreter returns a new instance of the Interpreter. @@ -111,7 +111,7 @@ func NewEVMInterpreter(evm *EVM, cfg Config) *EVMInterpreter { return &EVMInterpreter{ evm: evm, cfg: cfg, - verifiedCiphertexts: make(map[common.Hash]verifiedCiphertext), + verifiedCiphertexts: make(map[common.Hash]*verifiedCiphertext), } } diff --git a/tests/solidity/zama/handles.sol b/tests/solidity/zama/handles.sol index 44fa73f38ff5..e72623fe0865 100644 --- a/tests/solidity/zama/handles.sol +++ b/tests/solidity/zama/handles.sol @@ -67,9 +67,11 @@ contract HandleOwner is Precompiles { handle = bogus_handle; } - // Returns the handle without delegation. Callers using it must fail. + // Returns the handle without delegation. Callers using it must succeed + // due to automatic delegation. function get_handle_without_delegate() public view returns (uint256) { - return handle; + uint256 h = handle; + return h; } // Returns the handle with delegation. Callers using it must succeed. @@ -82,6 +84,11 @@ contract HandleOwner is Precompiles { function callee_reencrypt() public view returns (uint256) { return callee.reencrypt(handle); } + + function load_handle_without_returning_it() public view returns(uint256) { + uint256 h = handle + 1; + return h; + } } contract Callee is Precompiles { @@ -97,13 +104,19 @@ contract Caller is Precompiles { owner = HandleOwner(owner_addr); } - // Fails, because the owner hasn't delegated. + // Succeeds, because we do automatic delegation on return. function reencrypt_without_delegate() public view returns (uint256) { return precompile_reencrypt(owner.get_handle_without_delegate()); } - // Succeeds, because the owner hasn't delegated. + // Succeeds, because there is an explicit delegate by the caller. function reencrypt_with_delegate() public view returns (uint256) { return precompile_reencrypt(owner.get_handle_with_delegate()); } + + // Fails, because the owner hasn't delegated, even though the handle is valid. + function reencrypt_with_a_valid_handle(uint256 handle) public view returns (uint256) { + owner.load_handle_without_returning_it(); + return precompile_reencrypt(handle); + } }