diff --git a/std/hash/sha3/hashes.go b/std/hash/sha3/hashes.go index b2282f9108..4fac328955 100644 --- a/std/hash/sha3/hashes.go +++ b/std/hash/sha3/hashes.go @@ -9,12 +9,13 @@ import ( // New256 creates a new SHA3-256 hash. // Its generic security strength is 256 bits against preimage attacks, // and 128 bits against collision attacks. -func New256(api frontend.API) (hash.BinaryHasher, error) { +func New256(api frontend.API) (hash.BinaryFixedLengthHasher, error) { uapi, err := uints.New[uints.U64](api) if err != nil { return nil, err } return &digest{ + api: api, uapi: uapi, state: newState(), dsbyte: 0x06, @@ -26,12 +27,13 @@ func New256(api frontend.API) (hash.BinaryHasher, error) { // New384 creates a new SHA3-384 hash. // Its generic security strength is 384 bits against preimage attacks, // and 192 bits against collision attacks. -func New384(api frontend.API) (hash.BinaryHasher, error) { +func New384(api frontend.API) (hash.BinaryFixedLengthHasher, error) { uapi, err := uints.New[uints.U64](api) if err != nil { return nil, err } return &digest{ + api: api, uapi: uapi, state: newState(), dsbyte: 0x06, @@ -43,12 +45,13 @@ func New384(api frontend.API) (hash.BinaryHasher, error) { // New512 creates a new SHA3-512 hash. // Its generic security strength is 512 bits against preimage attacks, // and 256 bits against collision attacks. -func New512(api frontend.API) (hash.BinaryHasher, error) { +func New512(api frontend.API) (hash.BinaryFixedLengthHasher, error) { uapi, err := uints.New[uints.U64](api) if err != nil { return nil, err } return &digest{ + api: api, uapi: uapi, state: newState(), dsbyte: 0x06, @@ -61,12 +64,13 @@ func New512(api frontend.API) (hash.BinaryHasher, error) { // // Only use this function if you require compatibility with an existing cryptosystem // that uses non-standard padding. All other users should use New256 instead. -func NewLegacyKeccak256(api frontend.API) (hash.BinaryHasher, error) { +func NewLegacyKeccak256(api frontend.API) (hash.BinaryFixedLengthHasher, error) { uapi, err := uints.New[uints.U64](api) if err != nil { return nil, err } return &digest{ + api: api, uapi: uapi, state: newState(), dsbyte: 0x01, @@ -79,12 +83,13 @@ func NewLegacyKeccak256(api frontend.API) (hash.BinaryHasher, error) { // // Only use this function if you require compatibility with an existing cryptosystem // that uses non-standard padding. All other users should use New512 instead. -func NewLegacyKeccak512(api frontend.API) (hash.BinaryHasher, error) { +func NewLegacyKeccak512(api frontend.API) (hash.BinaryFixedLengthHasher, error) { uapi, err := uints.New[uints.U64](api) if err != nil { return nil, err } return &digest{ + api: api, uapi: uapi, state: newState(), dsbyte: 0x01, diff --git a/std/hash/sha3/sha3.go b/std/hash/sha3/sha3.go index 76cd2c8a0a..dfce4c0762 100644 --- a/std/hash/sha3/sha3.go +++ b/std/hash/sha3/sha3.go @@ -1,11 +1,15 @@ package sha3 import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/cmp" "github.com/consensys/gnark/std/math/uints" "github.com/consensys/gnark/std/permutation/keccakf" + "math/big" ) type digest struct { + api frontend.API uapi *uints.BinaryField[uints.U64] state [25]uints.U64 // 1600 bits state: 25 x 64 in []uints.U8 // input to be digested @@ -27,11 +31,68 @@ func (d *digest) Reset() { func (d *digest) Sum() []uints.U8 { padded := d.padding() + blocks := d.composeBlocks(padded) d.absorbing(blocks) return d.squeezeBlocks() } +func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 { + // padding + padded := make([]uints.U8, len(d.in)) + copy(padded[:], d.in[:]) + padded = append(padded, uints.NewU8Array(make([]uint8, d.rate))...) + numberOfBlocks := frontend.Variable(0) + + for i := 0; i < len(padded)-d.rate; i++ { + reachEnd := cmp.IsEqual(d.api, i+1, length) + switch q := d.rate - ((i + 1) % d.rate); q { + case 1: + padded[i+1].Val = d.api.Select(reachEnd, d.dsbyte^0x80, padded[i+1].Val) + numberOfBlocks = d.api.Select(reachEnd, (i+2)/d.rate, numberOfBlocks) + case 2: + padded[i+1].Val = d.api.Select(reachEnd, d.dsbyte, padded[i+1].Val) + padded[i+2].Val = d.api.Select(reachEnd, 0x80, padded[i+2].Val) + numberOfBlocks = d.api.Select(reachEnd, (i+3)/d.rate, numberOfBlocks) + default: + padded[i+1].Val = d.api.Select(reachEnd, d.dsbyte, padded[i+1].Val) + for j := 0; j < q-2; j++ { + padded[i+2+j].Val = d.api.Select(reachEnd, 0, padded[i+2+j].Val) + } + padded[i+q].Val = d.api.Select(reachEnd, 0x80, padded[i+q].Val) + numberOfBlocks = d.api.Select(reachEnd, (i+1+q)/d.rate, numberOfBlocks) + } + } + + // compose blocks + blocks := d.composeBlocks(padded) + + // absorbing + var state [25]uints.U64 + var resultState [25]uints.U64 + copy(resultState[:], d.state[:]) + copy(state[:], d.state[:]) + + comparator := cmp.NewBoundedComparator(d.api, big.NewInt(int64(len(blocks))), false) + + for i, block := range blocks { + for j := range block { + state[j] = d.uapi.Xor(state[j], block[j]) + } + state = keccakf.Permute(d.uapi, state) + isInRange := comparator.IsLess(i, numberOfBlocks) + for j := 0; j < 25; j++ { + for k := 0; k < 8; k++ { + resultState[j][k].Val = d.api.Select(isInRange, state[j][k].Val, resultState[j][k].Val) + } + } + } + copy(d.state[:], resultState[:]) + + // squeeze blocks + return d.squeezeBlocks() +} + func (d *digest) padding() []uints.U8 { padded := make([]uints.U8, len(d.in)) copy(padded[:], d.in[:]) diff --git a/std/hash/sha3/sha3_test.go b/std/hash/sha3/sha3_test.go index 0336746519..234148973c 100644 --- a/std/hash/sha3/sha3_test.go +++ b/std/hash/sha3/sha3_test.go @@ -3,6 +3,7 @@ package sha3 import ( "crypto/rand" "fmt" + "golang.org/x/crypto/sha3" "hash" "testing" @@ -11,11 +12,10 @@ import ( zkhash "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/math/uints" "github.com/consensys/gnark/test" - "golang.org/x/crypto/sha3" ) type testCase struct { - zk func(api frontend.API) (zkhash.BinaryHasher, error) + zk func(api frontend.API) (zkhash.BinaryFixedLengthHasher, error) native func() hash.Hash } @@ -59,7 +59,7 @@ func (c *sha3Circuit) Define(api frontend.API) error { func TestSHA3(t *testing.T) { assert := test.NewAssert(t) - in := make([]byte, 310) + in := make([]byte, 100) _, err := rand.Reader.Read(in) assert.NoError(err) @@ -88,3 +88,67 @@ func TestSHA3(t *testing.T) { }, name) } } + +type sha3FixedLengthSumCircuit struct { + In []uints.U8 + Expected []uints.U8 + Length frontend.Variable + hasher string +} + +func (c *sha3FixedLengthSumCircuit) Define(api frontend.API) error { + newHasher, ok := testCases[c.hasher] + if !ok { + return fmt.Errorf("hash function unknown: %s", c.hasher) + } + h, err := newHasher.zk(api) + if err != nil { + return err + } + uapi, err := uints.New[uints.U64](api) + if err != nil { + return err + } + h.Write(c.In) + res := h.FixedLengthSum(c.Length) + + for i := range c.Expected { + uapi.ByteAssertEq(c.Expected[i], res[i]) + } + return nil +} + +func TestSHA3FixedLengthSum(t *testing.T) { + assert := test.NewAssert(t) + in := make([]byte, 310) + _, err := rand.Reader.Read(in) + assert.NoError(err) + + for name := range testCases { + assert.Run(func(assert *test.Assert) { + name := name + strategy := testCases[name] + h := strategy.native() + length := len(in) - 10 + h.Write(in[:length]) + expected := h.Sum(nil) + + circuit := &sha3FixedLengthSumCircuit{ + In: make([]uints.U8, len(in)), + Expected: make([]uints.U8, len(expected)), + Length: 0, + hasher: name, + } + + witness := &sha3FixedLengthSumCircuit{ + In: uints.NewU8Array(in), + Expected: uints.NewU8Array(expected), + Length: length, + } + + if err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()); err != nil { + t.Fatalf("%s: %s", name, err) + } + }, name) + } +} diff --git a/std/math/cmp/generic.go b/std/math/cmp/generic.go index 97bac20567..016e0387f5 100644 --- a/std/math/cmp/generic.go +++ b/std/math/cmp/generic.go @@ -7,6 +7,13 @@ import ( "math/big" ) +// IsEqual returns 1 if a = b, and returns 0 if a != b. a and b should be +// integers in range [0, P-1], where P is the order of the underlying field used +// by the proof system. +func IsEqual(api frontend.API, a, b frontend.Variable) frontend.Variable { + return api.IsZero(api.Sub(a, b)) +} + // IsLess returns 1 if a < b, and returns 0 if a >= b. a and b should be // integers in range [0, P-1], where P is the order of the underlying field used // by the proof system.