-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Bring KMS client and multiclient over to chainlink
- Loading branch information
1 parent
a7348f0
commit 8e3dcdc
Showing
3 changed files
with
317 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
package kms | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"crypto/ecdsa" | ||
"encoding/asn1" | ||
"encoding/hex" | ||
"fmt" | ||
"math/big" | ||
|
||
"github.com/aws/aws-sdk-go/aws" | ||
"github.com/aws/aws-sdk-go/service/kms" | ||
"github.com/ethereum/go-ethereum/accounts/abi/bind" | ||
"github.com/ethereum/go-ethereum/common" | ||
"github.com/ethereum/go-ethereum/core/types" | ||
"github.com/ethereum/go-ethereum/crypto" | ||
"github.com/ethereum/go-ethereum/crypto/secp256k1" | ||
) | ||
|
||
var ( | ||
secp256k1N = crypto.S256().Params().N | ||
secp256k1HalfN = new(big.Int).Div(secp256k1N, big.NewInt(2)) | ||
) | ||
|
||
// See https://docs.aws.amazon.com/kms/latest/APIReference/API_GetPublicKey.html#API_GetPublicKey_ResponseSyntax | ||
// and https://datatracker.ietf.org/doc/html/rfc5280 for why we need to unpack the KMS public key. | ||
type asn1SubjectPublicKeyInfo struct { | ||
AlgorithmIdentifier asn1AlgorithmIdentifier | ||
SubjectPublicKey asn1.BitString | ||
} | ||
|
||
type asn1AlgorithmIdentifier struct { | ||
Algorithm asn1.ObjectIdentifier | ||
Parameters asn1.ObjectIdentifier | ||
} | ||
|
||
// See https://aws.amazon.com/blogs/database/part2-use-aws-kms-to-securely-manage-ethereum-accounts/ for why we | ||
// need to manually prep the signature for Ethereum. | ||
type asn1ECDSASig struct { | ||
R asn1.RawValue | ||
S asn1.RawValue | ||
} | ||
|
||
type KMSClient interface { | ||
GetPublicKey(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error) | ||
Sign(input *kms.SignInput) (*kms.SignOutput, error) | ||
} | ||
|
||
type evmKMSClient struct { | ||
Client KMSClient | ||
KeyID string | ||
} | ||
|
||
func NewEVMKMSClient(client KMSClient, keyID string) *evmKMSClient { | ||
return &evmKMSClient{ | ||
Client: client, | ||
KeyID: keyID, | ||
} | ||
} | ||
|
||
func (c *evmKMSClient) GetKMSTransactOpts(ctx context.Context, chainID *big.Int) (*bind.TransactOpts, error) { | ||
ecdsaPublicKey, err := c.GetECDSAPublicKey() | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
pubKeyBytes := secp256k1.S256().Marshal(ecdsaPublicKey.X, ecdsaPublicKey.Y) | ||
keyAddr := crypto.PubkeyToAddress(*ecdsaPublicKey) | ||
if chainID == nil { | ||
return nil, fmt.Errorf("chainID is required") | ||
} | ||
signer := types.LatestSignerForChainID(chainID) | ||
|
||
signerFn := func(address common.Address, tx *types.Transaction) (*types.Transaction, error) { | ||
if address != keyAddr { | ||
return nil, bind.ErrNotAuthorized | ||
} | ||
|
||
txHashBytes := signer.Hash(tx).Bytes() | ||
|
||
mType := kms.MessageTypeDigest | ||
algo := kms.SigningAlgorithmSpecEcdsaSha256 | ||
signOutput, err := c.Client.Sign( | ||
&kms.SignInput{ | ||
KeyId: &c.KeyID, | ||
SigningAlgorithm: &algo, | ||
MessageType: &mType, | ||
Message: txHashBytes, | ||
}) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to call kms.Sign() on transaction: %v", err) | ||
} | ||
|
||
ethSig, err := kmsToEthSig(signOutput.Signature, pubKeyBytes, txHashBytes) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to convert KMS signature to Ethereum signature: %v", err) | ||
} | ||
|
||
return tx.WithSignature(signer, ethSig) | ||
} | ||
|
||
return &bind.TransactOpts{ | ||
From: keyAddr, | ||
Signer: signerFn, | ||
Context: ctx, | ||
}, nil | ||
} | ||
|
||
// GetECDSAPublicKey retrieves the public key from KMS and converts it to its ECDSA representation. | ||
func (c *evmKMSClient) GetECDSAPublicKey() (*ecdsa.PublicKey, error) { | ||
getPubKeyOutput, err := c.Client.GetPublicKey(&kms.GetPublicKeyInput{ | ||
KeyId: aws.String(c.KeyID), | ||
}) | ||
if err != nil { | ||
return nil, fmt.Errorf("can not get public key from KMS for KeyId=%s: %s", c.KeyID, err) | ||
} | ||
|
||
var asn1pubKeyInfo asn1SubjectPublicKeyInfo | ||
_, err = asn1.Unmarshal(getPubKeyOutput.PublicKey, &asn1pubKeyInfo) | ||
if err != nil { | ||
return nil, fmt.Errorf("can not parse asn1 public key for KeyId=%s: %s", c.KeyID, err) | ||
} | ||
|
||
pubKey, err := crypto.UnmarshalPubkey(asn1pubKeyInfo.SubjectPublicKey.Bytes) | ||
if err != nil { | ||
return nil, fmt.Errorf("can not unmarshal public key bytes: %s", err) | ||
} | ||
return pubKey, nil | ||
} | ||
|
||
func kmsToEthSig(kmsSig, ecdsaPubKeyBytes, hash []byte) ([]byte, error) { | ||
var asn1Sig asn1ECDSASig | ||
_, err := asn1.Unmarshal(kmsSig, &asn1Sig) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
rBytes := asn1Sig.R.Bytes | ||
sBytes := asn1Sig.S.Bytes | ||
|
||
// Adjust S value from signature to match Eth standard. | ||
// See: https://aws.amazon.com/blogs/database/part2-use-aws-kms-to-securely-manage-ethereum-accounts/ | ||
// "After we extract r and s successfully, we have to test if the value of s is greater than secp256k1n/2 as | ||
// specified in EIP-2 and flip it if required." | ||
sBigInt := new(big.Int).SetBytes(sBytes) | ||
if sBigInt.Cmp(secp256k1HalfN) > 0 { | ||
sBytes = new(big.Int).Sub(secp256k1N, sBigInt).Bytes() | ||
} | ||
|
||
return recoverEthSignature(ecdsaPubKeyBytes, hash, rBytes, sBytes) | ||
} | ||
|
||
// See: https://aws.amazon.com/blogs/database/part2-use-aws-kms-to-securely-manage-ethereum-accounts/ | ||
func recoverEthSignature(expectedPublicKeyBytes, txHash, r, s []byte) ([]byte, error) { | ||
rsSig := append(padTo32Bytes(r), padTo32Bytes(s)...) | ||
ethSig := append(rsSig, []byte{0}...) | ||
|
||
recoveredPublicKeyBytes, err := crypto.Ecrecover(txHash, ethSig) | ||
if err != nil { | ||
return nil, fmt.Errorf("failing to call Ecrecover: %v", err) | ||
} | ||
|
||
if hex.EncodeToString(recoveredPublicKeyBytes) != hex.EncodeToString(expectedPublicKeyBytes) { | ||
ethSig = append(rsSig, []byte{1}...) | ||
recoveredPublicKeyBytes, err = crypto.Ecrecover(txHash, ethSig) | ||
if err != nil { | ||
return nil, fmt.Errorf("failing to call Ecrecover: %v", err) | ||
} | ||
|
||
if hex.EncodeToString(recoveredPublicKeyBytes) != hex.EncodeToString(expectedPublicKeyBytes) { | ||
return nil, fmt.Errorf("can not reconstruct public key from sig") | ||
} | ||
} | ||
|
||
return ethSig, nil | ||
} | ||
|
||
func padTo32Bytes(buffer []byte) []byte { | ||
buffer = bytes.TrimLeft(buffer, "\x00") | ||
for len(buffer) < 32 { | ||
zeroBuf := []byte{0} | ||
buffer = append(zeroBuf, buffer...) | ||
} | ||
return buffer | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
package kms | ||
|
||
import ( | ||
"encoding/hex" | ||
"testing" | ||
|
||
"github.com/test-go/testify/require" | ||
) | ||
|
||
func TestKMSToEthSigConversion(t *testing.T) { | ||
kmsSigBytes, err := hex.DecodeString("304402206168865941bafcae3a8cf8b26edbb5693d62222b2e54d962c1aabbeaddf33b6802205edc7f597d2bf2d1eaa14fc514a6202bafcffe52b13ae3fec00674d92a874b73") | ||
require.NoError(t, err) | ||
ecdsaPublicKeyBytes, err := hex.DecodeString("04a735e9e3cb526f83be23b03f1f5ae7788a8654e3f0fcfb4f978290de07ebd47da30eeb72e904fdd4a81b46e320908ff4345e119148f89c1f04674c14a506e24b") | ||
require.NoError(t, err) | ||
txHashBytes, err := hex.DecodeString("a2f037301e90f58c084fe4bec2eef14b26e620d6b6cb46051037d03b29ab7d9a") | ||
require.NoError(t, err) | ||
expectedEthSignBytes, err := hex.DecodeString("6168865941bafcae3a8cf8b26edbb5693d62222b2e54d962c1aabbeaddf33b685edc7f597d2bf2d1eaa14fc514a6202bafcffe52b13ae3fec00674d92a874b7300") | ||
require.NoError(t, err) | ||
|
||
actualEthSig, err := kmsToEthSig( | ||
kmsSigBytes, | ||
ecdsaPublicKeyBytes, | ||
txHashBytes, | ||
) | ||
require.NoError(t, err) | ||
require.Equal(t, expectedEthSignBytes, actualEthSig) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
package deployment | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"math/big" | ||
"time" | ||
|
||
"github.com/avast/retry-go/v4" | ||
"github.com/ethereum/go-ethereum/common" | ||
"github.com/ethereum/go-ethereum/core/types" | ||
"github.com/ethereum/go-ethereum/ethclient" | ||
) | ||
|
||
const ( | ||
RPC_RETRY_ATTEMPTS = 10 | ||
RPC_RETRY_DELAY = 1000 * time.Millisecond | ||
) | ||
|
||
// MultiClient should comply with the coreenv.OnchainClient interface | ||
var _ OnchainClient = &MultiClient{} | ||
|
||
type MultiClient struct { | ||
*ethclient.Client | ||
backup []*ethclient.Client | ||
} | ||
|
||
type RPC struct { | ||
RPCName string `toml:"rpc_name"` | ||
HTTPURL string `toml:"http_url"` | ||
WSURL string `toml:"ws_url"` | ||
} | ||
|
||
func NewMultiClient(rpcs []RPC) *MultiClient { | ||
if len(rpcs) == 0 { | ||
panic("No RPCs provided") | ||
} | ||
clients := make([]*ethclient.Client, 0, len(rpcs)) | ||
for _, rpc := range rpcs { | ||
client, err := ethclient.Dial(rpc.HTTPURL) | ||
if err != nil { | ||
panic(err) | ||
} | ||
clients = append(clients, client) | ||
} | ||
return &MultiClient{ | ||
Client: clients[0], | ||
backup: clients[1:], | ||
} | ||
} | ||
|
||
func (mc *MultiClient) TransactionReceipt(ctx context.Context, txHash common.Hash) (*types.Receipt, error) { | ||
var receipt *types.Receipt | ||
err := mc.retryWithBackups(func(client *ethclient.Client) error { | ||
var err error | ||
receipt, err = client.TransactionReceipt(ctx, txHash) | ||
return err | ||
}) | ||
return receipt, err | ||
} | ||
|
||
func (mc *MultiClient) SendTransaction(ctx context.Context, tx *types.Transaction) error { | ||
return mc.retryWithBackups(func(client *ethclient.Client) error { | ||
return client.SendTransaction(ctx, tx) | ||
}) | ||
} | ||
|
||
func (mc *MultiClient) CodeAt(ctx context.Context, account common.Address, blockNumber *big.Int) ([]byte, error) { | ||
var code []byte | ||
err := mc.retryWithBackups(func(client *ethclient.Client) error { | ||
var err error | ||
code, err = client.CodeAt(ctx, account, blockNumber) | ||
return err | ||
}) | ||
return code, err | ||
} | ||
|
||
func (mc *MultiClient) NonceAt(ctx context.Context, account common.Address) (uint64, error) { | ||
var count uint64 | ||
err := mc.retryWithBackups(func(client *ethclient.Client) error { | ||
var err error | ||
count, err = client.NonceAt(ctx, account, nil) | ||
return err | ||
}) | ||
return count, err | ||
} | ||
|
||
func (mc *MultiClient) retryWithBackups(op func(*ethclient.Client) error) error { | ||
var err error | ||
for _, client := range append([]*ethclient.Client{mc.Client}, mc.backup...) { | ||
err2 := retry.Do(func() error { | ||
err = op(client) | ||
if err != nil { | ||
fmt.Printf(" [MultiClient RPC] Retrying with new client, error: %v\n", err) | ||
return err | ||
} | ||
return nil | ||
}, retry.Attempts(RPC_RETRY_ATTEMPTS), retry.Delay(RPC_RETRY_DELAY)) | ||
if err2 == nil { | ||
return nil | ||
} | ||
} | ||
return err | ||
} |