Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gkr_nonnative intial review #1162

Open
wants to merge 26 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
522 changes: 522 additions & 0 deletions frontend/variable.go

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions std/fiat-shamir/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package fiatshamir
import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/hash"
"github.com/consensys/gnark/std/math/emulated"
)

type Settings struct {
Expand All @@ -12,6 +13,13 @@ type Settings struct {
Hash hash.FieldHasher
}

type SettingsFr[FR emulated.FieldParams] struct {
amit0365 marked this conversation as resolved.
Show resolved Hide resolved
Transcript *Transcript
Prefix string
BaseChallenges []emulated.Element[FR]
Hash hash.FieldHasher
}

func WithTranscript(transcript *Transcript, prefix string, baseChallenges ...frontend.Variable) Settings {
return Settings{
Transcript: transcript,
Expand All @@ -20,9 +28,24 @@ func WithTranscript(transcript *Transcript, prefix string, baseChallenges ...fro
}
}

func WithTranscriptFr[FR emulated.FieldParams](transcript *Transcript, prefix string, baseChallenges ...emulated.Element[FR]) SettingsFr[FR] {
return SettingsFr[FR]{
Transcript: transcript,
Prefix: prefix,
BaseChallenges: baseChallenges,
}
}

func WithHash(hash hash.FieldHasher, baseChallenges ...frontend.Variable) Settings {
return Settings{
BaseChallenges: baseChallenges,
Hash: hash,
}
}

func WithHashFr[FR emulated.FieldParams](hash hash.FieldHasher, baseChallenges ...emulated.Element[FR]) SettingsFr[FR] {
return SettingsFr[FR]{
BaseChallenges: baseChallenges,
Hash: hash,
}
}
1 change: 1 addition & 0 deletions std/gkr/gkr.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ func Verify(api frontend.API, c Circuit, assignment WireAssignment, proof Proof,
claims := newClaimsManager(c, assignment)

var firstChallenge []frontend.Variable
// why no bind values here?
amit0365 marked this conversation as resolved.
Show resolved Hide resolved
firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix))
if err != nil {
return err
Expand Down
12 changes: 11 additions & 1 deletion std/gkr/gkr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ import (
"reflect"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
"github.com/consensys/gnark/profile"
fiatshamir "github.com/consensys/gnark/std/fiat-shamir"
"github.com/consensys/gnark/std/polynomial"
"github.com/consensys/gnark/test"
Expand Down Expand Up @@ -74,6 +77,14 @@ func generateTestVerifier(path string, options ...option) func(t *testing.T) {
TestCaseName: path,
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Omit debugging code.

p:= profile.Start()
frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, validCircuit)
p.Stop()

fmt.Println(p.NbConstraints())
fmt.Println(p.Top())
//r1cs.CheckUnconstrainedWires()

invalidCircuit := &GkrVerifierCircuit{
Input: make([][]frontend.Variable, len(testCase.Input)),
Output: make([][]frontend.Variable, len(testCase.Output)),
Expand Down Expand Up @@ -327,7 +338,6 @@ func TestLoadCircuit(t *testing.T) {
assert.Equal(t, []*Wire{}, c[0].Inputs)
assert.Equal(t, []*Wire{&c[0]}, c[1].Inputs)
assert.Equal(t, []*Wire{&c[1]}, c[2].Inputs)

}

func TestTopSortTrivial(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion std/math/emulated/element.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,4 @@ func (e *Element[T]) copy() *Element[T] {
r.overflow = e.overflow
r.internal = e.internal
return &r
}
}
amit0365 marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion std/math/emulated/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) {
if uint(f.api.Compiler().FieldBitLen()) < 2*f.fParams.BitsPerLimb()+1 {
return nil, fmt.Errorf("elements with limb length %d does not fit into scalar field", f.fParams.BitsPerLimb())
}

println("NewField mulcheck")
amit0365 marked this conversation as resolved.
Show resolved Hide resolved
native.Compiler().Defer(f.performMulChecks)
if storer, ok := native.(kvstore.Store); ok {
storer.SetKeyValue(ctxKey[T]{}, f)
Expand Down
10 changes: 10 additions & 0 deletions std/math/emulated/field_mul.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,16 @@ func (f *Field[T]) Mul(a, b *Element[T]) *Element[T] {
return f.reduceAndOp(func(a, b *Element[T], u uint) *Element[T] { return f.mulMod(a, b, u, nil) }, f.mulPreCond, a, b)
}

// // MulAcc computes a*b and reduces it modulo the field order. The returned Element
amit0365 marked this conversation as resolved.
Show resolved Hide resolved
// // has default number of limbs and zero overflow. If the result wouldn't fit
// // into Element, then locally reduces the inputs first. Doesn't mutate inputs.
// //
// // For multiplying by a constant, use [Field[T].MulConst] method which is more
// // efficient.
// func (f *Field[T]) MulAcc(a, b *Element[T], c *Element[T]) *Element[T] {
// return f.reduceAndOp(func(a, b *Element[T], u uint) *Element[T] { return f.mulMod(a, b, u, nil) }, f.mulPreCond, a, b)
// }

// MulMod computes a*b and reduces it modulo the field order. The returned Element
// has default number of limbs and zero overflow.
//
Expand Down
16 changes: 16 additions & 0 deletions std/math/polynomial/polynomial.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ type Univariate[FR emulated.FieldParams] []emulated.Element[FR]
// coefficients.
type Multilinear[FR emulated.FieldParams] []emulated.Element[FR]

func (ml *Multilinear[FR]) NumVars() int {
return bits.Len(uint(len(*ml) - 1))
}

func valueOf[FR emulated.FieldParams](univ []*big.Int) []emulated.Element[FR] {
ret := make([]emulated.Element[FR], len(univ))
for i := range univ {
Expand Down Expand Up @@ -89,6 +93,18 @@ func New[FR emulated.FieldParams](api frontend.API) (*Polynomial[FR], error) {
}, nil
}

func (p *Polynomial[FR]) Mul(a, b *emulated.Element[FR]) *emulated.Element[FR] {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You do not need to have Mul, Add and AssertIsEqual as methods on Polynomial. You can directly call emulated.Field methods Mul, Add, AssertIsEqual on your inputs. I'd recommend removing methods here.

Copy link
Author

@amit0365 amit0365 Jun 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed. Now creating a new instance of field with emulated.Field[FR]{} and using the api from there.

return p.f.Mul(a, b)
}

func (p *Polynomial[FR]) Add(a, b *emulated.Element[FR]) *emulated.Element[FR] {
return p.f.Add(a, b)
}

func (p *Polynomial[FR]) AssertIsEqual(a, b *emulated.Element[FR]) {
p.f.AssertIsEqual(a, b)
}

// EvalUnivariate evaluates univariate polynomial at a point at. It returns the
// evaluation. The method does not mutate the inputs.
func (p *Polynomial[FR]) EvalUnivariate(P Univariate[FR], at *emulated.Element[FR]) *emulated.Element[FR] {
Expand Down
106 changes: 104 additions & 2 deletions std/polynomial/polynomial.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,78 @@ import (
"math/bits"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark-crypto/utils"
)

type Polynomial []frontend.Variable
type MultiLin []frontend.Variable

var minFoldScaledLogSize = 16

func FromSlice(s []frontend.Variable) []*frontend.Variable {
amit0365 marked this conversation as resolved.
Show resolved Hide resolved
r := make([]*frontend.Variable, len(s))
for i := range s {
r[i] = &s[i]
}
return r
}

// FromSliceReferences maps slice of emulated element references to their values.
func FromSliceReferences(in []*frontend.Variable) []frontend.Variable {
amit0365 marked this conversation as resolved.
Show resolved Hide resolved
r := make([]frontend.Variable, len(in))
for i := range in {
r[i] = *in[i]
}
return r
}

func _clone(m MultiLin, p *Pool) MultiLin {
amit0365 marked this conversation as resolved.
Show resolved Hide resolved
if p == nil {
return m.Clone()
} else {
return p.Clone(m)
}
}

func _dump(m MultiLin, p *Pool) {
if p != nil {
p.Dump(m)
}
}

// Evaluate assumes len(m) = 1 << len(at)
// it doesn't modify m
func (m MultiLin) EvaluatePool(api frontend.API, at []frontend.Variable, pool *Pool) frontend.Variable {
amit0365 marked this conversation as resolved.
Show resolved Hide resolved
_m := _clone(m, pool)

/*minFoldScaledLogSize := 16
if api is r1cs {
minFoldScaledLogSize = math.MaxInt64 // no scaling for r1cs
}*/

scaleCorrectionFactor := frontend.Variable(1)
// at each iteration fold by at[i]
for len(_m) > 1 {
if len(_m) >= minFoldScaledLogSize {
scaleCorrectionFactor = api.Mul(scaleCorrectionFactor, _m.foldScaled(api, at[0]))
} else {
_m.Fold(api, at[0])
}
_m = _m[:len(_m)/2]
at = at[1:]
}

if len(at) != 0 {
panic("incompatible evaluation vector size")
}

result := _m[0]

_dump(_m, pool)

return api.Mul(result, scaleCorrectionFactor)
}

// Evaluate assumes len(m) = 1 << len(at)
// it doesn't modify m
func (m MultiLin) Evaluate(api frontend.API, at []frontend.Variable) frontend.Variable {
Expand All @@ -27,7 +92,7 @@ func (m MultiLin) Evaluate(api frontend.API, at []frontend.Variable) frontend.Va
if len(_m) >= minFoldScaledLogSize {
scaleCorrectionFactor = api.Mul(scaleCorrectionFactor, _m.foldScaled(api, at[0]))
} else {
_m.fold(api, at[0])
_m.Fold(api, at[0])
}
_m = _m[:len(_m)/2]
at = at[1:]
Expand All @@ -42,7 +107,7 @@ func (m MultiLin) Evaluate(api frontend.API, at []frontend.Variable) frontend.Va

// fold fixes the value of m's first variable to at, thus halving m's required bookkeeping table size
// WARNING: The user should halve m themselves after the call
func (m MultiLin) fold(api frontend.API, at frontend.Variable) {
func (m MultiLin) Fold(api frontend.API, at frontend.Variable) {
zero := m[:len(m)/2]
one := m[len(m)/2:]
for j := range zero {
Expand All @@ -51,6 +116,43 @@ func (m MultiLin) fold(api frontend.API, at frontend.Variable) {
}
}

func (m *MultiLin) FoldParallel(api frontend.API, r frontend.Variable) utils.Task {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not need it - circuit compilation happen sequentially anyway and you cannot parallelize it.

mid := len(*m) / 2
bottom, top := (*m)[:mid], (*m)[mid:]

*m = bottom

return func(start, end int) {
var t frontend.Variable // no need to update the top part
for i := start; i < end; i++ {
// table[i] ← table[i] + r (table[i + mid] - table[i])
t = api.Sub(&top[i], &bottom[i])
t = api.Mul(&t, &r)
bottom[i] = api.Add(&bottom[i], &t)
}
}
}

// Eq sets m to the representation of the polynomial Eq(q₁, ..., qₙ, *, ..., *) × m[0]
func (m *MultiLin) Eq(api frontend.API, q []frontend.Variable) {
n := len(q)

if len(*m) != 1<<n {
panic("destination must have size 2 raised to the size of source")
}

//At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁)
for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁
// go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ
for j := 0; j < (1 << i); j++ {
j0 := j << (n - i) // bᵢ₊₁ = 0
j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1
(*m)[j1] = api.Mul((*m)[j1], q[i]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁
(*m)[j0] = api.Sub((*m)[j0], (*m)[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁)
}
}
}

// foldScaled(m, at) = fold(m, at) / (1 - at)
// it returns 1 - at, for convenience
func (m MultiLin) foldScaled(api frontend.API, at frontend.Variable) (denom frontend.Variable) {
Expand Down
2 changes: 1 addition & 1 deletion std/polynomial/polynomial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (c *foldMultiLinCircuit) Define(api frontend.API) error {
return errors.New("folding size mismatch")
}
m := MultiLin(c.M)
m.fold(api, c.At)
m.Fold(api, c.At)
for i := range c.Result {
api.AssertIsEqual(m[i], c.Result[i])
}
Expand Down
Loading
Loading