diff --git a/.gitignore b/.gitignore index 1be9af1ba2..8abda9a03b 100644 --- a/.gitignore +++ b/.gitignore @@ -55,4 +55,5 @@ gnarkd/circuits/** go.work go.work.sum -examples/gbotrel/** \ No newline at end of file +examples/gbotrel/** +build/ \ No newline at end of file diff --git a/README.md b/README.md index c86eedac50..35787c62db 100644 --- a/README.md +++ b/README.md @@ -161,6 +161,40 @@ func main() { ### GPU Support +#### Zeknox Library +Unlock free GPU acceleration with the [OKX zeknox library](https://github.com/okx/zeknox). + +##### Download prebuilt binaries +```sh +curl -L -o libzeknox.a https://github.com/okx/zeknox/releases/download/v1.0.0/bn254-msm-86-89-90-libzeknox.a +curl -L -o libblst.a https://github.com/okx/zeknox/releases/download/v1.0.0/libblst.a +sudo cp libblst.a libzeknox.a /usr/local/lib/ +``` + +If you want to build from source, see the guide at https://github.com/okx/zeknox. + +##### Enjoy GPU + +`groth16.Prove(r1cs, pk, witnessData, backend.WithZeknoxAcceleration())` + +```sh +go run -tags=zeknox examples/zeknox/main.go +# (place -tags before the filename) +``` + +##### Test +Add the following code to [mimc_test.go](examples/mimc/mimc_test.go): +```go +assert.ProverSucceeded(&mimcCircuit, &Circuit{ + PreImage: "16130099170765464552823636852555369511329944820189892919423002775646948828469", + Hash: "12886436712380113721405259596386800092738845035233065858332878701083870690753", + }, test.WithCurves(ecc.BN254), test.WithProverOpts(backend.WithZeknoxAcceleration())) +``` + +```sh +go test github.com/consensys/gnark/examples/mimc -tags=prover_checks,zeknox +``` + #### Icicle Library The following schemes and curves support experimental use of Ingonyama's Icicle GPU library for low level zk-SNARK primitives such as MSM, NTT, and polynomial operations: @@ -178,7 +212,7 @@ You can then toggle on or off icicle acceleration by providing the `WithIcicleAc ```go // toggle on proofIci, err := groth16.Prove(ccs, pk, secretWitness, backend.WithIcicleAcceleration()) 
- + // toggle off proof, err := groth16.Prove(ccs, pk, secretWitness) ``` diff --git a/backend/backend.go b/backend/backend.go index 5bcfc7bc20..52da9c5671 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -133,6 +133,19 @@ func WithProverKZGFoldingHashFunction(hFunc hash.Hash) ProverOption { } } +// WithZeknoxAcceleration requests to use [ZEKNOX] GPU proving backend for the +// prover. This option requires that the program is compiled with `zeknox` build +// tag and the ZEKNOX dependencies are properly installed. See [ZEKNOX] for +// installation description. +// +// [ZEKNOX]: https://github.com/okx/zeknox +func WithZeknoxAcceleration() ProverOption { + return func(pc *ProverConfig) error { + pc.Accelerator = "zeknox" + return nil + } +} + // WithIcicleAcceleration requests to use [ICICLE] GPU proving backend for the // prover. This option requires that the program is compiled with `icicle` build // tag and the ICICLE dependencies are properly installed. See [ICICLE] for diff --git a/backend/groth16/bn254/prove.go b/backend/groth16/bn254/prove.go index 5f0d413133..3cb96cdc96 100644 --- a/backend/groth16/bn254/prove.go +++ b/backend/groth16/bn254/prove.go @@ -142,7 +142,9 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var h []fr.Element chHDone := make(chan struct{}, 1) go func() { + startH := time.Now() h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + log.Debug().Dur("computeH took", time.Since(startH)).Msg("computed H") solution.A = nil solution.B = nil solution.C = nil diff --git a/backend/groth16/bn254/zeknox/doc.go b/backend/groth16/bn254/zeknox/doc.go new file mode 100644 index 0000000000..2200b550c7 --- /dev/null +++ b/backend/groth16/bn254/zeknox/doc.go @@ -0,0 +1,2 @@ +// Package zeknox_bn254 implements zeknox acceleration for BN254 Groth16 backend. 
+package zeknox_bn254 diff --git a/backend/groth16/bn254/zeknox/marshal_test.go b/backend/groth16/bn254/zeknox/marshal_test.go new file mode 100644 index 0000000000..5e9b2aeaea --- /dev/null +++ b/backend/groth16/bn254/zeknox/marshal_test.go @@ -0,0 +1,67 @@ +package zeknox_bn254_test + +import ( + "bytes" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + zeknox_bn254 "github.com/consensys/gnark/backend/groth16/bn254/zeknox" + cs_bn254 "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/test" +) + +type circuit struct { + A, B frontend.Variable `gnark:",public"` + Res frontend.Variable +} + +func (c *circuit) Define(api frontend.API) error { + api.AssertIsEqual(api.Mul(c.A, c.B), c.Res) + return nil +} + +func TestMarshalNative(t *testing.T) { + assert := test.NewAssert(t) + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit{}) + assert.NoError(err) + tCcs := ccs.(*cs_bn254.R1CS) + nativePK := groth16_bn254.ProvingKey{} + nativeVK := groth16_bn254.VerifyingKey{} + err = groth16_bn254.Setup(tCcs, &nativePK, &nativeVK) + assert.NoError(err) + + pk := groth16.NewProvingKey(ecc.BN254) + buf := new(bytes.Buffer) + _, err = nativePK.WriteTo(buf) + assert.NoError(err) + _, err = pk.ReadFrom(buf) + assert.NoError(err) + if pk.IsDifferent(&nativePK) { + t.Error("marshal output difference") + } +} + +func TestMarshalZeknox(t *testing.T) { + assert := test.NewAssert(t) + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit{}) + assert.NoError(err) + tCcs := ccs.(*cs_bn254.R1CS) + zePK := zeknox_bn254.ProvingKey{} + VK := groth16_bn254.VerifyingKey{} + err = zeknox_bn254.Setup(tCcs, &zePK, &VK) + assert.NoError(err) + + nativePK := groth16_bn254.ProvingKey{} + buf := new(bytes.Buffer) + _, 
err = zePK.WriteTo(buf) + assert.NoError(err) + _, err = nativePK.ReadFrom(buf) + assert.NoError(err) + if zePK.IsDifferent(&nativePK) { + t.Error("marshal output difference") + } +} diff --git a/backend/groth16/bn254/zeknox/nozeknox.go b/backend/groth16/bn254/zeknox/nozeknox.go new file mode 100644 index 0000000000..b920f92c00 --- /dev/null +++ b/backend/groth16/bn254/zeknox/nozeknox.go @@ -0,0 +1,34 @@ +//go:build !zeknox + +package zeknox_bn254 + +import ( + "fmt" + + "github.com/consensys/gnark/backend" + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + "github.com/consensys/gnark/backend/witness" + cs "github.com/consensys/gnark/constraint/bn254" +) + +type ProvingKey struct { + groth16_bn254.ProvingKey +} + +const HasZeknox = false + +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*groth16_bn254.Proof, error) { + return nil, fmt.Errorf("zeknox backend requested but program compiled without 'zeknox' build tag") +} + +func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *groth16_bn254.VerifyingKey) error { + return groth16_bn254.Setup(r1cs, &pk.ProvingKey, vk) +} + +func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { + return groth16_bn254.DummySetup(r1cs, &pk.ProvingKey) +} + +func (pk *ProvingKey) Free() { + // nothing to do here +} diff --git a/backend/groth16/bn254/zeknox/provingkey.go b/backend/groth16/bn254/zeknox/provingkey.go new file mode 100644 index 0000000000..32f80dfa51 --- /dev/null +++ b/backend/groth16/bn254/zeknox/provingkey.go @@ -0,0 +1,55 @@ +//go:build zeknox + +package zeknox_bn254 + +import ( + "github.com/consensys/gnark-crypto/ecc/bn254" + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + cs "github.com/consensys/gnark/constraint/bn254" + "github.com/okx/zeknox/wrappers/go/device" +) + +type deviceInfo struct { + G1Device struct { + A, B, K, Z DevicePoints[bn254.G1Affine] + } + G2Device struct { + B DevicePoints[bn254.G2Affine] + } + 
InfinityPointIndicesK []int +} + +type ProvingKey struct { + groth16_bn254.ProvingKey + *deviceInfo +} + +type DevicePoints[T bn254.G1Affine | bn254.G2Affine] struct { + *device.HostOrDeviceSlice[T] + // Gnark points are in Montgomery form + // After 1 GPU MSM, points in GPU are converted to affine form + // Pass it to MSM config + Mont bool +} + +func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *groth16_bn254.VerifyingKey) error { + return groth16_bn254.Setup(r1cs, &pk.ProvingKey, vk) +} + +func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { + return groth16_bn254.DummySetup(r1cs, &pk.ProvingKey) +} + +// You should call this method to free the GPU memory +// +// pk := groth16.NewProvingKey(ecc.BN254) +// defer pk.(*zeknox_bn254.ProvingKey).Free() +func (pk *ProvingKey) Free() { + if pk.deviceInfo != nil { + pk.deviceInfo.G1Device.A.Free() + pk.deviceInfo.G1Device.B.Free() + pk.deviceInfo.G1Device.K.Free() + pk.deviceInfo.G1Device.Z.Free() + pk.deviceInfo.G2Device.B.Free() + } +} diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go new file mode 100644 index 0000000000..52ece00523 --- /dev/null +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -0,0 +1,540 @@ +//go:build zeknox + +package zeknox_bn254 + +import ( + "context" + "fmt" + "math/big" + "runtime" + "time" + "unsafe" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/hash_to_field" + "github.com/consensys/gnark/backend" + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + cs "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/constraint/solver" + fcs 
"github.com/consensys/gnark/frontend/cs" + "github.com/consensys/gnark/internal/utils" + "github.com/consensys/gnark/logger" + "github.com/okx/zeknox/wrappers/go/device" + "github.com/okx/zeknox/wrappers/go/msm" + "golang.org/x/sync/errgroup" +) + +const HasZeknox = true + +// Use single GPU +const deviceId = 0 + +func (pk *ProvingKey) setupDevicePointers() error { + if pk.deviceInfo != nil { + return nil + } + pk.deviceInfo = &deviceInfo{} + + // MSM G1 & G2 Device Setup + g, _ := errgroup.WithContext(context.TODO()) + // G1.A + deviceA := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) + g.Go(func() error { return CopyToDevice(pk.G1.A, deviceA) }) + // G1.B + deviceG1B := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) + g.Go(func() error { return CopyToDevice(pk.G1.B, deviceG1B) }) + + // G1.K + var pointsNoInfinity []curve.G1Affine + for i, gnarkPoint := range pk.G1.K { + if gnarkPoint.IsInfinity() { + pk.InfinityPointIndicesK = append(pk.InfinityPointIndicesK, i) + } else { + pointsNoInfinity = append(pointsNoInfinity, gnarkPoint) + } + } + deviceK := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) + g.Go(func() error { return CopyToDevice(pointsNoInfinity, deviceK) }) + + // G1.Z + deviceZ := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) + g.Go(func() error { return CopyToDevice(pk.G1.Z, deviceZ) }) + + // G2.B + deviceG2B := make(chan *device.HostOrDeviceSlice[curve.G2Affine], 1) + g.Go(func() error { return CopyToDevice(pk.G2.B, deviceG2B) }) + + // wait for all points to be copied to the device + // if any of the copy failed, return the error + if err := g.Wait(); err != nil { + return err + } + // if no error, store device pointers in pk + pk.G1Device.A = DevicePoints[curve.G1Affine]{ + HostOrDeviceSlice: <-deviceA, + Mont: true, + } + pk.G1Device.B = DevicePoints[curve.G1Affine]{ + HostOrDeviceSlice: <-deviceG1B, + Mont: true, + } + pk.G1Device.K = DevicePoints[curve.G1Affine]{ + HostOrDeviceSlice: <-deviceK, + Mont: 
true, + } + pk.G1Device.Z = DevicePoints[curve.G1Affine]{ + HostOrDeviceSlice: <-deviceZ, + Mont: true, + } + pk.G2Device.B = DevicePoints[curve.G2Affine]{ + HostOrDeviceSlice: <-deviceG2B, + Mont: true, + } + + return nil +} + +// Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*groth16_bn254.Proof, error) { + opt, err := backend.NewProverConfig(opts...) + if err != nil { + return nil, fmt.Errorf("new prover config: %w", err) + } + if opt.HashToFieldFn == nil { + opt.HashToFieldFn = hash_to_field.New([]byte(constraint.CommitmentDst)) + } + if opt.Accelerator != "zeknox" { + return groth16_bn254.Prove(r1cs, &pk.ProvingKey, fullWitness, opts...) + } + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Str("acceleration", "zeknox").Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + if pk.deviceInfo == nil { + start := time.Now() + if err := pk.setupDevicePointers(); err != nil { + return nil, fmt.Errorf("setup device pointers: %w", err) + } + log.Debug().Dur("took", time.Since(start)).Msg("Copy proving key to device") + } + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &groth16_bn254.Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] + + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + + // override hints + bsb22ID := solver.GetHintID(fcs.Bsb22CommitmentComputePlaceholder) + solverOpts = append(solverOpts, solver.OverrideHint(bsb22ID, func(_ *big.Int, in []*big.Int, out []*big.Int) error { + i := int(in[0].Int64()) + in = in[1:] + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[+len(hashed):] + for j, inJ 
:= range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } + + opt.HashToFieldFn.Write(constraint.SerializeCommitment(proof.Commitments[i].Marshal(), hashed, (fr.Bits-1)/8+1)) + hashBts := opt.HashToFieldFn.Sum(nil) + opt.HashToFieldFn.Reset() + nbBuf := fr.Bytes + if opt.HashToFieldFn.Size() < fr.Bytes { + nbBuf = opt.HashToFieldFn.Size() + } + var res fr.Element + res.SetBytes(hashBts[:nbBuf]) + res.BigInt(out[0]) + return nil + })) + + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err + } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + + start := time.Now() + poks := make([]curve.G1Affine, len(pk.CommitmentKeys)) + + for i := range pk.CommitmentKeys { + var err error + if poks[i], err = pk.CommitmentKeys[i].ProveKnowledge(privateCommittedValues[i]); err != nil { + return nil, err + } + } + // compute challenge for folding the PoKs from the commitments + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + challenge, err := fr.Hash(commitmentsSerialized, []byte("G16-BSB22"), 1) + if err != nil { + return nil, err + } + if _, err = proof.CommitmentPok.Fold(poks, challenge[0], ecc.MultiExpConfig{NbTasks: 1}); err != nil { + return nil, err + } + + // we need to copy and filter the wireValues for each multi exp + // as pk.G1.A, pk.G1.B and pk.G2.B may have (a significant) number of point at infinity + var wireValuesA, wireValuesB []fr.Element + chWireValuesA, chWireValuesB := make(chan struct{}, 1), make(chan struct{}, 1) + + go func() { + wireValuesA = make([]fr.Element, len(wireValues)-int(pk.NbInfinityA)) + for i, j := 0, 0; j < len(wireValuesA); i++ { + if pk.InfinityA[i] { + 
continue + } + wireValuesA[j] = wireValues[i] + j++ + } + close(chWireValuesA) + }() + go func() { + wireValuesB = make([]fr.Element, len(wireValues)-int(pk.NbInfinityB)) + for i, j := 0, 0; j < len(wireValuesB); i++ { + if pk.InfinityB[i] { + continue + } + wireValuesB[j] = wireValues[i] + j++ + } + close(chWireValuesB) + }() + + // sample random r and s + var r, s big.Int + var _r, _s, _kr fr.Element + if _, err := _r.SetRandom(); err != nil { + return nil, err + } + if _, err := _s.SetRandom(); err != nil { + return nil, err + } + // -rs + // Why it is called kr? not rs? -> notation from DIZK paper + _kr.Mul(&_r, &_s).Neg(&_kr) + + _r.BigInt(&r) + _s.BigInt(&s) + + // computes r[δ], s[δ], kr[δ] + deltas := curve.BatchScalarMultiplicationG1(&pk.G1.Delta, []fr.Element{_r, _s, _kr}) + + var bs1, ar curve.G1Jac + + computeBS1 := func() error { + <-chWireValuesB + var wireB *device.HostOrDeviceSlice[fr.Element] + chWireB := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + if err := CopyToDevice(wireValuesB, chWireB); err != nil { + return err + } + wireB = <-chWireB + defer wireB.Free() + startBs1 := time.Now() + if err := gpuMsm(&bs1, &pk.G1Device.B, wireB); err != nil { + return err + } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", wireB.Len()), time.Since(startBs1)).Msg("bs1 done") + // + beta + s[δ] + bs1.AddMixed(&pk.G1.Beta) + bs1.AddMixed(&deltas[1]) + return nil + } + + computeAR1 := func() error { + <-chWireValuesA + var wireA *device.HostOrDeviceSlice[fr.Element] + chWireA := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + if err := CopyToDevice(wireValuesA, chWireA); err != nil { + return err + } + wireA = <-chWireA + defer wireA.Free() + startAr := time.Now() + if err := gpuMsm(&ar, &pk.G1Device.A, wireA); err != nil { + return err + } + + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", wireA.Len()), time.Since(startAr)).Msg("ar done") + ar.AddMixed(&pk.G1.Alpha) + ar.AddMixed(&deltas[0]) + proof.Ar.FromJacobian(&ar) + return nil + } + + var krs2 
curve.G1Jac + computeKRS2 := func() error { + // quotient poly H (witness reduction / FFT part) + var h []fr.Element + { + startH := time.Now() + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + log.Debug().Dur("took", time.Since(startH)).Msg("computed H") + solution.A = nil + solution.B = nil + solution.C = nil + } + // Copy h poly to device, since we haven't implemented FFT on device + var deviceH *device.HostOrDeviceSlice[fr.Element] + chDeviceH := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 + if err := CopyToDevice(h[:sizeH], chDeviceH); err != nil { + return err + } + deviceH = <-chDeviceH + defer deviceH.Free() + // MSM G1 Krs2 + startKrs2 := time.Now() + if err := gpuMsm(&krs2, &pk.G1Device.Z, deviceH); err != nil { + return err + } + + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", sizeH), time.Since(startKrs2)).Msg("krs2 done") + return nil + } + + var krs1 curve.G1Jac + + computeKRS1_CPU := func() error { + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + // original Groth16 witness without pedersen commitment + wireValuesWithoutCom := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) + + // CPU + // Compute this MSM on CPU, as it can be done in parallel with other MSM on GPU, also reduce data copy + startKrs := time.Now() + if _, err := krs1.MultiExp(pk.G1.K, wireValuesWithoutCom, ecc.MultiExpConfig{NbTasks: runtime.NumCPU() / 2}); err != nil { + return err + } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValues)), time.Since(startKrs)).Msg("CPU krs done") + // -rs[δ] + krs1.AddMixed(&deltas[2]) + return nil + } + + computeBS2_CPU := func() error { + // Bs2 (1 multi exp G2 - size = len(wires)) + var Bs, deltaS 
curve.G2Jac + + nbTasks := runtime.NumCPU() / 2 + if nbTasks <= 16 { + // if we don't have a lot of CPUs, this may artificially split the MSM + nbTasks *= 2 + } + <-chWireValuesB + startBs := time.Now() + if _, err := Bs.MultiExp(pk.G2.B, wireValuesB, ecc.MultiExpConfig{NbTasks: nbTasks}); err != nil { + return err + } + log.Debug().Dur(fmt.Sprintf("MSMG2 %d took", len(wireValuesB)), time.Since(startBs)).Msg("Bs.MultiExp done") + + deltaS.FromAffine(&pk.G2.Delta) + deltaS.ScalarMultiplication(&deltaS, &s) + Bs.AddAssign(&deltaS) + Bs.AddMixed(&pk.G2.Beta) + + proof.Bs.FromJacobian(&Bs) + return nil + } + + g, _ := errgroup.WithContext(context.TODO()) + // CPU MSM + g.Go(computeKRS1_CPU) + g.Go(computeBS2_CPU) + + // Serial GPU MSM + computeAR1() + computeBS1() + // computeKRS1() + computeKRS2() + // computeBS2() + if err := g.Wait(); err != nil { + return nil, err + } + + // FinalKRS = KRS1 + KRS2 + s*AR + r*BS1 + { + var p1 curve.G1Jac + krs1.AddAssign(&krs2) + p1.ScalarMultiplication(&ar, &s) + krs1.AddAssign(&p1) + p1.ScalarMultiplication(&bs1, &r) + krs1.AddAssign(&p1) + proof.Krs.FromJacobian(&krs1) + } + + log.Debug().Dur("took", time.Since(start)).Msg("prover done") + + return proof, nil +} + +// if len(toRemove) == 0, returns slice +// +// else, returns a new slice without the indexes in toRemove. 
The first value in the slice is taken to have index sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { + + if len(toRemove) == 0 { + return slice + } + + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) + + // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) + for i := 0; i < len(slice); i++ { + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } + continue + } + r = append(r, slice[i]) + } + + return +} + +func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { + // H part of Krs + // Compute H (hz=ab-c, where z=-2 on ker X^n+1 (z(x)=x^n-1)) + // 1 - _a = ifft(a), _b = ifft(b), _c = ifft(c) + // 2 - ca = fft_coset(_a), cb = fft_coset(_b), cc = fft_coset(_c) + // 3 - h = ifft_coset(ca o cb - cc) + + n := len(a) + + // add padding to ensure input length is domain cardinality + padding := make([]fr.Element, int(domain.Cardinality)-n) + a = append(a, padding...) + b = append(b, padding...) + c = append(c, padding...) 
+ n = len(a) + + // a -> aPoly, b -> bPoly, c -> cPoly + // point-value form -> coefficient form + domain.FFTInverse(a, fft.DIF) + domain.FFTInverse(b, fft.DIF) + domain.FFTInverse(c, fft.DIF) + + // evaluate aPoly, bPoly, cPoly on coset (roots of unity) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) + + // vanishing poly t(x) = x^N - 1 + // calculate 1/t(g), where g is domain.FrMultiplicativeGen + var den, one fr.Element + one.SetOne() + // g^N + den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) + // 1/(g^N - 1) + den.Sub(&den, &one).Inverse(&den) + + // h = (a*b - c)/t + // h = ifft_coset(ca o cb - cc) + // reusing a to avoid unnecessary memory allocation + utils.Parallelize(n, func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &b[i]). + Sub(&a[i], &c[i]). + Mul(&a[i], &den) + } + }) + + // ifft_coset: point-value form -> coefficient form + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) + + return a +} + +// GPU Msm for either G1 or G2 points +func gpuMsm[R curve.G1Jac | curve.G2Jac, P curve.G1Affine | curve.G2Affine]( + res *R, + points *DevicePoints[P], + scalars *device.HostOrDeviceSlice[fr.Element], +) error { + // Check inputs + if !points.IsOnDevice() || !scalars.IsOnDevice() { + panic("points and scalars must be on device") + } + if points.Len() != scalars.Len() { + panic("points and scalars should be in the same length") + } + + // Setup MSM config + cfg := msm.DefaultMSMConfig() + cfg.AreInputPointInMont = points.Mont + cfg.AreInputsOnDevice = true + cfg.AreInputScalarInMont = true + cfg.AreOutputPointInMont = true + cfg.Npoints = uint32(points.Len()) + cfg.LargeBucketFactor = 2 + + switch any(points).(type) { + case *DevicePoints[curve.G1Affine]: + resAffine := curve.G1Affine{} + if err := msm.MSM_G1(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { + return err + } + if r, ok := any(res).(*curve.G1Jac); ok { 
+ r.FromAffine(&resAffine) + } else { + panic("res type should be *curve.G1Jac") + } + case *DevicePoints[curve.G2Affine]: + resAffine := curve.G2Affine{} + if err := msm.MSM_G2(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { + return err + } + if r, ok := any(res).(*curve.G2Jac); ok { + r.FromAffine(&resAffine) + } else { + panic("res type should be *curve.G2Jac") + } + default: + panic("invalid points type") + } + + // After GPU MSM, points in GPU are converted to affine form + points.Mont = false + + return nil +} + +func CopyToDevice[T any](hostData []T, chDeviceSlice chan *device.HostOrDeviceSlice[T]) error { + deviceSlice, err := device.CudaMalloc[T](deviceId, len(hostData)) + if err != nil { + chDeviceSlice <- nil + return err + } + if err := deviceSlice.CopyFromHost(hostData[:]); err != nil { + chDeviceSlice <- nil + return err + } + chDeviceSlice <- deviceSlice + return nil +} diff --git a/backend/groth16/groth16.go b/backend/groth16/groth16.go index a56b5730a3..6f16787fe6 100644 --- a/backend/groth16/groth16.go +++ b/backend/groth16/groth16.go @@ -51,6 +51,7 @@ import ( groth16_bls24317 "github.com/consensys/gnark/backend/groth16/bls24-317" groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" icicle_bn254 "github.com/consensys/gnark/backend/groth16/bn254/icicle" + zeknox_bn254 "github.com/consensys/gnark/backend/groth16/bn254/zeknox" groth16_bw6633 "github.com/consensys/gnark/backend/groth16/bw6-633" groth16_bw6761 "github.com/consensys/gnark/backend/groth16/bw6-761" ) @@ -198,6 +199,9 @@ func Prove(r1cs constraint.ConstraintSystem, pk ProvingKey, fullWitness witness. return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), fullWitness, opts...) case *cs_bn254.R1CS: + if zeknox_bn254.HasZeknox { + return zeknox_bn254.Prove(_r1cs, pk.(*zeknox_bn254.ProvingKey), fullWitness, opts...) 
+ } if icicle_bn254.HasIcicle { return icicle_bn254.Prove(_r1cs, pk.(*icicle_bn254.ProvingKey), fullWitness, opts...) } @@ -247,6 +251,13 @@ func Setup(r1cs constraint.ConstraintSystem) (ProvingKey, VerifyingKey, error) { return &pk, &vk, nil case *cs_bn254.R1CS: var vk groth16_bn254.VerifyingKey + if zeknox_bn254.HasZeknox { + var pk zeknox_bn254.ProvingKey + if err := zeknox_bn254.Setup(_r1cs, &pk, &vk); err != nil { + return nil, nil, err + } + return &pk, &vk, nil + } if icicle_bn254.HasIcicle { var pk icicle_bn254.ProvingKey if err := icicle_bn254.Setup(_r1cs, &pk, &vk); err != nil { @@ -309,6 +320,13 @@ func DummySetup(r1cs constraint.ConstraintSystem) (ProvingKey, error) { } return &pk, nil case *cs_bn254.R1CS: + if zeknox_bn254.HasZeknox { + var pk zeknox_bn254.ProvingKey + if err := zeknox_bn254.DummySetup(_r1cs, &pk); err != nil { + return nil, err + } + return &pk, nil + } if icicle_bn254.HasIcicle { var pk icicle_bn254.ProvingKey if err := icicle_bn254.DummySetup(_r1cs, &pk); err != nil { @@ -357,6 +375,9 @@ func NewProvingKey(curveID ecc.ID) ProvingKey { switch curveID { case ecc.BN254: pk = &groth16_bn254.ProvingKey{} + if zeknox_bn254.HasZeknox { + pk = &zeknox_bn254.ProvingKey{} + } if icicle_bn254.HasIcicle { pk = &icicle_bn254.ProvingKey{} } diff --git a/examples/zeknox/main.go b/examples/zeknox/main.go new file mode 100644 index 0000000000..bceab3edce --- /dev/null +++ b/examples/zeknox/main.go @@ -0,0 +1,152 @@ +// main.go + +package main + +import ( + "flag" + "log" + "time" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/logger" + "github.com/consensys/gnark/std/hash/sha3" + "github.com/consensys/gnark/std/math/uints" + cryptosha3 "golang.org/x/crypto/sha3" +) 
+ +type sha3Circuit struct { + In []uints.U8 `gnark:",secret"` + Expected [32]uints.U8 `gnark:",public"` +} + +func (c *sha3Circuit) Define(api frontend.API) error { + h, err := sha3.New256(api) + if err != nil { + return err + } + uapi, err := uints.New[uints.U64](api) + if err != nil { + return err + } + + h.Write(c.In) + res := h.Sum() + + for i := range c.Expected { + uapi.ByteAssertEq(c.Expected[i], res[i]) + } + return nil +} + +const inputLength = 128 + +func compileCircuit(newBuilder frontend.NewBuilder) (constraint.ConstraintSystem, error) { + circuit := sha3Circuit{ + In: make([]uints.U8, inputLength), + } + r1cs, err := frontend.Compile(ecc.BN254.ScalarField(), newBuilder, &circuit) + if err != nil { + return nil, err + } + return r1cs, nil +} + +func generateWitness() (witness.Witness, error) { + input := make([]byte, inputLength) + dgst := cryptosha3.Sum256(input) + witness := sha3Circuit{ + In: uints.NewU8Array(input[:]), + } + copy(witness.Expected[:], uints.NewU8Array(dgst[:])) + + witnessData, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } + return witnessData, nil +} + +func main() { + logger.Disable() + + nRuns := flag.Int("r", 5, "number of runs") + flag.Parse() + log.Printf("Number of runs: %d", *nRuns) + + r1cs, err := compileCircuit(r1cs.NewBuilder) + if err != nil { + panic(err) + } + pk, vk, err := groth16.Setup(r1cs) + if err != nil { + panic(err) + } + + // Witness generation + witnessData, err := generateWitness() + if err != nil { + panic(err) + } + publicWitness, err := witnessData.Public() + if err != nil { + panic(err) + } + + // GPU Prove & Verify + // Warmup GPU + proofZeknox, err := groth16.Prove(r1cs, pk, witnessData, backend.WithZeknoxAcceleration()) + if err != nil { + panic(err) + } + if err := groth16.Verify(proofZeknox, vk, publicWitness); err != nil { + panic(err) + } + // Actual run + tgpu := float64(0) + for i := 0; i < *nRuns; i++ { + start := time.Now() + proofZeknox, 
err = groth16.Prove(r1cs, pk, witnessData, backend.WithZeknoxAcceleration()) + if err != nil { + panic(err) + } + tgpu += float64(time.Since(start).Milliseconds()) + if err := groth16.Verify(proofZeknox, vk, publicWitness); err != nil { + panic(err) + } + } + tgpu /= float64(*nRuns) + log.Printf("zeknox GPU prove average time: %v ms", tgpu) + + // CPU Prove & Verify + // Warmup CPU + proof, err := groth16.Prove(r1cs, pk, witnessData) + if err != nil { + panic(err) + } + if err := groth16.Verify(proof, vk, publicWitness); err != nil { + panic(err) + } + // Actual run + tcpu := float64(0) + for i := 0; i < *nRuns; i++ { + start := time.Now() + proof, err = groth16.Prove(r1cs, pk, witnessData) + if err != nil { + panic(err) + } + tcpu += float64(time.Since(start).Milliseconds()) + if err := groth16.Verify(proof, vk, publicWitness); err != nil { + panic(err) + } + } + tcpu /= float64(*nRuns) + log.Printf("CPU prove average time: %v ms", tcpu) + + log.Printf("Speedup: %f", tcpu/tgpu) +} diff --git a/go.mod b/go.mod index 2b2a34cd6e..95ee562819 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module github.com/consensys/gnark -go 1.22 +go 1.22.2 -toolchain go1.22.6 +toolchain go1.23.1 require ( github.com/bits-and-blooms/bitset v1.14.2 @@ -16,6 +16,7 @@ require ( github.com/icza/bitio v1.1.0 github.com/ingonyama-zk/iciclegnark v0.1.0 github.com/leanovate/gopter v0.2.11 + github.com/okx/zeknox v1.0.0 github.com/ronanh/intcomp v1.1.0 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 @@ -31,7 +32,6 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/mmcloughlin/addchain v0.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/x448/float16 v0.8.4 // indirect golang.org/x/sys v0.24.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 9f971f37cf..21ce300c0e 100644 --- a/go.sum +++ b/go.sum @@ -61,8 +61,6 @@ github.com/consensys/bavard 
v0.1.22 h1:Uw2CGvbXSZWhqK59X0VG/zOjpTFuOMcPLStrp1ihI github.com/consensys/bavard v0.1.22/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/compress v0.2.5 h1:gJr1hKzbOD36JFsF1AN8lfXz1yevnJi1YolffY19Ntk= github.com/consensys/compress v0.2.5/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk= -github.com/consensys/gnark-crypto v0.14.1-0.20241010154951-6638408a49f3 h1:jVatckGR1s3OHs4QnGsppX+w2P3eedlWxi7ZFq56rjA= -github.com/consensys/gnark-crypto v0.14.1-0.20241010154951-6638408a49f3/go.mod h1:F/hJyWBcTr1sWeifAKfEN3aVb3G4U5zheEC8IbWQun4= github.com/consensys/gnark-crypto v0.14.1-0.20241122181107-03e007d865c0 h1:uFZaZWG0FOoiFN3fAQzH2JXDuybdNwiJzBujy81YtU4= github.com/consensys/gnark-crypto v0.14.1-0.20241122181107-03e007d865c0/go.mod h1:F/hJyWBcTr1sWeifAKfEN3aVb3G4U5zheEC8IbWQun4= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= @@ -232,6 +230,8 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= +github.com/okx/zeknox v1.0.0 h1:W/nZnaBIQjB5LHK2DsVdlL5v7NHP4OBPLziZ6gh/14U= +github.com/okx/zeknox v1.0.0/go.mod h1:zlHemJhkN7W22xWWtANF66oPdzUJYT1frlkdSZhLQbc= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -243,8 +243,8 @@ github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndr github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod 
h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/ronanh/intcomp v1.1.0 h1:i54kxmpmSoOZFcWPMWryuakN0vLxLswASsGa07zkvLU= github.com/ronanh/intcomp v1.1.0/go.mod h1:7FOLy3P3Zj3er/kVrU/pl+Ql7JFZj7bwliMGketo0IU= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=