diff --git a/sha3/sha3_test.go b/sha3/sha3_test.go index 26d1549b59..c4e4498007 100644 --- a/sha3/sha3_test.go +++ b/sha3/sha3_test.go @@ -27,14 +27,6 @@ const ( katFilename = "testdata/keccakKats.json.deflate" ) -// Internal-use instances of SHAKE used to test against KATs. -func newHashShake128() hash.Hash { - return &state{rate: 168, dsbyte: 0x1f, outputLen: 512} -} -func newHashShake256() hash.Hash { - return &state{rate: 136, dsbyte: 0x1f, outputLen: 512} -} - // testDigests contains functions returning hash.Hash instances // with output-length equal to the KAT length for SHA-3, Keccak // and SHAKE instances. @@ -45,15 +37,20 @@ var testDigests = map[string]func() hash.Hash{ "SHA3-512": New512, "Keccak-256": NewLegacyKeccak256, "Keccak-512": NewLegacyKeccak512, - "SHAKE128": newHashShake128, - "SHAKE256": newHashShake256, } -// testShakes contains functions that return ShakeHash instances for -// testing the ShakeHash-specific interface. -var testShakes = map[string]func() ShakeHash{ - "SHAKE128": NewShake128, - "SHAKE256": NewShake256, +// testShakes contains functions that return sha3.ShakeHash instances for +// with output-length equal to the KAT length. +var testShakes = map[string]struct { + constructor func(N []byte, S []byte) ShakeHash + defAlgoName string + defCustomStr string +}{ + // NewCShake without customization produces same result as SHAKE + "SHAKE128": {NewCShake128, "", ""}, + "SHAKE256": {NewCShake256, "", ""}, + "cSHAKE128": {NewCShake128, "CSHAKE128", "CustomStrign"}, + "cSHAKE256": {NewCShake256, "CSHAKE256", "CustomStrign"}, } // decodeHex converts a hex-encoded string into a raw byte string. @@ -71,6 +68,10 @@ type KeccakKats struct { Digest string `json:"digest"` Length int64 `json:"length"` Message string `json:"message"` + + // Defined only for cSHAKE + N string `json:"N"` + S string `json:"S"` } } @@ -103,10 +104,9 @@ func TestKeccakKats(t *testing.T) { t.Errorf("error decoding KATs: %s", err) } - // Do the KATs. - for functionName, kats := range katSet.Kats { - d := testDigests[functionName]() - for _, kat := range kats { + for algo, function := range testDigests { + d := function() + for _, kat := range katSet.Kats[algo] { d.Reset() in, err := hex.DecodeString(kat.Message) if err != nil { @@ -115,8 +115,39 @@ func TestKeccakKats(t *testing.T) { d.Write(in[:kat.Length/8]) got := strings.ToUpper(hex.EncodeToString(d.Sum(nil))) if got != kat.Digest { - t.Errorf("function=%s, implementation=%s, length=%d\nmessage:\n %s\ngot:\n %s\nwanted:\n %s", - functionName, impl, kat.Length, kat.Message, got, kat.Digest) + t.Errorf("function=%s, implementation=%s, length=%d\nmessage:\n %s\ngot:\n %s\nwanted:\n %s", + algo, impl, kat.Length, kat.Message, got, kat.Digest) + t.Logf("wanted %+v", kat) + t.FailNow() + } + continue + } + } + + for algo, v := range testShakes { + for _, kat := range katSet.Kats[algo] { + N, err := hex.DecodeString(kat.N) + if err != nil { + t.Errorf("error decoding KAT: %s", err) + } + + S, err := hex.DecodeString(kat.S) + if err != nil { + t.Errorf("error decoding KAT: %s", err) + } + d := v.constructor(N, S) + in, err := hex.DecodeString(kat.Message) + if err != nil { + t.Errorf("error decoding KAT: %s", err) + } + + d.Write(in[:kat.Length/8]) + out := make([]byte, len(kat.Digest)/2) + d.Read(out) + got := strings.ToUpper(hex.EncodeToString(out)) + if got != kat.Digest { + t.Errorf("function=%s, implementation=%s, length=%d N:%s\n S:%s\nmessage:\n %s \ngot:\n %s\nwanted:\n %s", + algo, impl, kat.Length, kat.N, kat.S, kat.Message, got, kat.Digest) t.Logf("wanted %+v", kat) t.FailNow() } @@ -184,6 +215,34 @@ func TestUnalignedWrite(t *testing.T) { t.Errorf("Unaligned writes, implementation=%s, alg=%s\ngot %q, want %q", impl, alg, got, want) } } + + // Same for SHAKE + for alg, df := range testShakes { + want := make([]byte, 16) + got := make([]byte, 16) + d := df.constructor([]byte(df.defAlgoName), []byte(df.defCustomStr)) + + d.Reset() + d.Write(buf) + d.Read(want) + d.Reset() + for i := 0; i < len(buf); { + // Cycle through offsets which make a 137 byte sequence. + // Because 137 is prime this sequence should exercise all corner cases. + offsets := [17]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1} + for _, j := range offsets { + if v := len(buf) - i; v < j { + j = v + } + d.Write(buf[i : i+j]) + i += j + } + } + d.Read(got) + if !bytes.Equal(got, want) { + t.Errorf("Unaligned writes, implementation=%s, alg=%s\ngot %q, want %q", impl, alg, got, want) + } + } }) } @@ -225,13 +284,13 @@ func TestAppendNoRealloc(t *testing.T) { // the same output as repeatedly squeezing the instance. func TestSqueezing(t *testing.T) { testUnalignedAndGeneric(t, func(impl string) { - for functionName, newShakeHash := range testShakes { - d0 := newShakeHash() + for algo, v := range testShakes { + d0 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) d0.Write([]byte(testString)) ref := make([]byte, 32) d0.Read(ref) - d1 := newShakeHash() + d1 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) d1.Write([]byte(testString)) var multiple []byte for range ref { @@ -240,7 +299,7 @@ func TestSqueezing(t *testing.T) { multiple = append(multiple, one...) } if !bytes.Equal(ref, multiple) { - t.Errorf("%s (%s): squeezing %d bytes one at a time failed", functionName, impl, len(ref)) + t.Errorf("%s (%s): squeezing %d bytes one at a time failed", algo, impl, len(ref)) } } }) @@ -255,6 +314,50 @@ func sequentialBytes(size int) []byte { return result } +func TestReset(t *testing.T) { + out1 := make([]byte, 32) + out2 := make([]byte, 32) + + for _, v := range testShakes { + // Calculate hash for the first time + c := v.constructor(nil, []byte{0x99, 0x98}) + c.Write(sequentialBytes(0x100)) + c.Read(out1) + + // Calculate hash again + c.Reset() + c.Write(sequentialBytes(0x100)) + c.Read(out2) + + if !bytes.Equal(out1, out2) { + t.Error("\nExpected:\n", out1, "\ngot:\n", out2) + } + } +} + +func TestClone(t *testing.T) { + out1 := make([]byte, 16) + out2 := make([]byte, 16) + in := sequentialBytes(0x100) + + for _, v := range testShakes { + h1 := v.constructor(nil, []byte{0x01}) + h1.Write([]byte{0x01}) + + h2 := h1.Clone() + + h1.Write(in) + h1.Read(out1) + + h2.Write(in) + h2.Read(out2) + + if !bytes.Equal(out1, out2) { + t.Error("\nExpected:\n", hex.EncodeToString(out1), "\ngot:\n", hex.EncodeToString(out2)) + } + } +} + // BenchmarkPermutationFunction measures the speed of the permutation function // with no input data. func BenchmarkPermutationFunction(b *testing.B) { @@ -341,3 +444,37 @@ func Example_mac() { fmt.Printf("%x\n", h) // Output: 78de2974bd2711d5549ffd32b753ef0f5fa80a0db2556db60f0987eb8a9218ff } + +func ExampleNewCShake256() { + out := make([]byte, 32) + msg := []byte("The quick brown fox jumps over the lazy dog") + + // Example 1: Simple cshake + c1 := NewCShake256([]byte("NAME"), []byte("Partition1")) + c1.Write(msg) + c1.Read(out) + fmt.Println(hex.EncodeToString(out)) + + // Example 2: Different customization string produces different digest + c1 = NewCShake256([]byte("NAME"), []byte("Partition2")) + c1.Write(msg) + c1.Read(out) + fmt.Println(hex.EncodeToString(out)) + + // Example 3: Longer output length produces longer digest + out = make([]byte, 64) + c1 = NewCShake256([]byte("NAME"), []byte("Partition1")) + c1.Write(msg) + c1.Read(out) + fmt.Println(hex.EncodeToString(out)) + + // Example 4: Next read produces different result + c1.Read(out) + fmt.Println(hex.EncodeToString(out)) + + // Output: + //a90a4c6ca9af2156eba43dc8398279e6b60dcd56fb21837afe6c308fd4ceb05b + //a8db03e71f3e4da5c4eee9d28333cdd355f51cef3c567e59be5beb4ecdbb28f0 + //a90a4c6ca9af2156eba43dc8398279e6b60dcd56fb21837afe6c308fd4ceb05b9dd98c6ee866ca7dc5a39d53e960f400bcd5a19c8a2d6ec6459f63696543a0d8 + //85e73a72228d08b46515553ca3a29d47df3047e5d84b12d6c2c63e579f4fd1105716b7838e92e981863907f434bfd4443c9e56ea09da998d2f9b47db71988109 +} diff --git a/sha3/shake.go b/sha3/shake.go index 97c9b0624a..a39e5d5129 100644 --- a/sha3/shake.go +++ b/sha3/shake.go @@ -5,10 +5,18 @@ package sha3 // This file defines the ShakeHash interface, and provides -// functions for creating SHAKE instances, as well as utility +// functions for creating SHAKE and cSHAKE instances, as well as utility // functions for hashing bytes to arbitrary-length output. +// +// +// SHAKE implementation is based on FIPS PUB 202 [1] +// cSHAKE implementations is based on NIST SP 800-185 [2] +// +// [1] https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.202.pdf +// [2] https://doi.org/10.6028/NIST.SP.800-185 import ( + "encoding/binary" "io" ) @@ -31,8 +39,77 @@ type ShakeHash interface { Reset() } -func (d *state) Clone() ShakeHash { - return d.clone() +// cSHAKE specific context +type cshakeState struct { + state // SHA-3 state context and Read/Write operations + + // initBlock is the cSHAKE specific initialization set of bytes. It is initialized + // by newCShake function and stores concatenation of N followed by S, encoded + // by the method specified in 3.3 of [1]. + // It is stored here in order for Reset() to be able to put context into + // initial state. + initBlock []byte +} + +// Consts for configuring initial SHA-3 state +const ( + dsbyteShake = 0x1f + dsbyteCShake = 0x04 + rate128 = 168 + rate256 = 136 +) + +func bytepad(input []byte, w int) []byte { + // leftEncode always returns max 9 bytes + buf := make([]byte, 0, 9+len(input)+w) + buf = append(buf, leftEncode(uint64(w))...) + buf = append(buf, input...) + padlen := w - (len(buf) % w) + return append(buf, make([]byte, padlen)...) +} + +func leftEncode(value uint64) []byte { + var b [9]byte + binary.BigEndian.PutUint64(b[1:], value) + // Trim all but last leading zero bytes + i := byte(1) + for i < 8 && b[i] == 0 { + i++ + } + // Prepend number of encoded bytes + b[i-1] = 9 - i + return b[i-1:] +} + +func newCShake(N, S []byte, rate int, dsbyte byte) ShakeHash { + c := cshakeState{state: state{rate: rate, dsbyte: dsbyte}} + + // leftEncode returns max 9 bytes + c.initBlock = make([]byte, 0, 9*2+len(N)+len(S)) + c.initBlock = append(c.initBlock, leftEncode(uint64(len(N)*8))...) + c.initBlock = append(c.initBlock, N...) + c.initBlock = append(c.initBlock, leftEncode(uint64(len(S)*8))...) + c.initBlock = append(c.initBlock, S...) + c.Write(bytepad(c.initBlock, c.rate)) + return &c +} + +// Reset resets the hash to initial state. +func (c *cshakeState) Reset() { + c.state.Reset() + c.Write(bytepad(c.initBlock, c.rate)) +} + +// Clone returns copy of a cSHAKE context within its current state. +func (c *cshakeState) Clone() ShakeHash { + b := make([]byte, len(c.initBlock)) + copy(b, c.initBlock) + return &cshakeState{state: *c.clone(), initBlock: b} +} + +// Clone returns copy of SHAKE context within its current state. +func (c *state) Clone() ShakeHash { + return c.clone() } // NewShake128 creates a new SHAKE128 variable-output-length ShakeHash. @@ -42,7 +119,7 @@ func NewShake128() ShakeHash { if h := newShake128Asm(); h != nil { return h } - return &state{rate: 168, dsbyte: 0x1f} + return &state{rate: rate128, dsbyte: dsbyteShake} } // NewShake256 creates a new SHAKE256 variable-output-length ShakeHash. @@ -52,7 +129,33 @@ func NewShake256() ShakeHash { if h := newShake256Asm(); h != nil { return h } - return &state{rate: 136, dsbyte: 0x1f} + return &state{rate: rate256, dsbyte: dsbyteShake} +} + +// NewCShake128 creates a new instance of cSHAKE128 variable-output-length ShakeHash, +// a customizable variant of SHAKE128. +// N is used to define functions based on cSHAKE, it can be empty when plain cSHAKE is +// desired. S is a customization byte string used for domain separation - two cSHAKE +// computations on same input with different S yield unrelated outputs. +// When N and S are both empty, this is equivalent to NewShake128. +func NewCShake128(N, S []byte) ShakeHash { + if len(N) == 0 && len(S) == 0 { + return NewShake128() + } + return newCShake(N, S, rate128, dsbyteCShake) +} + +// NewCShake256 creates a new instance of cSHAKE256 variable-output-length ShakeHash, +// a customizable variant of SHAKE256. +// N is used to define functions based on cSHAKE, it can be empty when plain cSHAKE is +// desired. S is a customization byte string used for domain separation - two cSHAKE +// computations on same input with different S yield unrelated outputs. +// When N and S are both empty, this is equivalent to NewShake256. +func NewCShake256(N, S []byte) ShakeHash { + if len(N) == 0 && len(S) == 0 { + return NewShake256() + } + return newCShake(N, S, rate256, dsbyteCShake) } // ShakeSum128 writes an arbitrary-length digest of data into hash. diff --git a/sha3/testdata/keccakKats.json.deflate b/sha3/testdata/keccakKats.json.deflate index 62e85ae242..7a94c2f8bc 100644 Binary files a/sha3/testdata/keccakKats.json.deflate and b/sha3/testdata/keccakKats.json.deflate differ