Skip to content

Commit

Permalink
Bring KMS client and multiclient over to chainlink
Browse files Browse the repository at this point in the history
  • Loading branch information
ogtownsend committed Sep 17, 2024
1 parent a7348f0 commit 8e3dcdc
Show file tree
Hide file tree
Showing 3 changed files with 317 additions and 0 deletions.
186 changes: 186 additions & 0 deletions integration-tests/deployment/kms/evm_kmsclient.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
package kms

import (
"bytes"
"context"
"crypto/ecdsa"
"encoding/asn1"
"encoding/hex"
"fmt"
"math/big"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
)

var (
secp256k1N = crypto.S256().Params().N
secp256k1HalfN = new(big.Int).Div(secp256k1N, big.NewInt(2))
)

// See https://docs.aws.amazon.com/kms/latest/APIReference/API_GetPublicKey.html#API_GetPublicKey_ResponseSyntax
// and https://datatracker.ietf.org/doc/html/rfc5280 for why we need to unpack the KMS public key.
type asn1SubjectPublicKeyInfo struct {
AlgorithmIdentifier asn1AlgorithmIdentifier
SubjectPublicKey asn1.BitString
}

type asn1AlgorithmIdentifier struct {
Algorithm asn1.ObjectIdentifier
Parameters asn1.ObjectIdentifier
}

// See https://aws.amazon.com/blogs/database/part2-use-aws-kms-to-securely-manage-ethereum-accounts/ for why we
// need to manually prep the signature for Ethereum.
type asn1ECDSASig struct {
R asn1.RawValue
S asn1.RawValue
}

type KMSClient interface {
GetPublicKey(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error)
Sign(input *kms.SignInput) (*kms.SignOutput, error)
}

type evmKMSClient struct {
Client KMSClient
KeyID string
}

func NewEVMKMSClient(client KMSClient, keyID string) *evmKMSClient {
return &evmKMSClient{
Client: client,
KeyID: keyID,
}
}

func (c *evmKMSClient) GetKMSTransactOpts(ctx context.Context, chainID *big.Int) (*bind.TransactOpts, error) {
ecdsaPublicKey, err := c.GetECDSAPublicKey()
if err != nil {
return nil, err
}

pubKeyBytes := secp256k1.S256().Marshal(ecdsaPublicKey.X, ecdsaPublicKey.Y)
keyAddr := crypto.PubkeyToAddress(*ecdsaPublicKey)
if chainID == nil {
return nil, fmt.Errorf("chainID is required")
}
signer := types.LatestSignerForChainID(chainID)

signerFn := func(address common.Address, tx *types.Transaction) (*types.Transaction, error) {
if address != keyAddr {
return nil, bind.ErrNotAuthorized
}

txHashBytes := signer.Hash(tx).Bytes()

mType := kms.MessageTypeDigest
algo := kms.SigningAlgorithmSpecEcdsaSha256
signOutput, err := c.Client.Sign(
&kms.SignInput{
KeyId: &c.KeyID,
SigningAlgorithm: &algo,
MessageType: &mType,
Message: txHashBytes,
})
if err != nil {
return nil, fmt.Errorf("failed to call kms.Sign() on transaction: %v", err)
}

ethSig, err := kmsToEthSig(signOutput.Signature, pubKeyBytes, txHashBytes)
if err != nil {
return nil, fmt.Errorf("failed to convert KMS signature to Ethereum signature: %v", err)
}

return tx.WithSignature(signer, ethSig)
}

return &bind.TransactOpts{
From: keyAddr,
Signer: signerFn,
Context: ctx,
}, nil
}

