Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT-#333] Allow assert_events to filter by event args #340

Merged
merged 1 commit into from
Sep 20, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions scenario_player/tasks/blockchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,32 @@ def query_blockchain_events(


class AssertBlockchainEventsTask(Task):
""" Assert on blockchain events.

Required parameters:
- ``contract_name``
Which contract events to assert on. Example: ``TokenNetwork``
- ``event_name``
Contract specific event name to filter for.
- ``num_events``
The number of expected events.

Optional parameters:
- ``event_args``
A dictionary of event specific arguments that is used to further filter the found events.
This has a special handling for node addresses: If the name of an argument contains the
word ``participant`` an integer node index can be given instead of an ethereum address.

Example::

- assert_events:
contract_name: "TokenNetwork"
event_name: "ChannelClosed"
num_events: 1
event_args: {closing_participant: 1} # The 1 refers to scenario node index 1

"""

_name = "assert_events"

def __init__(
Expand All @@ -103,6 +129,12 @@ def __init__(
self.contract_name = config["contract_name"]
self.event_name = config["event_name"]
self.num_events = config["num_events"]
self.event_args: Dict[str, Any] = config.get("event_args", {}).copy()
for key, value in self.event_args.items():
if "participant" in key:
if isinstance(value, int) or (isinstance(value, str) and value.isnumeric()):
# Replace node index with eth address
self.event_args[key] = self._runner.get_node_address(int(value))

self.web3 = self._runner.client.web3

Expand Down Expand Up @@ -134,6 +166,12 @@ def _run(self, *args, **kwargs): # pylint: disable=unused-argument
# Filter matching events
events = [e for e in events if e["event"] == self.event_name]

if self.event_args:
event_args_items = self.event_args.items()
# Filter the events by the given event args.
# `.items()` produces a set like object which supports intersection (`&`)
events = [e for e in events if e["args"] and event_args_items & e["args"].items()]

# Raise exception when events do not match
if not self.num_events == len(events):
raise ScenarioAssertionError(
Expand Down