Skip to content

Commit

Permalink
Support execute transaction from Trezor (#337)
Browse files Browse the repository at this point in the history
  • Loading branch information
moisses89 authored Jan 17, 2024
1 parent 0c1b04b commit 5b89e01
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 10 deletions.
3 changes: 1 addition & 2 deletions safe_cli/operators/hw_wallets/hw_wallet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import re
from abc import ABC, abstractmethod

from eth_typing import HexStr
from web3.types import TxParams

from .constants import BIP32_ETH_PATTERN, BIP32_LEGACY_LEDGER_PATTERN
Expand Down Expand Up @@ -50,7 +49,7 @@ def sign_typed_hash(self, domain_hash: bytes, message_hash: bytes) -> bytes:
@abstractmethod
def get_signed_raw_transaction(
self, tx_parameters: TxParams, chain_id: int
) -> HexStr:
) -> bytes:
"""
:param chain_id:
Expand Down
7 changes: 4 additions & 3 deletions safe_cli/operators/hw_wallets/ledger_wallet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional

from eth_typing import ChecksumAddress, HexStr
from eth_typing import ChecksumAddress
from hexbytes import HexBytes
from ledgerblue.Dongle import Dongle
from ledgereth import create_transaction, sign_typed_data_draft
from ledgereth.accounts import get_account_by_path
Expand Down Expand Up @@ -54,7 +55,7 @@ def sign_typed_hash(self, domain_hash: bytes, message_hash: bytes) -> bytes:
@raise_ledger_exception_as_hw_wallet_exception
def get_signed_raw_transaction(
self, tx_parameters: TxParams, chain_id: int
) -> HexStr:
) -> bytes:
"""
:param chain_id:
Expand All @@ -73,4 +74,4 @@ def get_signed_raw_transaction(
sender_path=self.derivation_path,
dongle=self.dongle,
)
return signed_transaction.raw_transaction()
return HexBytes(signed_transaction.raw_transaction())
81 changes: 77 additions & 4 deletions safe_cli/operators/hw_wallets/trezor_wallet.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from functools import lru_cache

from eth_typing import ChecksumAddress, HexStr
import rlp
from eth_typing import ChecksumAddress
from hexbytes import HexBytes
from trezorlib import tools
from trezorlib.client import TrezorClient, get_default_client
from trezorlib.ethereum import get_address, sign_typed_data_hash
from trezorlib.ethereum import (
get_address,
sign_tx,
sign_tx_eip1559,
sign_typed_data_hash,
)
from trezorlib.ui import ClickUI
from web3.types import TxParams

Expand Down Expand Up @@ -51,11 +58,77 @@ def sign_typed_hash(self, domain_hash: bytes, message_hash: bytes) -> bytes:
)
return signed.signature

def get_signed_raw_transaction(self, tx_parameters: TxParams) -> HexStr:
@raise_trezor_exception_as_hw_wallet_exception
def get_signed_raw_transaction(
self, tx_parameters: TxParams, chain_id: int
) -> bytes:
"""
:param chain_id:
:param tx_parameters:
:return: raw transaction signed
"""
raise NotImplementedError
address_n = tools.parse_path(self.derivation_path)
if tx_parameters.get("maxPriorityFeePerGas"):
# EIP1559
v, r, s = sign_tx_eip1559(
self.client,
n=address_n,
nonce=tx_parameters["nonce"],
gas_limit=tx_parameters["gas"],
to=tx_parameters["to"],
value=tx_parameters["value"],
data=HexBytes(tx_parameters["data"]),
chain_id=chain_id,
max_gas_fee=tx_parameters.get("maxFeePerGas"),
max_priority_fee=tx_parameters.get("maxPriorityFeePerGas"),
)

encoded_transaction = (
"0x02"
+ rlp.encode(
[
chain_id,
tx_parameters["nonce"],
tx_parameters.get("maxPriorityFeePerGas"),
tx_parameters.get("maxFeePerGas"),
tx_parameters["gas"],
HexBytes(tx_parameters["to"]),
tx_parameters["value"],
HexBytes(tx_parameters["data"]),
[],
v,
HexBytes(r),
HexBytes(s),
]
).hex()
)
else:
# Legacy transaction
v, r, s = sign_tx(
self.client,
n=address_n,
nonce=tx_parameters["nonce"],
gas_price=tx_parameters["gasPrice"],
gas_limit=tx_parameters["gas"],
to=tx_parameters["to"],
value=tx_parameters["value"],
data=HexBytes(tx_parameters.get("data")),
chain_id=chain_id,
)

