diff --git a/openssl/aes_test.go b/openssl/aes_test.go index c06fcc5..8428c39 100644 --- a/openssl/aes_test.go +++ b/openssl/aes_test.go @@ -9,7 +9,6 @@ package openssl import ( "bytes" "crypto/cipher" - "math" "testing" ) @@ -36,6 +35,7 @@ func TestNewGCMNonce(t *testing.T) { if err != nil { t.Errorf("expected no error for standard tag / nonce size, got: %#v", err) } + c.finalize() } func TestSealAndOpen(t *testing.T) { @@ -60,6 +60,7 @@ func TestSealAndOpen(t *testing.T) { if !bytes.Equal(decrypted, plainText) { t.Errorf("unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, plainText) } + c.finalize() } func TestSealAndOpenAuthenticationError(t *testing.T) { @@ -81,6 +82,7 @@ func TestSealAndOpenAuthenticationError(t *testing.T) { if err != errOpen { t.Errorf("expected authentication error, got: %#v", err) } + c.finalize() } func assertPanic(t *testing.T, f func()) { @@ -107,7 +109,11 @@ func TestSealPanic(t *testing.T) { gcm.Seal(nil, make([]byte, gcmStandardNonceSize-1), []byte{0x01, 0x02, 0x03}, nil) }) assertPanic(t, func() { - gcm.Seal(nil, make([]byte, gcmStandardNonceSize), make([]byte, math.MaxInt), nil) + // maxInt is implemented as math.MaxInt, but this constant + // is only available since go1.17. + // TODO: use math.MaxInt once go1.16 is no longer supported. + maxInt := int((^uint(0)) >> 1) + gcm.Seal(nil, make([]byte, gcmStandardNonceSize), make([]byte, maxInt), nil) }) } diff --git a/openssl/ecdsa_test.go b/openssl/ecdsa_test.go new file mode 100644 index 0000000..132a488 --- /dev/null +++ b/openssl/ecdsa_test.go @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//go:build linux && !android +// +build linux,!android + +package openssl + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "testing" +) + +func testAllCurves(t *testing.T, f func(*testing.T, elliptic.Curve)) { + tests := []struct { + name string + curve elliptic.Curve + }{ + {"P256", elliptic.P256()}, + {"P224", elliptic.P224()}, + {"P384", elliptic.P384()}, + {"P521", elliptic.P521()}, + } + for _, test := range tests { + curve := test.curve + t.Run(test.name, func(t *testing.T) { + t.Parallel() + f(t, curve) + }) + } +} + +func TestECDSAKeyGeneration(t *testing.T) { + testAllCurves(t, testECDSAKeyGeneration) +} + +func testECDSAKeyGeneration(t *testing.T, c elliptic.Curve) { + priv, err := generateKeycurve(c) + if err != nil { + t.Fatal(err) + } + if !c.IsOnCurve(priv.PublicKey.X, priv.PublicKey.Y) { + t.Errorf("public key invalid: %s", err) + } +} + +func TestECDSASignAndVerify(t *testing.T) { + testAllCurves(t, testECDSASignAndVerify) +} + +func testECDSASignAndVerify(t *testing.T, c elliptic.Curve) { + key, err := generateKeycurve(c) + if err != nil { + t.Fatal(err) + } + + priv, err := NewPrivateKeyECDSA(key.Params().Name, key.X, key.Y, key.D) + if err != nil { + t.Fatal(err) + } + hashed := []byte("testing") + r, s, err := SignECDSA(priv, hashed) + if err != nil { + t.Errorf("error signing: %s", err) + return + } + + pub, err := NewPublicKeyECDSA(key.Params().Name, key.X, key.Y) + if err != nil { + t.Fatal(err) + } + if !VerifyECDSA(pub, hashed, r, s) { + t.Errorf("Verify failed") + } + hashed[0] ^= 0xff + if VerifyECDSA(pub, hashed, r, s) { + t.Errorf("Verify succeeded despite intentionally invalid hash!") + } +} + +func generateKeycurve(c elliptic.Curve) (*ecdsa.PrivateKey, error) { + x, y, d, err := GenerateKeyECDSA(c.Params().Name) + if err != nil { + return nil, err + } + return &ecdsa.PrivateKey{PublicKey: ecdsa.PublicKey{Curve: c, X: x, Y: y}, D: d}, nil +} diff --git a/openssl/hmac_test.go b/openssl/hmac_test.go index 097b584..1b6a6e8 100644 --- a/openssl/hmac_test.go +++ b/openssl/hmac_test.go @@ -7,14 +7,58 @@ package openssl import ( + "bytes" + "hash" "testing" ) -// Just tests that we can create an HMAC instance. -// Previously would cause panic because of incorrect -// stack allocation of opaque OpenSSL type. -func TestNewHMAC(t *testing.T) { - mac := NewHMAC(NewSHA256, nil) - mac.Write([]byte("foo")) - t.Logf("%x\n", mac.Sum(nil)) +func TestHMAC(t *testing.T) { + var tests = []struct { + name string + fn func() hash.Hash + }{ + {"sha1", NewSHA1}, + {"sha224", NewSHA224}, + {"sha256", NewSHA256}, + {"sha384", NewSHA384}, + {"sha512", NewSHA512}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + h := NewHMAC(tt.fn, nil) + h.Write([]byte("hello")) + sumHello := h.Sum(nil) + + h = NewHMAC(tt.fn, nil) + h.Write([]byte("hello world")) + sumHelloWorld := h.Sum(nil) + + // Test that Sum has no effect on future Sum or Write operations. + // This is a bit unusual as far as usage, but it's allowed + // by the definition of Go hash.Hash, and some clients expect it to work. + h = NewHMAC(tt.fn, nil) + h.Write([]byte("hello")) + if sum := h.Sum(nil); !bytes.Equal(sum, sumHello) { + t.Fatalf("1st Sum after hello = %x, want %x", sum, sumHello) + } + if sum := h.Sum(nil); !bytes.Equal(sum, sumHello) { + t.Fatalf("2nd Sum after hello = %x, want %x", sum, sumHello) + } + + h.Write([]byte(" world")) + if sum := h.Sum(nil); !bytes.Equal(sum, sumHelloWorld) { + t.Fatalf("1st Sum after hello world = %x, want %x", sum, sumHelloWorld) + } + if sum := h.Sum(nil); !bytes.Equal(sum, sumHelloWorld) { + t.Fatalf("2nd Sum after hello world = %x, want %x", sum, sumHelloWorld) + } + + h.Reset() + h.Write([]byte("hello")) + if sum := h.Sum(nil); !bytes.Equal(sum, sumHello) { + t.Fatalf("Sum after Reset + hello = %x, want %x", sum, sumHello) + } + }) + } } diff --git a/openssl/rand_test.go b/openssl/rand_test.go new file mode 100644 index 0000000..ee8bcf0 --- /dev/null +++ b/openssl/rand_test.go @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//go:build linux && !android +// +build linux,!android + +package openssl + +import ( + "testing" +) + +func TestRand(t *testing.T) { + _, err := RandReader.Read(make([]byte, 5)) + if err != nil { + t.Fatal(err) + } +} diff --git a/openssl/rsa_test.go b/openssl/rsa_test.go new file mode 100644 index 0000000..a9849f8 --- /dev/null +++ b/openssl/rsa_test.go @@ -0,0 +1,146 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//go:build linux && !android +// +build linux,!android + +package openssl + +import ( + "bytes" + "crypto" + "strconv" + "testing" +) + +func TestRSAKeyGeneration(t *testing.T) { + for _, size := range []int{2048, 3072} { + t.Run(strconv.Itoa(size), func(t *testing.T) { + t.Parallel() + priv, pub := newRSAKey(t, size) + msg := []byte("hi!") + enc, err := EncryptRSAPKCS1(pub, msg) + if err != nil { + t.Fatalf("EncryptPKCS1v15: %v", err) + } + dec, err := DecryptRSAPKCS1(priv, enc) + if err != nil { + t.Fatalf("DecryptPKCS1v15: %v", err) + } + if !bytes.Equal(dec, msg) { + t.Fatalf("got:%x want:%x", dec, msg) + } + }) + } +} + +func TestEncryptDecryptOAEP(t *testing.T) { + sha256 := NewSHA256() + msg := []byte("hi!") + label := []byte("ho!") + priv, pub := newRSAKey(t, 2048) + enc, err := EncryptRSAOAEP(sha256, pub, msg, label) + if err != nil { + t.Fatal(err) + } + dec, err := DecryptRSAOAEP(sha256, priv, enc, label) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(dec, msg) { + t.Errorf("got:%x want:%x", dec, msg) + } +} + +func TestEncryptDecryptOAEP_WrongLabel(t *testing.T) { + sha256 := NewSHA256() + msg := []byte("hi!") + priv, pub := newRSAKey(t, 2048) + enc, err := EncryptRSAOAEP(sha256, pub, msg, []byte("ho!")) + if err != nil { + t.Fatal(err) + } + dec, err := DecryptRSAOAEP(sha256, priv, enc, []byte("wrong!")) + if err == nil { + t.Errorf("error expected") + } + if dec != nil { + t.Errorf("got:%x want: nil", dec) + } +} + +func TestSignVerifyPKCS1v15(t *testing.T) { + sha256 := NewSHA256() + priv, pub := newRSAKey(t, 2048) + sha256.Write([]byte("hi!")) + hashed := sha256.Sum(nil) + signed, err := SignRSAPKCS1v15(priv, crypto.SHA256, hashed) + if err != nil { + t.Fatal(err) + } + err = VerifyRSAPKCS1v15(pub, crypto.SHA256, hashed, signed) + if err != nil { + t.Fatal(err) + } +} + +func TestSignVerifyPKCS1v15_Unhashed(t *testing.T) { + msg := []byte("hi!") + priv, pub := newRSAKey(t, 2048) + signed, err := SignRSAPKCS1v15(priv, 0, msg) + if err != nil { + t.Fatal(err) + } + err = VerifyRSAPKCS1v15(pub, 0, msg, signed) + if err != nil { + t.Fatal(err) + } +} + +func TestSignVerifyPKCS1v15_Invalid(t *testing.T) { + sha256 := NewSHA256() + msg := []byte("hi!") + priv, pub := newRSAKey(t, 2048) + sha256.Write(msg) + hashed := sha256.Sum(nil) + signed, err := SignRSAPKCS1v15(priv, crypto.SHA256, hashed) + if err != nil { + t.Fatal(err) + } + err = VerifyRSAPKCS1v15(pub, crypto.SHA256, msg, signed) + if err == nil { + t.Fatal("error expected") + } +} + +func TestSignVerifyRSAPSS(t *testing.T) { + sha1 := NewSHA1() + priv, pub := newRSAKey(t, 2048) + sha1.Write([]byte("testing")) + hashed := sha1.Sum(nil) + signed, err := SignRSAPSS(priv, crypto.SHA1, hashed, 0) + if err != nil { + t.Fatal(err) + } + err = VerifyRSAPSS(pub, crypto.SHA1, hashed, signed, 0) + if err != nil { + t.Fatal(err) + } +} + +func newRSAKey(t *testing.T, size int) (*PrivateKeyRSA, *PublicKeyRSA) { + t.Helper() + N, E, D, P, Q, Dp, Dq, Qinv, err := GenerateKeyRSA(size) + if err != nil { + t.Fatalf("GenerateKeyRSA(%d): %v", size, err) + } + priv, err := NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv) + if err != nil { + t.Fatalf("NewPrivateKeyRSA(%d): %v", size, err) + } + pub, err := NewPublicKeyRSA(N, E) + if err != nil { + t.Fatalf("NewPublicKeyRSA(%d): %v", size, err) + } + return priv, pub +} diff --git a/openssl/sha_test.go b/openssl/sha_test.go new file mode 100644 index 0000000..ef43037 --- /dev/null +++ b/openssl/sha_test.go @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//go:build linux && !android +// +build linux,!android + +package openssl + +import ( + "bytes" + "encoding" + "hash" + "testing" +) + +func TestSha(t *testing.T) { + msg := []byte("testig") + var tests = []struct { + name string + fn func() hash.Hash + }{ + {"sha1", NewSHA1}, + {"sha224", NewSHA224}, + {"sha256", NewSHA256}, + {"sha384", NewSHA384}, + {"sha512", NewSHA512}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + h := tt.fn() + initSum := h.Sum(nil) + n, err := h.Write(msg) + if err != nil { + t.Fatal(err) + } + if n != len(msg) { + t.Errorf("got: %d, want: %d", n, len(msg)) + } + sum := h.Sum(nil) + if size := h.Size(); len(sum) != size { + t.Errorf("got: %d, want: %d", len(sum), size) + } + if bytes.Equal(sum, initSum) { + t.Error("Write didn't change internal hash state") + } + + state, err := h.(encoding.BinaryMarshaler).MarshalBinary() + if err != nil { + t.Errorf("could not marshal: %v", err) + } + h2 := tt.fn() + if err := h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state); err != nil { + t.Errorf("could not unmarshal: %v", err) + } + if actual, actual2 := h.Sum(nil), h2.Sum(nil); !bytes.Equal(actual, actual2) { + t.Errorf("0x%x != marshaled 0x%x", actual, actual2) + } + + h.Reset() + sum = h.Sum(nil) + if !bytes.Equal(sum, initSum) { + t.Errorf("got:%x want:%x", sum, initSum) + } + }) + } +}