Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve AES performance #12

Merged
merged 6 commits into from
Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 26 additions & 38 deletions openssl/aes.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ const aesBlockSize = 16

type aesCipher struct {
key []byte
enc_ctx *C.EVP_CIPHER_CTX
dec_ctx *C.EVP_CIPHER_CTX
cipher *C.EVP_CIPHER
enc_ctx C.GO_EVP_CIPHER_CTX_PTR
dec_ctx C.GO_EVP_CIPHER_CTX_PTR
cipher C.GO_EVP_CIPHER_PTR
}

type extraModes interface {
Expand Down Expand Up @@ -96,8 +96,7 @@ func (c *aesCipher) Encrypt(dst, src []byte) {
}
}

outlen := C.int(0)
C.go_openssl_EVP_CipherUpdate(c.enc_ctx, (*C.uchar)(unsafe.Pointer(&dst[0])), &outlen, (*C.uchar)(unsafe.Pointer(&src[0])), C.int(aesBlockSize))
C.go_openssl_EVP_EncryptUpdate_wrapper(c.enc_ctx, base(dst), base(src), aesBlockSize)
runtime.KeepAlive(c)
}

Expand All @@ -119,13 +118,12 @@ func (c *aesCipher) Decrypt(dst, src []byte) {
}
}

outlen := C.int(0)
C.go_openssl_EVP_CipherUpdate(c.dec_ctx, (*C.uchar)(unsafe.Pointer(&dst[0])), &outlen, (*C.uchar)(unsafe.Pointer(&src[0])), C.int(aesBlockSize))
C.go_openssl_EVP_DecryptUpdate_wrapper(c.dec_ctx, base(dst), base(src), aesBlockSize)
runtime.KeepAlive(c)
}

type aesCBC struct {
ctx *C.EVP_CIPHER_CTX
ctx C.GO_EVP_CIPHER_CTX_PTR
}

