diff --git a/tests/parser/test_call_graph_stability.py b/tests/parser/test_call_graph_stability.py new file mode 100644 index 0000000000..6785169ba3 --- /dev/null +++ b/tests/parser/test_call_graph_stability.py @@ -0,0 +1,68 @@ +import random +import string + +import hypothesis.strategies as st +import pytest +from hypothesis import given, settings + +import vyper.ast as vy_ast +from vyper.compiler.phases import CompilerData + + +# random names for functions +@settings(max_examples=20, deadline=None) +@given( + st.lists( + st.tuples( + st.sampled_from(["@pure", "@view", "@nonpayable", "@payable"]), + st.text(alphabet=string.ascii_lowercase, min_size=1), + ), + unique_by=lambda x: x[1], # unique on function name + min_size=1, + max_size=10, + ) +) +@pytest.mark.fuzzing +def test_call_graph_stability_fuzz(funcs): + def generate_func_def(mutability, func_name, i): + return f""" +@internal +{mutability} +def {func_name}() -> uint256: + return {i} + """ + + func_defs = "\n".join(generate_func_def(m, s, i) for i, (m, s) in enumerate(funcs)) + + for _ in range(10): + func_names = [f for (_, f) in funcs] + random.shuffle(func_names) + + self_calls = "\n".join(f" self.{f}()" for f in func_names) + code = f""" +{func_defs} + +@external +def foo(): +{self_calls} + """ + t = CompilerData(code) + + # check the .called_functions data structure on foo() directly + foo = t.vyper_module_folded.get_children(vy_ast.FunctionDef, filters={"name": "foo"})[0] + foo_t = foo._metadata["type"] + assert [f.name for f in foo_t.called_functions] == func_names + + # now for sanity, ensure the order that the function definitions appear + # in the IR is the same as the order of the calls + sigs = t.function_signatures + del sigs["foo"] + ir = t.ir_runtime + ir_funcs = [] + # search for function labels + for d in ir.args: # currently: (seq ... (seq (label foo ...)) ...) + if d.value == "seq" and d.args[0].value == "label": + r = d.args[0].args[0].value + if isinstance(r, str) and r.startswith("internal"): + ir_funcs.append(r) + assert ir_funcs == [f.internal_function_label for f in sigs.values()] diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 99a1957fd7..75fa3a1214 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -1,7 +1,7 @@ import re import warnings from collections import OrderedDict -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Tuple from vyper import ast as vy_ast from vyper.ast.validation import validate_call_args @@ -28,7 +28,7 @@ from vyper.semantics.types.shortcuts import UINT256_T from vyper.semantics.types.subscriptable import TupleT from vyper.semantics.types.utils import type_from_abi, type_from_annotation -from vyper.utils import keccak256 +from vyper.utils import OrderedSet, keccak256 class ContractFunctionT(VyperType): @@ -89,7 +89,7 @@ def __init__( self.nonreentrant = nonreentrant # a list of internal functions this function calls - self.called_functions: Set["ContractFunctionT"] = set() + self.called_functions = OrderedSet() # special kwargs that are allowed in call site self.call_site_kwargs = { diff --git a/vyper/utils.py b/vyper/utils.py index 37a3f13b3d..2440117d0c 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -11,6 +11,19 @@ from vyper.exceptions import DecimalOverrideException, InvalidLiteral +class OrderedSet(dict): + """ + a minimal "ordered set" class. this is needed in some places + because, while dict guarantees you can recover insertion order + vanilla sets do not. + no attempt is made to fully implement the set API, will add + functionality as needed. + """ + + def add(self, item): + self[item] = None + + class DecimalContextOverride(decimal.Context): def __setattr__(self, name, value): if name == "prec":