Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve RSA/ECDSA performance #13

Merged
merged 1 commit into from
Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions openssl/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,28 @@ type ecdsaSignature struct {

type PrivateKeyECDSA struct {
// _pkey MUST NOT be accessed directly. Instead, use the withKey method.
_pkey *C.EVP_PKEY
_pkey C.GO_EVP_PKEY_PTR
}

func (k *PrivateKeyECDSA) finalize() {
C.go_openssl_EVP_PKEY_free(k._pkey)
}

func (k *PrivateKeyECDSA) withKey(f func(*C.EVP_PKEY) C.int) C.int {
func (k *PrivateKeyECDSA) withKey(f func(C.GO_EVP_PKEY_PTR) C.int) C.int {
defer runtime.KeepAlive(k)
return f(k._pkey)
}

type PublicKeyECDSA struct {
// _pkey MUST NOT be accessed directly. Instead, use the withKey method.
_pkey *C.EVP_PKEY
_pkey C.GO_EVP_PKEY_PTR
}

func (k *PublicKeyECDSA) finalize() {
C.go_openssl_EVP_PKEY_free(k._pkey)
}

func (k *PublicKeyECDSA) withKey(f func(*C.EVP_PKEY) C.int) C.int {
func (k *PublicKeyECDSA) withKey(f func(C.GO_EVP_PKEY_PTR) C.int) C.int {
defer runtime.KeepAlive(k)
return f(k._pkey)
}
Expand Down Expand Up @@ -79,7 +79,7 @@ func NewPublicKeyECDSA(curve string, X, Y *big.Int) (*PublicKeyECDSA, error) {
return k, nil
}

func newECKey(curve string, X, Y, D *big.Int) (pkey *C.EVP_PKEY, err error) {
func newECKey(curve string, X, Y, D *big.Int) (pkey C.GO_EVP_PKEY_PTR, err error) {
var nid C.int
if nid, err = curveNID(curve); err != nil {
return nil, err
Expand Down
30 changes: 15 additions & 15 deletions openssl/evpkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func cryptoHashToMD(ch crypto.Hash) *C.EVP_MD {
return nil
}

func generateEVPPKey(id C.int, bits int, curve string) (*C.EVP_PKEY, error) {
func generateEVPPKey(id C.int, bits int, curve string) (C.GO_EVP_PKEY_PTR, error) {
if (bits == 0 && curve == "") || (bits != 0 && curve != "") {
return nil, fail("incorrect generateEVPPKey parameters")
}
Expand All @@ -89,20 +89,20 @@ func generateEVPPKey(id C.int, bits int, curve string) (*C.EVP_PKEY, error) {
return nil, newOpenSSLError("EVP_PKEY_CTX_ctrl failed")
}
}
var pkey *C.EVP_PKEY
var pkey C.GO_EVP_PKEY_PTR
if C.go_openssl_EVP_PKEY_keygen(ctx, &pkey) != 1 {
return nil, newOpenSSLError("EVP_PKEY_keygen failed")
}
return pkey, nil
}

type withKeyFunc func(func(*C.EVP_PKEY) C.int) C.int
type initFunc func(*C.EVP_PKEY_CTX) C.int
type cryptFunc func(*C.EVP_PKEY_CTX, *C.uint8_t, *C.uint, *C.uint8_t, C.uint) C.int
type withKeyFunc func(func(C.GO_EVP_PKEY_PTR) C.int) C.int
type initFunc func(C.GO_EVP_PKEY_CTX_PTR) C.int
type cryptFunc func(C.GO_EVP_PKEY_CTX_PTR, *C.uint8_t, *C.uint, *C.uint8_t, C.uint) C.int

func setupEVP(withKey withKeyFunc, padding C.int,
h hash.Hash, label []byte, saltLen int, ch crypto.Hash,
init initFunc) (ctx *C.EVP_PKEY_CTX, err error) {
init initFunc) (ctx C.GO_EVP_PKEY_CTX_PTR, err error) {
defer func() {
if err != nil {
if ctx != nil {
Expand All @@ -112,7 +112,7 @@ func setupEVP(withKey withKeyFunc, padding C.int,
}
}()

withKey(func(pkey *C.EVP_PKEY) C.int {
withKey(func(pkey C.GO_EVP_PKEY_PTR) C.int {
ctx = C.go_openssl_EVP_PKEY_CTX_new(pkey, nil)
return 1
})
Expand Down Expand Up @@ -230,40 +230,40 @@ func cryptEVP(withKey withKeyFunc, padding C.int,
}

func evpEncrypt(withKey withKeyFunc, padding C.int, h hash.Hash, label, msg []byte) ([]byte, error) {
encryptInit := func(ctx *C.EVP_PKEY_CTX) C.int {
encryptInit := func(ctx C.GO_EVP_PKEY_CTX_PTR) C.int {
return C.go_openssl_EVP_PKEY_encrypt_init(ctx)
}
encrypt := func(ctx *C.EVP_PKEY_CTX, out *C.uint8_t, outLen *C.uint, in *C.uint8_t, inLen C.uint) C.int {
encrypt := func(ctx C.GO_EVP_PKEY_CTX_PTR, out *C.uint8_t, outLen *C.uint, in *C.uint8_t, inLen C.uint) C.int {
return C.go_openssl_EVP_PKEY_encrypt(ctx, out, outLen, in, inLen)
}
return cryptEVP(withKey, padding, h, label, 0, 0, encryptInit, encrypt, nil, msg)
}

func evpDecrypt(withKey withKeyFunc, padding C.int, h hash.Hash, label, msg []byte) ([]byte, error) {
decryptInit := func(ctx *C.EVP_PKEY_CTX) C.int {
decryptInit := func(ctx C.GO_EVP_PKEY_CTX_PTR) C.int {
return C.go_openssl_EVP_PKEY_decrypt_init(ctx)
}
decrypt := func(ctx *C.EVP_PKEY_CTX, out *C.uint8_t, outLen *C.uint, in *C.uint8_t, inLen C.uint) C.int {
decrypt := func(ctx C.GO_EVP_PKEY_CTX_PTR, out *C.uint8_t, outLen *C.uint, in *C.uint8_t, inLen C.uint) C.int {
return C.go_openssl_EVP_PKEY_decrypt(ctx, out, outLen, in, inLen)
}
return cryptEVP(withKey, padding, h, label, 0, 0, decryptInit, decrypt, nil, msg)
}

func evpSign(withKey withKeyFunc, padding C.int, saltLen int, h crypto.Hash, hashed []byte) ([]byte, error) {
signtInit := func(ctx *C.EVP_PKEY_CTX) C.int {
signtInit := func(ctx C.GO_EVP_PKEY_CTX_PTR) C.int {
return C.go_openssl_EVP_PKEY_sign_init(ctx)
}
sign := func(ctx *C.EVP_PKEY_CTX, out *C.uint8_t, outLen *C.uint, in *C.uint8_t, inLen C.uint) C.int {
sign := func(ctx C.GO_EVP_PKEY_CTX_PTR, out *C.uint8_t, outLen *C.uint, in *C.uint8_t, inLen C.uint) C.int {
return C.go_openssl_EVP_PKEY_sign(ctx, out, outLen, in, inLen)
}
return cryptEVP(withKey, padding, nil, nil, saltLen, h, signtInit, sign, nil, hashed)
}

func evpVerify(withKey withKeyFunc, padding C.int, saltLen int, h crypto.Hash, sig, hashed []byte) error {
verifyInit := func(ctx *C.EVP_PKEY_CTX) C.int {
verifyInit := func(ctx C.GO_EVP_PKEY_CTX_PTR) C.int {
return C.go_openssl_EVP_PKEY_verify_init(ctx)
}
verify := func(ctx *C.EVP_PKEY_CTX, out *C.uint8_t, outLen *C.uint, in *C.uint8_t, inLen C.uint) C.int {
verify := func(ctx C.GO_EVP_PKEY_CTX_PTR, out *C.uint8_t, outLen *C.uint, in *C.uint8_t, inLen C.uint) C.int {
return C.go_openssl_EVP_PKEY_verify(ctx, out, *outLen, in, inLen)
}
_, err := cryptEVP(withKey, padding, nil, nil, saltLen, h, verifyInit, verify, sig, hashed)
Expand Down
3 changes: 3 additions & 0 deletions openssl/goopenssl.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ void go_openssl_load_functions(void* handle, const void* v1_0_sentinel, const vo
typedef void* GO_EVP_CIPHER_PTR;
typedef void* GO_EVP_CIPHER_CTX_PTR;

typedef void* GO_EVP_PKEY_PTR;
typedef void* GO_EVP_PKEY_CTX_PTR;

// Define pointers to all the used OpenSSL functions.
// Calling C function pointers from Go is currently not supported.
// It is possible to circumvent this by using a C function wrapper.
Expand Down
40 changes: 20 additions & 20 deletions openssl/openssl_funcs.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,33 +154,33 @@ DEFINEFUNC(const GO_EVP_CIPHER_PTR, EVP_aes_256_ecb, (void), ()) \
DEFINEFUNC(const GO_EVP_CIPHER_PTR, EVP_aes_256_gcm, (void), ()) \
DEFINEFUNC(void, EVP_CIPHER_CTX_free, (GO_EVP_CIPHER_CTX_PTR arg0), (arg0)) \
DEFINEFUNC(int, EVP_CIPHER_CTX_ctrl, (GO_EVP_CIPHER_CTX_PTR ctx, int type, int arg, void *ptr), (ctx, type, arg, ptr)) \
DEFINEFUNC(EVP_PKEY *, EVP_PKEY_new, (void), ()) \
DEFINEFUNC_RENAMED(int, EVP_PKEY_get_size, EVP_PKEY_size, (const EVP_PKEY *pkey), (pkey)) \
DEFINEFUNC(void, EVP_PKEY_free, (EVP_PKEY * arg0), (arg0)) \
DEFINEFUNC(EC_KEY *, EVP_PKEY_get1_EC_KEY, (EVP_PKEY *pkey), (pkey)) \
DEFINEFUNC(RSA *, EVP_PKEY_get1_RSA, (EVP_PKEY *pkey), (pkey)) \
DEFINEFUNC(int, EVP_PKEY_assign, (EVP_PKEY *pkey, int type, void *key), (pkey, type, key)) \
DEFINEFUNC(GO_EVP_PKEY_PTR, EVP_PKEY_new, (void), ()) \
DEFINEFUNC_RENAMED(int, EVP_PKEY_get_size, EVP_PKEY_size, (const GO_EVP_PKEY_PTR pkey), (pkey)) \
DEFINEFUNC(void, EVP_PKEY_free, (GO_EVP_PKEY_PTR arg0), (arg0)) \
DEFINEFUNC(EC_KEY *, EVP_PKEY_get1_EC_KEY, (GO_EVP_PKEY_PTR pkey), (pkey)) \
DEFINEFUNC(RSA *, EVP_PKEY_get1_RSA, (GO_EVP_PKEY_PTR pkey), (pkey)) \
DEFINEFUNC(int, EVP_PKEY_assign, (GO_EVP_PKEY_PTR pkey, int type, void *key), (pkey, type, key)) \
DEFINEFUNC(int, EVP_PKEY_verify, \
(EVP_PKEY_CTX *ctx, const uint8_t *sig, unsigned int siglen, const uint8_t *tbs, unsigned int tbslen), \
(GO_EVP_PKEY_CTX_PTR ctx, const uint8_t *sig, unsigned int siglen, const uint8_t *tbs, unsigned int tbslen), \
(ctx, sig, siglen, tbs, tbslen)) \
DEFINEFUNC(EVP_PKEY_CTX *, EVP_PKEY_CTX_new, (EVP_PKEY * arg0, ENGINE *arg1), (arg0, arg1)) \
DEFINEFUNC(EVP_PKEY_CTX *, EVP_PKEY_CTX_new_id, (int id, ENGINE *e), (id, e)) \
DEFINEFUNC(int, EVP_PKEY_keygen_init, (EVP_PKEY_CTX *ctx), (ctx)) \
DEFINEFUNC(int, EVP_PKEY_keygen, (EVP_PKEY_CTX *ctx, EVP_PKEY **ppkey), (ctx, ppkey)) \
DEFINEFUNC(void, EVP_PKEY_CTX_free, (EVP_PKEY_CTX * arg0), (arg0)) \
DEFINEFUNC(GO_EVP_PKEY_CTX_PTR, EVP_PKEY_CTX_new, (GO_EVP_PKEY_PTR arg0, ENGINE *arg1), (arg0, arg1)) \
DEFINEFUNC(GO_EVP_PKEY_CTX_PTR, EVP_PKEY_CTX_new_id, (int id, ENGINE *e), (id, e)) \
DEFINEFUNC(int, EVP_PKEY_keygen_init, (GO_EVP_PKEY_CTX_PTR ctx), (ctx)) \
DEFINEFUNC(int, EVP_PKEY_keygen, (GO_EVP_PKEY_CTX_PTR ctx, GO_EVP_PKEY_PTR *ppkey), (ctx, ppkey)) \
DEFINEFUNC(void, EVP_PKEY_CTX_free, (GO_EVP_PKEY_CTX_PTR arg0), (arg0)) \
DEFINEFUNC(int, EVP_PKEY_CTX_ctrl, \
(EVP_PKEY_CTX * ctx, int keytype, int optype, int cmd, int p1, void *p2), \
(GO_EVP_PKEY_CTX_PTR ctx, int keytype, int optype, int cmd, int p1, void *p2), \
(ctx, keytype, optype, cmd, p1, p2)) \
DEFINEFUNC(int, EVP_PKEY_decrypt, \
(EVP_PKEY_CTX * arg0, uint8_t *arg1, unsigned int *arg2, const uint8_t *arg3, unsigned int arg4), \
(GO_EVP_PKEY_CTX_PTR arg0, uint8_t *arg1, unsigned int *arg2, const uint8_t *arg3, unsigned int arg4), \
(arg0, arg1, arg2, arg3, arg4)) \
DEFINEFUNC(int, EVP_PKEY_encrypt, \
(EVP_PKEY_CTX * arg0, uint8_t *arg1, unsigned int *arg2, const uint8_t *arg3, unsigned int arg4), \
(GO_EVP_PKEY_CTX_PTR arg0, uint8_t *arg1, unsigned int *arg2, const uint8_t *arg3, unsigned int arg4), \
(arg0, arg1, arg2, arg3, arg4)) \
DEFINEFUNC(int, EVP_PKEY_decrypt_init, (EVP_PKEY_CTX * arg0), (arg0)) \
DEFINEFUNC(int, EVP_PKEY_encrypt_init, (EVP_PKEY_CTX * arg0), (arg0)) \
DEFINEFUNC(int, EVP_PKEY_sign_init, (EVP_PKEY_CTX * arg0), (arg0)) \
DEFINEFUNC(int, EVP_PKEY_verify_init, (EVP_PKEY_CTX * arg0), (arg0)) \
DEFINEFUNC(int, EVP_PKEY_decrypt_init, (GO_EVP_PKEY_CTX_PTR arg0), (arg0)) \
DEFINEFUNC(int, EVP_PKEY_encrypt_init, (GO_EVP_PKEY_CTX_PTR arg0), (arg0)) \
DEFINEFUNC(int, EVP_PKEY_sign_init, (GO_EVP_PKEY_CTX_PTR arg0), (arg0)) \
DEFINEFUNC(int, EVP_PKEY_verify_init, (GO_EVP_PKEY_CTX_PTR arg0), (arg0)) \
DEFINEFUNC(int, EVP_PKEY_sign, \
(EVP_PKEY_CTX * arg0, uint8_t *arg1, unsigned int *arg2, const uint8_t *arg3, unsigned int arg4), \
(GO_EVP_PKEY_CTX_PTR arg0, uint8_t *arg1, unsigned int *arg2, const uint8_t *arg3, unsigned int arg4), \
(arg0, arg1, arg2, arg3, arg4))
10 changes: 5 additions & 5 deletions openssl/rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error)

type PublicKeyRSA struct {
// _pkey MUST NOT be accessed directly. Instead, use the withKey method.
_pkey *C.EVP_PKEY
_pkey C.GO_EVP_PKEY_PTR
}

func NewPublicKeyRSA(N, E *big.Int) (*PublicKeyRSA, error) {
Expand Down Expand Up @@ -69,7 +69,7 @@ func (k *PublicKeyRSA) finalize() {
C.go_openssl_EVP_PKEY_free(k._pkey)
}

func (k *PublicKeyRSA) withKey(f func(*C.EVP_PKEY) C.int) C.int {
func (k *PublicKeyRSA) withKey(f func(C.GO_EVP_PKEY_PTR) C.int) C.int {
// Because of the finalizer, any time _pkey is passed to cgo, that call must
// be followed by a call to runtime.KeepAlive, to make sure k is not
// collected (and finalized) before the cgo call returns.
Expand All @@ -79,7 +79,7 @@ func (k *PublicKeyRSA) withKey(f func(*C.EVP_PKEY) C.int) C.int {

type PrivateKeyRSA struct {
// _pkey MUST NOT be accessed directly. Instead, use the withKey method.
_pkey *C.EVP_PKEY
_pkey C.GO_EVP_PKEY_PTR
}

func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) (*PrivateKeyRSA, error) {
Expand Down Expand Up @@ -119,7 +119,7 @@ func (k *PrivateKeyRSA) finalize() {
C.go_openssl_EVP_PKEY_free(k._pkey)
}

func (k *PrivateKeyRSA) withKey(f func(*C.EVP_PKEY) C.int) C.int {
func (k *PrivateKeyRSA) withKey(f func(C.GO_EVP_PKEY_PTR) C.int) C.int {
// Because of the finalizer, any time _pkey is passed to cgo, that call must
// be followed by a call to runtime.KeepAlive, to make sure k is not
// collected (and finalized) before the cgo call returns.
Expand Down Expand Up @@ -190,7 +190,7 @@ func SignRSAPKCS1v15(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte) ([]byte,
}

func VerifyRSAPKCS1v15(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte) error {
if pub.withKey(func(pkey *C.EVP_PKEY) C.int {
if pub.withKey(func(pkey C.GO_EVP_PKEY_PTR) C.int {
size := C.go_openssl_EVP_PKEY_get_size(pkey)
if len(sig) < int(size) {
return 0
Expand Down
27 changes: 27 additions & 0 deletions openssl/rsa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package openssl
import (
"bytes"
"crypto"
"math/big"
"strconv"
"testing"
)
Expand Down Expand Up @@ -144,3 +145,29 @@ func newRSAKey(t *testing.T, size int) (*PrivateKeyRSA, *PublicKeyRSA) {
}
return priv, pub
}

func fromBase10(base10 string) *big.Int {
i, ok := new(big.Int).SetString(base10, 10)
if !ok {
panic("bad number: " + base10)
}
return i
}

func BenchmarkEncryptRSAPKCS1(b *testing.B) {
b.StopTimer()
test2048PubKey, err := NewPublicKeyRSA(
fromBase10("14314132931241006650998084889274020608918049032671858325988396851334124245188214251956198731333464217832226406088020736932173064754214329009979944037640912127943488972644697423190955557435910767690712778463524983667852819010259499695177313115447116110358524558307947613422897787329221478860907963827160223559690523660574329011927531289655711860504630573766609239332569210831325633840174683944553667352219670930408593321661375473885147973879086994006440025257225431977751512374815915392249179976902953721486040787792801849818254465486633791826766873076617116727073077821584676715609985777563958286637185868165868520557"),
big.NewInt(3),
)
if err != nil {
b.Fatal(err)
}
b.StartTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := EncryptRSAPKCS1(test2048PubKey, []byte("testing")); err != nil {
b.Fatal(err)
}
}
}