From 7ab2c45f2514b34d5b3815d1ea23dd4d79de1bec Mon Sep 17 00:00:00 2001 From: Tsachi Herman Date: Tue, 8 Feb 2022 18:53:39 -0500 Subject: [PATCH] Use context instead of an atomic var. --- crypto/merklesignature/keysBuilder.go | 15 ++++++------ .../merkleSignatureScheme_test.go | 24 +++++++++---------- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/crypto/merklesignature/keysBuilder.go b/crypto/merklesignature/keysBuilder.go index 4d69fc30e0..93553e7d5e 100644 --- a/crypto/merklesignature/keysBuilder.go +++ b/crypto/merklesignature/keysBuilder.go @@ -17,9 +17,9 @@ package merklesignature import ( + "context" "runtime" "sync" - "sync/atomic" "github.com/algorand/go-algorand/crypto" ) @@ -28,7 +28,8 @@ import ( func KeysBuilder(numberOfKeys uint64) ([]crypto.FalconSigner, error) { numOfKeysPerRoutine, _ := calculateRanges(numberOfKeys) - var terminate int64 + ctx, ctxCancel := context.WithCancel(context.Background()) + defer ctxCancel() errors := make(chan error, 1) defer close(errors) @@ -48,13 +49,13 @@ func KeysBuilder(numberOfKeys uint64) ([]crypto.FalconSigner, error) { wg.Add(1) go func(startIdx, endIdx uint64, keys []crypto.FalconSigner) { defer wg.Done() - if err := generateKeysForRange(startIdx, endIdx, &terminate, keys); err != nil { + if err := generateKeysForRange(ctx, startIdx, endIdx, keys); err != nil { // write to the error channel, if it's not full already. select { case errors <- err: default: } - atomic.StoreInt64(&terminate, 1) + ctxCancel() } }(i, endIdx, keys) } @@ -80,10 +81,10 @@ func calculateRanges(numberOfKeys uint64) (numOfKeysPerRoutine uint64, numOfRout return } -func generateKeysForRange(startIdx uint64, endIdx uint64, terminate *int64, keys []crypto.FalconSigner) error { +func generateKeysForRange(ctx context.Context, startIdx uint64, endIdx uint64, keys []crypto.FalconSigner) error { for k := startIdx; k < endIdx; k++ { - if atomic.LoadInt64(terminate) != 0 { - return nil + if ctx.Err() != nil { + break } sigAlgo, err := crypto.NewFalconSigner() if err != nil { diff --git a/crypto/merklesignature/merkleSignatureScheme_test.go b/crypto/merklesignature/merkleSignatureScheme_test.go index f95c3534a4..7f2a0600b6 100644 --- a/crypto/merklesignature/merkleSignatureScheme_test.go +++ b/crypto/merklesignature/merkleSignatureScheme_test.go @@ -18,6 +18,7 @@ package merklesignature import ( "crypto/rand" + "errors" "math" "testing" @@ -240,7 +241,7 @@ func TestSigning(t *testing.T) { err = signer.GetVerifier().Verify(start+5, hashable, sig) a.Error(err) - a.Contains(err.Error(), ErrSignatureSchemeVerificationFailed) + a.True(errors.Is(err, ErrSignatureSchemeVerificationFailed)) signer = generateTestSigner(50, 100, 12, a) a.Equal(4, length(signer, a)) @@ -274,16 +275,16 @@ func TestBadRound(t *testing.T) { err := signer.GetVerifier().Verify(start+1, hashable, sig) a.Error(err) - a.Contains(err.Error(), ErrSignatureSchemeVerificationFailed) + a.True(errors.Is(err, ErrSignatureSchemeVerificationFailed)) hashable, sig = makeSig(signer, start+1, a) err = signer.GetVerifier().Verify(start, hashable, sig) a.Error(err) - a.Contains(err.Error(), ErrSignatureSchemeVerificationFailed) + a.True(errors.Is(err, ErrSignatureSchemeVerificationFailed)) err = signer.GetVerifier().Verify(start+2, hashable, sig) a.Error(err) - a.Contains(err.Error(), ErrSignatureSchemeVerificationFailed) + a.True(errors.Is(err, ErrSignatureSchemeVerificationFailed)) } func TestBadMerkleProofInSignature(t *testing.T) { @@ -297,7 +298,7 @@ func TestBadMerkleProofInSignature(t *testing.T) { sig2.Proof.Path = sig2.Proof.Path[:len(sig2.Proof.Path)-1] err := signer.GetVerifier().Verify(start, hashable, sig2) a.Error(err) - a.Contains(err.Error(), ErrSignatureSchemeVerificationFailed) + a.True(errors.Is(err, ErrSignatureSchemeVerificationFailed)) sig3 := copySig(sig) someDigest := crypto.Digest{} @@ -305,7 +306,7 @@ func TestBadMerkleProofInSignature(t *testing.T) { sig3.Proof.Path[0] = someDigest[:] err = signer.GetVerifier().Verify(start, hashable, sig3) a.Error(err) - a.Contains(err.Error(), ErrSignatureSchemeVerificationFailed) + a.True(errors.Is(err, ErrSignatureSchemeVerificationFailed)) } func copySig(sig Signature) Signature { @@ -334,7 +335,7 @@ func TestIncorrectByteSignature(t *testing.T) { err := signer.GetVerifier().Verify(start, hashable, sig2) a.Error(err) - a.Contains(err.Error(), ErrSignatureSchemeVerificationFailed) + a.True(errors.Is(err, ErrSignatureSchemeVerificationFailed)) } func TestIncorrectMerkleIndex(t *testing.T) { @@ -352,17 +353,16 @@ func TestIncorrectMerkleIndex(t *testing.T) { sig.MerkleArrayIndex = 0 err = signer.GetVerifier().Verify(20, h, sig) a.Error(err) - a.Contains(err.Error(), ErrSignatureSchemeVerificationFailed) + a.True(errors.Is(err, ErrSignatureSchemeVerificationFailed)) sig.MerkleArrayIndex = math.MaxUint64 err = signer.GetVerifier().Verify(20, h, sig) a.Error(err) - a.Contains(err.Error(), ErrSignatureSchemeVerificationFailed) + a.True(errors.Is(err, ErrSignatureSchemeVerificationFailed)) err = signer.GetVerifier().Verify(20, h, sig) a.Error(err) - a.Contains(err.Error(), ErrSignatureSchemeVerificationFailed) - + a.True(errors.Is(err, ErrSignatureSchemeVerificationFailed)) } func TestAttemptToUseDifferentKey(t *testing.T) { @@ -382,7 +382,7 @@ func TestAttemptToUseDifferentKey(t *testing.T) { err := signer.GetVerifier().Verify(start+1, hashable, sig2) a.Error(err) - a.Contains(err.Error(), ErrSignatureSchemeVerificationFailed) + a.True(errors.Is(err, ErrSignatureSchemeVerificationFailed)) } func TestMarshal(t *testing.T) {