func (x *aesCBC) BlockSize() int { return aesBlockSize }
Expand All @@ -141,12 +139,7 @@ func (x *aesCBC) CryptBlocks(dst, src []byte) {
panic("crypto/cipher: output smaller than input")
}
if len(src) > 0 {
outlen := C.int(0)
if C.go_openssl_EVP_CipherUpdate(
x.ctx,
base(dst), &outlen,
base(src), C.int(len(src)),
) != C.int(1) {
if C.go_openssl_EVP_CipherUpdate_wrapper(x.ctx, base(dst), base(src), C.int(len(src))) != 1 {
panic("crypto/cipher: CipherUpdate failed")
}
runtime.KeepAlive(x)
Expand All @@ -157,15 +150,15 @@ func (x *aesCBC) SetIV(iv []byte) {
if len(iv) != aesBlockSize {
panic("cipher: incorrect length IV")
}
if C.int(1) != C.go_openssl_EVP_CipherInit_ex(x.ctx, nil, nil, nil, (*C.uchar)(unsafe.Pointer(&iv[0])), -1) {
if C.go_openssl_EVP_CipherInit_ex(x.ctx, nil, nil, nil, base(iv), -1) != 1 {
panic("cipher: unable to initialize EVP cipher ctx")
}
}

func (c *aesCipher) NewCBCEncrypter(iv []byte) cipher.BlockMode {
x := new(aesCBC)

var cipher *C.EVP_CIPHER
var cipher C.GO_EVP_CIPHER_PTR
switch len(c.key) * 8 {
case 128:
cipher = C.go_openssl_EVP_aes_128_cbc()
Expand Down Expand Up @@ -194,7 +187,7 @@ func (c *aesCBC) finalize() {
func (c *aesCipher) NewCBCDecrypter(iv []byte) cipher.BlockMode {
x := new(aesCBC)

var cipher *C.EVP_CIPHER
var cipher C.GO_EVP_CIPHER_PTR
switch len(c.key) * 8 {
case 128:
cipher = C.go_openssl_EVP_aes_128_cbc()
Expand All @@ -211,7 +204,7 @@ func (c *aesCipher) NewCBCDecrypter(iv []byte) cipher.BlockMode {
if err != nil {
panic(err)
}
if C.int(1) != C.go_openssl_EVP_CIPHER_CTX_set_padding(x.ctx, 0) {
if C.go_openssl_EVP_CIPHER_CTX_set_padding(x.ctx, 0) != 1 {
panic("cipher: unable to set padding")
}

Expand All @@ -220,7 +213,7 @@ func (c *aesCipher) NewCBCDecrypter(iv []byte) cipher.BlockMode {
}

type aesCTR struct {
ctx *C.EVP_CIPHER_CTX
ctx C.GO_EVP_CIPHER_CTX_PTR
}

func (x *aesCTR) XORKeyStream(dst, src []byte) {
Expand All @@ -233,18 +226,14 @@ func (x *aesCTR) XORKeyStream(dst, src []byte) {
if len(src) == 0 {
return
}
C.go_openssl_EVP_EncryptUpdate_wrapper(
x.ctx,
(*C.uint8_t)(unsafe.Pointer(&dst[0])),
(*C.uint8_t)(unsafe.Pointer(&src[0])),
C.size_t(len(src)))
C.go_openssl_EVP_EncryptUpdate_wrapper(x.ctx, base(dst), base(src), C.int(len(src)))
runtime.KeepAlive(x)
}

func (c *aesCipher) NewCTR(iv []byte) cipher.Stream {
x := new(aesCTR)

var cipher *C.EVP_CIPHER
var cipher C.GO_EVP_CIPHER_PTR
switch len(c.key) * 8 {
case 128:
cipher = C.go_openssl_EVP_aes_128_ctr()
Expand Down Expand Up @@ -273,7 +262,7 @@ func (c *aesCTR) finalize() {
type aesGCM struct {
key []byte
tls bool
cipher *C.EVP_CIPHER
cipher C.GO_EVP_CIPHER_PTR
}

const (
Expand Down Expand Up @@ -370,23 +359,23 @@ func (g *aesGCM) Seal(dst, nonce, plaintext, additionalData []byte) []byte {

var encLen C.int
// Encrypt additional data.
if C.go_openssl_EVP_EncryptUpdate(ctx, nil, &encLen, base(additionalData), C.int(len(additionalData))) != C.int(1) {
if C.go_openssl_EVP_EncryptUpdate(ctx, nil, &encLen, base(additionalData), C.int(len(additionalData))) != 1 {
panic(fail("EVP_CIPHER_CTX_seal"))
}

// Encrypt plain text.
if C.go_openssl_EVP_EncryptUpdate(ctx, base(out), &encLen, base(plaintext), C.int(len(plaintext))) != C.int(1) {
if C.go_openssl_EVP_EncryptUpdate(ctx, base(out), &encLen, base(plaintext), C.int(len(plaintext))) != 1 {
panic(fail("EVP_CIPHER_CTX_seal"))
}

// Finalise encryption.
var encFinalLen C.int
if C.go_openssl_EVP_EncryptFinal_ex(ctx, base(out[encLen:]), &encFinalLen) != C.int(1) {
if C.go_openssl_EVP_EncryptFinal_ex(ctx, base(out[encLen:]), &encFinalLen) != 1 {
panic(fail("EVP_CIPHER_CTX_seal"))
}
encLen += encFinalLen

if C.go_openssl_EVP_CIPHER_CTX_ctrl(ctx, C.EVP_CTRL_GCM_GET_TAG, C.int(16), unsafe.Pointer(&out[encLen])) != C.int(1) {
if C.go_openssl_EVP_CIPHER_CTX_ctrl(ctx, C.EVP_CTRL_GCM_GET_TAG, 16, unsafe.Pointer(&out[encLen])) != 1 {
panic(fail("EVP_CIPHER_CTX_seal"))
}
encLen += 16
Expand Down Expand Up @@ -435,25 +424,24 @@ func (g *aesGCM) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, er
}

// Provide any AAD data.
var tmplen C.int
if C.go_openssl_EVP_DecryptUpdate(ctx, nil, &tmplen, base(additionalData), C.int(len(additionalData))) != C.int(1) {
var decLen C.int
if C.go_openssl_EVP_DecryptUpdate(ctx, nil, &decLen, base(additionalData), C.int(len(additionalData))) != 1 {
return clearAndFail(errOpen)
}

// Provide the message to be decrypted, and obtain the plaintext output.
var decLen C.int
if C.go_openssl_EVP_DecryptUpdate(ctx, base(out), &decLen, base(ciphertext), C.int(len(ciphertext))) != C.int(1) {
if C.go_openssl_EVP_DecryptUpdate(ctx, base(out), &decLen, base(ciphertext), C.int(len(ciphertext))) != 1 {
return clearAndFail(errOpen)
}

// Set expected tag value. Works in OpenSSL 1.0.1d and later.
if C.go_openssl_EVP_CIPHER_CTX_ctrl(ctx, C.EVP_CTRL_GCM_SET_TAG, 16, unsafe.Pointer(&tag[0])) != C.int(1) {
if C.go_openssl_EVP_CIPHER_CTX_ctrl(ctx, C.EVP_CTRL_GCM_SET_TAG, 16, unsafe.Pointer(&tag[0])) != 1 {
return clearAndFail(errOpen)
}

// Finalise the decryption.
var tagLen C.int
if C.go_openssl_EVP_DecryptFinal_ex(ctx, base(out[int(decLen):]), &tagLen) != C.int(1) {
if C.go_openssl_EVP_DecryptFinal_ex(ctx, base(out[int(decLen):]), &tagLen) != 1 {
return clearAndFail(errOpen)
}

Expand All @@ -475,12 +463,12 @@ func sliceForAppend(in []byte, n int) (head, tail []byte) {
return
}

func newCipherCtx(cipher *C.EVP_CIPHER, mode C.int, key, iv []byte) (*C.EVP_CIPHER_CTX, error) {
func newCipherCtx(cipher C.GO_EVP_CIPHER_PTR, mode C.int, key, iv []byte) (C.GO_EVP_CIPHER_CTX_PTR, error) {
ctx := C.go_openssl_EVP_CIPHER_CTX_new()
if ctx == nil {
return nil, fail("unable to create EVP cipher ctx")
}
if C.int(1) != C.go_openssl_EVP_CipherInit_ex(ctx, cipher, nil, base(key), base(iv), mode) {
if C.go_openssl_EVP_CipherInit_ex(ctx, cipher, nil, base(key), base(iv), mode) != 1 {
return nil, fail("unable to initialize EVP cipher ctx")
}
return ctx, nil
Expand Down
82 changes: 80 additions & 2 deletions openssl/aes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func testDecrypt(t *testing.T, resetNonce bool) {
111, 184, 94, 169, 188, 93, 38, 150,
3, 208, 185, 201, 212, 246, 238, 181,
}
if bytes.Compare(expectedCipherText, cipherText) != 0 {
if !bytes.Equal(expectedCipherText, cipherText) {
t.Fail()
}

Expand All @@ -257,7 +257,7 @@ func testDecrypt(t *testing.T, resetNonce bool) {
t.Fail()
}

if bytes.Compare(plainText, decrypted) != 0 {
if !bytes.Equal(plainText, decrypted) {
t.Errorf("decryption incorrect\nexp %v, got %v\n", plainText, decrypted)
}
}
Expand All @@ -279,3 +279,81 @@ func Test_aesCipher_finalize(t *testing.T) {
// in case test execution takes long enough, and it can't be finalized twice.
new(aesCipher).finalize()
}

func BenchmarkAES_Encrypt(b *testing.B) {
key := []byte{0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c}
in := []byte{0x32, 0x43, 0xf6, 0xa8, 0x88, 0x5a, 0x30, 0x8d, 0x31, 0x31, 0x98, 0xa2, 0xe0, 0x37, 0x07, 0x34}
c, err := NewAESCipher(key)
if err != nil {
b.Fatal("NewCipher:", err)
}
out := make([]byte, len(in))
b.SetBytes(int64(len(out)))
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
c.Encrypt(out, in)
}
}

func BenchmarkAES_Decrypt(b *testing.B) {
key := []byte{0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c}
src := []byte{0x39, 0x25, 0x84, 0x1d, 0x02, 0xdc, 0x09, 0xfb, 0xdc, 0x11, 0x85, 0x97, 0x19, 0x6a, 0x0b, 0x32}
c, err := NewAESCipher(key)
if err != nil {
b.Fatal("NewCipher:", err)
}
out := make([]byte, len(src))
b.SetBytes(int64(len(src)))
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
c.Encrypt(out, src)
}
}

func BenchmarkAESGCM_Open(b *testing.B) {
const length = 64
const keySize = 128 / 8
buf := make([]byte, length)

b.ReportAllocs()
b.SetBytes(int64(len(buf)))

var key = make([]byte, keySize)
var nonce [12]byte
var ad [13]byte
c, _ := NewAESCipher(key)
aesgcm, _ := c.(extraModes).NewGCM(gcmStandardNonceSize, gcmTagSize)
var out []byte

ct := aesgcm.Seal(nil, nonce[:], buf[:], ad[:])

b.ResetTimer()
for i := 0; i < b.N; i++ {
out, _ = aesgcm.Open(out[:0], nonce[:], ct, ad[:])
}
}

func BenchmarkAESGCM_Seal(b *testing.B) {
const length = 64
const keySize = 128 / 8
buf := make([]byte, length)

b.ReportAllocs()
b.SetBytes(int64(len(buf)))

var key = make([]byte, keySize)
var nonce [12]byte
var ad [13]byte
c, _ := NewAESCipher(key)
aesgcm, _ := c.(extraModes).NewGCM(gcmStandardNonceSize, gcmTagSize)
var out []byte

ct := aesgcm.Seal(nil, nonce[:], buf[:], ad[:])

b.ResetTimer()
for i := 0; i < b.N; i++ {
out, _ = aesgcm.Open(out[:0], nonce[:], ct, ad[:])
}
}
27 changes: 22 additions & 5 deletions openssl/goopenssl.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ void go_openssl_load_functions(void* handle, const void* v1_0_sentinel, const vo
#define GO_AES_ENCRYPT 1
#define GO_AES_DECRYPT 0

typedef void* GO_EVP_CIPHER_PTR;
typedef void* GO_EVP_CIPHER_CTX_PTR;

// Define pointers to all the used OpenSSL functions.
// Calling C function pointers from Go is currently not supported.
// It is possible to circumvent this by using a C function wrapper.
Expand Down Expand Up @@ -65,11 +68,25 @@ FOR_ALL_OPENSSL_FUNCTIONS
#undef DEFINEFUNC_3_0
#undef DEFINEFUNC_RENAMED

// This wrapper allocate out_len on the C stack, and check that it matches the expected
// value, to avoid having to pass a pointer from Go, which would escape to the heap.
static inline void
go_openssl_EVP_EncryptUpdate_wrapper(EVP_CIPHER_CTX *ctx, uint8_t *out, const uint8_t *in, size_t in_len)
// These wrappers allocate out_len on the C stack to avoid having to pass a pointer from Go, which would escape to the heap.
// Use them only in situations where the output length can be safely discarded.
static inline int
go_openssl_EVP_EncryptUpdate_wrapper(GO_EVP_CIPHER_CTX_PTR ctx, uint8_t *out, const uint8_t *in, int in_len)
{
int len;
return go_openssl_EVP_EncryptUpdate(ctx, out, &len, in, in_len);
}

static inline int
go_openssl_EVP_DecryptUpdate_wrapper(GO_EVP_CIPHER_CTX_PTR ctx, uint8_t *out, const uint8_t *in, int in_len)
{
int len;
return go_openssl_EVP_DecryptUpdate(ctx, out, &len, in, in_len);
}

static inline int
go_openssl_EVP_CipherUpdate_wrapper(GO_EVP_CIPHER_CTX_PTR ctx, uint8_t *out, const uint8_t *in, int in_len)
{
int len;
go_openssl_EVP_EncryptUpdate(ctx, out, &len, in, in_len);
return go_openssl_EVP_CipherUpdate(ctx, out, &len, in, in_len);
}
Loading