From 0be062ac593b9041ea3fbb0f9b414fdea6256dc6 Mon Sep 17 00:00:00 2001 From: Quim Muntal Date: Tue, 25 Jan 2022 11:01:50 +0000 Subject: [PATCH 1/6] increase test coverage --- openssl/aes_test.go | 11 +++- openssl/ecdsa_test.go | 88 +++++++++++++++++++++++++ openssl/hmac_test.go | 45 +++++++++++-- openssl/rand_test.go | 18 +++++ openssl/rsa_test.go | 148 ++++++++++++++++++++++++++++++++++++++++++ openssl/sha_test.go | 55 ++++++++++++++++ 6 files changed, 356 insertions(+), 9 deletions(-) create mode 100644 openssl/ecdsa_test.go create mode 100644 openssl/rand_test.go create mode 100644 openssl/rsa_test.go create mode 100644 openssl/sha_test.go diff --git a/openssl/aes_test.go b/openssl/aes_test.go index c06fcc5..8f19012 100644 --- a/openssl/aes_test.go +++ b/openssl/aes_test.go @@ -9,7 +9,6 @@ package openssl import ( "bytes" "crypto/cipher" - "math" "testing" ) @@ -36,6 +35,7 @@ func TestNewGCMNonce(t *testing.T) { if err != nil { t.Errorf("expected no error for standard tag / nonce size, got: %#v", err) } + c.finalize() } func TestSealAndOpen(t *testing.T) { @@ -60,6 +60,7 @@ func TestSealAndOpen(t *testing.T) { if !bytes.Equal(decrypted, plainText) { t.Errorf("unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, plainText) } + c.finalize() } func TestSealAndOpenAuthenticationError(t *testing.T) { @@ -81,6 +82,7 @@ func TestSealAndOpenAuthenticationError(t *testing.T) { if err != errOpen { t.Errorf("expected authentication error, got: %#v", err) } + c.finalize() } func assertPanic(t *testing.T, f func()) { @@ -107,7 +109,12 @@ func TestSealPanic(t *testing.T) { gcm.Seal(nil, make([]byte, gcmStandardNonceSize-1), []byte{0x01, 0x02, 0x03}, nil) }) assertPanic(t, func() { - gcm.Seal(nil, make([]byte, gcmStandardNonceSize), make([]byte, math.MaxInt), nil) + // maxInt is implemented as math.MaxInt, but this constant + // is only available since go1.17. + // TODO: use math.MaxInt once go1.16 is no longer supported. + const intSize = 32 << (^uint(0) >> 63) // 32 or 64 + maxInt := 1<<(intSize-1) - 1 + gcm.Seal(nil, make([]byte, gcmStandardNonceSize), make([]byte, maxInt), nil) }) } diff --git a/openssl/ecdsa_test.go b/openssl/ecdsa_test.go new file mode 100644 index 0000000..c85c706 --- /dev/null +++ b/openssl/ecdsa_test.go @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//go:build linux && !android +// +build linux,!android + +package openssl + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "testing" +) + +func testAllCurves(t *testing.T, f func(*testing.T, elliptic.Curve)) { + tests := []struct { + name string + curve elliptic.Curve + }{ + {"P256", elliptic.P256()}, + {"P224", elliptic.P224()}, + {"P384", elliptic.P384()}, + {"P521", elliptic.P521()}, + } + for _, test := range tests { + curve := test.curve + t.Run(test.name, func(t *testing.T) { + t.Parallel() + f(t, curve) + }) + } +} + +func TestECDSAKeyGeneration(t *testing.T) { + testAllCurves(t, testECDSAKeyGeneration) +} + +func testECDSAKeyGeneration(t *testing.T, c elliptic.Curve) { + priv, err := generateKeycurve(c) + if err != nil { + t.Fatal(err) + } + if !c.IsOnCurve(priv.PublicKey.X, priv.PublicKey.Y) { + t.Errorf("public key invalid: %s", err) + } +} + +func TestECDSASignAndVerify(t *testing.T) { + testAllCurves(t, testECDSASignAndVerify) +} + +func testECDSASignAndVerify(t *testing.T, c elliptic.Curve) { + key, err := generateKeycurve(c) + if err != nil { + t.Fatal(err) + } + + priv, err := NewPrivateKeyECDSA(key.Params().Name, key.X, key.Y, key.D) + if err != nil { + t.Fatal(err) + } + hashed := []byte("testing") + r, s, err := SignECDSA(priv, hashed) + if err != nil { + t.Errorf("error signing: %s", err) + return + } + + pub, err := NewPublicKeyECDSA(key.Params().Name, key.X, key.Y) + if err != nil { + t.Fatal(err) + } + if !VerifyECDSA(pub, hashed, r, s) { + t.Errorf("Verify failed") + } + hashed[0] ^= 0xff + if VerifyECDSA(pub, hashed, r, s) { + t.Errorf("Verify always works!") + } +} + +func generateKeycurve(c elliptic.Curve) (*ecdsa.PrivateKey, error) { + x, y, d, err := GenerateKeyECDSA(c.Params().Name) + if err != nil { + return nil, err + } + return &ecdsa.PrivateKey{PublicKey: ecdsa.PublicKey{Curve: c, X: x, Y: y}, D: d}, nil +} diff --git a/openssl/hmac_test.go b/openssl/hmac_test.go index 097b584..5cbe830 100644 --- a/openssl/hmac_test.go +++ b/openssl/hmac_test.go @@ -7,14 +7,45 @@ package openssl import ( + "bytes" + "hash" "testing" ) -// Just tests that we can create an HMAC instance. -// Previously would cause panic because of incorrect -// stack allocation of opaque OpenSSL type. -func TestNewHMAC(t *testing.T) { - mac := NewHMAC(NewSHA256, nil) - mac.Write([]byte("foo")) - t.Logf("%x\n", mac.Sum(nil)) +func TestHMAC(t *testing.T) { + for i, fn := range []func() hash.Hash{NewSHA1, NewSHA224, NewSHA256, NewSHA384, NewSHA512} { + h := NewHMAC(fn, nil) + h.Write([]byte("hello")) + sumHello := h.Sum(nil) + + h = NewHMAC(fn, nil) + h.Write([]byte("hello world")) + sumHelloWorld := h.Sum(nil) + + // Test that Sum has no effect on future Sum or Write operations. + // This is a bit unusual as far as usage, but it's allowed + // by the definition of Go hash.Hash, and some clients expect it to work. + h = NewHMAC(fn, nil) + h.Write([]byte("hello")) + if sum := h.Sum(nil); !bytes.Equal(sum, sumHello) { + t.Fatalf("i %d: 1st Sum after hello = %x, want %x", i, sum, sumHello) + } + if sum := h.Sum(nil); !bytes.Equal(sum, sumHello) { + t.Fatalf("i %d: 2nd Sum after hello = %x, want %x", i, sum, sumHello) + } + + h.Write([]byte(" world")) + if sum := h.Sum(nil); !bytes.Equal(sum, sumHelloWorld) { + t.Fatalf("i %d: 1st Sum after hello world = %x, want %x", i, sum, sumHelloWorld) + } + if sum := h.Sum(nil); !bytes.Equal(sum, sumHelloWorld) { + t.Fatalf("i %d: 2nd Sum after hello world = %x, want %x", i, sum, sumHelloWorld) + } + + h.Reset() + h.Write([]byte("hello")) + if sum := h.Sum(nil); !bytes.Equal(sum, sumHello) { + t.Fatalf("i %d: Sum after Reset + hello = %x, want %x", i, sum, sumHello) + } + } } diff --git a/openssl/rand_test.go b/openssl/rand_test.go new file mode 100644 index 0000000..ee8bcf0 --- /dev/null +++ b/openssl/rand_test.go @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//go:build linux && !android +// +build linux,!android + +package openssl + +import ( + "testing" +) + +func TestRand(t *testing.T) { + _, err := RandReader.Read(make([]byte, 5)) + if err != nil { + t.Fatal(err) + } +} diff --git a/openssl/rsa_test.go b/openssl/rsa_test.go new file mode 100644 index 0000000..a6e360b --- /dev/null +++ b/openssl/rsa_test.go @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//go:build linux && !android +// +build linux,!android + +package openssl + +import ( + "bytes" + "crypto" + "testing" +) + +func TestRSAKeyGeneration(t *testing.T) { + for _, size := range []int{128, 1024, 2048, 3072} { + priv, pub := newRSAKey(t, size) + testRSAKeyBasics(t, priv, pub) + } +} + +func testRSAKeyBasics(t *testing.T, priv *PrivateKeyRSA, pub *PublicKeyRSA) { + // Cannot call encrypt/decrypt directly. Test via PKCS1v15. + msg := []byte("hi!") + enc, err := EncryptRSAPKCS1(pub, msg) + if err != nil { + t.Errorf("EncryptPKCS1v15: %v", err) + return + } + dec, err := DecryptRSAPKCS1(priv, enc) + if err != nil { + t.Errorf("DecryptPKCS1v15: %v", err) + return + } + if !bytes.Equal(dec, msg) { + t.Errorf("got:%x want:%x", dec, msg) + } +} + +func TestEncryptDecryptOAEP(t *testing.T) { + sha256 := NewSHA256() + msg := []byte("hi!") + label := []byte("ho!") + priv, pub := newRSAKey(t, 2048) + enc, err := EncryptRSAOAEP(sha256, pub, msg, label) + if err != nil { + t.Fatal(err) + } + dec, err := DecryptRSAOAEP(sha256, priv, enc, label) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(dec, msg) { + t.Errorf("got:%x want:%x", dec, msg) + } +} + +func TestEncryptDecryptOAEP_WrongLabel(t *testing.T) { + sha256 := NewSHA256() + msg := []byte("hi!") + priv, pub := newRSAKey(t, 2048) + enc, err := EncryptRSAOAEP(sha256, pub, msg, []byte("ho!")) + if err != nil { + t.Fatal(err) + } + dec, err := DecryptRSAOAEP(sha256, priv, enc, []byte("wrong!")) + if err == nil { + t.Errorf("error expected") + } + if dec != nil { + t.Errorf("got:%x want: nil", dec) + } +} + +func TestSignVerifyPKCS1v15(t *testing.T) { + sha256 := NewSHA256() + msg := []byte("hi!") + priv, pub := newRSAKey(t, 2048) + hashed := sha256.Sum(msg) + signed, err := SignRSAPKCS1v15(priv, crypto.SHA256, hashed) + if err != nil { + t.Fatal(err) + } + err = VerifyRSAPKCS1v15(pub, crypto.SHA256, hashed, signed) + if err != nil { + t.Fatal(err) + } +} + +func TestSignVerifyPKCS1v15_Unhashed(t *testing.T) { + msg := []byte("hi!") + priv, pub := newRSAKey(t, 2048) + signed, err := SignRSAPKCS1v15(priv, 0, msg) + if err != nil { + t.Fatal(err) + } + err = VerifyRSAPKCS1v15(pub, 0, msg, signed) + if err != nil { + t.Fatal(err) + } +} + +func TestSignVerifyPKCS1v15_Invalid(t *testing.T) { + sha256 := NewSHA256() + msg := []byte("hi!") + priv, pub := newRSAKey(t, 2048) + hashed := sha256.Sum(msg) + signed, err := SignRSAPKCS1v15(priv, crypto.SHA256, hashed) + if err != nil { + t.Fatal(err) + } + err = VerifyRSAPKCS1v15(pub, crypto.SHA256, msg, signed) + if err == nil { + t.Fatal("error expected") + } +} + +func TestSignVerifyRSAPSS(t *testing.T) { + sha1 := NewSHA1() + priv, pub := newRSAKey(t, 2048) + sha1.Sum([]byte("testing")) + hashed := sha1.Sum(nil) + signed, err := SignRSAPSS(priv, crypto.SHA1, hashed, 0) + if err != nil { + t.Fatal(err) + } + err = VerifyRSAPSS(pub, crypto.SHA1, hashed, signed, 0) + if err != nil { + t.Fatal(err) + } +} + +func newRSAKey(t *testing.T, size int) (*PrivateKeyRSA, *PublicKeyRSA) { + t.Helper() + N, E, D, P, Q, Dp, Dq, Qinv, err := GenerateKeyRSA(size) + if err != nil { + t.Errorf("GenerateKeyRSA(%d): %v", size, err) + } + priv, err := NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv) + if err != nil { + t.Errorf("NewPrivateKeyRSA(%d): %v", size, err) + } + pub, err := NewPublicKeyRSA(N, E) + if err != nil { + t.Errorf("NewPublicKeyRSA(%d): %v", size, err) + } + return priv, pub +} diff --git a/openssl/sha_test.go b/openssl/sha_test.go new file mode 100644 index 0000000..c697816 --- /dev/null +++ b/openssl/sha_test.go @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//go:build linux && !android +// +build linux,!android + +package openssl + +import ( + "bytes" + "encoding" + "hash" + "testing" +) + +func TestSha(t *testing.T) { + msg := []byte("testig") + for i, fn := range []func() hash.Hash{NewSHA1, NewSHA224, NewSHA256, NewSHA384, NewSHA512} { + h := fn() + initSum := h.Sum(nil) + n, err := h.Write(msg) + if err != nil { + t.Errorf("i %d: %v", i, err) + continue + } + if n != len(msg) { + t.Errorf("i %d: got: %d, want: %d", i, n, len(msg)) + } + sum := h.Sum(nil) + if size := h.Size(); len(sum) != size { + t.Errorf("i %d: got: %d, want: %d", i, len(sum), size) + } + if bytes.Equal(sum, initSum) { + t.Errorf("i %d: Write didn't change internal hash state", i) + } + + state, err := h.(encoding.BinaryMarshaler).MarshalBinary() + if err != nil { + t.Errorf("i: %d: could not marshal: %v", i, err) + } + h2 := fn() + if err := h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state); err != nil { + t.Errorf("i: %d: could not unmarshal: %v", i, err) + } + if actual, actual2 := h.Sum(nil), h2.Sum(nil); !bytes.Equal(actual, actual2) { + t.Errorf("i: %d = 0x%x != marshaled 0x%x", i, actual, actual2) + } + + h.Reset() + sum = h.Sum(nil) + if !bytes.Equal(sum, initSum) { + t.Errorf("i %d: got:%x want:%x", i, sum, initSum) + } + } +} From 1c0e4d9e73935efada05b93af3387962d9ecb2da Mon Sep 17 00:00:00 2001 From: Quim Muntal Date: Tue, 25 Jan 2022 20:38:58 +0100 Subject: [PATCH 2/6] Update openssl/ecdsa_test.go Co-authored-by: Davis Goodin --- openssl/ecdsa_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openssl/ecdsa_test.go b/openssl/ecdsa_test.go index c85c706..132a488 100644 --- a/openssl/ecdsa_test.go +++ b/openssl/ecdsa_test.go @@ -75,7 +75,7 @@ func testECDSASignAndVerify(t *testing.T, c elliptic.Curve) { } hashed[0] ^= 0xff if VerifyECDSA(pub, hashed, r, s) { - t.Errorf("Verify always works!") + t.Errorf("Verify succeeded despite intentionally invalid hash!") } } From 03434b1e8b40ef6294ba3f486f5996796f61df0c Mon Sep 17 00:00:00 2001 From: Quim Muntal Date: Tue, 25 Jan 2022 19:44:08 +0000 Subject: [PATCH 3/6] use bits.OnesCount --- openssl/aes_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openssl/aes_test.go b/openssl/aes_test.go index 8f19012..6047c87 100644 --- a/openssl/aes_test.go +++ b/openssl/aes_test.go @@ -9,6 +9,7 @@ package openssl import ( "bytes" "crypto/cipher" + "math/bits" "testing" ) @@ -112,8 +113,7 @@ func TestSealPanic(t *testing.T) { // maxInt is implemented as math.MaxInt, but this constant // is only available since go1.17. // TODO: use math.MaxInt once go1.16 is no longer supported. - const intSize = 32 << (^uint(0) >> 63) // 32 or 64 - maxInt := 1<<(intSize-1) - 1 + maxInt := bits.OnesCount(^uint(0)) gcm.Seal(nil, make([]byte, gcmStandardNonceSize), make([]byte, maxInt), nil) }) } From 19b1594e4e3b9315d712a7458be3d7d59becc38e Mon Sep 17 00:00:00 2001 From: Quim Muntal Date: Tue, 25 Jan 2022 19:54:00 +0000 Subject: [PATCH 4/6] simplify MaxInt --- openssl/aes_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/openssl/aes_test.go b/openssl/aes_test.go index 6047c87..8428c39 100644 --- a/openssl/aes_test.go +++ b/openssl/aes_test.go @@ -9,7 +9,6 @@ package openssl import ( "bytes" "crypto/cipher" - "math/bits" "testing" ) @@ -113,7 +112,7 @@ func TestSealPanic(t *testing.T) { // maxInt is implemented as math.MaxInt, but this constant // is only available since go1.17. // TODO: use math.MaxInt once go1.16 is no longer supported. - maxInt := bits.OnesCount(^uint(0)) + maxInt := int((^uint(0)) >> 1) gcm.Seal(nil, make([]byte, gcmStandardNonceSize), make([]byte, maxInt), nil) }) } From bc735d49fb1659d1b357e58b663a26ebb62d980a Mon Sep 17 00:00:00 2001 From: Quim Muntal Date: Tue, 1 Feb 2022 14:25:41 +0000 Subject: [PATCH 5/6] use tabular test format --- openssl/hmac_test.go | 81 +++++++++++++++++++++++++------------------- openssl/rsa_test.go | 45 ++++++++++++------------ openssl/sha_test.go | 80 ++++++++++++++++++++++++------------------- 3 files changed, 114 insertions(+), 92 deletions(-) diff --git a/openssl/hmac_test.go b/openssl/hmac_test.go index 5cbe830..1b6a6e8 100644 --- a/openssl/hmac_test.go +++ b/openssl/hmac_test.go @@ -13,39 +13,52 @@ import ( ) func TestHMAC(t *testing.T) { - for i, fn := range []func() hash.Hash{NewSHA1, NewSHA224, NewSHA256, NewSHA384, NewSHA512} { - h := NewHMAC(fn, nil) - h.Write([]byte("hello")) - sumHello := h.Sum(nil) - - h = NewHMAC(fn, nil) - h.Write([]byte("hello world")) - sumHelloWorld := h.Sum(nil) - - // Test that Sum has no effect on future Sum or Write operations. - // This is a bit unusual as far as usage, but it's allowed - // by the definition of Go hash.Hash, and some clients expect it to work. - h = NewHMAC(fn, nil) - h.Write([]byte("hello")) - if sum := h.Sum(nil); !bytes.Equal(sum, sumHello) { - t.Fatalf("i %d: 1st Sum after hello = %x, want %x", i, sum, sumHello) - } - if sum := h.Sum(nil); !bytes.Equal(sum, sumHello) { - t.Fatalf("i %d: 2nd Sum after hello = %x, want %x", i, sum, sumHello) - } - - h.Write([]byte(" world")) - if sum := h.Sum(nil); !bytes.Equal(sum, sumHelloWorld) { - t.Fatalf("i %d: 1st Sum after hello world = %x, want %x", i, sum, sumHelloWorld) - } - if sum := h.Sum(nil); !bytes.Equal(sum, sumHelloWorld) { - t.Fatalf("i %d: 2nd Sum after hello world = %x, want %x", i, sum, sumHelloWorld) - } - - h.Reset() - h.Write([]byte("hello")) - if sum := h.Sum(nil); !bytes.Equal(sum, sumHello) { - t.Fatalf("i %d: Sum after Reset + hello = %x, want %x", i, sum, sumHello) - } + var tests = []struct { + name string + fn func() hash.Hash + }{ + {"sha1", NewSHA1}, + {"sha224", NewSHA224}, + {"sha256", NewSHA256}, + {"sha384", NewSHA384}, + {"sha512", NewSHA512}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + h := NewHMAC(tt.fn, nil) + h.Write([]byte("hello")) + sumHello := h.Sum(nil) + + h = NewHMAC(tt.fn, nil) + h.Write([]byte("hello world")) + sumHelloWorld := h.Sum(nil) + + // Test that Sum has no effect on future Sum or Write operations. + // This is a bit unusual as far as usage, but it's allowed + // by the definition of Go hash.Hash, and some clients expect it to work. + h = NewHMAC(tt.fn, nil) + h.Write([]byte("hello")) + if sum := h.Sum(nil); !bytes.Equal(sum, sumHello) { + t.Fatalf("1st Sum after hello = %x, want %x", sum, sumHello) + } + if sum := h.Sum(nil); !bytes.Equal(sum, sumHello) { + t.Fatalf("2nd Sum after hello = %x, want %x", sum, sumHello) + } + + h.Write([]byte(" world")) + if sum := h.Sum(nil); !bytes.Equal(sum, sumHelloWorld) { + t.Fatalf("1st Sum after hello world = %x, want %x", sum, sumHelloWorld) + } + if sum := h.Sum(nil); !bytes.Equal(sum, sumHelloWorld) { + t.Fatalf("2nd Sum after hello world = %x, want %x", sum, sumHelloWorld) + } + + h.Reset() + h.Write([]byte("hello")) + if sum := h.Sum(nil); !bytes.Equal(sum, sumHello) { + t.Fatalf("Sum after Reset + hello = %x, want %x", sum, sumHello) + } + }) } } diff --git a/openssl/rsa_test.go b/openssl/rsa_test.go index a6e360b..d3610d0 100644 --- a/openssl/rsa_test.go +++ b/openssl/rsa_test.go @@ -9,31 +9,28 @@ package openssl import ( "bytes" "crypto" + "strconv" "testing" ) func TestRSAKeyGeneration(t *testing.T) { - for _, size := range []int{128, 1024, 2048, 3072} { - priv, pub := newRSAKey(t, size) - testRSAKeyBasics(t, priv, pub) - } -} - -func testRSAKeyBasics(t *testing.T, priv *PrivateKeyRSA, pub *PublicKeyRSA) { - // Cannot call encrypt/decrypt directly. Test via PKCS1v15. - msg := []byte("hi!") - enc, err := EncryptRSAPKCS1(pub, msg) - if err != nil { - t.Errorf("EncryptPKCS1v15: %v", err) - return - } - dec, err := DecryptRSAPKCS1(priv, enc) - if err != nil { - t.Errorf("DecryptPKCS1v15: %v", err) - return - } - if !bytes.Equal(dec, msg) { - t.Errorf("got:%x want:%x", dec, msg) + for _, size := range []int{2048, 3072} { + t.Run(strconv.Itoa(size), func(t *testing.T) { + t.Parallel() + priv, pub := newRSAKey(t, size) + msg := []byte("hi!") + enc, err := EncryptRSAPKCS1(pub, msg) + if err != nil { + t.Fatalf("EncryptPKCS1v15: %v", err) + } + dec, err := DecryptRSAPKCS1(priv, enc) + if err != nil { + t.Fatalf("DecryptPKCS1v15: %v", err) + } + if !bytes.Equal(dec, msg) { + t.Fatalf("got:%x want:%x", dec, msg) + } + }) } } @@ -134,15 +131,15 @@ func newRSAKey(t *testing.T, size int) (*PrivateKeyRSA, *PublicKeyRSA) { t.Helper() N, E, D, P, Q, Dp, Dq, Qinv, err := GenerateKeyRSA(size) if err != nil { - t.Errorf("GenerateKeyRSA(%d): %v", size, err) + t.Fatalf("GenerateKeyRSA(%d): %v", size, err) } priv, err := NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv) if err != nil { - t.Errorf("NewPrivateKeyRSA(%d): %v", size, err) + t.Fatalf("NewPrivateKeyRSA(%d): %v", size, err) } pub, err := NewPublicKeyRSA(N, E) if err != nil { - t.Errorf("NewPublicKeyRSA(%d): %v", size, err) + t.Fatalf("NewPublicKeyRSA(%d): %v", size, err) } return priv, pub } diff --git a/openssl/sha_test.go b/openssl/sha_test.go index c697816..ef43037 100644 --- a/openssl/sha_test.go +++ b/openssl/sha_test.go @@ -15,41 +15,53 @@ import ( func TestSha(t *testing.T) { msg := []byte("testig") - for i, fn := range []func() hash.Hash{NewSHA1, NewSHA224, NewSHA256, NewSHA384, NewSHA512} { - h := fn() - initSum := h.Sum(nil) - n, err := h.Write(msg) - if err != nil { - t.Errorf("i %d: %v", i, err) - continue - } - if n != len(msg) { - t.Errorf("i %d: got: %d, want: %d", i, n, len(msg)) - } - sum := h.Sum(nil) - if size := h.Size(); len(sum) != size { - t.Errorf("i %d: got: %d, want: %d", i, len(sum), size) - } - if bytes.Equal(sum, initSum) { - t.Errorf("i %d: Write didn't change internal hash state", i) - } + var tests = []struct { + name string + fn func() hash.Hash + }{ + {"sha1", NewSHA1}, + {"sha224", NewSHA224}, + {"sha256", NewSHA256}, + {"sha384", NewSHA384}, + {"sha512", NewSHA512}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + h := tt.fn() + initSum := h.Sum(nil) + n, err := h.Write(msg) + if err != nil { + t.Fatal(err) + } + if n != len(msg) { + t.Errorf("got: %d, want: %d", n, len(msg)) + } + sum := h.Sum(nil) + if size := h.Size(); len(sum) != size { + t.Errorf("got: %d, want: %d", len(sum), size) + } + if bytes.Equal(sum, initSum) { + t.Error("Write didn't change internal hash state") + } - state, err := h.(encoding.BinaryMarshaler).MarshalBinary() - if err != nil { - t.Errorf("i: %d: could not marshal: %v", i, err) - } - h2 := fn() - if err := h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state); err != nil { - t.Errorf("i: %d: could not unmarshal: %v", i, err) - } - if actual, actual2 := h.Sum(nil), h2.Sum(nil); !bytes.Equal(actual, actual2) { - t.Errorf("i: %d = 0x%x != marshaled 0x%x", i, actual, actual2) - } + state, err := h.(encoding.BinaryMarshaler).MarshalBinary() + if err != nil { + t.Errorf("could not marshal: %v", err) + } + h2 := tt.fn() + if err := h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state); err != nil { + t.Errorf("could not unmarshal: %v", err) + } + if actual, actual2 := h.Sum(nil), h2.Sum(nil); !bytes.Equal(actual, actual2) { + t.Errorf("0x%x != marshaled 0x%x", actual, actual2) + } - h.Reset() - sum = h.Sum(nil) - if !bytes.Equal(sum, initSum) { - t.Errorf("i %d: got:%x want:%x", i, sum, initSum) - } + h.Reset() + sum = h.Sum(nil) + if !bytes.Equal(sum, initSum) { + t.Errorf("got:%x want:%x", sum, initSum) + } + }) } } From c1bc5df91e550db1e8268c9244b1eb866d94abde Mon Sep 17 00:00:00 2001 From: Quim Muntal Date: Fri, 4 Feb 2022 10:40:58 +0000 Subject: [PATCH 6/6] fix rsa tests --- openssl/rsa_test.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/openssl/rsa_test.go b/openssl/rsa_test.go index d3610d0..a9849f8 100644 --- a/openssl/rsa_test.go +++ b/openssl/rsa_test.go @@ -71,9 +71,9 @@ func TestEncryptDecryptOAEP_WrongLabel(t *testing.T) { func TestSignVerifyPKCS1v15(t *testing.T) { sha256 := NewSHA256() - msg := []byte("hi!") priv, pub := newRSAKey(t, 2048) - hashed := sha256.Sum(msg) + sha256.Write([]byte("hi!")) + hashed := sha256.Sum(nil) signed, err := SignRSAPKCS1v15(priv, crypto.SHA256, hashed) if err != nil { t.Fatal(err) @@ -101,7 +101,8 @@ func TestSignVerifyPKCS1v15_Invalid(t *testing.T) { sha256 := NewSHA256() msg := []byte("hi!") priv, pub := newRSAKey(t, 2048) - hashed := sha256.Sum(msg) + sha256.Write(msg) + hashed := sha256.Sum(nil) signed, err := SignRSAPKCS1v15(priv, crypto.SHA256, hashed) if err != nil { t.Fatal(err) @@ -115,7 +116,7 @@ func TestSignVerifyPKCS1v15_Invalid(t *testing.T) { func TestSignVerifyRSAPSS(t *testing.T) { sha1 := NewSHA1() priv, pub := newRSAKey(t, 2048) - sha1.Sum([]byte("testing")) + sha1.Write([]byte("testing")) hashed := sha1.Sum(nil) signed, err := SignRSAPSS(priv, crypto.SHA1, hashed, 0) if err != nil {