Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: implement FixedLengthSum function for sha3 #1379

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions std/hash/sha3/hashes.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ import (
// New256 creates a new SHA3-256 hash.
// Its generic security strength is 256 bits against preimage attacks,
// and 128 bits against collision attacks.
func New256(api frontend.API) (hash.BinaryHasher, error) {
func New256(api frontend.API) (hash.BinaryFixedLengthHasher, error) {
uapi, err := uints.New[uints.U64](api)
if err != nil {
return nil, err
}
return &digest{
api: api,
uapi: uapi,
state: newState(),
dsbyte: 0x06,
Expand All @@ -26,12 +27,13 @@ func New256(api frontend.API) (hash.BinaryHasher, error) {
// New384 creates a new SHA3-384 hash.
// Its generic security strength is 384 bits against preimage attacks,
// and 192 bits against collision attacks.
func New384(api frontend.API) (hash.BinaryHasher, error) {
func New384(api frontend.API) (hash.BinaryFixedLengthHasher, error) {
uapi, err := uints.New[uints.U64](api)
if err != nil {
return nil, err
}
return &digest{
api: api,
uapi: uapi,
state: newState(),
dsbyte: 0x06,
Expand All @@ -43,12 +45,13 @@ func New384(api frontend.API) (hash.BinaryHasher, error) {
// New512 creates a new SHA3-512 hash.
// Its generic security strength is 512 bits against preimage attacks,
// and 256 bits against collision attacks.
func New512(api frontend.API) (hash.BinaryHasher, error) {
func New512(api frontend.API) (hash.BinaryFixedLengthHasher, error) {
uapi, err := uints.New[uints.U64](api)
if err != nil {
return nil, err
}
return &digest{
api: api,
uapi: uapi,
state: newState(),
dsbyte: 0x06,
Expand All @@ -61,12 +64,13 @@ func New512(api frontend.API) (hash.BinaryHasher, error) {
//
// Only use this function if you require compatibility with an existing cryptosystem
// that uses non-standard padding. All other users should use New256 instead.
func NewLegacyKeccak256(api frontend.API) (hash.BinaryHasher, error) {
func NewLegacyKeccak256(api frontend.API) (hash.BinaryFixedLengthHasher, error) {
uapi, err := uints.New[uints.U64](api)
if err != nil {
return nil, err
}
return &digest{
api: api,
uapi: uapi,
state: newState(),
dsbyte: 0x01,
Expand All @@ -79,12 +83,13 @@ func NewLegacyKeccak256(api frontend.API) (hash.BinaryHasher, error) {
//
// Only use this function if you require compatibility with an existing cryptosystem
// that uses non-standard padding. All other users should use New512 instead.
func NewLegacyKeccak512(api frontend.API) (hash.BinaryHasher, error) {
func NewLegacyKeccak512(api frontend.API) (hash.BinaryFixedLengthHasher, error) {
uapi, err := uints.New[uints.U64](api)
if err != nil {
return nil, err
}
return &digest{
api: api,
uapi: uapi,
state: newState(),
dsbyte: 0x01,
Expand Down
61 changes: 61 additions & 0 deletions std/hash/sha3/sha3.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package sha3

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/cmp"
"github.com/consensys/gnark/std/math/uints"
"github.com/consensys/gnark/std/permutation/keccakf"
"math/big"
)

type digest struct {
api frontend.API
uapi *uints.BinaryField[uints.U64]
state [25]uints.U64 // 1600 bits state: 25 x 64
in []uints.U8 // input to be digested
Expand All @@ -27,11 +31,68 @@ func (d *digest) Reset() {

func (d *digest) Sum() []uints.U8 {
padded := d.padding()

blocks := d.composeBlocks(padded)
d.absorbing(blocks)
return d.squeezeBlocks()
}

func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 {
// padding
padded := make([]uints.U8, len(d.in))
copy(padded[:], d.in[:])
padded = append(padded, uints.NewU8Array(make([]uint8, d.rate))...)
numberOfBlocks := frontend.Variable(0)

for i := 0; i < len(padded)-d.rate; i++ {
reachEnd := cmp.IsEqual(d.api, i+1, length)
switch q := d.rate - ((i + 1) % d.rate); q {
case 1:
padded[i+1].Val = d.api.Select(reachEnd, d.dsbyte^0x80, padded[i+1].Val)
numberOfBlocks = d.api.Select(reachEnd, (i+2)/d.rate, numberOfBlocks)
case 2:
padded[i+1].Val = d.api.Select(reachEnd, d.dsbyte, padded[i+1].Val)
padded[i+2].Val = d.api.Select(reachEnd, 0x80, padded[i+2].Val)
numberOfBlocks = d.api.Select(reachEnd, (i+3)/d.rate, numberOfBlocks)
default:
padded[i+1].Val = d.api.Select(reachEnd, d.dsbyte, padded[i+1].Val)
for j := 0; j < q-2; j++ {
padded[i+2+j].Val = d.api.Select(reachEnd, 0, padded[i+2+j].Val)
}
padded[i+q].Val = d.api.Select(reachEnd, 0x80, padded[i+q].Val)
numberOfBlocks = d.api.Select(reachEnd, (i+1+q)/d.rate, numberOfBlocks)
}
}

// compose blocks
blocks := d.composeBlocks(padded)

// absorbing
var state [25]uints.U64
var resultState [25]uints.U64
copy(resultState[:], d.state[:])
copy(state[:], d.state[:])

comparator := cmp.NewBoundedComparator(d.api, big.NewInt(int64(len(blocks))), false)

for i, block := range blocks {
for j := range block {
state[j] = d.uapi.Xor(state[j], block[j])
}
state = keccakf.Permute(d.uapi, state)
isInRange := comparator.IsLess(i, numberOfBlocks)
for j := 0; j < 25; j++ {
for k := 0; k < 8; k++ {
resultState[j][k].Val = d.api.Select(isInRange, state[j][k].Val, resultState[j][k].Val)
}
}
}
copy(d.state[:], resultState[:])

// squeeze blocks
return d.squeezeBlocks()
}

func (d *digest) padding() []uints.U8 {
padded := make([]uints.U8, len(d.in))
copy(padded[:], d.in[:])
Expand Down
70 changes: 67 additions & 3 deletions std/hash/sha3/sha3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sha3
import (
"crypto/rand"
"fmt"
"golang.org/x/crypto/sha3"
"hash"
"testing"

Expand All @@ -11,11 +12,10 @@ import (
zkhash "github.com/consensys/gnark/std/hash"
"github.com/consensys/gnark/std/math/uints"
"github.com/consensys/gnark/test"
"golang.org/x/crypto/sha3"
)

type testCase struct {
zk func(api frontend.API) (zkhash.BinaryHasher, error)
zk func(api frontend.API) (zkhash.BinaryFixedLengthHasher, error)
native func() hash.Hash
}

Expand Down Expand Up @@ -59,7 +59,7 @@ func (c *sha3Circuit) Define(api frontend.API) error {

func TestSHA3(t *testing.T) {
assert := test.NewAssert(t)
in := make([]byte, 310)
in := make([]byte, 100)
_, err := rand.Reader.Read(in)
assert.NoError(err)

Expand Down Expand Up @@ -88,3 +88,67 @@ func TestSHA3(t *testing.T) {
}, name)
}
}

type sha3FixedLengthSumCircuit struct {
In []uints.U8
Expected []uints.U8
Length frontend.Variable
hasher string
}

func (c *sha3FixedLengthSumCircuit) Define(api frontend.API) error {
newHasher, ok := testCases[c.hasher]
if !ok {
return fmt.Errorf("hash function unknown: %s", c.hasher)
}
h, err := newHasher.zk(api)
if err != nil {
return err
}
uapi, err := uints.New[uints.U64](api)
if err != nil {
return err
}
h.Write(c.In)
res := h.FixedLengthSum(c.Length)

for i := range c.Expected {
uapi.ByteAssertEq(c.Expected[i], res[i])
}
return nil
}

func TestSHA3FixedLengthSum(t *testing.T) {
assert := test.NewAssert(t)
in := make([]byte, 310)
_, err := rand.Reader.Read(in)
assert.NoError(err)

for name := range testCases {
assert.Run(func(assert *test.Assert) {
name := name
strategy := testCases[name]
h := strategy.native()
length := len(in) - 10
h.Write(in[:length])
expected := h.Sum(nil)

circuit := &sha3FixedLengthSumCircuit{
In: make([]uints.U8, len(in)),
Expected: make([]uints.U8, len(expected)),
Length: 0,
hasher: name,
}

witness := &sha3FixedLengthSumCircuit{
In: uints.NewU8Array(in),
Expected: uints.NewU8Array(expected),
Length: length,
}

if err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()); err != nil {
t.Fatalf("%s: %s", name, err)
}
}, name)
}
}
7 changes: 7 additions & 0 deletions std/math/cmp/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ import (
"math/big"
)

// IsEqual returns 1 if a = b, and returns 0 if a != b. a and b should be
// integers in range [0, P-1], where P is the order of the underlying field used
// by the proof system.
func IsEqual(api frontend.API, a, b frontend.Variable) frontend.Variable {
return api.IsZero(api.Sub(a, b))
}

// IsLess returns 1 if a < b, and returns 0 if a >= b. a and b should be
// integers in range [0, P-1], where P is the order of the underlying field used
// by the proof system.
Expand Down