encoded_transaction = rlp.encode(
[
tx_parameters["nonce"],
tx_parameters["gasPrice"],
tx_parameters["gas"],
HexBytes(tx_parameters["to"]),
tx_parameters["value"],
HexBytes(tx_parameters["data"]),
v,
HexBytes(r),
HexBytes(s),
]
).hex()

return HexBytes(encoded_transaction)
1 change: 0 additions & 1 deletion safe_cli/operators/safe_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,6 @@ def load_hw_wallet(
if (
not self.default_sender
and not self.hw_wallet_manager.sender
and hw_wallet_type == HwWalletType.LEDGER
and balance > 0
):
self.hw_wallet_manager.set_sender(hw_wallet_type, derivation_path)
Expand Down
102 changes: 102 additions & 0 deletions tests/test_trezor_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest.mock import MagicMock

from eth_account import Account
from hexbytes import HexBytes
from trezorlib.client import TrezorClient
from trezorlib.exceptions import Cancelled, OutdatedFirmwareError, PinException
from trezorlib.messages import EthereumTypedDataSignature
Expand Down Expand Up @@ -149,3 +150,104 @@ def test_sign_typed_hash(
)
signature = trezor_wallet.sign_typed_hash(encode_hash[1], encode_hash[2])
self.assertEqual(expected_signature, signature)

@mock.patch(
"safe_cli.operators.hw_wallets.trezor_wallet.sign_tx",
autospec=True,
)
@mock.patch(
"safe_cli.operators.hw_wallets.trezor_wallet.sign_tx_eip1559",
autospec=True,
)
@mock.patch(
"safe_cli.operators.hw_wallets.trezor_wallet.get_address",
autospec=True,
)
@mock.patch(
"safe_cli.operators.hw_wallets.trezor_wallet.get_trezor_client",
autospec=True,
)
def test_get_signed_raw_transaction(
self,
mock_trezor_client: MagicMock,
mock_get_address: MagicMock,
mock_sign_tx_eip1559: MagicMock,
mock_sign_tx: MagicMock,
):
owner = Account.create()
to = Account.create()
transport_mock = MagicMock(auto_spec=True)
mock_trezor_client.return_value = TrezorClient(
transport_mock, ui=ClickUI(), _init_device=False
)
mock_trezor_client.return_value.is_outdated = MagicMock(return_value=False)
mock_get_address.return_value = owner.address
trezor_wallet = TrezorWallet("44'/60'/0'/0")

safe = self.deploy_test_safe(
owners=[owner.address],
threshold=1,
initial_funding_wei=self.w3.to_wei(0.1, "ether"),
)
safe_tx = SafeTx(
self.ethereum_client,
safe.address,
to.address,
10,
b"",
0,
200000,
200000,
self.gas_price,
None,
None,
safe_nonce=0,
)
safe_tx.sign(owner.key)
# Legacy transaction
tx_parameters = {
"from": owner.address,
"gasPrice": safe_tx.w3.eth.gas_price,
"nonce": 0,
"gas": safe_tx.recommended_gas(),
}
safe_tx.tx = safe_tx.w3_tx.build_transaction(tx_parameters)
signed_fields = safe_tx.w3.eth.account.sign_transaction(
safe_tx.tx, private_key=owner.key
)

mock_sign_tx.return_value = (
HexBytes(signed_fields.v),
HexBytes(signed_fields.r),
HexBytes(signed_fields.s),
)

raw_signed_tx = trezor_wallet.get_signed_raw_transaction(
safe_tx.tx, safe_tx.ethereum_client.get_chain_id()
) # return raw signed transaction
mock_sign_tx.assert_called_once()
self.assertEqual(signed_fields.rawTransaction, HexBytes(raw_signed_tx))

# EIP1559 transaction
tx_parameters = {
"from": owner.address,
"maxPriorityFeePerGas": safe_tx.w3.eth.gas_price,
"maxFeePerGas": safe_tx.w3.eth.gas_price,
"nonce": 1,
"gas": safe_tx.recommended_gas(),
}
safe_tx.tx = safe_tx.w3_tx.build_transaction(tx_parameters)
signed_fields = safe_tx.w3.eth.account.sign_transaction(
safe_tx.tx, private_key=owner.key
)

mock_sign_tx_eip1559.return_value = (
signed_fields.v,
HexBytes(signed_fields.r),
HexBytes(signed_fields.s),
)
raw_signed_tx = trezor_wallet.get_signed_raw_transaction(
safe_tx.tx, safe_tx.ethereum_client.get_chain_id()
) # return raw signed transaction
mock_sign_tx_eip1559.assert_called_once()
self.assertEqual(signed_fields.rawTransaction, HexBytes(raw_signed_tx))

0 comments on commit 5b89e01

Please sign in to comment.