diff --git a/pkg/cosign/tlog.go b/pkg/cosign/tlog.go index 827311993c2..83d6f61f179 100644 --- a/pkg/cosign/tlog.go +++ b/pkg/cosign/tlog.go @@ -208,7 +208,7 @@ func doUpload(ctx context.Context, rekorClient *client.Rekor, pe models.Proposed // Here, we display the proof and succeed. var existsErr *entries.CreateLogEntryConflict if errors.As(err, &existsErr) { - ui.Infof(ctx, "Signature already exists. Displaying proof") + ui.Infof(ctx, "Signature already exists. Fetching and verifying inclusion proof.") uriSplit := strings.Split(existsErr.Location.String(), "/") uuid := uriSplit[len(uriSplit)-1] e, err := GetTlogEntry(ctx, rekorClient, uuid) @@ -299,8 +299,8 @@ func getTreeUUID(entryUUID string) (string, error) { } } -// Validates UUID and also TreeID if present. -func isExpectedResponseUUID(requestEntryUUID string, responseEntryUUID string, treeid string) error { +// Validates UUID and also shard if present. +func isExpectedResponseUUID(requestEntryUUID string, responseEntryUUID string) error { // Comparare UUIDs requestUUID, err := getUUID(requestEntryUUID) if err != nil { @@ -313,19 +313,21 @@ func isExpectedResponseUUID(requestEntryUUID string, responseEntryUUID string, t if requestUUID != responseUUID { return fmt.Errorf("expected EntryUUID %s got UUID %s", requestEntryUUID, responseEntryUUID) } - // Compare tree ID if it is in the request. - requestTreeID, err := getTreeUUID(requestEntryUUID) + // Compare shards if it is in the request. + requestShardID, err := getTreeUUID(requestEntryUUID) if err != nil { return err } - if requestTreeID != "" { - tid, err := getTreeUUID(treeid) - if err != nil { - return err - } - if requestTreeID != tid { - return fmt.Errorf("expected EntryUUID %s got UUID %s from Tree %s", requestEntryUUID, responseEntryUUID, treeid) - } + responseShardID, err := getTreeUUID(responseEntryUUID) + if err != nil { + return err + } + // no shard ID prepends the entry UUID + if requestShardID == "" || responseShardID == "" { + return nil + } + if requestShardID != responseShardID { + return fmt.Errorf("expected UUID %s from shard %s: got UUID %s from shard %s", requestEntryUUID, responseEntryUUID, requestShardID, responseShardID) } return nil } @@ -357,8 +359,8 @@ func GetTlogEntry(ctx context.Context, rekorClient *client.Rekor, entryUUID stri return nil, err } for k, e := range resp.Payload { - // Validate that request EntryUUID matches the response UUID and response Tree ID - if err := isExpectedResponseUUID(entryUUID, k, *e.LogID); err != nil { + // Validate that request EntryUUID matches the response UUID and response shard ID + if err := isExpectedResponseUUID(entryUUID, k); err != nil { return nil, fmt.Errorf("unexpected entry returned from rekor server: %w", err) } // Check that body hash matches UUID diff --git a/pkg/cosign/tlog_test.go b/pkg/cosign/tlog_test.go index 08a9d6bca54..55bd76bb9f6 100644 --- a/pkg/cosign/tlog_test.go +++ b/pkg/cosign/tlog_test.go @@ -67,14 +67,12 @@ func TestExpectedRekorResponse(t *testing.T) { name string requestUUID string responseUUID string - treeID string wantErr bool }{ { name: "valid match with request & response entry UUID", requestUUID: validTreeID + validUUID, responseUUID: validTreeID + validUUID, - treeID: validTreeID, wantErr: false, }, // The following is the current typical Rekor behavior. @@ -82,63 +80,54 @@ func TestExpectedRekorResponse(t *testing.T) { name: "valid match with request entry UUID", requestUUID: validTreeID + validUUID, responseUUID: validUUID, - treeID: validTreeID, wantErr: false, }, { name: "valid match with request UUID", requestUUID: validUUID, responseUUID: validUUID, - treeID: validTreeID, wantErr: false, }, { name: "valid match with response entry UUID", requestUUID: validUUID, responseUUID: validTreeID + validUUID, - treeID: validTreeID, wantErr: false, }, { name: "mismatch uuid with response tree id", requestUUID: validUUID, responseUUID: validTreeID + validUUID1, - treeID: validTreeID, wantErr: true, }, { name: "mismatch uuid with request tree id", requestUUID: validTreeID + validUUID1, responseUUID: validUUID, - treeID: validTreeID, wantErr: true, }, { name: "mismatch tree id", requestUUID: validTreeID + validUUID, - responseUUID: validUUID, - treeID: validTreeID1, + responseUUID: validTreeID1 + validUUID, wantErr: true, }, { name: "invalid response tree id", requestUUID: validTreeID + validUUID, responseUUID: invalidTreeID + validUUID, - treeID: invalidTreeID, wantErr: true, }, { name: "invalid request tree id", requestUUID: invalidTreeID + validUUID, responseUUID: validUUID, - treeID: invalidTreeID, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := isExpectedResponseUUID(tt.requestUUID, - tt.responseUUID, tt.treeID); (got != nil) != tt.wantErr { + if got := isExpectedResponseUUID(tt.requestUUID, tt.responseUUID); (got != nil) != tt.wantErr { t.Errorf("isExpectedResponseUUID() = %v, want %v", got, tt.wantErr) } })