Skip to content

Commit

Permalink
Added ctx parameter to Sign and Verify interfaces to allow context pa…
Browse files Browse the repository at this point in the history
…ssing in the case of remote calls to sign/verify an envelope (i.e. a KMS call)
  • Loading branch information
khalkie committed Feb 10, 2023
1 parent 154aa5b commit 60bd7fd
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 52 deletions.
11 changes: 6 additions & 5 deletions dsse/sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ https://github.com/secure-systems-lab/dsse
package dsse

import (
"context"
"encoding/base64"
"errors"
"fmt"
Expand Down Expand Up @@ -77,7 +78,7 @@ using the current algorithm, and the key used (if applicable).
For an example see EcdsaSigner in sign_test.go.
*/
type Signer interface {
Sign(data []byte) ([]byte, error)
Sign(ctx context.Context, data []byte) ([]byte, error)
KeyID() (string, error)
}

Expand Down Expand Up @@ -143,7 +144,7 @@ Returned is an envelope as defined here:
https://github.com/secure-systems-lab/dsse/blob/master/envelope.md
One signature will be added for each Signer in the EnvelopeSigner.
*/
func (es *EnvelopeSigner) SignPayload(payloadType string, body []byte) (*Envelope, error) {
func (es *EnvelopeSigner) SignPayload(ctx context.Context, payloadType string, body []byte) (*Envelope, error) {
var e = Envelope{
Payload: base64.StdEncoding.EncodeToString(body),
PayloadType: payloadType,
Expand All @@ -152,7 +153,7 @@ func (es *EnvelopeSigner) SignPayload(payloadType string, body []byte) (*Envelop
paeEnc := PAE(payloadType, body)

for _, signer := range es.providers {
sig, err := signer.Sign(paeEnc)
sig, err := signer.Sign(ctx, paeEnc)
if err != nil {
return nil, err
}
Expand All @@ -176,8 +177,8 @@ Any domain specific validation such as parsing the decoded body and
validating the payload type is left out to the caller.
Verify returns a list of accepted keys each including a keyid, public and signiture of the accepted provider keys.
*/
func (es *EnvelopeSigner) Verify(e *Envelope) ([]AcceptedKey, error) {
return es.ev.Verify(e)
func (es *EnvelopeSigner) Verify(ctx context.Context, e *Envelope) ([]AcceptedKey, error) {
return es.ev.Verify(ctx, e)
}

/*
Expand Down
81 changes: 41 additions & 40 deletions dsse/sign_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dsse

import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
Expand Down Expand Up @@ -40,11 +41,11 @@ func TestPAE(t *testing.T) {

type nilsigner int

func (n nilsigner) Sign(data []byte) ([]byte, error) {
func (n nilsigner) Sign(ctx context.Context, data []byte) ([]byte, error) {
return data, nil
}

func (n nilsigner) Verify(data, sig []byte) error {
func (n nilsigner) Verify(ctx context.Context, data, sig []byte) error {
if len(data) != len(sig) {
return errLength
}
Expand All @@ -68,11 +69,11 @@ func (n nilsigner) Public() crypto.PublicKey {

type nullsigner int

func (n nullsigner) Sign(data []byte) ([]byte, error) {
func (n nullsigner) Sign(ctx context.Context, data []byte) ([]byte, error) {
return data, nil
}

func (n nullsigner) Verify(data, sig []byte) error {
func (n nullsigner) Verify(ctx context.Context, data, sig []byte) error {
if len(data) != len(sig) {
return errLength
}
Expand All @@ -96,11 +97,11 @@ func (n nullsigner) Public() crypto.PublicKey {

type errsigner int

func (n errsigner) Sign(data []byte) ([]byte, error) {
func (n errsigner) Sign(ctx context.Context, data []byte) ([]byte, error) {
return nil, fmt.Errorf("signing error")
}

func (n errsigner) Verify(data, sig []byte) error {
func (n errsigner) Verify(ctx context.Context, data, sig []byte) error {
return errVerify
}

Expand All @@ -117,11 +118,11 @@ type errverifier int
var errVerify = fmt.Errorf("accepted signatures do not match threshold, Found: 0, Expected 1")
var errThreshold = fmt.Errorf("invalid threshold")

func (n errverifier) Sign(data []byte) ([]byte, error) {
func (n errverifier) Sign(ctx context.Context, data []byte) ([]byte, error) {
return data, nil
}

func (n errverifier) Verify(data, sig []byte) error {
func (n errverifier) Verify(ctx context.Context, data, sig []byte) error {
return errVerify
}

Expand All @@ -135,11 +136,11 @@ func (n errverifier) Public() crypto.PublicKey {

type badverifier int

func (n badverifier) Sign(data []byte) ([]byte, error) {
func (n badverifier) Sign(ctx context.Context, data []byte) ([]byte, error) {
return append(data, byte(0)), nil
}

func (n badverifier) Verify(data, sig []byte) error {
func (n badverifier) Verify(ctx context.Context, data, sig []byte) error {

if len(data) != len(sig) {
return errLength
Expand Down Expand Up @@ -199,7 +200,7 @@ func TestNilSign(t *testing.T) {
signer, err := NewEnvelopeSigner(ns)
assert.Nil(t, err, "unexpected error")

got, err := signer.SignPayload(payloadType, []byte(payload))
got, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")
assert.Equal(t, &want, got, "bad signature")
}
Expand All @@ -209,7 +210,7 @@ func TestSignError(t *testing.T) {
signer, err := NewEnvelopeSigner(es)
assert.Nil(t, err, "unexpected error")

got, err := signer.SignPayload("t", []byte("d"))
got, err := signer.SignPayload(context.TODO(), "t", []byte("d"))
assert.Nil(t, got, "expected nil")
assert.NotNil(t, err, "error expected")
assert.Equal(t, "signing error", err.Error(), "wrong error")
Expand Down Expand Up @@ -252,7 +253,7 @@ type EcdsaSigner struct {
verified bool
}

func (es *EcdsaSigner) Sign(data []byte) ([]byte, error) {
func (es *EcdsaSigner) Sign(ctx context.Context, data []byte) ([]byte, error) {
// Data is complete message, hash it and sign the digest
digest := sha256.Sum256(data)
r, s, err := rfc6979.SignECDSA(es.key, digest[:], sha256.New)
Expand All @@ -268,7 +269,7 @@ func (es *EcdsaSigner) Sign(data []byte) ([]byte, error) {
return rawSig, nil
}

func (es *EcdsaSigner) Verify(data, sig []byte) error {
func (es *EcdsaSigner) Verify(ctx context.Context, data, sig []byte) error {
var r big.Int
var s big.Int
digest := sha256.Sum256(data)
Expand Down Expand Up @@ -319,12 +320,12 @@ func TestEcdsaSign(t *testing.T) {
signer, err := NewEnvelopeSigner(ecdsa)
assert.Nil(t, err, "unexpected error")

env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "unexpected error")
assert.Equal(t, &want, env, "Wrong envelope generated")

// Now verify
acceptedKeys, err := signer.Verify(env)
acceptedKeys, err := signer.Verify(context.TODO(), env)
assert.Nil(t, err, "unexpected error")
assert.True(t, ecdsa.verified, "verify was not called")
assert.Len(t, acceptedKeys, 1, "unexpected keys")
Expand Down Expand Up @@ -384,10 +385,10 @@ func TestVerifyOneProvider(t *testing.T) {
signer, err := NewEnvelopeSigner(ns)
assert.Nil(t, err, "unexpected error")

env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")

acceptedKeys, err := signer.Verify(env)
acceptedKeys, err := signer.Verify(context.TODO(), env)
assert.Nil(t, err, "unexpected error")
assert.Len(t, acceptedKeys, 1, "unexpected keys")
assert.Equal(t, acceptedKeys[0].KeyID, "nil", "unexpected keyid")
Expand All @@ -402,10 +403,10 @@ func TestVerifyMultipleProvider(t *testing.T) {
signer, err := NewEnvelopeSigner(ns, null)
assert.Nil(t, err, "unexpected error")

env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")

acceptedKeys, err := signer.Verify(env)
acceptedKeys, err := signer.Verify(context.TODO(), env)
assert.Nil(t, err, "unexpected error")
assert.Len(t, acceptedKeys, 2, "unexpected keys")
}
Expand All @@ -418,10 +419,10 @@ func TestVerifyMultipleProviderThreshold(t *testing.T) {
var null nullsigner
signer, err := NewMultiEnvelopeSigner(2, ns, null)
assert.Nil(t, err)
env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")

acceptedKeys, err := signer.Verify(env)
acceptedKeys, err := signer.Verify(context.TODO(), env)
assert.Nil(t, err, "unexpected error")
assert.Len(t, acceptedKeys, 2, "unexpected keys")
}
Expand All @@ -443,10 +444,10 @@ func TestVerifyErr(t *testing.T) {
signer, err := NewEnvelopeSigner(errv)
assert.Nil(t, err, "unexpected error")

env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")

_, err = signer.Verify(env)
_, err = signer.Verify(context.TODO(), env)
assert.Equal(t, errVerify, err, "wrong error")
}

Expand All @@ -458,10 +459,10 @@ func TestBadVerifier(t *testing.T) {
signer, err := NewEnvelopeSigner(badv)
assert.Nil(t, err, "unexpected error")

env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")

_, err = signer.Verify(env)
_, err = signer.Verify(context.TODO(), env)
assert.NotNil(t, err, "expected error")
}

Expand All @@ -472,7 +473,7 @@ func TestVerifyNoSig(t *testing.T) {

env := &Envelope{}

_, err = signer.Verify(env)
_, err = signer.Verify(context.TODO(), env)
assert.Equal(t, ErrNoSignature, err, "wrong error")
}

Expand All @@ -489,7 +490,7 @@ func TestVerifyBadBase64(t *testing.T) {
},
}

_, err := signer.Verify(env)
_, err := signer.Verify(context.TODO(), env)
assert.IsType(t, base64.CorruptInputError(0), err, "wrong error")
})

Expand All @@ -503,7 +504,7 @@ func TestVerifyBadBase64(t *testing.T) {
},
}

_, err := signer.Verify(env)
_, err := signer.Verify(context.TODO(), env)
assert.IsType(t, base64.CorruptInputError(0), err, "wrong error")
})
}
Expand All @@ -527,7 +528,7 @@ func TestVerifyNoMatch(t *testing.T) {
},
}

_, err = signer.Verify(env)
_, err = signer.Verify(context.TODO(), env)
assert.NotNil(t, err, "expected error")
}

Expand All @@ -537,11 +538,11 @@ type interceptSigner struct {
verifyCalled bool
}

func (i *interceptSigner) Sign(data []byte) ([]byte, error) {
func (i *interceptSigner) Sign(ctx context.Context, data []byte) ([]byte, error) {
return data, nil
}

func (i *interceptSigner) Verify(data, sig []byte) error {
func (i *interceptSigner) Verify(ctx context.Context, data, sig []byte) error {
i.verifyCalled = true

if i.verifyRes {
Expand Down Expand Up @@ -573,10 +574,10 @@ func TestVerifyOneFail(t *testing.T) {
signer, err := NewEnvelopeSigner(s1, s2)
assert.Nil(t, err, "unexpected error")

env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")

acceptedKeys, err := signer.Verify(env)
acceptedKeys, err := signer.Verify(context.TODO(), env)
assert.Nil(t, err, "expected error")
assert.True(t, s1.verifyCalled, "verify not called")
assert.True(t, s2.verifyCalled, "verify not called")
Expand All @@ -599,10 +600,10 @@ func TestVerifySameKeyID(t *testing.T) {
signer, err := NewEnvelopeSigner(s1, s2)
assert.Nil(t, err, "unexpected error")

env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")

acceptedKeys, err := signer.Verify(env)
acceptedKeys, err := signer.Verify(context.TODO(), env)
assert.Nil(t, err, "expected error")
assert.True(t, s1.verifyCalled, "verify not called")
assert.True(t, s2.verifyCalled, "verify not called")
Expand All @@ -627,10 +628,10 @@ func TestVerifyEmptyKeyID(t *testing.T) {
signer, err := NewEnvelopeSigner(s1, s2)
assert.Nil(t, err, "unexpected error")

env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")

acceptedKeys, err := signer.Verify(env)
acceptedKeys, err := signer.Verify(context.TODO(), env)
assert.Nil(t, err, "expected error")
// assert.True(t, s1.verifyCalled, "verify not called")
// assert.True(t, s2.verifyCalled, "verify not called")
Expand Down Expand Up @@ -658,10 +659,10 @@ func TestVerifyPublicKeyID(t *testing.T) {
signer, err := NewEnvelopeSigner(s1, s2)
assert.Nil(t, err, "unexpected error")

env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")

acceptedKeys, err := signer.Verify(env)
acceptedKeys, err := signer.Verify(context.TODO(), env)
assert.Nil(t, err, "expected error")
assert.Len(t, acceptedKeys, 1, "unexpected keys")
assert.Equal(t, acceptedKeys[0].KeyID, keyID, "unexpected keyid")
Expand Down
7 changes: 4 additions & 3 deletions dsse/verify.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dsse

import (
"context"
"crypto"
"errors"
"fmt"
Expand All @@ -15,7 +16,7 @@ must perform the same steps.
If KeyID returns successfully, only signature matching the key ID will be verified.
*/
type Verifier interface {
Verify(data, sig []byte) error
Verify(ctx context.Context, data, sig []byte) error
KeyID() (string, error)
Public() crypto.PublicKey
}
Expand All @@ -31,7 +32,7 @@ type AcceptedKey struct {
Sig Signature
}

func (ev *EnvelopeVerifier) Verify(e *Envelope) ([]AcceptedKey, error) {
func (ev *EnvelopeVerifier) Verify(ctx context.Context, e *Envelope) ([]AcceptedKey, error) {
if e == nil {
return nil, errors.New("cannot verify a nil envelope")
}
Expand Down Expand Up @@ -78,7 +79,7 @@ func (ev *EnvelopeVerifier) Verify(e *Envelope) ([]AcceptedKey, error) {
continue
}

err = v.Verify(paeEnc, sig)
err = v.Verify(ctx, paeEnc, sig)
if err != nil {
continue
}
Expand Down
Loading

0 comments on commit 60bd7fd

Please sign in to comment.