// GetECDSAPublicKey retrieves the public key from KMS and converts it to its ECDSA representation.
func (c *evmKMSClient) GetECDSAPublicKey() (*ecdsa.PublicKey, error) {
getPubKeyOutput, err := c.Client.GetPublicKey(&kms.GetPublicKeyInput{
KeyId: aws.String(c.KeyID),
})
if err != nil {
return nil, fmt.Errorf("can not get public key from KMS for KeyId=%s: %s", c.KeyID, err)
}

var asn1pubKeyInfo asn1SubjectPublicKeyInfo
_, err = asn1.Unmarshal(getPubKeyOutput.PublicKey, &asn1pubKeyInfo)
if err != nil {
return nil, fmt.Errorf("can not parse asn1 public key for KeyId=%s: %s", c.KeyID, err)
}

pubKey, err := crypto.UnmarshalPubkey(asn1pubKeyInfo.SubjectPublicKey.Bytes)
if err != nil {
return nil, fmt.Errorf("can not unmarshal public key bytes: %s", err)
}
return pubKey, nil
}

func kmsToEthSig(kmsSig, ecdsaPubKeyBytes, hash []byte) ([]byte, error) {
var asn1Sig asn1ECDSASig
_, err := asn1.Unmarshal(kmsSig, &asn1Sig)
if err != nil {
return nil, err
}

rBytes := asn1Sig.R.Bytes
sBytes := asn1Sig.S.Bytes

// Adjust S value from signature to match Eth standard.
// See: https://aws.amazon.com/blogs/database/part2-use-aws-kms-to-securely-manage-ethereum-accounts/
// "After we extract r and s successfully, we have to test if the value of s is greater than secp256k1n/2 as
// specified in EIP-2 and flip it if required."
sBigInt := new(big.Int).SetBytes(sBytes)
if sBigInt.Cmp(secp256k1HalfN) > 0 {
sBytes = new(big.Int).Sub(secp256k1N, sBigInt).Bytes()
}

return recoverEthSignature(ecdsaPubKeyBytes, hash, rBytes, sBytes)
}

// See: https://aws.amazon.com/blogs/database/part2-use-aws-kms-to-securely-manage-ethereum-accounts/
func recoverEthSignature(expectedPublicKeyBytes, txHash, r, s []byte) ([]byte, error) {
rsSig := append(padTo32Bytes(r), padTo32Bytes(s)...)
ethSig := append(rsSig, []byte{0}...)

recoveredPublicKeyBytes, err := crypto.Ecrecover(txHash, ethSig)
if err != nil {
return nil, fmt.Errorf("failing to call Ecrecover: %v", err)
}

if hex.EncodeToString(recoveredPublicKeyBytes) != hex.EncodeToString(expectedPublicKeyBytes) {
ethSig = append(rsSig, []byte{1}...)
recoveredPublicKeyBytes, err = crypto.Ecrecover(txHash, ethSig)
if err != nil {
return nil, fmt.Errorf("failing to call Ecrecover: %v", err)
}

if hex.EncodeToString(recoveredPublicKeyBytes) != hex.EncodeToString(expectedPublicKeyBytes) {
return nil, fmt.Errorf("can not reconstruct public key from sig")
}
}

return ethSig, nil
}

func padTo32Bytes(buffer []byte) []byte {
buffer = bytes.TrimLeft(buffer, "\x00")
for len(buffer) < 32 {
zeroBuf := []byte{0}
buffer = append(zeroBuf, buffer...)
}
return buffer
}
27 changes: 27 additions & 0 deletions integration-tests/deployment/kms/evm_kmsclient_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package kms

import (
"encoding/hex"
"testing"

"github.com/test-go/testify/require"
)

func TestKMSToEthSigConversion(t *testing.T) {
kmsSigBytes, err := hex.DecodeString("304402206168865941bafcae3a8cf8b26edbb5693d62222b2e54d962c1aabbeaddf33b6802205edc7f597d2bf2d1eaa14fc514a6202bafcffe52b13ae3fec00674d92a874b73")
require.NoError(t, err)
ecdsaPublicKeyBytes, err := hex.DecodeString("04a735e9e3cb526f83be23b03f1f5ae7788a8654e3f0fcfb4f978290de07ebd47da30eeb72e904fdd4a81b46e320908ff4345e119148f89c1f04674c14a506e24b")
require.NoError(t, err)
txHashBytes, err := hex.DecodeString("a2f037301e90f58c084fe4bec2eef14b26e620d6b6cb46051037d03b29ab7d9a")
require.NoError(t, err)
expectedEthSignBytes, err := hex.DecodeString("6168865941bafcae3a8cf8b26edbb5693d62222b2e54d962c1aabbeaddf33b685edc7f597d2bf2d1eaa14fc514a6202bafcffe52b13ae3fec00674d92a874b7300")
require.NoError(t, err)

actualEthSig, err := kmsToEthSig(
kmsSigBytes,
ecdsaPublicKeyBytes,
txHashBytes,
)
require.NoError(t, err)
require.Equal(t, expectedEthSignBytes, actualEthSig)
}
104 changes: 104 additions & 0 deletions integration-tests/deployment/multiclient.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package deployment

import (
"context"
"fmt"
"math/big"
"time"

"github.com/avast/retry-go/v4"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethclient"
)

const (
RPC_RETRY_ATTEMPTS = 10
RPC_RETRY_DELAY = 1000 * time.Millisecond
)

// MultiClient should comply with the coreenv.OnchainClient interface
var _ OnchainClient = &MultiClient{}

type MultiClient struct {
*ethclient.Client
backup []*ethclient.Client
}

type RPC struct {
RPCName string `toml:"rpc_name"`
HTTPURL string `toml:"http_url"`
WSURL string `toml:"ws_url"`
}

func NewMultiClient(rpcs []RPC) *MultiClient {
if len(rpcs) == 0 {
panic("No RPCs provided")
}
clients := make([]*ethclient.Client, 0, len(rpcs))
for _, rpc := range rpcs {
client, err := ethclient.Dial(rpc.HTTPURL)
if err != nil {
panic(err)
}
clients = append(clients, client)
}
return &MultiClient{
Client: clients[0],
backup: clients[1:],
}
}

func (mc *MultiClient) TransactionReceipt(ctx context.Context, txHash common.Hash) (*types.Receipt, error) {
var receipt *types.Receipt
err := mc.retryWithBackups(func(client *ethclient.Client) error {
var err error
receipt, err = client.TransactionReceipt(ctx, txHash)
return err
})
return receipt, err
}

func (mc *MultiClient) SendTransaction(ctx context.Context, tx *types.Transaction) error {
return mc.retryWithBackups(func(client *ethclient.Client) error {
return client.SendTransaction(ctx, tx)
})
}

func (mc *MultiClient) CodeAt(ctx context.Context, account common.Address, blockNumber *big.Int) ([]byte, error) {
var code []byte
err := mc.retryWithBackups(func(client *ethclient.Client) error {
var err error
code, err = client.CodeAt(ctx, account, blockNumber)
return err
})
return code, err
}

func (mc *MultiClient) NonceAt(ctx context.Context, account common.Address) (uint64, error) {
var count uint64
err := mc.retryWithBackups(func(client *ethclient.Client) error {
var err error
count, err = client.NonceAt(ctx, account, nil)
return err
})
return count, err
}

func (mc *MultiClient) retryWithBackups(op func(*ethclient.Client) error) error {
var err error
for _, client := range append([]*ethclient.Client{mc.Client}, mc.backup...) {
err2 := retry.Do(func() error {
err = op(client)
if err != nil {
fmt.Printf(" [MultiClient RPC] Retrying with new client, error: %v\n", err)
return err
}
return nil
}, retry.Attempts(RPC_RETRY_ATTEMPTS), retry.Delay(RPC_RETRY_DELAY))
if err2 == nil {
return nil
}
}
return err
}

0 comments on commit 8e3dcdc

Please sign in to comment.