From e0a9c5a63dbc5ce8fb2fb14394742331640ca64b Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 10 Oct 2024 12:00:01 +0800 Subject: [PATCH 01/62] log msm g1 g2 time, and add comment --- backend/groth16/bn254/prove.go | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/backend/groth16/bn254/prove.go b/backend/groth16/bn254/prove.go index 5f0d413133..6b58202c60 100644 --- a/backend/groth16/bn254/prove.go +++ b/backend/groth16/bn254/prove.go @@ -138,7 +138,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b return nil, err } - // H (witness reduction / FFT part) + // quotient poly H (witness reduction / FFT part) var h []fr.Element chHDone := make(chan struct{}, 1) go func() { @@ -186,6 +186,8 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b if _, err := _s.SetRandom(); err != nil { return nil, err } + // -rs + // Why it is called kr? not rs? 
-> notation from DIZK paper _kr.Mul(&_r, &_s).Neg(&_kr) _r.BigInt(&r) @@ -201,11 +203,14 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b chBs1Done := make(chan error, 1) computeBS1 := func() { <-chWireValuesB + startBs1 := time.Now() if _, err := bs1.MultiExp(pk.G1.B, wireValuesB, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chBs1Done <- err close(chBs1Done) return } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValuesB)), time.Since(startBs1)).Msg("bs1.MultiExp done") + // + beta + s[δ] bs1.AddMixed(&pk.G1.Beta) bs1.AddMixed(&deltas[1]) chBs1Done <- nil @@ -214,11 +219,13 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b chArDone := make(chan error, 1) computeAR1 := func() { <-chWireValuesA + startAr := time.Now() if _, err := ar.MultiExp(pk.G1.A, wireValuesA, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chArDone <- err close(chArDone) return } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValuesA)), time.Since(startAr)).Msg("ar.MultiExp done") ar.AddMixed(&pk.G1.Alpha) ar.AddMixed(&deltas[0]) proof.Ar.FromJacobian(&ar) @@ -234,7 +241,9 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b chKrs2Done := make(chan error, 1) sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 go func() { + startKrs2 := time.Now() _, err := krs2.MultiExp(pk.G1.Z, h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", sizeH), time.Since(startKrs2)).Msg("krs2.MultiExp done") chKrs2Done <- err }() @@ -244,10 +253,13 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) + startKrs := time.Now() if _, err := krs.MultiExp(pk.G1.K, _wireValues, 
ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chKrsDone <- err return } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(_wireValues)), time.Since(startKrs)).Msg("krs.MultiExp done") + // -rs[δ] krs.AddMixed(&deltas[2]) n := 3 for n != 0 { @@ -290,9 +302,11 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b nbTasks *= 2 } <-chWireValuesB + startBs := time.Now() if _, err := Bs.MultiExp(pk.G2.B, wireValuesB, ecc.MultiExpConfig{NbTasks: nbTasks}); err != nil { return err } + log.Debug().Dur(fmt.Sprintf("MSMG2 %d took", len(wireValuesB)), time.Since(startBs)).Msg("Bs.MultiExp done") deltaS.FromAffine(&pk.G2.Delta) deltaS.ScalarMultiplication(&deltaS, &s) @@ -369,19 +383,27 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { c = append(c, padding...) n = len(a) + // a -> aPoly, b -> bPoly, c -> cPoly + // point-value form -> coefficient form domain.FFTInverse(a, fft.DIF) domain.FFTInverse(b, fft.DIF) domain.FFTInverse(c, fft.DIF) + // evaluate aPoly, bPoly, cPoly on coset (roots of unity) domain.FFT(a, fft.DIT, fft.OnCoset()) domain.FFT(b, fft.DIT, fft.OnCoset()) domain.FFT(c, fft.DIT, fft.OnCoset()) + // vanishing poly t(x) = x^N - 1 + // calcualte 1/t(g), where g is the generator var den, one fr.Element one.SetOne() + // g^N den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) + // 1/(g^N - 1) den.Sub(&den, &one).Inverse(&den) + // h = (a*b - c)/t // h = ifft_coset(ca o cb - cc) // reusing a to avoid unnecessary memory allocation utils.Parallelize(n, func(start, end int) { @@ -392,7 +414,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { } }) - // ifft_coset + // ifft_coset: point-value form -> coefficient form domain.FFTInverse(a, fft.DIF, fft.OnCoset()) return a From 48e8cc7e9299d7a0668e49cb053fa284288803b2 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 10 Oct 2024 14:17:29 +0800 Subject: [PATCH 02/62] log 
computeH time --- backend/groth16/bn254/prove.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backend/groth16/bn254/prove.go b/backend/groth16/bn254/prove.go index 6b58202c60..6aa0d93e30 100644 --- a/backend/groth16/bn254/prove.go +++ b/backend/groth16/bn254/prove.go @@ -142,7 +142,9 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var h []fr.Element chHDone := make(chan struct{}, 1) go func() { + startH := time.Now() h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + log.Debug().Dur("computeH took", time.Since(startH)).Msg("computed H") solution.A = nil solution.B = nil solution.C = nil From 0a93689b1ea82e8e6b2c72061a3e5d1256bc338b Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 10 Oct 2024 18:13:18 +0800 Subject: [PATCH 03/62] init zeknox GPU acceleration --- backend/backend.go | 13 ++++ backend/groth16/bn254/zeknox/doc.go | 2 + backend/groth16/bn254/zeknox/marshal_test.go | 67 ++++++++++++++++++++ backend/groth16/bn254/zeknox/nozeknox.go | 18 ++++++ backend/groth16/bn254/zeknox/provingkey.go | 36 +++++++++++ backend/groth16/groth16.go | 21 ++++++ 6 files changed, 157 insertions(+) create mode 100644 backend/groth16/bn254/zeknox/doc.go create mode 100644 backend/groth16/bn254/zeknox/marshal_test.go create mode 100644 backend/groth16/bn254/zeknox/nozeknox.go create mode 100644 backend/groth16/bn254/zeknox/provingkey.go diff --git a/backend/backend.go b/backend/backend.go index 7c427e5825..acd89a0585 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -121,6 +121,19 @@ func WithProverKZGFoldingHashFunction(hFunc hash.Hash) ProverOption { } } +// WithZeknoxAcceleration requests to use [ZEKNOX] GPU proving backend for the +// prover. This option requires that the program is compiled with `zeknox` build +// tag and the ZEKNOX dependencies are properly installed. See [ZEKNOX] for +// installation description. 
+// +// [ZEKNOX]: https://github.com/okx/cryptography_cuda +func WithZeknoxAcceleration() ProverOption { + return func(pc *ProverConfig) error { + pc.Accelerator = "zeknox" + return nil + } +} + // WithIcicleAcceleration requests to use [ICICLE] GPU proving backend for the // prover. This option requires that the program is compiled with `icicle` build // tag and the ICICLE dependencies are properly installed. See [ICICLE] for diff --git a/backend/groth16/bn254/zeknox/doc.go b/backend/groth16/bn254/zeknox/doc.go new file mode 100644 index 0000000000..2200b550c7 --- /dev/null +++ b/backend/groth16/bn254/zeknox/doc.go @@ -0,0 +1,2 @@ +// Package zeknox_bn254 implements zeknox acceleration for BN254 Groth16 backend. +package zeknox_bn254 diff --git a/backend/groth16/bn254/zeknox/marshal_test.go b/backend/groth16/bn254/zeknox/marshal_test.go new file mode 100644 index 0000000000..5e9b2aeaea --- /dev/null +++ b/backend/groth16/bn254/zeknox/marshal_test.go @@ -0,0 +1,67 @@ +package zeknox_bn254_test + +import ( + "bytes" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + zeknox_bn254 "github.com/consensys/gnark/backend/groth16/bn254/zeknox" + cs_bn254 "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/test" +) + +type circuit struct { + A, B frontend.Variable `gnark:",public"` + Res frontend.Variable +} + +func (c *circuit) Define(api frontend.API) error { + api.AssertIsEqual(api.Mul(c.A, c.B), c.Res) + return nil +} + +func TestMarshalNative(t *testing.T) { + assert := test.NewAssert(t) + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit{}) + assert.NoError(err) + tCcs := ccs.(*cs_bn254.R1CS) + nativePK := groth16_bn254.ProvingKey{} + nativeVK := groth16_bn254.VerifyingKey{} + err = groth16_bn254.Setup(tCcs, 
&nativePK, &nativeVK) + assert.NoError(err) + + pk := groth16.NewProvingKey(ecc.BN254) + buf := new(bytes.Buffer) + _, err = nativePK.WriteTo(buf) + assert.NoError(err) + _, err = pk.ReadFrom(buf) + assert.NoError(err) + if pk.IsDifferent(&nativePK) { + t.Error("marshal output difference") + } +} + +func TestMarshalZeknox(t *testing.T) { + assert := test.NewAssert(t) + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit{}) + assert.NoError(err) + tCcs := ccs.(*cs_bn254.R1CS) + zePK := zeknox_bn254.ProvingKey{} + VK := groth16_bn254.VerifyingKey{} + err = zeknox_bn254.Setup(tCcs, &zePK, &VK) + assert.NoError(err) + + nativePK := groth16_bn254.ProvingKey{} + buf := new(bytes.Buffer) + _, err = zePK.WriteTo(buf) + assert.NoError(err) + _, err = nativePK.ReadFrom(buf) + assert.NoError(err) + if zePK.IsDifferent(&nativePK) { + t.Error("marshal output difference") + } +} diff --git a/backend/groth16/bn254/zeknox/nozeknox.go b/backend/groth16/bn254/zeknox/nozeknox.go new file mode 100644 index 0000000000..a1c94bb97b --- /dev/null +++ b/backend/groth16/bn254/zeknox/nozeknox.go @@ -0,0 +1,18 @@ +//go:build !zeknox + +package zeknox_bn254 + +import ( + "fmt" + + "github.com/consensys/gnark/backend" + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + "github.com/consensys/gnark/backend/witness" + cs "github.com/consensys/gnark/constraint/bn254" +) + +const HasZeknox = false + +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*groth16_bn254.Proof, error) { + return nil, fmt.Errorf("zeknox backend requested but program compiled without 'zeknox' build tag") +} diff --git a/backend/groth16/bn254/zeknox/provingkey.go b/backend/groth16/bn254/zeknox/provingkey.go new file mode 100644 index 0000000000..b13c59b0d1 --- /dev/null +++ b/backend/groth16/bn254/zeknox/provingkey.go @@ -0,0 +1,36 @@ +package zeknox_bn254 + +import ( + "unsafe" + + groth16_bn254 
"github.com/consensys/gnark/backend/groth16/bn254" + cs "github.com/consensys/gnark/constraint/bn254" +) + +type deviceInfo struct { + G1Device struct { + A, B, K, Z unsafe.Pointer + } + DomainDevice struct { + Twiddles, TwiddlesInv unsafe.Pointer + CosetTable, CosetTableInv unsafe.Pointer + } + G2Device struct { + B unsafe.Pointer + } + DenDevice unsafe.Pointer + InfinityPointIndicesK []int +} + +type ProvingKey struct { + groth16_bn254.ProvingKey + *deviceInfo +} + +func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *groth16_bn254.VerifyingKey) error { + return groth16_bn254.Setup(r1cs, &pk.ProvingKey, vk) +} + +func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { + return groth16_bn254.DummySetup(r1cs, &pk.ProvingKey) +} diff --git a/backend/groth16/groth16.go b/backend/groth16/groth16.go index a56b5730a3..6f16787fe6 100644 --- a/backend/groth16/groth16.go +++ b/backend/groth16/groth16.go @@ -51,6 +51,7 @@ import ( groth16_bls24317 "github.com/consensys/gnark/backend/groth16/bls24-317" groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" icicle_bn254 "github.com/consensys/gnark/backend/groth16/bn254/icicle" + zeknox_bn254 "github.com/consensys/gnark/backend/groth16/bn254/zeknox" groth16_bw6633 "github.com/consensys/gnark/backend/groth16/bw6-633" groth16_bw6761 "github.com/consensys/gnark/backend/groth16/bw6-761" ) @@ -198,6 +199,9 @@ func Prove(r1cs constraint.ConstraintSystem, pk ProvingKey, fullWitness witness. return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), fullWitness, opts...) case *cs_bn254.R1CS: + if zeknox_bn254.HasZeknox { + return zeknox_bn254.Prove(_r1cs, pk.(*zeknox_bn254.ProvingKey), fullWitness, opts...) + } if icicle_bn254.HasIcicle { return icicle_bn254.Prove(_r1cs, pk.(*icicle_bn254.ProvingKey), fullWitness, opts...) 
} @@ -247,6 +251,13 @@ func Setup(r1cs constraint.ConstraintSystem) (ProvingKey, VerifyingKey, error) { return &pk, &vk, nil case *cs_bn254.R1CS: var vk groth16_bn254.VerifyingKey + if zeknox_bn254.HasZeknox { + var pk zeknox_bn254.ProvingKey + if err := zeknox_bn254.Setup(_r1cs, &pk, &vk); err != nil { + return nil, nil, err + } + return &pk, &vk, nil + } if icicle_bn254.HasIcicle { var pk icicle_bn254.ProvingKey if err := icicle_bn254.Setup(_r1cs, &pk, &vk); err != nil { @@ -309,6 +320,13 @@ func DummySetup(r1cs constraint.ConstraintSystem) (ProvingKey, error) { } return &pk, nil case *cs_bn254.R1CS: + if zeknox_bn254.HasZeknox { + var pk zeknox_bn254.ProvingKey + if err := zeknox_bn254.DummySetup(_r1cs, &pk); err != nil { + return nil, err + } + return &pk, nil + } if icicle_bn254.HasIcicle { var pk icicle_bn254.ProvingKey if err := icicle_bn254.DummySetup(_r1cs, &pk); err != nil { @@ -357,6 +375,9 @@ func NewProvingKey(curveID ecc.ID) ProvingKey { switch curveID { case ecc.BN254: pk = &groth16_bn254.ProvingKey{} + if zeknox_bn254.HasZeknox { + pk = &zeknox_bn254.ProvingKey{} + } if icicle_bn254.HasIcicle { pk = &icicle_bn254.ProvingKey{} } From e56658859a01ab07e2b8a98c0bb60aad571fc654 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Wed, 16 Oct 2024 14:32:17 +0800 Subject: [PATCH 04/62] MSM G1 & G2 accelerating! 
with local cuda repo --- backend/groth16/bn254/zeknox/provingkey.go | 6 +- backend/groth16/bn254/zeknox/zeknox.go | 552 +++++++++++++++++++++ go.mod | 8 +- go.sum | 4 +- 4 files changed, 563 insertions(+), 7 deletions(-) create mode 100644 backend/groth16/bn254/zeknox/zeknox.go diff --git a/backend/groth16/bn254/zeknox/provingkey.go b/backend/groth16/bn254/zeknox/provingkey.go index b13c59b0d1..f54a97a669 100644 --- a/backend/groth16/bn254/zeknox/provingkey.go +++ b/backend/groth16/bn254/zeknox/provingkey.go @@ -4,19 +4,21 @@ import ( "unsafe" groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254" cs "github.com/consensys/gnark/constraint/bn254" + "github.com/okx/cryptography_cuda/wrappers/go/device" ) type deviceInfo struct { G1Device struct { - A, B, K, Z unsafe.Pointer + A, B, K, Z *device.HostOrDeviceSlice[bn254.G1Affine] } DomainDevice struct { Twiddles, TwiddlesInv unsafe.Pointer CosetTable, CosetTableInv unsafe.Pointer } G2Device struct { - B unsafe.Pointer + B *device.HostOrDeviceSlice[bn254.G2Affine] } DenDevice unsafe.Pointer InfinityPointIndicesK []int diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go new file mode 100644 index 0000000000..e49c9aa031 --- /dev/null +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -0,0 +1,552 @@ +//go:build zeknox + +package zeknox_bn254 + +import ( + "context" + "fmt" + "math/big" + "runtime" + "time" + "unsafe" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/hash_to_field" + "github.com/consensys/gnark/backend" + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + 
cs "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/constraint/solver" + fcs "github.com/consensys/gnark/frontend/cs" + "github.com/consensys/gnark/internal/utils" + "github.com/consensys/gnark/logger" + "github.com/okx/cryptography_cuda/wrappers/go/device" + "github.com/okx/cryptography_cuda/wrappers/go/msm" + "golang.org/x/sync/errgroup" +) + +const HasZeknox = true + +// Use single GPU +const deviceId = 0 + +func (pk *ProvingKey) setupDevicePointers() error { + if pk.deviceInfo != nil { + return nil + } + pk.deviceInfo = &deviceInfo{} + // TODO: setup FFT + + // MSM G1 & G2 Device Setup + g, _ := errgroup.WithContext(context.TODO()) + // G1.A + deviceA := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) + g.Go(func() error { return CopyToDevice(pk.G1.A, deviceA) }) + + // G1.B + deviceG1B := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) + g.Go(func() error { return CopyToDevice(pk.G1.B, deviceG1B) }) + + // G1.K + var pointsNoInfinity []curve.G1Affine + for i, gnarkPoint := range pk.G1.K { + if gnarkPoint.IsInfinity() { + pk.InfinityPointIndicesK = append(pk.InfinityPointIndicesK, i) + } else { + pointsNoInfinity = append(pointsNoInfinity, gnarkPoint) + } + } + deviceK := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) + g.Go(func() error { return CopyToDevice(pointsNoInfinity, deviceK) }) + + // G1.Z + deviceZ := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) + g.Go(func() error { return CopyToDevice(pk.G1.Z, deviceZ) }) + + // G2.B + deviceG2B := make(chan *device.HostOrDeviceSlice[curve.G2Affine], 1) + g.Go(func() error { return CopyToDevice(pk.G2.B, deviceG2B) }) + + // wait for all points to be copied to the device + // if any of the copy failed, return the error + if err := g.Wait(); err != nil { + return err + } + // if no error, store device pointers in pk + pk.G1Device.A = <-deviceA + pk.G1Device.B = <-deviceG1B + pk.G1Device.K = <-deviceK + pk.G1Device.Z = <-deviceZ + pk.G2Device.B = 
<-deviceG2B + + return nil +} + +// Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*groth16_bn254.Proof, error) { + fmt.Println("zeknox_bn254.Prove") + opt, err := backend.NewProverConfig(opts...) + if err != nil { + return nil, fmt.Errorf("new prover config: %w", err) + } + if opt.HashToFieldFn == nil { + opt.HashToFieldFn = hash_to_field.New([]byte(constraint.CommitmentDst)) + } + if opt.Accelerator != "zeknox" { + return groth16_bn254.Prove(r1cs, &pk.ProvingKey, fullWitness, opts...) + } + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Str("acceleration", "zeknox").Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + if pk.deviceInfo == nil { + log.Debug().Msg("precomputing proving key in GPU") + if err := pk.setupDevicePointers(); err != nil { + return nil, fmt.Errorf("setup device pointers: %w", err) + } + } + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &groth16_bn254.Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] + + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + + // override hints + bsb22ID := solver.GetHintID(fcs.Bsb22CommitmentComputePlaceholder) + solverOpts = append(solverOpts, solver.OverrideHint(bsb22ID, func(_ *big.Int, in []*big.Int, out []*big.Int) error { + i := int(in[0].Int64()) + in = in[1:] + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[+len(hashed):] + for j, inJ := range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } 
+ + opt.HashToFieldFn.Write(constraint.SerializeCommitment(proof.Commitments[i].Marshal(), hashed, (fr.Bits-1)/8+1)) + hashBts := opt.HashToFieldFn.Sum(nil) + opt.HashToFieldFn.Reset() + nbBuf := fr.Bytes + if opt.HashToFieldFn.Size() < fr.Bytes { + nbBuf = opt.HashToFieldFn.Size() + } + var res fr.Element + res.SetBytes(hashBts[:nbBuf]) + res.BigInt(out[0]) + return nil + })) + + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err + } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + + start := time.Now() + poks := make([]curve.G1Affine, len(pk.CommitmentKeys)) + + for i := range pk.CommitmentKeys { + var err error + if poks[i], err = pk.CommitmentKeys[i].ProveKnowledge(privateCommittedValues[i]); err != nil { + return nil, err + } + } + // compute challenge for folding the PoKs from the commitments + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + challenge, err := fr.Hash(commitmentsSerialized, []byte("G16-BSB22"), 1) + if err != nil { + return nil, err + } + if _, err = proof.CommitmentPok.Fold(poks, challenge[0], ecc.MultiExpConfig{NbTasks: 1}); err != nil { + return nil, err + } + + // quotient poly H (witness reduction / FFT part) + var h []fr.Element + chHDone := make(chan struct{}, 1) + go func() { + startH := time.Now() + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + log.Debug().Dur("computeH took", time.Since(startH)).Msg("computed H") + solution.A = nil + solution.B = nil + solution.C = nil + chHDone <- struct{}{} + }() + + // we need to copy and filter the wireValues for each multi exp + // as pk.G1.A, pk.G1.B and pk.G2.B may have (a significant) number of point at infinity + var deviceWireValuesA, deviceWireValuesB *device.HostOrDeviceSlice[fr.Element] + // indicate if the wire values have been 
copied to the device + chWireValuesA, chWireValuesB := make(chan error, 1), make(chan error, 1) + + go func() { + wireValuesA := make([]fr.Element, len(wireValues)-int(pk.NbInfinityA)) + for i, j := 0, 0; j < len(wireValuesA); i++ { + if pk.InfinityA[i] { + continue + } + wireValuesA[j] = wireValues[i] + j++ + } + chDeviceValues := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + if err := CopyToDevice(wireValuesA, chDeviceValues); err != nil { + chWireValuesA <- err + close(chWireValuesA) + return + } + deviceWireValuesA = <-chDeviceValues + close(chWireValuesA) + }() + go func() { + wireValuesB := make([]fr.Element, len(wireValues)-int(pk.NbInfinityB)) + for i, j := 0, 0; j < len(wireValuesB); i++ { + if pk.InfinityB[i] { + continue + } + wireValuesB[j] = wireValues[i] + j++ + } + chDeviceValues := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + if err := CopyToDevice(wireValuesB, chDeviceValues); err != nil { + chWireValuesB <- err + close(chWireValuesB) + return + } + deviceWireValuesB = <-chDeviceValues + close(chWireValuesB) + }() + + // sample random r and s + var r, s big.Int + var _r, _s, _kr fr.Element + if _, err := _r.SetRandom(); err != nil { + return nil, err + } + if _, err := _s.SetRandom(); err != nil { + return nil, err + } + // -rs + // Why it is called kr? not rs? 
-> notation from DIZK paper + _kr.Mul(&_r, &_s).Neg(&_kr) + + _r.BigInt(&r) + _s.BigInt(&s) + + // computes r[δ], s[δ], kr[δ] + deltas := curve.BatchScalarMultiplicationG1(&pk.G1.Delta, []fr.Element{_r, _s, _kr}) + + var bs1, ar curve.G1Jac + + n := runtime.NumCPU() + + chBs1Done := make(chan error, 1) + + computeBS1 := func() { + if err := <-chWireValuesB; err != nil { + chBs1Done <- err + close(chBs1Done) + return + } + startBs1 := time.Now() + if err := msmG1(&bs1, pk.G1Device.B, deviceWireValuesB); err != nil { + chBs1Done <- err + close(chBs1Done) + return + } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesB.Len()), time.Since(startBs1)).Msg("bs1.MultiExp done") + // + beta + s[δ] + bs1.AddMixed(&pk.G1.Beta) + bs1.AddMixed(&deltas[1]) + chBs1Done <- nil + } + + chArDone := make(chan error, 1) + computeAR1 := func() { + if err := <-chWireValuesA; err != nil { + chArDone <- err + close(chArDone) + return + } + startAr := time.Now() + if err := msmG1(&ar, pk.G1Device.A, deviceWireValuesA); err != nil { + chArDone <- err + close(chArDone) + return + } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesA.Len()), time.Since(startAr)).Msg("ar.MultiExp done") + ar.AddMixed(&pk.G1.Alpha) + ar.AddMixed(&deltas[0]) + proof.Ar.FromJacobian(&ar) + chArDone <- nil + } + + chKrsDone := make(chan error, 1) + var deviceH *device.HostOrDeviceSlice[fr.Element] + computeKRS := func() { + // we could NOT split the Krs multiExp in 2, and just append pk.G1.K and pk.G1.Z + // however, having similar lengths for our tasks helps with parallelism + + var krs, krs2, p1 curve.G1Jac + chKrs2Done := make(chan error, 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 + go func() { + startKrs2 := time.Now() + // Copy h poly to device, since we haven't implemented FFT on device + chDevice := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + if err := CopyToDevice(h[:sizeH], chDevice); err != nil { + chKrs2Done <- err 
+ close(chKrs2Done) + return + } + deviceH = <-chDevice + if err := msmG1(&krs2, pk.G1Device.Z, deviceH); err != nil { + chKrs2Done <- err + close(chKrs2Done) + return + } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", sizeH), time.Since(startKrs2)).Msg("krs2.MultiExp done") + chKrs2Done <- err + }() + + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) + + startKrs := time.Now() + if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + chKrsDone <- err + return + } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(_wireValues)), time.Since(startKrs)).Msg("krs.MultiExp done") + // -rs[δ] + krs.AddMixed(&deltas[2]) + n := 3 + for n != 0 { + select { + case err := <-chKrs2Done: + if err != nil { + chKrsDone <- err + return + } + krs.AddAssign(&krs2) + case err := <-chArDone: + if err != nil { + chKrsDone <- err + return + } + p1.ScalarMultiplication(&ar, &s) + krs.AddAssign(&p1) + case err := <-chBs1Done: + if err != nil { + chKrsDone <- err + return + } + p1.ScalarMultiplication(&bs1, &r) + krs.AddAssign(&p1) + } + n-- + } + + proof.Krs.FromJacobian(&krs) + chKrsDone <- nil + } + + computeBS2 := func() error { + // Bs2 (1 multi exp G2 - size = len(wires)) + var Bs, deltaS curve.G2Jac + + nbTasks := n + if nbTasks <= 16 { + // if we don't have a lot of CPUs, this may artificially split the MSM + nbTasks *= 2 + } + <-chWireValuesB + startBs := time.Now() + if err := msmG2(&Bs, pk.G2Device.B, deviceWireValuesB); err != nil { + return err + } + log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", deviceWireValuesB.Len()), time.Since(startBs)).Msg("Bs.MultiExp done") + + deltaS.FromAffine(&pk.G2.Delta) + deltaS.ScalarMultiplication(&deltaS, &s) + 
Bs.AddAssign(&deltaS) + Bs.AddMixed(&pk.G2.Beta) + + proof.Bs.FromJacobian(&Bs) + return nil + } + + // wait for FFT to end, as it uses all our CPUs + <-chHDone + + // schedule our proof part computations + go computeKRS() + go computeAR1() + go computeBS1() + if err := computeBS2(); err != nil { + return nil, err + } + + // wait for all parts of the proof to be computed. + if err := <-chKrsDone; err != nil { + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("prover done") + + // Free device memory + go func() { + deviceWireValuesA.Free() + deviceWireValuesB.Free() + deviceH.Free() + }() + + return proof, nil +} + +// if len(toRemove) == 0, returns slice +// else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { + + if len(toRemove) == 0 { + return slice + } + + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) + + // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) + for i := 0; i < len(slice); i++ { + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } + continue + } + r = append(r, slice[i]) + } + + return +} + +func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { + // H part of Krs + // Compute H (hz=ab-c, where z=-2 on ker X^n+1 (z(x)=x^n-1)) + // 1 - _a = ifft(a), _b = ifft(b), _c = ifft(c) + // 2 - ca = fft_coset(_a), ba = fft_coset(_b), cc = fft_coset(_c) + // 3 - h = ifft_coset(ca o cb - cc) + + n := len(a) + + // add padding to ensure input length is domain cardinality + padding := make([]fr.Element, int(domain.Cardinality)-n) + a = append(a, padding...) + b = append(b, padding...) + c = append(c, padding...) 
+ n = len(a) + + // a -> aPoly, b -> bPoly, c -> cPoly + // point-value form -> coefficient form + domain.FFTInverse(a, fft.DIF) + domain.FFTInverse(b, fft.DIF) + domain.FFTInverse(c, fft.DIF) + + // evaluate aPoly, bPoly, cPoly on coset (roots of unity) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) + + // vanishing poly t(x) = x^N - 1 + // calcualte 1/t(g), where g is the generator + var den, one fr.Element + one.SetOne() + // g^N + den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) + // 1/(g^N - 1) + den.Sub(&den, &one).Inverse(&den) + + // h = (a*b - c)/t + // h = ifft_coset(ca o cb - cc) + // reusing a to avoid unnecessary memory allocation + utils.Parallelize(n, func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &b[i]). + Sub(&a[i], &c[i]). + Mul(&a[i], &den) + } + }) + + // ifft_coset: point-value form -> coefficient form + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) + + return a +} + +func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { + if points.Len() != scalars.Len() { + return fmt.Errorf("MSM: len(points) != len(scalars)") + } + cfg := msm.DefaultMSMConfig() + cfg.ArePointsInMont = true + cfg.Npoints = uint32(points.Len()) + cfg.FfiAffineSz = 64 + if err := msm.MSM_G1(unsafe.Pointer(res), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { + return err + } + return nil +} + +func msmG2(res *curve.G2Jac, points *device.HostOrDeviceSlice[curve.G2Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { + if points.Len() != scalars.Len() { + return fmt.Errorf("MSM: len(points) != len(scalars)") + } + cfg := msm.DefaultMSMConfig() + cfg.AreInputsOnDevice = true + cfg.ArePointsInMont = true + cfg.Npoints = uint32(points.Len()) + cfg.LargeBucketFactor = 2 + // TODO: MSM_G2 should return Jacobian + // 
https://github.com/okx/cryptography_cuda/issues/90 + resAffine := curve.G2Affine{} + if err := msm.MSM_G2(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { + return err + } + res.FromAffine(&resAffine) + return nil +} + +func CopyToDevice[T any](hostData []T, chDeviceSlice chan *device.HostOrDeviceSlice[T]) error { + deviceSlice, err := device.CudaMalloc[T](deviceId, len(hostData)) + if err != nil { + chDeviceSlice <- nil + return err + } + if err := deviceSlice.CopyFromHost(hostData[:]); err != nil { + chDeviceSlice <- nil + return err + } + chDeviceSlice <- deviceSlice + return nil +} diff --git a/go.mod b/go.mod index 4ed242d670..e643586c0c 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module github.com/consensys/gnark -go 1.22 +go 1.22.2 -toolchain go1.22.6 +toolchain go1.23.1 require ( github.com/bits-and-blooms/bitset v1.14.2 @@ -16,6 +16,7 @@ require ( github.com/icza/bitio v1.1.0 github.com/ingonyama-zk/iciclegnark v0.1.0 github.com/leanovate/gopter v0.2.11 + github.com/okx/cryptography_cuda/wrappers/go v0.0.0-20241016023422-25c1f0f5f44e github.com/ronanh/intcomp v1.1.0 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 @@ -31,9 +32,10 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/mmcloughlin/addchain v0.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/x448/float16 v0.8.4 // indirect golang.org/x/sys v0.24.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect rsc.io/tmplfunc v0.0.3 // indirect ) + +replace github.com/okx/cryptography_cuda/wrappers/go => /home/okxdex/data/zkdex-pap/workspace/jason-huang/cryptography_cuda/wrappers/go diff --git a/go.sum b/go.sum index efc71ddf96..81dc4a651b 100644 --- a/go.sum +++ b/go.sum @@ -241,8 +241,8 @@ github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndr github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod 
h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/ronanh/intcomp v1.1.0 h1:i54kxmpmSoOZFcWPMWryuakN0vLxLswASsGa07zkvLU= github.com/ronanh/intcomp v1.1.0/go.mod h1:7FOLy3P3Zj3er/kVrU/pl+Ql7JFZj7bwliMGketo0IU= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= From ffd27815a0e28f2f842143fcac1cbcf053efd6a6 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 17 Oct 2024 11:53:09 +0800 Subject: [PATCH 05/62] sequencial GPU MSM & refactor --- backend/groth16/bn254/zeknox/nozeknox.go | 2 +- backend/groth16/bn254/zeknox/zeknox.go | 97 +++++++++++++++--------- go.mod | 2 +- go.sum | 2 + 4 files changed, 64 insertions(+), 39 deletions(-) diff --git a/backend/groth16/bn254/zeknox/nozeknox.go b/backend/groth16/bn254/zeknox/nozeknox.go index a1c94bb97b..8859d6f319 100644 --- a/backend/groth16/bn254/zeknox/nozeknox.go +++ b/backend/groth16/bn254/zeknox/nozeknox.go @@ -1,4 +1,4 @@ -//go:build !zeknox +//go:build zeknox package zeknox_bn254 diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index e49c9aa031..df967e908d 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -1,4 +1,4 @@ -//go:build zeknox +//go:build !zeknox package zeknox_bn254 @@ -89,7 +89,6 @@ func (pk *ProvingKey) setupDevicePointers() error { // Prove generates the proof of knowledge of a r1cs with 
full witness (secret + public part). func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*groth16_bn254.Proof, error) { - fmt.Println("zeknox_bn254.Prove") opt, err := backend.NewProverConfig(opts...) if err != nil { return nil, fmt.Errorf("new prover config: %w", err) @@ -102,10 +101,11 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b } log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Str("acceleration", "zeknox").Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() if pk.deviceInfo == nil { - log.Debug().Msg("precomputing proving key in GPU") + start := time.Now() if err := pk.setupDevicePointers(); err != nil { return nil, fmt.Errorf("setup device pointers: %w", err) } + log.Debug().Dur("took", time.Since(start)).Msg("Copy proving key to device") } commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) @@ -182,7 +182,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b go func() { startH := time.Now() h = computeH(solution.A, solution.B, solution.C, &pk.Domain) - log.Debug().Dur("computeH took", time.Since(startH)).Msg("computed H") + log.Debug().Dur("took", time.Since(startH)).Msg("computed H") solution.A = nil solution.B = nil solution.C = nil @@ -207,7 +207,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b chDeviceValues := make(chan *device.HostOrDeviceSlice[fr.Element], 1) if err := CopyToDevice(wireValuesA, chDeviceValues); err != nil { chWireValuesA <- err - close(chWireValuesA) return } deviceWireValuesA = <-chDeviceValues @@ -225,7 +224,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b chDeviceValues := make(chan *device.HostOrDeviceSlice[fr.Element], 1) if err := CopyToDevice(wireValuesB, chDeviceValues); err != nil { chWireValuesB <- err - close(chWireValuesB) return } deviceWireValuesB = 
<-chDeviceValues @@ -253,23 +251,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var bs1, ar curve.G1Jac - n := runtime.NumCPU() - chBs1Done := make(chan error, 1) computeBS1 := func() { if err := <-chWireValuesB; err != nil { chBs1Done <- err - close(chBs1Done) return } startBs1 := time.Now() if err := msmG1(&bs1, pk.G1Device.B, deviceWireValuesB); err != nil { chBs1Done <- err - close(chBs1Done) return } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesB.Len()), time.Since(startBs1)).Msg("bs1.MultiExp done") + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesB.Len()), time.Since(startBs1)).Msg("bs1 done") // + beta + s[δ] bs1.AddMixed(&pk.G1.Beta) bs1.AddMixed(&deltas[1]) @@ -280,16 +274,14 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b computeAR1 := func() { if err := <-chWireValuesA; err != nil { chArDone <- err - close(chArDone) return } startAr := time.Now() if err := msmG1(&ar, pk.G1Device.A, deviceWireValuesA); err != nil { chArDone <- err - close(chArDone) return } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesA.Len()), time.Since(startAr)).Msg("ar.MultiExp done") + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesA.Len()), time.Since(startAr)).Msg("ar done") ar.AddMixed(&pk.G1.Alpha) ar.AddMixed(&deltas[0]) proof.Ar.FromJacobian(&ar) @@ -304,40 +296,56 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var krs, krs2, p1 curve.G1Jac chKrs2Done := make(chan error, 1) - sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 go func() { startKrs2 := time.Now() // Copy h poly to device, since we haven't implemented FFT on device - chDevice := make(chan *device.HostOrDeviceSlice[fr.Element], 1) - if err := CopyToDevice(h[:sizeH], chDevice); err != nil { + chDeviceH := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + sizeH := int(pk.Domain.Cardinality - 1) // 
comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 + if err := CopyToDevice(h[:sizeH], chDeviceH); err != nil { chKrs2Done <- err - close(chKrs2Done) return } - deviceH = <-chDevice + deviceH = <-chDeviceH if err := msmG1(&krs2, pk.G1Device.Z, deviceH); err != nil { chKrs2Done <- err - close(chKrs2Done) return } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", sizeH), time.Since(startKrs2)).Msg("krs2.MultiExp done") - chKrs2Done <- err + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", sizeH), time.Since(startKrs2)).Msg("krs2 done") + close(chKrs2Done) }() // filter the wire values if needed // TODO Perf @Tabaie worst memory allocation offender toRemove := commitmentInfo.GetPrivateCommitted() toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) - _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) + // original Groth16 witness without pedersen commitment + wireValuesWithoutCom := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) startKrs := time.Now() - if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + // GPU runtime error + // var deviceWire *device.HostOrDeviceSlice[fr.Element] + // defer deviceWire.Free() + // chDeviceWire := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + // if err := CopyToDevice(wireValuesWithoutCom, chDeviceWire); err != nil { + // chKrsDone <- err + // return + // } + // deviceWire = <-chDeviceWire + // if err := msmG1(&krs, pk.G1Device.K, deviceWire); err != nil { + // chKrsDone <- err + // return + // } + + // CPU + // Compute this MSM on CPU, as it can be done in parallel with other MSM on GPU + if _, err := krs.MultiExp(pk.G1.K, wireValuesWithoutCom, ecc.MultiExpConfig{NbTasks: runtime.NumCPU() / 2}); err != nil { chKrsDone <- err return } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(_wireValues)), 
time.Since(startKrs)).Msg("krs.MultiExp done") + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValues)), time.Since(startKrs)).Msg("krs done") // -rs[δ] krs.AddMixed(&deltas[2]) + n := 3 for n != 0 { select { @@ -373,17 +381,14 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b // Bs2 (1 multi exp G2 - size = len(wires)) var Bs, deltaS curve.G2Jac - nbTasks := n - if nbTasks <= 16 { - // if we don't have a lot of CPUs, this may artificially split the MSM - nbTasks *= 2 + if err := <-chWireValuesB; err != nil { + return err } - <-chWireValuesB startBs := time.Now() if err := msmG2(&Bs, pk.G2Device.B, deviceWireValuesB); err != nil { return err } - log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", deviceWireValuesB.Len()), time.Since(startBs)).Msg("Bs.MultiExp done") + log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", deviceWireValuesB.Len()), time.Since(startBs)).Msg("Bs done") deltaS.FromAffine(&pk.G2.Delta) deltaS.ScalarMultiplication(&deltaS, &s) @@ -398,17 +403,34 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b <-chHDone // schedule our proof part computations - go computeKRS() - go computeAR1() - go computeBS1() - if err := computeBS2(); err != nil { + // Sequencial GPU execution + // TODO: see GPU utilization data + computeAR1() + if err := <-chArDone; err != nil { return nil, err } - - // wait for all parts of the proof to be computed. + computeBS1() + if err := <-chBs1Done; err != nil { + return nil, err + } + computeKRS() if err := <-chKrsDone; err != nil { return nil, err } + if err := computeBS2(); err != nil { + return nil, err + } + + // Parallel GPU execution, memory may hit limit + // go computeKRS() + // go computeAR1() + // go computeBS1() + // go computeBS2() + + // wait for all parts of the proof to be computed. 
+ // if err := <-chKrsDone; err != nil { + // return nil, err + // } log.Debug().Dur("took", time.Since(start)).Msg("prover done") @@ -423,6 +445,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b } // if len(toRemove) == 0, returns slice +// // else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex // this assumes len(slice) > len(toRemove) // filterHeap modifies toRemove diff --git a/go.mod b/go.mod index e643586c0c..52c8797024 100644 --- a/go.mod +++ b/go.mod @@ -38,4 +38,4 @@ require ( rsc.io/tmplfunc v0.0.3 // indirect ) -replace github.com/okx/cryptography_cuda/wrappers/go => /home/okxdex/data/zkdex-pap/workspace/jason-huang/cryptography_cuda/wrappers/go +// replace github.com/okx/cryptography_cuda/wrappers/go => /home/okxdex/data/zkdex-pap/workspace/jason-huang/cryptography_cuda/wrappers/go diff --git a/go.sum b/go.sum index 81dc4a651b..fc472ffae3 100644 --- a/go.sum +++ b/go.sum @@ -230,6 +230,8 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= +github.com/okx/cryptography_cuda/wrappers/go v0.0.0-20241016023422-25c1f0f5f44e h1:NT/U7+AJ93s0U4af9I5fEtpE33Etf68wEUif7Q/s1mo= +github.com/okx/cryptography_cuda/wrappers/go v0.0.0-20241016023422-25c1f0f5f44e/go.mod h1:y9SSivg7t0Fs0PZQJ/l2jUhWT67SeEj9XYgz5ysjyEw= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 
From e26498afb01aa4f501b79168e6f1c24df034b3ed Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 17 Oct 2024 15:18:42 +0800 Subject: [PATCH 06/62] mimc test gpu acceleration --- examples/mimc/mimc_test.go | 7 +++++++ go.mod | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/mimc/mimc_test.go b/examples/mimc/mimc_test.go index 5583193bf6..7631974ad8 100644 --- a/examples/mimc/mimc_test.go +++ b/examples/mimc/mimc_test.go @@ -18,9 +18,11 @@ import ( "testing" "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/test" ) +// go test github.com/consensys/gnark/examples/mimc -tags=prover_checks func TestPreimage(t *testing.T) { assert := test.NewAssert(t) @@ -36,4 +38,9 @@ func TestPreimage(t *testing.T) { Hash: "12886436712380113721405259596386800092738845035233065858332878701083870690753", }, test.WithCurves(ecc.BN254)) + assert.ProverSucceeded(&mimcCircuit, &Circuit{ + PreImage: "16130099170765464552823636852555369511329944820189892919423002775646948828469", + Hash: "12886436712380113721405259596386800092738845035233065858332878701083870690753", + }, test.WithCurves(ecc.BN254), test.WithProverOpts(backend.WithZeknoxAcceleration())) + } diff --git a/go.mod b/go.mod index 52c8797024..e643586c0c 100644 --- a/go.mod +++ b/go.mod @@ -38,4 +38,4 @@ require ( rsc.io/tmplfunc v0.0.3 // indirect ) -// replace github.com/okx/cryptography_cuda/wrappers/go => /home/okxdex/data/zkdex-pap/workspace/jason-huang/cryptography_cuda/wrappers/go +replace github.com/okx/cryptography_cuda/wrappers/go => /home/okxdex/data/zkdex-pap/workspace/jason-huang/cryptography_cuda/wrappers/go From e6eb42e7607d2740921806550687c41dda1b8026 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 17 Oct 2024 17:27:54 +0800 Subject: [PATCH 07/62] fix verify bug, delete channel --- backend/groth16/bn254/zeknox/zeknox.go | 86 
+++++++------------------- 1 file changed, 22 insertions(+), 64 deletions(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index df967e908d..780fbb2a23 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -251,46 +251,38 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var bs1, ar curve.G1Jac - chBs1Done := make(chan error, 1) - - computeBS1 := func() { + computeBS1 := func() error { if err := <-chWireValuesB; err != nil { - chBs1Done <- err - return + return err } startBs1 := time.Now() if err := msmG1(&bs1, pk.G1Device.B, deviceWireValuesB); err != nil { - chBs1Done <- err - return + return err } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesB.Len()), time.Since(startBs1)).Msg("bs1 done") // + beta + s[δ] bs1.AddMixed(&pk.G1.Beta) bs1.AddMixed(&deltas[1]) - chBs1Done <- nil + return nil } - chArDone := make(chan error, 1) - computeAR1 := func() { + computeAR1 := func() error { if err := <-chWireValuesA; err != nil { - chArDone <- err - return + return err } startAr := time.Now() if err := msmG1(&ar, pk.G1Device.A, deviceWireValuesA); err != nil { - chArDone <- err - return + return err } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesA.Len()), time.Since(startAr)).Msg("ar done") ar.AddMixed(&pk.G1.Alpha) ar.AddMixed(&deltas[0]) proof.Ar.FromJacobian(&ar) - chArDone <- nil + return nil } - chKrsDone := make(chan error, 1) var deviceH *device.HostOrDeviceSlice[fr.Element] - computeKRS := func() { + computeKRS := func() error { // we could NOT split the Krs multiExp in 2, and just append pk.G1.K and pk.G1.Z // however, having similar lengths for our tasks helps with parallelism @@ -321,7 +313,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b // original Groth16 witness without pedersen commitment wireValuesWithoutCom := filterHeap(wireValues[r1cs.GetNbPublicVariables():], 
r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - startKrs := time.Now() // GPU runtime error // var deviceWire *device.HostOrDeviceSlice[fr.Element] // defer deviceWire.Free() @@ -338,43 +329,25 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b // CPU // Compute this MSM on CPU, as it can be done in parallel with other MSM on GPU + startKrs := time.Now() if _, err := krs.MultiExp(pk.G1.K, wireValuesWithoutCom, ecc.MultiExpConfig{NbTasks: runtime.NumCPU() / 2}); err != nil { - chKrsDone <- err - return + return err } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValues)), time.Since(startKrs)).Msg("krs done") // -rs[δ] krs.AddMixed(&deltas[2]) - n := 3 - for n != 0 { - select { - case err := <-chKrs2Done: - if err != nil { - chKrsDone <- err - return - } - krs.AddAssign(&krs2) - case err := <-chArDone: - if err != nil { - chKrsDone <- err - return - } - p1.ScalarMultiplication(&ar, &s) - krs.AddAssign(&p1) - case err := <-chBs1Done: - if err != nil { - chKrsDone <- err - return - } - p1.ScalarMultiplication(&bs1, &r) - krs.AddAssign(&p1) - } - n-- + if err := <-chKrs2Done; err != nil { + return err } + krs.AddAssign(&krs2) + p1.ScalarMultiplication(&ar, &s) + krs.AddAssign(&p1) + p1.ScalarMultiplication(&bs1, &r) + krs.AddAssign(&p1) proof.Krs.FromJacobian(&krs) - chKrsDone <- nil + return nil } computeBS2 := func() error { @@ -405,33 +378,18 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b // schedule our proof part computations // Sequencial GPU execution // TODO: see GPU utilization data - computeAR1() - if err := <-chArDone; err != nil { + if err := computeAR1(); err != nil { return nil, err } - computeBS1() - if err := <-chBs1Done; err != nil { + if err := computeBS1(); err != nil { return nil, err } - computeKRS() - if err := <-chKrsDone; err != nil { + if err := computeKRS(); err != nil { return nil, err } if err := computeBS2(); err != nil { return nil, err } - - // 
Parallel GPU execution, memory may hit limit - // go computeKRS() - // go computeAR1() - // go computeBS1() - // go computeBS2() - - // wait for all parts of the proof to be computed. - // if err := <-chKrsDone; err != nil { - // return nil, err - // } - log.Debug().Dur("took", time.Since(start)).Msg("prover done") // Free device memory From 31479f7118aa3e19c4b1062c6dab4375df0602c9 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Fri, 18 Oct 2024 11:10:18 +0800 Subject: [PATCH 08/62] add p256 example for testing --- .gitignore | 3 +- examples/main.go | 8 ++ examples/p256/circuit.go | 111 ++++++++++++++++ examples/p256/p256.go | 255 +++++++++++++++++++++++++++++++++++++ examples/p256/p256_test.go | 17 +++ 5 files changed, 393 insertions(+), 1 deletion(-) create mode 100644 examples/main.go create mode 100644 examples/p256/circuit.go create mode 100644 examples/p256/p256.go create mode 100644 examples/p256/p256_test.go diff --git a/.gitignore b/.gitignore index 1be9af1ba2..8abda9a03b 100644 --- a/.gitignore +++ b/.gitignore @@ -55,4 +55,5 @@ gnarkd/circuits/** go.work go.work.sum -examples/gbotrel/** \ No newline at end of file +examples/gbotrel/** +build/ \ No newline at end of file diff --git a/examples/main.go b/examples/main.go new file mode 100644 index 0000000000..6e438873e6 --- /dev/null +++ b/examples/main.go @@ -0,0 +1,8 @@ +package main + +import "github.com/consensys/gnark/examples/p256" + +func main() { + p256.Groth16Setup("build/") + p256.Groth16Prove("build/") +} \ No newline at end of file diff --git a/examples/p256/circuit.go b/examples/p256/circuit.go new file mode 100644 index 0000000000..c6613beb89 --- /dev/null +++ b/examples/p256/circuit.go @@ -0,0 +1,111 @@ +package p256 + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/hash/sha3" + "github.com/consensys/gnark/std/math/emulated" + 
"github.com/consensys/gnark/std/math/uints" +) + +type EcdsaCircuit[T, S emulated.FieldParams] struct { + Commitment frontend.Variable `gnark:",public"` // Keccak256(Pub[0], Msg[0], Sig[1], Msg[1], ...)[1:32], ignore the first byte, since BN254 order < uint256 + + Pub [NumSignatures]PublicKey[T, S] `gnark:",secret"` + Msg [NumSignatures]emulated.Element[S] `gnark:",secret"` + Sig [NumSignatures]Signature[S] `gnark:",secret"` +} + +func (c *EcdsaCircuit[T, S]) Define(api frontend.API) error { + // Verify all ECDSA-P256 signatures + for i := range c.Sig { + c.Pub[i].Verify(api, sw_emulated.GetCurveParams[T](), &c.Msg[i], &c.Sig[i]) + } + // Keccak256 Commit to all signatures + h, err := sha3.NewLegacyKeccak256(api) + if err != nil { + return err + } + uapi, err := uints.New[uints.U64](api) + if err != nil { + return err + } + + var tInstance T + var sInstance S + perSignatureHashSize := 2*tInstance.NbLimbs() + sInstance.NbLimbs() + + hashIn := make([]uints.U8, 0, NumSignatures*perSignatureHashSize) + for i := 0; i < NumSignatures; i++ { + // hashIn += Pub[i].X + // Pay attention to the ordering! + for j := len(c.Pub[i].X.Limbs) - 1; j >= 0; j-- { + pubXLimb := uapi.UnpackMSB(uapi.ValueOf(c.Pub[i].X.Limbs[j])) + hashIn = append(hashIn, pubXLimb[:]...) + } + // hashIn += Pub[i].Y + for j := len(c.Pub[i].X.Limbs) - 1; j >= 0; j-- { + pubYLimb := uapi.UnpackMSB(uapi.ValueOf(c.Pub[i].Y.Limbs[j])) + hashIn = append(hashIn, pubYLimb[:]...) + } + // hashIn += Msg[i] + for j := len(c.Msg[i].Limbs) - 1; j >= 0; j-- { + msgLimb := uapi.UnpackMSB(uapi.ValueOf(c.Msg[i].Limbs[j])) + hashIn = append(hashIn, msgLimb[:]...) 
+ } + } + h.Write(hashIn) + hashOutU8 := h.Sum() // Keccak256(Pub[0], Msg[0], Sig[1], Msg[1], ...)[0:32] + + // Commitment = hashoutU8[1:32] + hashOutU8[0] = uints.NewU8(0) // ignore the first byte, since BN254 order < uint256 + // Big endian [32]bytes to BigInt + for i := range hashOutU8 { + index := len(hashOutU8) - i - 1 + c.Commitment = api.MulAcc(c.Commitment, hashOutU8[index].Val, 1<<(i*8)) + } + return nil +} + +// Signature represents the signature for some message. +type Signature[Scalar emulated.FieldParams] struct { + R, S emulated.Element[Scalar] +} + +// PublicKey represents the public key to verify the signature for. +type PublicKey[Base, Scalar emulated.FieldParams] sw_emulated.AffinePoint[Base] + +// Verify asserts that the signature sig verifies for the message msg and public +// key pk. The curve parameters params define the elliptic curve. +// +// We assume that the message msg is already hashed to the scalar field. +func (pk PublicKey[T, S]) Verify(api frontend.API, params sw_emulated.CurveParams, msg *emulated.Element[S], sig *Signature[S]) { + cr, err := sw_emulated.New[T, S](api, params) + if err != nil { + panic(err) + } + scalarApi, err := emulated.NewField[S](api) + if err != nil { + panic(err) + } + baseApi, err := emulated.NewField[T](api) + if err != nil { + panic(err) + } + pkpt := sw_emulated.AffinePoint[T](pk) + sInv := scalarApi.Inverse(&sig.S) + msInv := scalarApi.MulMod(msg, sInv) + rsInv := scalarApi.MulMod(&sig.R, sInv) + + // q = [rsInv]pkpt + [msInv]g + q := cr.JointScalarMulBase(&pkpt, rsInv, msInv) + qx := baseApi.Reduce(&q.X) + qxBits := baseApi.ToBits(qx) + rbits := scalarApi.ToBits(&sig.R) + if len(rbits) != len(qxBits) { + panic("non-equal lengths") + } + for i := range rbits { + api.AssertIsEqual(rbits[i], qxBits[i]) + } +} diff --git a/examples/p256/p256.go b/examples/p256/p256.go new file mode 100644 index 0000000000..cc580160e6 --- /dev/null +++ b/examples/p256/p256.go @@ -0,0 +1,255 @@ +package p256 + +import ( + 
cryptoecdsa "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "fmt" + "io" + "log" + "math/big" + "os" + "strconv" + "time" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/backend/solidity" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + gnark_io "github.com/consensys/gnark/io" + "github.com/consensys/gnark/std/math/emulated" + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" + "golang.org/x/crypto/sha3" +) + +const NumSignatures = 1 + +var circuitName string + +func init() { + circuitName = "p256-" + strconv.Itoa(NumSignatures) +} + +func compileCircuit(newBuilder frontend.NewBuilder) (constraint.ConstraintSystem, error) { + circuit := EcdsaCircuit[emulated.P256Fp, emulated.P256Fr]{} + r1cs, err := frontend.Compile(ecc.BN254.ScalarField(), newBuilder, &circuit) + if err != nil { + return nil, err + } + return r1cs, nil +} + +func generateWitnessCircuit() EcdsaCircuit[emulated.P256Fp, emulated.P256Fr] { + witness := EcdsaCircuit[emulated.P256Fp, emulated.P256Fr]{} + perSignatureHashSize := 2*emulated.P256Fp{}.NbLimbs() + emulated.P256Fr{}.NbLimbs() + hashIn := make([]byte, 0, NumSignatures*perSignatureHashSize) + for i := 0; i < NumSignatures; i++ { + // Keygen + privKey, _ := cryptoecdsa.GenerateKey(elliptic.P256(), rand.Reader) + publicKey := privKey.PublicKey + + // Sign + msg, err := genRandomBytes(i + 20) + if err != nil { + panic(err) + } + msgHash := keccak256(msg) + sigBin, _ := privKey.Sign(rand.Reader, msgHash[:], nil) + + // Try verify + var ( + r, s = &big.Int{}, &big.Int{} + inner cryptobyte.String + ) + input := cryptobyte.String(sigBin) + if !input.ReadASN1(&inner, asn1.SEQUENCE) || + !input.Empty() || + !inner.ReadASN1Integer(r) || + !inner.ReadASN1Integer(s) || + !inner.Empty() { + 
panic("invalid sig") + } + flag := cryptoecdsa.Verify(&publicKey, msgHash[:], r, s) + if !flag { + println("can't verify signature") + } + + // hashIn += Pub[i].X + Pub[i].Y + Msg[i] + pubX := publicKey.X.Bytes() + pubY := publicKey.Y.Bytes() + // println("pubX:", hex.EncodeToString(pubX)) + // println("pubY:", hex.EncodeToString(pubY)) + // println("msgHash:", hex.EncodeToString(msgHash[:])) + hashIn = append(hashIn, pubX[:]...) + hashIn = append(hashIn, pubY[:]...) + hashIn = append(hashIn, msgHash[:]...) + // Assign to circuit witness + witness.Sig[i] = Signature[emulated.P256Fr]{ + R: emulated.ValueOf[emulated.P256Fr](r), + S: emulated.ValueOf[emulated.P256Fr](s), + } + witness.Msg[i] = emulated.ValueOf[emulated.P256Fr](msgHash[:]) + witness.Pub[i] = PublicKey[emulated.P256Fp, emulated.P256Fr]{ + X: emulated.ValueOf[emulated.P256Fp](publicKey.X), + Y: emulated.ValueOf[emulated.P256Fp](publicKey.Y), + } + } + hashOut := keccak256(hashIn) + hashOut[0] = 0 // ignore the first byte, since BN254 order < uint256 + // println("hashOut:", hex.EncodeToString(hashOut[:])) + witness.Commitment = hashOut[:] + return witness +} + +func generateWitness() (witness.Witness, error) { + witness := generateWitnessCircuit() + witnessData, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } + + return witnessData, nil +} + +func Groth16Setup(fileDir string) { + r1cs, err := compileCircuit(r1cs.NewBuilder) + if err != nil { + panic(err) + } + pk, vk, err := groth16.Setup(r1cs) + if err != nil { + panic(err) + } + // Write to file + if _, err := os.Stat(fileDir); os.IsNotExist(err) { + err := os.MkdirAll(fileDir, os.ModePerm) + if err != nil { + panic(err) + } + } + WriteToFile(pk, fileDir+circuitName+".zkey") + WriteToFile(r1cs, fileDir+circuitName+".r1cs") + WriteToFile(vk, fileDir+circuitName+".vkey") +} + +func Groth16Prove(fileDir string) { + // proveStart := time.Now() + // Witness generation + start := time.Now() + witnessData, err 
:= generateWitness() + if err != nil { + panic(err) + } + elapsed := time.Since(start) + log.Printf("Witness Generation: %d ms", elapsed.Milliseconds()) + + // Read files + start = time.Now() + r1cs := groth16.NewCS(ecc.BN254) + ReadFromFile(r1cs, fileDir+circuitName+".r1cs") + elapsed = time.Since(start) + log.Printf("Read r1cs: %d ms", elapsed.Milliseconds()) + + start = time.Now() + pk := groth16.NewProvingKey(ecc.BN254) + + UnsafeReadFromFile(pk, fileDir+circuitName+".zkey") + elapsed = time.Since(start) + log.Printf("Read zkey: %d ms", elapsed.Milliseconds()) + + // Proof generation & verification + publicWitness, err := witnessData.Public() + if err != nil { + panic(err) + } + vk := groth16.NewVerifyingKey(ecc.BN254) + ReadFromFile(vk, fileDir+circuitName+".vkey") + + // CPU + for i := 0; i < 1; i++ { + fmt.Printf("------ CPU Prove %d ------", i+1) + start = time.Now() + proof, err := groth16.Prove(r1cs, pk, witnessData, solidity.WithProverTargetSolidityVerifier(backend.GROTH16)) + if err != nil { + panic(err) + } + elapsed = time.Since(start) + log.Printf("CPU Prove: %d ms", elapsed.Milliseconds()) + if err := groth16.Verify(proof, vk, publicWitness, solidity.WithVerifierTargetSolidityVerifier(backend.GROTH16)); err != nil { + panic(err) + } + } + + // GPU + for i := 0; i < 10; i++ { + fmt.Printf("------ GPU Prove %d ------\n", i+1) + start = time.Now() + proof, err := groth16.Prove(r1cs, pk, witnessData, solidity.WithProverTargetSolidityVerifier(backend.GROTH16), backend.WithZeknoxAcceleration()) + if err != nil { + panic(err) + } + elapsed = time.Since(start) + log.Printf("GPU Prove %d: %d ms", i+1, elapsed.Milliseconds()) + if err := groth16.Verify(proof, vk, publicWitness, solidity.WithVerifierTargetSolidityVerifier(backend.GROTH16)); err != nil { + // panic(err) + fmt.Printf("!!! 
GPU Verify %d: %s\n", i+1, err) + } + } +} + +func genRandomBytes(size int) ([]byte, error) { + blk := make([]byte, size) + _, err := rand.Read(blk) + if err != nil { + return nil, err + } + return blk, nil +} + +func keccak256(data []byte) (digest [32]byte) { + h := sha3.NewLegacyKeccak256() + h.Write(data) + h.Sum(digest[:0]) + return +} + +func WriteToFile(data io.WriterTo, fileName string) { + file, err := os.Create(fileName) + if err != nil { + panic(err) + } + defer file.Close() + _, err = data.WriteTo(file) + if err != nil { + panic(err) + } +} + +func ReadFromFile(data io.ReaderFrom, fileName string) { + file, err := os.Open(fileName) + if err != nil { + panic(err) + } + defer file.Close() + // Use the ReadFrom method to read the file's content into data. + if _, err := data.ReadFrom(file); err != nil { + panic(err) + } +} + +// faster than readFromFile +func UnsafeReadFromFile(data gnark_io.UnsafeReaderFrom, fileName string) { + file, err := os.Open(fileName) + if err != nil { + panic(err) + } + defer file.Close() + if _, err := data.UnsafeReadFrom(file); err != nil { + panic(err) + } +} \ No newline at end of file diff --git a/examples/p256/p256_test.go b/examples/p256/p256_test.go new file mode 100644 index 0000000000..863773e13b --- /dev/null +++ b/examples/p256/p256_test.go @@ -0,0 +1,17 @@ +package p256 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/test" +) + +func TestP256(t *testing.T) { + assert := test.NewAssert(t) + witnessCircuit := generateWitnessCircuit() + circuit := EcdsaCircuit[emulated.P256Fp, emulated.P256Fr]{} + assert.CheckCircuit(&circuit, test.WithValidAssignment(&witnessCircuit), test.WithBackends(backend.GROTH16), test.WithCurves(ecc.BN254), test.WithProverOpts(backend.WithZeknoxAcceleration())) +} From c481b0a6f23ddff5139f137eb44afbbb6843f201 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" 
<20609724+doutv@users.noreply.github.com> Date: Fri, 18 Oct 2024 11:15:55 +0800 Subject: [PATCH 09/62] parallel but verify fail in most cases --- backend/groth16/bn254/zeknox/zeknox.go | 99 +++++++++++++++++--------- examples/main.go | 2 +- 2 files changed, 67 insertions(+), 34 deletions(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 780fbb2a23..f81d48fcca 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -251,38 +251,45 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var bs1, ar curve.G1Jac - computeBS1 := func() error { + chBs1Done := make(chan error, 1) + computeBS1 := func() { if err := <-chWireValuesB; err != nil { - return err + chBs1Done <- err + return } startBs1 := time.Now() if err := msmG1(&bs1, pk.G1Device.B, deviceWireValuesB); err != nil { - return err + chBs1Done <- err + return } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesB.Len()), time.Since(startBs1)).Msg("bs1 done") // + beta + s[δ] bs1.AddMixed(&pk.G1.Beta) bs1.AddMixed(&deltas[1]) - return nil + chBs1Done <- nil } - computeAR1 := func() error { + chArDone := make(chan error, 1) + computeAR1 := func() { if err := <-chWireValuesA; err != nil { - return err + chArDone <- err + return } startAr := time.Now() if err := msmG1(&ar, pk.G1Device.A, deviceWireValuesA); err != nil { - return err + chArDone <- err + return } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesA.Len()), time.Since(startAr)).Msg("ar done") ar.AddMixed(&pk.G1.Alpha) ar.AddMixed(&deltas[0]) proof.Ar.FromJacobian(&ar) - return nil + chArDone <- nil } + chKrsDone := make(chan error, 1) var deviceH *device.HostOrDeviceSlice[fr.Element] - computeKRS := func() error { + computeKRS := func() { // we could NOT split the Krs multiExp in 2, and just append pk.G1.K and pk.G1.Z // however, having similar lengths for our tasks helps with parallelism @@ -329,37 +336,63 @@ 
func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b // CPU // Compute this MSM on CPU, as it can be done in parallel with other MSM on GPU + // Also, reduce data copy startKrs := time.Now() if _, err := krs.MultiExp(pk.G1.K, wireValuesWithoutCom, ecc.MultiExpConfig{NbTasks: runtime.NumCPU() / 2}); err != nil { - return err + chKrsDone <- err + return } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValues)), time.Since(startKrs)).Msg("krs done") // -rs[δ] krs.AddMixed(&deltas[2]) - if err := <-chKrs2Done; err != nil { - return err + n := 3 + for n != 0 { + select { + // wait krs2 + case err := <-chKrs2Done: + if err != nil { + chKrsDone <- err + return + } + krs.AddAssign(&krs2) + // wait ar + case err := <-chArDone: + if err != nil { + chKrsDone <- err + return + } + p1.ScalarMultiplication(&ar, &s) + krs.AddAssign(&p1) + // wait bs1 + case err := <-chBs1Done: + if err != nil { + chKrsDone <- err + return + } + p1.ScalarMultiplication(&bs1, &r) + krs.AddAssign(&p1) + } + n-- } - krs.AddAssign(&krs2) - p1.ScalarMultiplication(&ar, &s) - krs.AddAssign(&p1) - p1.ScalarMultiplication(&bs1, &r) - krs.AddAssign(&p1) proof.Krs.FromJacobian(&krs) - return nil + chKrsDone <- nil } - computeBS2 := func() error { + chBs2Done := make(chan error, 1) + computeBS2 := func() { // Bs2 (1 multi exp G2 - size = len(wires)) var Bs, deltaS curve.G2Jac if err := <-chWireValuesB; err != nil { - return err + chBs2Done <- err + return } startBs := time.Now() if err := msmG2(&Bs, pk.G2Device.B, deviceWireValuesB); err != nil { - return err + chBs2Done <- err + return } log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", deviceWireValuesB.Len()), time.Since(startBs)).Msg("Bs done") @@ -369,27 +402,27 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b Bs.AddMixed(&pk.G2.Beta) proof.Bs.FromJacobian(&Bs) - return nil + chBs2Done <- nil } // wait for FFT to end, as it uses all our CPUs <-chHDone - // schedule our proof part 
computations - // Sequencial GPU execution - // TODO: see GPU utilization data - if err := computeAR1(); err != nil { - return nil, err - } - if err := computeBS1(); err != nil { - return nil, err - } - if err := computeKRS(); err != nil { + // Parallel GPU execution, memory may hit limit + go computeAR1() + go computeBS1() + go computeBS2() + go computeKRS() + + // wait for all parts of the proof to be computed. + // Krs done means AR1, BS1 are done + if err := <-chKrsDone; err != nil { return nil, err } - if err := computeBS2(); err != nil { + if err := <-chBs2Done; err != nil { return nil, err } + log.Debug().Dur("took", time.Since(start)).Msg("prover done") // Free device memory diff --git a/examples/main.go b/examples/main.go index 6e438873e6..e588a87664 100644 --- a/examples/main.go +++ b/examples/main.go @@ -3,6 +3,6 @@ package main import "github.com/consensys/gnark/examples/p256" func main() { - p256.Groth16Setup("build/") + // p256.Groth16Setup("build/") p256.Groth16Prove("build/") } \ No newline at end of file From 4cb00d6b643a15d80b401e4c08ca82eee772ad4b Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Fri, 18 Oct 2024 14:20:48 +0800 Subject: [PATCH 10/62] fix msm cfg & add input check --- backend/groth16/bn254/zeknox/zeknox.go | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index f81d48fcca..50f36b07ff 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -411,14 +411,15 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b // Parallel GPU execution, memory may hit limit go computeAR1() go computeBS1() - go computeBS2() go computeKRS() - // wait for all parts of the proof to be computed. 
- // Krs done means AR1, BS1 are done + // wait krs, ar1, bs1 + // krs done means AR1, BS1 are done if err := <-chKrsDone; err != nil { return nil, err } + // bs2 and bs1 both depend on wireValuesB + computeBS2() if err := <-chBs2Done; err != nil { return nil, err } @@ -518,11 +519,20 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { return a } -func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { +func checkMsmInputs[P, S any](points *device.HostOrDeviceSlice[P], scalars *device.HostOrDeviceSlice[S]) error { + if !points.IsOnDevice() || !scalars.IsOnDevice() { + return fmt.Errorf("MSM: points and scalars must be on device") + } if points.Len() != scalars.Len() { return fmt.Errorf("MSM: len(points) != len(scalars)") } + return nil +} + +func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { + checkMsmInputs(points, scalars) cfg := msm.DefaultMSMConfig() + cfg.AreInputsOnDevice = true cfg.ArePointsInMont = true cfg.Npoints = uint32(points.Len()) cfg.FfiAffineSz = 64 @@ -533,9 +543,7 @@ func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], s } func msmG2(res *curve.G2Jac, points *device.HostOrDeviceSlice[curve.G2Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { - if points.Len() != scalars.Len() { - return fmt.Errorf("MSM: len(points) != len(scalars)") - } + checkMsmInputs(points, scalars) cfg := msm.DefaultMSMConfig() cfg.AreInputsOnDevice = true cfg.ArePointsInMont = true From 4d43c3ebb4f2659f34443d70aa682a396b4be4be Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Fri, 18 Oct 2024 15:10:24 +0800 Subject: [PATCH 11/62] fix parallel GPU proving, use errgroup --- backend/groth16/bn254/zeknox/zeknox.go | 253 +++++++++---------------- 1 file changed, 94 insertions(+), 159 deletions(-) diff --git 
a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 50f36b07ff..76dd72af7a 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -176,27 +176,13 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b return nil, err } - // quotient poly H (witness reduction / FFT part) - var h []fr.Element - chHDone := make(chan struct{}, 1) - go func() { - startH := time.Now() - h = computeH(solution.A, solution.B, solution.C, &pk.Domain) - log.Debug().Dur("took", time.Since(startH)).Msg("computed H") - solution.A = nil - solution.B = nil - solution.C = nil - chHDone <- struct{}{} - }() - // we need to copy and filter the wireValues for each multi exp // as pk.G1.A, pk.G1.B and pk.G2.B may have (a significant) number of point at infinity - var deviceWireValuesA, deviceWireValuesB *device.HostOrDeviceSlice[fr.Element] - // indicate if the wire values have been copied to the device - chWireValuesA, chWireValuesB := make(chan error, 1), make(chan error, 1) + var wireValuesA, wireValuesB []fr.Element + chWireValuesA, chWireValuesB := make(chan struct{}, 1), make(chan struct{}, 1) go func() { - wireValuesA := make([]fr.Element, len(wireValues)-int(pk.NbInfinityA)) + wireValuesA = make([]fr.Element, len(wireValues)-int(pk.NbInfinityA)) for i, j := 0, 0; j < len(wireValuesA); i++ { if pk.InfinityA[i] { continue @@ -204,16 +190,10 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireValuesA[j] = wireValues[i] j++ } - chDeviceValues := make(chan *device.HostOrDeviceSlice[fr.Element], 1) - if err := CopyToDevice(wireValuesA, chDeviceValues); err != nil { - chWireValuesA <- err - return - } - deviceWireValuesA = <-chDeviceValues close(chWireValuesA) }() go func() { - wireValuesB := make([]fr.Element, len(wireValues)-int(pk.NbInfinityB)) + wireValuesB = make([]fr.Element, len(wireValues)-int(pk.NbInfinityB)) for i, j := 0, 0; j < len(wireValuesB); 
i++ { if pk.InfinityB[i] { continue @@ -221,12 +201,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireValuesB[j] = wireValues[i] j++ } - chDeviceValues := make(chan *device.HostOrDeviceSlice[fr.Element], 1) - if err := CopyToDevice(wireValuesB, chDeviceValues); err != nil { - chWireValuesB <- err - return - } - deviceWireValuesB = <-chDeviceValues close(chWireValuesB) }() @@ -251,68 +225,78 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var bs1, ar curve.G1Jac - chBs1Done := make(chan error, 1) - computeBS1 := func() { - if err := <-chWireValuesB; err != nil { - chBs1Done <- err - return + computeBS1 := func() error { + <- chWireValuesB + var wireB *device.HostOrDeviceSlice[fr.Element] + chWireB := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + if err := CopyToDevice(wireValuesB, chWireB); err != nil { + return err } + wireB = <-chWireB + defer wireB.Free() startBs1 := time.Now() - if err := msmG1(&bs1, pk.G1Device.B, deviceWireValuesB); err != nil { - chBs1Done <- err - return + if err := msmG1(&bs1, pk.G1Device.B, wireB); err != nil { + return err } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesB.Len()), time.Since(startBs1)).Msg("bs1 done") + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", wireB.Len()), time.Since(startBs1)).Msg("bs1 done") // + beta + s[δ] bs1.AddMixed(&pk.G1.Beta) bs1.AddMixed(&deltas[1]) - chBs1Done <- nil + return nil } - chArDone := make(chan error, 1) - computeAR1 := func() { - if err := <-chWireValuesA; err != nil { - chArDone <- err - return + computeAR1 := func() error { + <- chWireValuesA + var wireA *device.HostOrDeviceSlice[fr.Element] + chWireA := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + if err := CopyToDevice(wireValuesA, chWireA); err != nil { + return err } + wireA = <-chWireA + defer wireA.Free() startAr := time.Now() - if err := msmG1(&ar, pk.G1Device.A, deviceWireValuesA); err != nil { - chArDone <- err - return + if 
err := msmG1(&ar, pk.G1Device.A, wireA); err != nil { + return err } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesA.Len()), time.Since(startAr)).Msg("ar done") + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", wireA.Len()), time.Since(startAr)).Msg("ar done") ar.AddMixed(&pk.G1.Alpha) ar.AddMixed(&deltas[0]) proof.Ar.FromJacobian(&ar) - chArDone <- nil + return nil } - chKrsDone := make(chan error, 1) - var deviceH *device.HostOrDeviceSlice[fr.Element] - computeKRS := func() { - // we could NOT split the Krs multiExp in 2, and just append pk.G1.K and pk.G1.Z - // however, having similar lengths for our tasks helps with parallelism - - var krs, krs2, p1 curve.G1Jac - chKrs2Done := make(chan error, 1) - go func() { - startKrs2 := time.Now() - // Copy h poly to device, since we haven't implemented FFT on device - chDeviceH := make(chan *device.HostOrDeviceSlice[fr.Element], 1) - sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 - if err := CopyToDevice(h[:sizeH], chDeviceH); err != nil { - chKrs2Done <- err - return - } - deviceH = <-chDeviceH - if err := msmG1(&krs2, pk.G1Device.Z, deviceH); err != nil { - chKrs2Done <- err - return - } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", sizeH), time.Since(startKrs2)).Msg("krs2 done") - close(chKrs2Done) - }() + var krs2 curve.G1Jac + computeKRS2 := func() error { + // quotient poly H (witness reduction / FFT part) + var h []fr.Element + { + startH := time.Now() + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + log.Debug().Dur("took", time.Since(startH)).Msg("computed H") + solution.A = nil + solution.B = nil + solution.C = nil + } + // Copy h poly to device, since we haven't implemented FFT on device + var deviceH *device.HostOrDeviceSlice[fr.Element] + chDeviceH := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 + if err := CopyToDevice(h[:sizeH], 
chDeviceH); err != nil { + return err + } + deviceH = <-chDeviceH + defer deviceH.Free() + // MSM G1 Krs2 + startKrs2 := time.Now() + if err := msmG1(&krs2, pk.G1Device.Z, deviceH); err != nil { + return err + } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", sizeH), time.Since(startKrs2)).Msg("krs2 done") + return nil + } + var krs1 curve.G1Jac + computeKRS1 := func() error { // filter the wire values if needed // TODO Perf @Tabaie worst memory allocation offender toRemove := commitmentInfo.GetPrivateCommitted() @@ -320,81 +304,35 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b // original Groth16 witness without pedersen commitment wireValuesWithoutCom := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - // GPU runtime error - // var deviceWire *device.HostOrDeviceSlice[fr.Element] - // defer deviceWire.Free() - // chDeviceWire := make(chan *device.HostOrDeviceSlice[fr.Element], 1) - // if err := CopyToDevice(wireValuesWithoutCom, chDeviceWire); err != nil { - // chKrsDone <- err - // return - // } - // deviceWire = <-chDeviceWire - // if err := msmG1(&krs, pk.G1Device.K, deviceWire); err != nil { - // chKrsDone <- err - // return - // } - // CPU - // Compute this MSM on CPU, as it can be done in parallel with other MSM on GPU - // Also, reduce data copy + // Compute this MSM on CPU, as it can be done in parallel with other MSM on GPU, also reduce data copy startKrs := time.Now() - if _, err := krs.MultiExp(pk.G1.K, wireValuesWithoutCom, ecc.MultiExpConfig{NbTasks: runtime.NumCPU() / 2}); err != nil { - chKrsDone <- err - return + if _, err := krs1.MultiExp(pk.G1.K, wireValuesWithoutCom, ecc.MultiExpConfig{NbTasks: runtime.NumCPU() / 2}); err != nil { + return err } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValues)), time.Since(startKrs)).Msg("krs done") // -rs[δ] - krs.AddMixed(&deltas[2]) - - n := 3 - for n != 0 { - select { - // wait krs2 - case err 
:= <-chKrs2Done: - if err != nil { - chKrsDone <- err - return - } - krs.AddAssign(&krs2) - // wait ar - case err := <-chArDone: - if err != nil { - chKrsDone <- err - return - } - p1.ScalarMultiplication(&ar, &s) - krs.AddAssign(&p1) - // wait bs1 - case err := <-chBs1Done: - if err != nil { - chKrsDone <- err - return - } - p1.ScalarMultiplication(&bs1, &r) - krs.AddAssign(&p1) - } - n-- - } - - proof.Krs.FromJacobian(&krs) - chKrsDone <- nil + krs1.AddMixed(&deltas[2]) + return nil } - chBs2Done := make(chan error, 1) - computeBS2 := func() { + computeBS2 := func() error { + <-chWireValuesB // Bs2 (1 multi exp G2 - size = len(wires)) var Bs, deltaS curve.G2Jac - if err := <-chWireValuesB; err != nil { - chBs2Done <- err - return + var wireB *device.HostOrDeviceSlice[fr.Element] + chWireB := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + if err := CopyToDevice(wireValuesB, chWireB); err != nil { + return err } + wireB = <-chWireB + defer wireB.Free() startBs := time.Now() - if err := msmG2(&Bs, pk.G2Device.B, deviceWireValuesB); err != nil { - chBs2Done <- err - return + if err := msmG2(&Bs, pk.G2Device.B, wireB); err != nil { + return err } - log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", deviceWireValuesB.Len()), time.Since(startBs)).Msg("Bs done") + log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", wireB.Len()), time.Since(startBs)).Msg("Bs done") deltaS.FromAffine(&pk.G2.Delta) deltaS.ScalarMultiplication(&deltaS, &s) @@ -402,37 +340,34 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b Bs.AddMixed(&pk.G2.Beta) proof.Bs.FromJacobian(&Bs) - chBs2Done <- nil + return nil } - // wait for FFT to end, as it uses all our CPUs - <-chHDone - - // Parallel GPU execution, memory may hit limit - go computeAR1() - go computeBS1() - go computeKRS() + // Parallel execution, memory may hit limit + g, _ := errgroup.WithContext(context.TODO()) + g.Go(computeAR1) + g.Go(computeBS1) + g.Go(computeKRS1) + g.Go(computeKRS2) + g.Go(computeBS2) - // 
wait krs, ar1, bs1 - // krs done means AR1, BS1 are done - if err := <-chKrsDone; err != nil { + if err := g.Wait(); err != nil { return nil, err } - // bs2 and bs1 both depend on wireValuesB - computeBS2() - if err := <-chBs2Done; err != nil { - return nil, err + + // FinalKRS = KRS1 + KRS2 + s*AR + r*BS1 + { + var p1 curve.G1Jac + krs1.AddAssign(&krs2) + p1.ScalarMultiplication(&ar, &s) + krs1.AddAssign(&p1) + p1.ScalarMultiplication(&bs1, &r) + krs1.AddAssign(&p1) + proof.Krs.FromJacobian(&krs1) } log.Debug().Dur("took", time.Since(start)).Msg("prover done") - // Free device memory - go func() { - deviceWireValuesA.Free() - deviceWireValuesB.Free() - deviceH.Free() - }() - return proof, nil } From 6895eb126f9f24b9002b1050c9fa4f8eaf0243a7 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Fri, 18 Oct 2024 15:29:55 +0800 Subject: [PATCH 12/62] small fix --- backend/groth16/bn254/zeknox/zeknox.go | 2 +- examples/p256/p256.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 76dd72af7a..6c60e081ff 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -310,7 +310,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b if _, err := krs1.MultiExp(pk.G1.K, wireValuesWithoutCom, ecc.MultiExpConfig{NbTasks: runtime.NumCPU() / 2}); err != nil { return err } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValues)), time.Since(startKrs)).Msg("krs done") + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValues)), time.Since(startKrs)).Msg("CPU krs done") // -rs[δ] krs1.AddMixed(&deltas[2]) return nil diff --git a/examples/p256/p256.go b/examples/p256/p256.go index cc580160e6..87c135d0d2 100644 --- a/examples/p256/p256.go +++ b/examples/p256/p256.go @@ -27,7 +27,7 @@ import ( "golang.org/x/crypto/sha3" ) -const NumSignatures = 1 +const 
NumSignatures = 10 var circuitName string From f52ad4ca0c291bc0d88698d8ea92f17e17c13bb9 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Fri, 18 Oct 2024 18:49:08 +0800 Subject: [PATCH 13/62] add doc --- README.md | 30 +++++++++++++++++++++++++++++- go.mod | 4 ++-- go.sum | 6 ++++-- 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 351ac156f3..92786166c4 100644 --- a/README.md +++ b/README.md @@ -161,6 +161,34 @@ func main() { ### GPU Support +#### Zeknox Library +Unlock free GPU acceleration with [OKX Zeknox library](https://github.com/okx/cryptography_cuda) + +##### Download prebuilt binary +```sh +sudo cp libblst.a libcryptocuda.a /usr/local/lib/ +``` + +If you want to build from source, see guide in https://github.com/okx/cryptography_cuda + +##### Enjoy GPU +Run `groth16.Prove(r1cs, pk, witnessData, backend.WithZeknoxAcceleration())` + +Test +```go +assert.ProverSucceeded(&mimcCircuit, &Circuit{ + PreImage: "16130099170765464552823636852555369511329944820189892919423002775646948828469", + Hash: "12886436712380113721405259596386800092738845035233065858332878701083870690753", + }, test.WithCurves(ecc.BN254), test.WithProverOpts(backend.WithZeknoxAcceleration())) +``` + +```sh +go run -tags=zeknox examples/main.go +# (place -tags before the filename) + +go test github.com/consensys/gnark/examples/mimc -tags=prover_checks,zeknox +``` + #### Icicle Library The following schemes and curves support experimental use of Ingonyama's Icicle GPU library for low level zk-SNARK primitives such as MSM, NTT, and polynomial operations: @@ -178,7 +206,7 @@ You can then toggle on or off icicle acceleration by providing the `WithIcicleAc ```go // toggle on proofIci, err := groth16.Prove(ccs, pk, secretWitness, backend.WithIcicleAcceleration()) - + // toggle off proof, err := groth16.Prove(ccs, pk, secretWitness) ``` diff --git a/go.mod b/go.mod index e643586c0c..0770b591b8 100644 --- a/go.mod +++ 
b/go.mod @@ -16,7 +16,7 @@ require ( github.com/icza/bitio v1.1.0 github.com/ingonyama-zk/iciclegnark v0.1.0 github.com/leanovate/gopter v0.2.11 - github.com/okx/cryptography_cuda/wrappers/go v0.0.0-20241016023422-25c1f0f5f44e + github.com/okx/cryptography_cuda v0.0.0-20241018104554-bafea0c91f28 github.com/ronanh/intcomp v1.1.0 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 @@ -38,4 +38,4 @@ require ( rsc.io/tmplfunc v0.0.3 // indirect ) -replace github.com/okx/cryptography_cuda/wrappers/go => /home/okxdex/data/zkdex-pap/workspace/jason-huang/cryptography_cuda/wrappers/go +// replace github.com/okx/cryptography_cuda => /home/okxdex/data/zkdex-pap/workspace/jason-huang/cryptography_cuda diff --git a/go.sum b/go.sum index fc472ffae3..9cca00bced 100644 --- a/go.sum +++ b/go.sum @@ -230,8 +230,10 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/okx/cryptography_cuda/wrappers/go v0.0.0-20241016023422-25c1f0f5f44e h1:NT/U7+AJ93s0U4af9I5fEtpE33Etf68wEUif7Q/s1mo= -github.com/okx/cryptography_cuda/wrappers/go v0.0.0-20241016023422-25c1f0f5f44e/go.mod h1:y9SSivg7t0Fs0PZQJ/l2jUhWT67SeEj9XYgz5ysjyEw= +github.com/okx/cryptography_cuda v0.0.0-20241018104030-628693daf868 h1:aPaETd6bRKs2VpM8C9bZrOJprtUMIN2MXQIcCOtovX8= +github.com/okx/cryptography_cuda v0.0.0-20241018104030-628693daf868/go.mod h1:uoZvaCZ82rXfJuYz+hXCzDaMtts0zTGJt96rBqkoucQ= +github.com/okx/cryptography_cuda v0.0.0-20241018104554-bafea0c91f28 h1:c3aLIA4Wje6nGEx4XksWkwRI0U6kC9ITXa7ZBp6d5DU= +github.com/okx/cryptography_cuda v0.0.0-20241018104554-bafea0c91f28/go.mod h1:uoZvaCZ82rXfJuYz+hXCzDaMtts0zTGJt96rBqkoucQ= 
github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= From 18e871d8b6f5b4081ecf072ae4f8581cb4415c28 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Tue, 22 Oct 2024 09:55:42 +0800 Subject: [PATCH 14/62] set msm LargeBucketFactor config --- backend/groth16/bn254/zeknox/zeknox.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 6c60e081ff..62417cc78e 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -470,7 +470,7 @@ func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], s cfg.AreInputsOnDevice = true cfg.ArePointsInMont = true cfg.Npoints = uint32(points.Len()) - cfg.FfiAffineSz = 64 + cfg.LargeBucketFactor = 2 if err := msm.MSM_G1(unsafe.Pointer(res), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { return err } From 88237e25f0373248d082fe7339c6232cb5429b11 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Wed, 23 Oct 2024 14:05:20 +0800 Subject: [PATCH 15/62] update msmg1, msmg1 return affine --- backend/groth16/bn254/zeknox/zeknox.go | 6 +++--- examples/p256/p256.go | 2 +- go.mod | 2 +- go.sum | 6 ++---- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 62417cc78e..5513684775 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -471,9 +471,11 @@ func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], s cfg.ArePointsInMont = true cfg.Npoints = uint32(points.Len()) cfg.LargeBucketFactor = 2 - 
if err := msm.MSM_G1(unsafe.Pointer(res), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { + resAffine := curve.G1Affine{} + if err := msm.MSM_G1(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { return err } + res.FromAffine(&resAffine) return nil } @@ -484,8 +486,6 @@ func msmG2(res *curve.G2Jac, points *device.HostOrDeviceSlice[curve.G2Affine], s cfg.ArePointsInMont = true cfg.Npoints = uint32(points.Len()) cfg.LargeBucketFactor = 2 - // TODO: MSM_G2 should return Jacobian - // https://github.com/okx/cryptography_cuda/issues/90 resAffine := curve.G2Affine{} if err := msm.MSM_G2(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { return err diff --git a/examples/p256/p256.go b/examples/p256/p256.go index 87c135d0d2..cc580160e6 100644 --- a/examples/p256/p256.go +++ b/examples/p256/p256.go @@ -27,7 +27,7 @@ import ( "golang.org/x/crypto/sha3" ) -const NumSignatures = 10 +const NumSignatures = 1 var circuitName string diff --git a/go.mod b/go.mod index 0770b591b8..6596bce5d9 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/icza/bitio v1.1.0 github.com/ingonyama-zk/iciclegnark v0.1.0 github.com/leanovate/gopter v0.2.11 - github.com/okx/cryptography_cuda v0.0.0-20241018104554-bafea0c91f28 + github.com/okx/cryptography_cuda v0.0.0-20241023025010-e04a13d4df26 github.com/ronanh/intcomp v1.1.0 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 diff --git a/go.sum b/go.sum index 9cca00bced..8da0c561d1 100644 --- a/go.sum +++ b/go.sum @@ -230,10 +230,8 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod 
h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/okx/cryptography_cuda v0.0.0-20241018104030-628693daf868 h1:aPaETd6bRKs2VpM8C9bZrOJprtUMIN2MXQIcCOtovX8= -github.com/okx/cryptography_cuda v0.0.0-20241018104030-628693daf868/go.mod h1:uoZvaCZ82rXfJuYz+hXCzDaMtts0zTGJt96rBqkoucQ= -github.com/okx/cryptography_cuda v0.0.0-20241018104554-bafea0c91f28 h1:c3aLIA4Wje6nGEx4XksWkwRI0U6kC9ITXa7ZBp6d5DU= -github.com/okx/cryptography_cuda v0.0.0-20241018104554-bafea0c91f28/go.mod h1:uoZvaCZ82rXfJuYz+hXCzDaMtts0zTGJt96rBqkoucQ= +github.com/okx/cryptography_cuda v0.0.0-20241023025010-e04a13d4df26 h1:HgiJDIO/n8DTRCTRaw7CYm042Ieyo00O7wD90ZUteO0= +github.com/okx/cryptography_cuda v0.0.0-20241023025010-e04a13d4df26/go.mod h1:uoZvaCZ82rXfJuYz+hXCzDaMtts0zTGJt96rBqkoucQ= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= From 26ba8db8008c89a5416fc2eaca058d8038e922a0 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Wed, 23 Oct 2024 18:59:22 +0800 Subject: [PATCH 16/62] fix cuda int --- backend/groth16/bn254/zeknox/zeknox.go | 114 ++++++++++++++++++++----- 1 file changed, 94 insertions(+), 20 deletions(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 5513684775..af085f7106 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -7,6 +7,7 @@ import ( "fmt" "math/big" "runtime" + "sync/atomic" "time" "unsafe" @@ -30,6 +31,11 @@ import ( "golang.org/x/sync/errgroup" ) +var g2_point_b_mont int32 = 0 +var g1_point_b_mont int32 = 0 +var g1_point_a_mont int32 = 0 +var g1_point_z_mont int32 = 0 + const HasZeknox = true // Use single GPU @@ -226,7 +232,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness 
witness.Witness, opts ...b var bs1, ar curve.G1Jac computeBS1 := func() error { - <- chWireValuesB + <-chWireValuesB var wireB *device.HostOrDeviceSlice[fr.Element] chWireB := make(chan *device.HostOrDeviceSlice[fr.Element], 1) if err := CopyToDevice(wireValuesB, chWireB); err != nil { @@ -235,7 +241,17 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireB = <-chWireB defer wireB.Free() startBs1 := time.Now() - if err := msmG1(&bs1, pk.G1Device.B, wireB); err != nil { + + val := atomic.LoadInt32(&g1_point_b_mont) + mont := true + if val == 1 { + mont = false + } else { + atomic.StoreInt32(&g1_point_b_mont, 1) + mont = true + } + + if err := msmG1(&bs1, pk.G1Device.B, wireB, mont); err != nil { return err } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", wireB.Len()), time.Since(startBs1)).Msg("bs1 done") @@ -246,7 +262,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b } computeAR1 := func() error { - <- chWireValuesA + <-chWireValuesA var wireA *device.HostOrDeviceSlice[fr.Element] chWireA := make(chan *device.HostOrDeviceSlice[fr.Element], 1) if err := CopyToDevice(wireValuesA, chWireA); err != nil { @@ -255,9 +271,20 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireA = <-chWireA defer wireA.Free() startAr := time.Now() - if err := msmG1(&ar, pk.G1Device.A, wireA); err != nil { + + val := atomic.LoadInt32(&g1_point_a_mont) + mont := true + if val == 1 { + mont = false + } else { + atomic.StoreInt32(&g1_point_a_mont, 1) + mont = true + } + + if err := msmG1(&ar, pk.G1Device.A, wireA, mont); err != nil { return err } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", wireA.Len()), time.Since(startAr)).Msg("ar done") ar.AddMixed(&pk.G1.Alpha) ar.AddMixed(&deltas[0]) @@ -288,9 +315,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b defer deviceH.Free() // MSM G1 Krs2 startKrs2 := time.Now() - if err := msmG1(&krs2, pk.G1Device.Z, 
deviceH); err != nil { + + val := atomic.LoadInt32(&g1_point_z_mont) + mont := true + if val == 1 { + mont = false + } else { + atomic.StoreInt32(&g1_point_z_mont, 1) + mont = true + } + if err := msmG1(&krs2, pk.G1Device.Z, deviceH, mont); err != nil { return err } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", sizeH), time.Since(startKrs2)).Msg("krs2 done") return nil } @@ -329,9 +366,29 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireB = <-chWireB defer wireB.Free() startBs := time.Now() - if err := msmG2(&Bs, pk.G2Device.B, wireB); err != nil { + // scalar := onHost(wireValuesB[:]) + // point := onHost(pk.G2.B[:]) + // if err := msmG2(&Bs, &point, &scalar); err != nil { + // return err + // } + + val := atomic.LoadInt32(&g2_point_b_mont) + mont := true + if val == 1 { + mont = false + } else { + atomic.StoreInt32(&g2_point_b_mont, 1) + mont = true + } + + if err := msmG2(&Bs, pk.G2Device.B, wireB, mont); err != nil { + return err + } + + if _, err := Bs.MultiExp(pk.G2.B, wireValuesB, ecc.MultiExpConfig{NbTasks: 16}); err != nil { return err } + log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", wireB.Len()), time.Since(startBs)).Msg("Bs done") deltaS.FromAffine(&pk.G2.Delta) @@ -344,16 +401,21 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b } // Parallel execution, memory may hit limit - g, _ := errgroup.WithContext(context.TODO()) - g.Go(computeAR1) - g.Go(computeBS1) - g.Go(computeKRS1) - g.Go(computeKRS2) - g.Go(computeBS2) - - if err := g.Wait(); err != nil { - return nil, err - } + // g, _ := errgroup.WithContext(context.TODO()) + // g.Go(computeAR1) + computeAR1() + // g.Go(computeBS1) + computeBS1() + // g.Go(computeKRS1) + computeKRS1() + // g.Go(computeKRS2) + computeKRS2() + + // if err := g.Wait(); err != nil { + // return nil, err + // } + + computeBS2() // FinalKRS = KRS1 + KRS2 + s*AR + r*BS1 { @@ -371,6 +433,12 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness 
witness.Witness, opts ...b return proof, nil } +func onHost[T any](hostData []T) device.HostOrDeviceSlice[T] { + deviceSlice := device.NewEmpty[T]() + deviceSlice.OnHost(hostData) + return *deviceSlice +} + // if len(toRemove) == 0, returns slice // // else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex @@ -464,11 +532,14 @@ func checkMsmInputs[P, S any](points *device.HostOrDeviceSlice[P], scalars *devi return nil } -func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { +func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], scalars *device.HostOrDeviceSlice[fr.Element], input_point_in_mont bool) error { checkMsmInputs(points, scalars) cfg := msm.DefaultMSMConfig() cfg.AreInputsOnDevice = true - cfg.ArePointsInMont = true + + cfg.AreInputPointInMont = input_point_in_mont + cfg.AreInputScalarInMont = true + cfg.AreOutputPointInMont = true cfg.Npoints = uint32(points.Len()) cfg.LargeBucketFactor = 2 resAffine := curve.G1Affine{} @@ -479,11 +550,14 @@ func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], s return nil } -func msmG2(res *curve.G2Jac, points *device.HostOrDeviceSlice[curve.G2Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { +func msmG2(res *curve.G2Jac, points *device.HostOrDeviceSlice[curve.G2Affine], scalars *device.HostOrDeviceSlice[fr.Element], mont bool) error { checkMsmInputs(points, scalars) cfg := msm.DefaultMSMConfig() cfg.AreInputsOnDevice = true - cfg.ArePointsInMont = true + + cfg.AreInputPointInMont = mont + cfg.AreOutputPointInMont = true + cfg.AreInputScalarInMont = true cfg.Npoints = uint32(points.Len()) cfg.LargeBucketFactor = 2 resAffine := curve.G2Affine{} From d99fc59e400640aa7034bffc87e4f62095a85ebf Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:20:06 
+0800 Subject: [PATCH 17/62] generate witness in every prove --- examples/p256/p256.go | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/examples/p256/p256.go b/examples/p256/p256.go index cc580160e6..45a445117a 100644 --- a/examples/p256/p256.go +++ b/examples/p256/p256.go @@ -141,18 +141,12 @@ func Groth16Prove(fileDir string) { // proveStart := time.Now() // Witness generation start := time.Now() - witnessData, err := generateWitness() - if err != nil { - panic(err) - } - elapsed := time.Since(start) - log.Printf("Witness Generation: %d ms", elapsed.Milliseconds()) // Read files start = time.Now() r1cs := groth16.NewCS(ecc.BN254) ReadFromFile(r1cs, fileDir+circuitName+".r1cs") - elapsed = time.Since(start) + elapsed := time.Since(start) log.Printf("Read r1cs: %d ms", elapsed.Milliseconds()) start = time.Now() @@ -163,16 +157,19 @@ func Groth16Prove(fileDir string) { log.Printf("Read zkey: %d ms", elapsed.Milliseconds()) // Proof generation & verification - publicWitness, err := witnessData.Public() - if err != nil { - panic(err) - } vk := groth16.NewVerifyingKey(ecc.BN254) ReadFromFile(vk, fileDir+circuitName+".vkey") // CPU for i := 0; i < 1; i++ { fmt.Printf("------ CPU Prove %d ------", i+1) + witnessData, err := generateWitness() + if err != nil { + panic(err) + } + elapsed := time.Since(start) + log.Printf("Witness Generation: %d ms", elapsed.Milliseconds()) + start = time.Now() proof, err := groth16.Prove(r1cs, pk, witnessData, solidity.WithProverTargetSolidityVerifier(backend.GROTH16)) if err != nil { @@ -180,6 +177,10 @@ func Groth16Prove(fileDir string) { } elapsed = time.Since(start) log.Printf("CPU Prove: %d ms", elapsed.Milliseconds()) + publicWitness, err := witnessData.Public() + if err != nil { + panic(err) + } if err := groth16.Verify(proof, vk, publicWitness, solidity.WithVerifierTargetSolidityVerifier(backend.GROTH16)); err != nil { panic(err) } @@ -188,6 +189,13 @@ func 
Groth16Prove(fileDir string) { // GPU for i := 0; i < 10; i++ { fmt.Printf("------ GPU Prove %d ------\n", i+1) + witnessData, err := generateWitness() + if err != nil { + panic(err) + } + elapsed := time.Since(start) + log.Printf("Witness Generation: %d ms", elapsed.Milliseconds()) + start = time.Now() proof, err := groth16.Prove(r1cs, pk, witnessData, solidity.WithProverTargetSolidityVerifier(backend.GROTH16), backend.WithZeknoxAcceleration()) if err != nil { @@ -195,6 +203,10 @@ func Groth16Prove(fileDir string) { } elapsed = time.Since(start) log.Printf("GPU Prove %d: %d ms", i+1, elapsed.Milliseconds()) + publicWitness, err := witnessData.Public() + if err != nil { + panic(err) + } if err := groth16.Verify(proof, vk, publicWitness, solidity.WithVerifierTargetSolidityVerifier(backend.GROTH16)); err != nil { // panic(err) fmt.Printf("!!! GPU Verify %d: %s\n", i+1, err) @@ -252,4 +264,4 @@ func UnsafeReadFromFile(data gnark_io.UnsafeReaderFrom, fileName string) { if _, err := data.UnsafeReadFrom(file); err != nil { panic(err) } -} \ No newline at end of file +} From 6a13e9b8d5681c000ba0e7257342f72e67a5ae2c Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:20:06 +0800 Subject: [PATCH 18/62] delete unused deviceInfo --- backend/groth16/bn254/zeknox/provingkey.go | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/backend/groth16/bn254/zeknox/provingkey.go b/backend/groth16/bn254/zeknox/provingkey.go index f54a97a669..4f66ec9661 100644 --- a/backend/groth16/bn254/zeknox/provingkey.go +++ b/backend/groth16/bn254/zeknox/provingkey.go @@ -1,10 +1,8 @@ package zeknox_bn254 import ( - "unsafe" - - groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" "github.com/consensys/gnark-crypto/ecc/bn254" + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" cs "github.com/consensys/gnark/constraint/bn254" "github.com/okx/cryptography_cuda/wrappers/go/device" ) @@ -13,14 
+11,9 @@ type deviceInfo struct { G1Device struct { A, B, K, Z *device.HostOrDeviceSlice[bn254.G1Affine] } - DomainDevice struct { - Twiddles, TwiddlesInv unsafe.Pointer - CosetTable, CosetTableInv unsafe.Pointer - } G2Device struct { B *device.HostOrDeviceSlice[bn254.G2Affine] } - DenDevice unsafe.Pointer InfinityPointIndicesK []int } From 59bccd93760439530f6bbcf4f4ef9872835d9456 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:21:18 +0800 Subject: [PATCH 19/62] deviceInfo each points store ArePointsInMont --- backend/groth16/bn254/zeknox/provingkey.go | 26 ++++- backend/groth16/bn254/zeknox/zeknox.go | 105 +++++++-------------- examples/main.go | 2 +- examples/p256/p256.go | 10 +- 4 files changed, 64 insertions(+), 79 deletions(-) diff --git a/backend/groth16/bn254/zeknox/provingkey.go b/backend/groth16/bn254/zeknox/provingkey.go index 4f66ec9661..fc58d89a00 100644 --- a/backend/groth16/bn254/zeknox/provingkey.go +++ b/backend/groth16/bn254/zeknox/provingkey.go @@ -9,10 +9,10 @@ import ( type deviceInfo struct { G1Device struct { - A, B, K, Z *device.HostOrDeviceSlice[bn254.G1Affine] + A, B, K, Z DevicePoints[bn254.G1Affine] } G2Device struct { - B *device.HostOrDeviceSlice[bn254.G2Affine] + B DevicePoints[bn254.G2Affine] } InfinityPointIndicesK []int } @@ -22,6 +22,14 @@ type ProvingKey struct { *deviceInfo } +type DevicePoints[T bn254.G1Affine | bn254.G2Affine] struct { + *device.HostOrDeviceSlice[T] + // Gnark points are in Montgomery form + // After 1 GPU MSM, points in GPU are converted to affine form + // Pass it to MSM config + ArePointsInMont bool +} + func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *groth16_bn254.VerifyingKey) error { return groth16_bn254.Setup(r1cs, &pk.ProvingKey, vk) } @@ -29,3 +37,17 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *groth16_bn254.VerifyingKey) error func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { return groth16_bn254.DummySetup(r1cs, &pk.ProvingKey) } 
+ +// You should call this method to free the GPU memory +// +// pk := groth16.NewProvingKey(ecc.BN254) +// defer pk.(*zeknox_bn254.ProvingKey).Free() +func (pk *ProvingKey) Free() { + if pk.deviceInfo != nil { + pk.deviceInfo.G1Device.A.Free() + pk.deviceInfo.G1Device.B.Free() + pk.deviceInfo.G1Device.K.Free() + pk.deviceInfo.G1Device.Z.Free() + pk.deviceInfo.G2Device.B.Free() + } +} \ No newline at end of file diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index af085f7106..9bf292d9f6 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -7,7 +7,6 @@ import ( "fmt" "math/big" "runtime" - "sync/atomic" "time" "unsafe" @@ -46,7 +45,6 @@ func (pk *ProvingKey) setupDevicePointers() error { return nil } pk.deviceInfo = &deviceInfo{} - // TODO: setup FFT // MSM G1 & G2 Device Setup g, _ := errgroup.WithContext(context.TODO()) @@ -84,11 +82,26 @@ func (pk *ProvingKey) setupDevicePointers() error { return err } // if no error, store device pointers in pk - pk.G1Device.A = <-deviceA - pk.G1Device.B = <-deviceG1B - pk.G1Device.K = <-deviceK - pk.G1Device.Z = <-deviceZ - pk.G2Device.B = <-deviceG2B + pk.G1Device.A = DevicePoints[curve.G1Affine]{ + HostOrDeviceSlice: <-deviceA, + ArePointsInMont: true, + } + pk.G1Device.B = DevicePoints[curve.G1Affine]{ + HostOrDeviceSlice: <-deviceG1B, + ArePointsInMont: true, + } + pk.G1Device.K = DevicePoints[curve.G1Affine]{ + HostOrDeviceSlice: <-deviceK, + ArePointsInMont: true, + } + pk.G1Device.Z = DevicePoints[curve.G1Affine]{ + HostOrDeviceSlice: <-deviceZ, + ArePointsInMont: true, + } + pk.G2Device.B = DevicePoints[curve.G2Affine]{ + HostOrDeviceSlice: <-deviceG2B, + ArePointsInMont: true, + } return nil } @@ -232,6 +245,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var bs1, ar curve.G1Jac computeBS1 := func() error { + <-chWireValuesB <-chWireValuesB var wireB *device.HostOrDeviceSlice[fr.Element] 
chWireB := make(chan *device.HostOrDeviceSlice[fr.Element], 1) @@ -241,17 +255,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireB = <-chWireB defer wireB.Free() startBs1 := time.Now() - - val := atomic.LoadInt32(&g1_point_b_mont) - mont := true - if val == 1 { - mont = false - } else { - atomic.StoreInt32(&g1_point_b_mont, 1) - mont = true - } - - if err := msmG1(&bs1, pk.G1Device.B, wireB, mont); err != nil { + if err := msmG1(&bs1, &pk.G1Device.B, wireB); err != nil { return err } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", wireB.Len()), time.Since(startBs1)).Msg("bs1 done") @@ -262,6 +266,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b } computeAR1 := func() error { + <-chWireValuesA <-chWireValuesA var wireA *device.HostOrDeviceSlice[fr.Element] chWireA := make(chan *device.HostOrDeviceSlice[fr.Element], 1) @@ -271,17 +276,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireA = <-chWireA defer wireA.Free() startAr := time.Now() - - val := atomic.LoadInt32(&g1_point_a_mont) - mont := true - if val == 1 { - mont = false - } else { - atomic.StoreInt32(&g1_point_a_mont, 1) - mont = true - } - - if err := msmG1(&ar, pk.G1Device.A, wireA, mont); err != nil { + if err := msmG1(&ar, &pk.G1Device.A, wireA); err != nil { return err } @@ -315,16 +310,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b defer deviceH.Free() // MSM G1 Krs2 startKrs2 := time.Now() - - val := atomic.LoadInt32(&g1_point_z_mont) - mont := true - if val == 1 { - mont = false - } else { - atomic.StoreInt32(&g1_point_z_mont, 1) - mont = true - } - if err := msmG1(&krs2, pk.G1Device.Z, deviceH, mont); err != nil { + if err := msmG1(&krs2, &pk.G1Device.Z, deviceH); err != nil { return err } @@ -366,26 +352,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireB = <-chWireB defer wireB.Free() startBs := time.Now() - 
// scalar := onHost(wireValuesB[:]) - // point := onHost(pk.G2.B[:]) - // if err := msmG2(&Bs, &point, &scalar); err != nil { - // return err - // } - - val := atomic.LoadInt32(&g2_point_b_mont) - mont := true - if val == 1 { - mont = false - } else { - atomic.StoreInt32(&g2_point_b_mont, 1) - mont = true - } - - if err := msmG2(&Bs, pk.G2Device.B, wireB, mont); err != nil { - return err - } - - if _, err := Bs.MultiExp(pk.G2.B, wireValuesB, ecc.MultiExpConfig{NbTasks: 16}); err != nil { + if err := msmG2(&Bs, &pk.G2Device.B, wireB); err != nil { return err } @@ -522,7 +489,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { return a } -func checkMsmInputs[P, S any](points *device.HostOrDeviceSlice[P], scalars *device.HostOrDeviceSlice[S]) error { +func checkMsmInputs[P curve.G1Affine | curve.G2Affine](points *DevicePoints[P], scalars *device.HostOrDeviceSlice[fr.Element]) error { if !points.IsOnDevice() || !scalars.IsOnDevice() { return fmt.Errorf("MSM: points and scalars must be on device") } @@ -532,38 +499,36 @@ func checkMsmInputs[P, S any](points *device.HostOrDeviceSlice[P], scalars *devi return nil } -func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], scalars *device.HostOrDeviceSlice[fr.Element], input_point_in_mont bool) error { +func msmG1(res *curve.G1Jac, points *DevicePoints[curve.G1Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { checkMsmInputs(points, scalars) cfg := msm.DefaultMSMConfig() cfg.AreInputsOnDevice = true - - cfg.AreInputPointInMont = input_point_in_mont - cfg.AreInputScalarInMont = true - cfg.AreOutputPointInMont = true + cfg.ArePointsInMont = points.ArePointsInMont cfg.Npoints = uint32(points.Len()) cfg.LargeBucketFactor = 2 resAffine := curve.G1Affine{} if err := msm.MSM_G1(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { return err } + // After 1 GPU MSM, points in GPU are converted to affine form + points.ArePointsInMont = 
false res.FromAffine(&resAffine) return nil } -func msmG2(res *curve.G2Jac, points *device.HostOrDeviceSlice[curve.G2Affine], scalars *device.HostOrDeviceSlice[fr.Element], mont bool) error { +func msmG2(res *curve.G2Jac, points *DevicePoints[curve.G2Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { checkMsmInputs(points, scalars) cfg := msm.DefaultMSMConfig() cfg.AreInputsOnDevice = true - - cfg.AreInputPointInMont = mont - cfg.AreOutputPointInMont = true - cfg.AreInputScalarInMont = true + cfg.ArePointsInMont = points.ArePointsInMont cfg.Npoints = uint32(points.Len()) cfg.LargeBucketFactor = 2 resAffine := curve.G2Affine{} if err := msm.MSM_G2(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { return err } + // After 1 GPU MSM, points in GPU are converted to affine form + points.ArePointsInMont = false res.FromAffine(&resAffine) return nil } diff --git a/examples/main.go b/examples/main.go index e588a87664..a27d74c87c 100644 --- a/examples/main.go +++ b/examples/main.go @@ -5,4 +5,4 @@ import "github.com/consensys/gnark/examples/p256" func main() { // p256.Groth16Setup("build/") p256.Groth16Prove("build/") -} \ No newline at end of file +} diff --git a/examples/p256/p256.go b/examples/p256/p256.go index 45a445117a..42be79c039 100644 --- a/examples/p256/p256.go +++ b/examples/p256/p256.go @@ -15,6 +15,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/groth16" + zeknox_bn254 "github.com/consensys/gnark/backend/groth16/bn254/zeknox" "github.com/consensys/gnark/backend/solidity" "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/constraint" @@ -138,20 +139,17 @@ func Groth16Setup(fileDir string) { } func Groth16Prove(fileDir string) { - // proveStart := time.Now() - // Witness generation + // Read r1cs start := time.Now() - - // Read files - start = time.Now() r1cs := groth16.NewCS(ecc.BN254) ReadFromFile(r1cs, 
fileDir+circuitName+".r1cs") elapsed := time.Since(start) log.Printf("Read r1cs: %d ms", elapsed.Milliseconds()) + // read zkey start = time.Now() pk := groth16.NewProvingKey(ecc.BN254) - + defer pk.(*zeknox_bn254.ProvingKey).Free() UnsafeReadFromFile(pk, fileDir+circuitName+".zkey") elapsed = time.Since(start) log.Printf("Read zkey: %d ms", elapsed.Milliseconds()) From fd2ace4cc7b0e952c4ee2faaff4837ebda5e60e2 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:21:18 +0800 Subject: [PATCH 20/62] update cuda library, verify GPU proof success! --- backend/groth16/bn254/zeknox/provingkey.go | 4 ++-- backend/groth16/bn254/zeknox/zeknox.go | 21 ++++++++++++--------- go.mod | 2 +- go.sum | 4 ++-- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/backend/groth16/bn254/zeknox/provingkey.go b/backend/groth16/bn254/zeknox/provingkey.go index fc58d89a00..854dad9d78 100644 --- a/backend/groth16/bn254/zeknox/provingkey.go +++ b/backend/groth16/bn254/zeknox/provingkey.go @@ -27,7 +27,7 @@ type DevicePoints[T bn254.G1Affine | bn254.G2Affine] struct { // Gnark points are in Montgomery form // After 1 GPU MSM, points in GPU are converted to affine form // Pass it to MSM config - ArePointsInMont bool + Mont bool } func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *groth16_bn254.VerifyingKey) error { @@ -50,4 +50,4 @@ func (pk *ProvingKey) Free() { pk.deviceInfo.G1Device.Z.Free() pk.deviceInfo.G2Device.B.Free() } -} \ No newline at end of file +} diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 9bf292d9f6..8fc4e7740d 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -84,23 +84,23 @@ func (pk *ProvingKey) setupDevicePointers() error { // if no error, store device pointers in pk pk.G1Device.A = DevicePoints[curve.G1Affine]{ HostOrDeviceSlice: <-deviceA, - ArePointsInMont: true, + Mont: true, } pk.G1Device.B = 
DevicePoints[curve.G1Affine]{ HostOrDeviceSlice: <-deviceG1B, - ArePointsInMont: true, + Mont: true, } pk.G1Device.K = DevicePoints[curve.G1Affine]{ HostOrDeviceSlice: <-deviceK, - ArePointsInMont: true, + Mont: true, } pk.G1Device.Z = DevicePoints[curve.G1Affine]{ HostOrDeviceSlice: <-deviceZ, - ArePointsInMont: true, + Mont: true, } pk.G2Device.B = DevicePoints[curve.G2Affine]{ HostOrDeviceSlice: <-deviceG2B, - ArePointsInMont: true, + Mont: true, } return nil @@ -503,7 +503,8 @@ func msmG1(res *curve.G1Jac, points *DevicePoints[curve.G1Affine], scalars *devi checkMsmInputs(points, scalars) cfg := msm.DefaultMSMConfig() cfg.AreInputsOnDevice = true - cfg.ArePointsInMont = points.ArePointsInMont + cfg.AreInputScalarInMont = true + cfg.AreOutputPointInMont = true cfg.Npoints = uint32(points.Len()) cfg.LargeBucketFactor = 2 resAffine := curve.G1Affine{} @@ -511,7 +512,7 @@ func msmG1(res *curve.G1Jac, points *DevicePoints[curve.G1Affine], scalars *devi return err } // After 1 GPU MSM, points in GPU are converted to affine form - points.ArePointsInMont = false + points.Mont = false res.FromAffine(&resAffine) return nil } @@ -520,7 +521,9 @@ func msmG2(res *curve.G2Jac, points *DevicePoints[curve.G2Affine], scalars *devi checkMsmInputs(points, scalars) cfg := msm.DefaultMSMConfig() cfg.AreInputsOnDevice = true - cfg.ArePointsInMont = points.ArePointsInMont + cfg.AreInputPointInMont = points.Mont + cfg.AreInputScalarInMont = true + cfg.AreOutputPointInMont = true cfg.Npoints = uint32(points.Len()) cfg.LargeBucketFactor = 2 resAffine := curve.G2Affine{} @@ -528,7 +531,7 @@ func msmG2(res *curve.G2Jac, points *DevicePoints[curve.G2Affine], scalars *devi return err } // After 1 GPU MSM, points in GPU are converted to affine form - points.ArePointsInMont = false + points.Mont = false res.FromAffine(&resAffine) return nil } diff --git a/go.mod b/go.mod index 6596bce5d9..65b07a707e 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/icza/bitio v1.1.0 
github.com/ingonyama-zk/iciclegnark v0.1.0 github.com/leanovate/gopter v0.2.11 - github.com/okx/cryptography_cuda v0.0.0-20241023025010-e04a13d4df26 + github.com/okx/cryptography_cuda v0.0.0-20241023112133-1756b0ee9527 github.com/ronanh/intcomp v1.1.0 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 diff --git a/go.sum b/go.sum index 8da0c561d1..dc772d6de9 100644 --- a/go.sum +++ b/go.sum @@ -230,8 +230,8 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/okx/cryptography_cuda v0.0.0-20241023025010-e04a13d4df26 h1:HgiJDIO/n8DTRCTRaw7CYm042Ieyo00O7wD90ZUteO0= -github.com/okx/cryptography_cuda v0.0.0-20241023025010-e04a13d4df26/go.mod h1:uoZvaCZ82rXfJuYz+hXCzDaMtts0zTGJt96rBqkoucQ= +github.com/okx/cryptography_cuda v0.0.0-20241023112133-1756b0ee9527 h1:rItWN8zYu0DqhyQvKfzmYVRUJmIJonqNP9N8WNhYoAo= +github.com/okx/cryptography_cuda v0.0.0-20241023112133-1756b0ee9527/go.mod h1:Azb8uIJJqdXIq5np5A/RK2ga1sW1949bKFLg3yZCZjI= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= From a2694e02c868a0d36928240d50da848245f5bda5 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:21:18 +0800 Subject: [PATCH 21/62] refactor msm, 1 msm func for both G1 and G2 --- backend/groth16/bn254/zeknox/zeknox.go | 75 ++++++++++++++------------ 1 file changed, 41 insertions(+), 34 deletions(-) diff --git 
a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 8fc4e7740d..817b24cf63 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -255,7 +255,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireB = <-chWireB defer wireB.Free() startBs1 := time.Now() - if err := msmG1(&bs1, &pk.G1Device.B, wireB); err != nil { + if err := gpuMsm(&bs1, &pk.G1Device.B, wireB); err != nil { return err } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", wireB.Len()), time.Since(startBs1)).Msg("bs1 done") @@ -276,7 +276,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireA = <-chWireA defer wireA.Free() startAr := time.Now() - if err := msmG1(&ar, &pk.G1Device.A, wireA); err != nil { + if err := gpuMsm(&ar, &pk.G1Device.A, wireA); err != nil { return err } @@ -310,7 +310,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b defer deviceH.Free() // MSM G1 Krs2 startKrs2 := time.Now() - if err := msmG1(&krs2, &pk.G1Device.Z, deviceH); err != nil { + if err := gpuMsm(&krs2, &pk.G1Device.Z, deviceH); err != nil { return err } @@ -352,7 +352,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireB = <-chWireB defer wireB.Free() startBs := time.Now() - if err := msmG2(&Bs, &pk.G2Device.B, wireB); err != nil { + if err := gpuMsm(&Bs, &pk.G2Device.B, wireB); err != nil { return err } @@ -489,50 +489,57 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { return a } -func checkMsmInputs[P curve.G1Affine | curve.G2Affine](points *DevicePoints[P], scalars *device.HostOrDeviceSlice[fr.Element]) error { +// GPU Msm for either G1 or G2 points +func gpuMsm[R curve.G1Jac | curve.G2Jac, P curve.G1Affine | curve.G2Affine]( + res *R, + points *DevicePoints[P], + scalars *device.HostOrDeviceSlice[fr.Element], +) error { + // Check inputs if !points.IsOnDevice() || 
!scalars.IsOnDevice() { - return fmt.Errorf("MSM: points and scalars must be on device") + panic("points and scalars must be on device") } if points.Len() != scalars.Len() { - return fmt.Errorf("MSM: len(points) != len(scalars)") + panic("points and scalars should be in the same length") } - return nil -} -func msmG1(res *curve.G1Jac, points *DevicePoints[curve.G1Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { - checkMsmInputs(points, scalars) + // Setup MSM config cfg := msm.DefaultMSMConfig() + cfg.AreInputPointInMont = points.Mont cfg.AreInputsOnDevice = true cfg.AreInputScalarInMont = true cfg.AreOutputPointInMont = true cfg.Npoints = uint32(points.Len()) cfg.LargeBucketFactor = 2 - resAffine := curve.G1Affine{} - if err := msm.MSM_G1(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { - return err - } - // After 1 GPU MSM, points in GPU are converted to affine form - points.Mont = false - res.FromAffine(&resAffine) - return nil -} -func msmG2(res *curve.G2Jac, points *DevicePoints[curve.G2Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { - checkMsmInputs(points, scalars) - cfg := msm.DefaultMSMConfig() - cfg.AreInputsOnDevice = true - cfg.AreInputPointInMont = points.Mont - cfg.AreInputScalarInMont = true - cfg.AreOutputPointInMont = true - cfg.Npoints = uint32(points.Len()) - cfg.LargeBucketFactor = 2 - resAffine := curve.G2Affine{} - if err := msm.MSM_G2(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { - return err + switch any(points).(type) { + case *DevicePoints[curve.G1Affine]: + resAffine := curve.G1Affine{} + if err := msm.MSM_G1(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { + return err + } + if r, ok := any(res).(*curve.G1Jac); ok { + r.FromAffine(&resAffine) + } else { + panic("res type should be *curve.G1Jac") + } + case *DevicePoints[curve.G2Affine]: + resAffine := curve.G2Affine{} + if err := 
msm.MSM_G2(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { + return err + } + if r, ok := any(res).(*curve.G2Jac); ok { + r.FromAffine(&resAffine) + } else { + panic("res type should be *curve.G2Jac") + } + default: + panic("invalid points type") } - // After 1 GPU MSM, points in GPU are converted to affine form + + // After GPU MSM, points in GPU are converted to affine form points.Mont = false - res.FromAffine(&resAffine) + return nil } From 8328573370ad359191c0060a7032c7d342149e58 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:42:55 +0800 Subject: [PATCH 22/62] parallel msm, sometimes verify fail --- backend/groth16/bn254/zeknox/zeknox.go | 41 ++++++++------------------ examples/p256/p256.go | 2 +- 2 files changed, 13 insertions(+), 30 deletions(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 817b24cf63..2a3bfa89bb 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -30,11 +30,6 @@ import ( "golang.org/x/sync/errgroup" ) -var g2_point_b_mont int32 = 0 -var g1_point_b_mont int32 = 0 -var g1_point_a_mont int32 = 0 -var g1_point_z_mont int32 = 0 - const HasZeknox = true // Use single GPU @@ -245,7 +240,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var bs1, ar curve.G1Jac computeBS1 := func() error { - <-chWireValuesB <-chWireValuesB var wireB *device.HostOrDeviceSlice[fr.Element] chWireB := make(chan *device.HostOrDeviceSlice[fr.Element], 1) @@ -266,7 +260,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b } computeAR1 := func() error { - <-chWireValuesA <-chWireValuesA var wireA *device.HostOrDeviceSlice[fr.Element] chWireA := make(chan *device.HostOrDeviceSlice[fr.Element], 1) @@ -368,19 +361,15 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b } // 
Parallel execution, memory may hit limit - // g, _ := errgroup.WithContext(context.TODO()) - // g.Go(computeAR1) - computeAR1() - // g.Go(computeBS1) - computeBS1() - // g.Go(computeKRS1) - computeKRS1() - // g.Go(computeKRS2) - computeKRS2() - - // if err := g.Wait(); err != nil { - // return nil, err - // } + g, _ := errgroup.WithContext(context.TODO()) + g.Go(computeAR1) + g.Go(computeBS1) + g.Go(computeKRS1) + g.Go(computeKRS2) + + if err := g.Wait(); err != nil { + return nil, err + } computeBS2() @@ -400,12 +389,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b return proof, nil } -func onHost[T any](hostData []T) device.HostOrDeviceSlice[T] { - deviceSlice := device.NewEmpty[T]() - deviceSlice.OnHost(hostData) - return *deviceSlice -} - // if len(toRemove) == 0, returns slice // // else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex @@ -519,10 +502,10 @@ func gpuMsm[R curve.G1Jac | curve.G2Jac, P curve.G1Affine | curve.G2Affine]( return err } if r, ok := any(res).(*curve.G1Jac); ok { - r.FromAffine(&resAffine) - } else { + r.FromAffine(&resAffine) + } else { panic("res type should be *curve.G1Jac") - } + } case *DevicePoints[curve.G2Affine]: resAffine := curve.G2Affine{} if err := msm.MSM_G2(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { diff --git a/examples/p256/p256.go b/examples/p256/p256.go index 42be79c039..d365367947 100644 --- a/examples/p256/p256.go +++ b/examples/p256/p256.go @@ -206,8 +206,8 @@ func Groth16Prove(fileDir string) { panic(err) } if err := groth16.Verify(proof, vk, publicWitness, solidity.WithVerifierTargetSolidityVerifier(backend.GROTH16)); err != nil { + fmt.Printf("\n!!! GPU Verify %d: %s\n\n", i+1, err) // panic(err) - fmt.Printf("!!! 
GPU Verify %d: %s\n", i+1, err) } } } From 8c49c6c4f6703153ca865af088eed121a18c614c Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:55:56 +0800 Subject: [PATCH 23/62] parallel + copy point every time --- backend/groth16/bn254/zeknox/zeknox.go | 13 +++++++++++-- examples/p256/p256.go | 4 ++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 2a3bfa89bb..760a51a697 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -114,6 +114,10 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b return groth16_bn254.Prove(r1cs, &pk.ProvingKey, fullWitness, opts...) } log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Str("acceleration", "zeknox").Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + if pk.deviceInfo != nil { + pk.Free() + pk.deviceInfo = nil + } if pk.deviceInfo == nil { start := time.Now() if err := pk.setupDevicePointers(); err != nil { @@ -366,12 +370,17 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b g.Go(computeBS1) g.Go(computeKRS1) g.Go(computeKRS2) - + g.Go(computeBS2) if err := g.Wait(); err != nil { return nil, err } - computeBS2() + // Serial execution + // computeAR1() + // computeBS1() + // computeKRS1() + // computeKRS2() + // computeBS2() // FinalKRS = KRS1 + KRS2 + s*AR + r*BS1 { diff --git a/examples/p256/p256.go b/examples/p256/p256.go index d365367947..b85a6cffa8 100644 --- a/examples/p256/p256.go +++ b/examples/p256/p256.go @@ -185,7 +185,7 @@ func Groth16Prove(fileDir string) { } // GPU - for i := 0; i < 10; i++ { + for i := 0; i < 20; i++ { fmt.Printf("------ GPU Prove %d ------\n", i+1) witnessData, err := generateWitness() if err != nil { @@ -207,7 +207,7 @@ func Groth16Prove(fileDir string) { } if err := 
groth16.Verify(proof, vk, publicWitness, solidity.WithVerifierTargetSolidityVerifier(backend.GROTH16)); err != nil { fmt.Printf("\n!!! GPU Verify %d: %s\n\n", i+1, err) - // panic(err) + panic(err) } } } From 8d84de24282542c06a8317f2f4b0a147901bc311 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 24 Oct 2024 15:00:03 +0800 Subject: [PATCH 24/62] serial GPU msm, always success --- backend/groth16/bn254/zeknox/zeknox.go | 23 +++++++---------------- examples/p256/p256.go | 14 ++------------ 2 files changed, 9 insertions(+), 28 deletions(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 760a51a697..2195c231c4 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -114,10 +114,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b return groth16_bn254.Prove(r1cs, &pk.ProvingKey, fullWitness, opts...) } log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Str("acceleration", "zeknox").Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() - if pk.deviceInfo != nil { - pk.Free() - pk.deviceInfo = nil - } if pk.deviceInfo == nil { start := time.Now() if err := pk.setupDevicePointers(); err != nil { @@ -364,24 +360,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b return nil } - // Parallel execution, memory may hit limit g, _ := errgroup.WithContext(context.TODO()) - g.Go(computeAR1) - g.Go(computeBS1) + // CPU MSM g.Go(computeKRS1) - g.Go(computeKRS2) - g.Go(computeBS2) + + // Serial GPU MSM + computeAR1() + computeBS1() + computeKRS2() + computeBS2() if err := g.Wait(); err != nil { return nil, err } - // Serial execution - // computeAR1() - // computeBS1() - // computeKRS1() - // computeKRS2() - // computeBS2() - // FinalKRS = KRS1 + KRS2 + s*AR + r*BS1 { var p1 curve.G1Jac diff --git a/examples/p256/p256.go 
b/examples/p256/p256.go index b85a6cffa8..87b1fafbcf 100644 --- a/examples/p256/p256.go +++ b/examples/p256/p256.go @@ -28,7 +28,7 @@ import ( "golang.org/x/crypto/sha3" ) -const NumSignatures = 1 +const NumSignatures = 10 var circuitName string @@ -165,16 +165,11 @@ func Groth16Prove(fileDir string) { if err != nil { panic(err) } - elapsed := time.Since(start) - log.Printf("Witness Generation: %d ms", elapsed.Milliseconds()) - start = time.Now() proof, err := groth16.Prove(r1cs, pk, witnessData, solidity.WithProverTargetSolidityVerifier(backend.GROTH16)) if err != nil { panic(err) } - elapsed = time.Since(start) - log.Printf("CPU Prove: %d ms", elapsed.Milliseconds()) publicWitness, err := witnessData.Public() if err != nil { panic(err) @@ -191,23 +186,18 @@ func Groth16Prove(fileDir string) { if err != nil { panic(err) } - elapsed := time.Since(start) - log.Printf("Witness Generation: %d ms", elapsed.Milliseconds()) - start = time.Now() proof, err := groth16.Prove(r1cs, pk, witnessData, solidity.WithProverTargetSolidityVerifier(backend.GROTH16), backend.WithZeknoxAcceleration()) if err != nil { panic(err) } - elapsed = time.Since(start) - log.Printf("GPU Prove %d: %d ms", i+1, elapsed.Milliseconds()) publicWitness, err := witnessData.Public() if err != nil { panic(err) } if err := groth16.Verify(proof, vk, publicWitness, solidity.WithVerifierTargetSolidityVerifier(backend.GROTH16)); err != nil { fmt.Printf("\n!!! 
GPU Verify %d: %s\n\n", i+1, err) - panic(err) + // panic(err) } } } From 554277a242191fcee1a9124bf12a1b0b50673f97 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Tue, 5 Nov 2024 18:56:20 +0800 Subject: [PATCH 25/62] small improvement in zeknox prover --- backend/groth16/bn254/zeknox/provingkey.go | 2 +- backend/groth16/bn254/zeknox/zeknox.go | 92 ++++++++++++++++++---- examples/p256/p256.go | 2 +- go.mod | 4 +- go.sum | 2 - 5 files changed, 80 insertions(+), 22 deletions(-) diff --git a/backend/groth16/bn254/zeknox/provingkey.go b/backend/groth16/bn254/zeknox/provingkey.go index 854dad9d78..fd6966e294 100644 --- a/backend/groth16/bn254/zeknox/provingkey.go +++ b/backend/groth16/bn254/zeknox/provingkey.go @@ -4,7 +4,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254" groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" cs "github.com/consensys/gnark/constraint/bn254" - "github.com/okx/cryptography_cuda/wrappers/go/device" + "github.com/okx/zeknox/wrappers/go/device" ) type deviceInfo struct { diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 2195c231c4..6b046cdd42 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -25,8 +25,8 @@ import ( fcs "github.com/consensys/gnark/frontend/cs" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" - "github.com/okx/cryptography_cuda/wrappers/go/device" - "github.com/okx/cryptography_cuda/wrappers/go/msm" + "github.com/okx/zeknox/wrappers/go/device" + "github.com/okx/zeknox/wrappers/go/msm" "golang.org/x/sync/errgroup" ) @@ -62,6 +62,7 @@ func (pk *ProvingKey) setupDevicePointers() error { } deviceK := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) g.Go(func() error { return CopyToDevice(pointsNoInfinity, deviceK) }) + // g.Go(func() error { return CopyToDevice(pk.G1.K, deviceK) }) // G1.Z deviceZ := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) @@ -312,7 +313,8 @@ 
func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b } var krs1 curve.G1Jac - computeKRS1 := func() error { + + computeKRS1_CPU := func() error { // filter the wire values if needed // TODO Perf @Tabaie worst memory allocation offender toRemove := commitmentInfo.GetPrivateCommitted() @@ -332,24 +334,80 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b return nil } - computeBS2 := func() error { - <-chWireValuesB + /* + computeKRS1 := func() error { + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + // original Groth16 witness without pedersen commitment + wireValuesWithoutCom := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) + + var deviceWireValuesWithoutCom *device.HostOrDeviceSlice[fr.Element] + chDeviceW := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + sizeW := len(wireValuesWithoutCom) + // copy to GPU + if err := CopyToDevice(wireValuesWithoutCom[:sizeW], chDeviceW); err != nil { + return err + } + deviceWireValuesWithoutCom = <-chDeviceW + defer deviceWireValuesWithoutCom.Free() + // on GPU + startKrs := time.Now() + if err := gpuMsm(&krs1, &pk.G1Device.K, deviceWireValuesWithoutCom); err != nil { + return err + } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValues)), time.Since(startKrs)).Msg("GPU krs1 done") + // -rs[δ] + krs1.AddMixed(&deltas[2]) + return nil + } + */ + /* + computeBS2 := func() error { + <-chWireValuesB + // Bs2 (1 multi exp G2 - size = len(wires)) + var Bs, deltaS curve.G2Jac + + var wireB *device.HostOrDeviceSlice[fr.Element] + chWireB := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + if err := CopyToDevice(wireValuesB, chWireB); err != nil { + return err + } + wireB = <-chWireB + defer wireB.Free() + startBs := time.Now() + 
if err := gpuMsm(&Bs, &pk.G2Device.B, wireB); err != nil { + return err + } + + log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", wireB.Len()), time.Since(startBs)).Msg("Bs done") + + deltaS.FromAffine(&pk.G2.Delta) + deltaS.ScalarMultiplication(&deltaS, &s) + Bs.AddAssign(&deltaS) + Bs.AddMixed(&pk.G2.Beta) + + proof.Bs.FromJacobian(&Bs) + return nil + } + */ + + computeBS2_CPU := func() error { // Bs2 (1 multi exp G2 - size = len(wires)) var Bs, deltaS curve.G2Jac - var wireB *device.HostOrDeviceSlice[fr.Element] - chWireB := make(chan *device.HostOrDeviceSlice[fr.Element], 1) - if err := CopyToDevice(wireValuesB, chWireB); err != nil { - return err + nbTasks := runtime.NumCPU() / 2 + if nbTasks <= 16 { + // if we don't have a lot of CPUs, this may artificially split the MSM + nbTasks *= 2 } - wireB = <-chWireB - defer wireB.Free() + <-chWireValuesB startBs := time.Now() - if err := gpuMsm(&Bs, &pk.G2Device.B, wireB); err != nil { + if _, err := Bs.MultiExp(pk.G2.B, wireValuesB, ecc.MultiExpConfig{NbTasks: nbTasks}); err != nil { return err } - - log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", wireB.Len()), time.Since(startBs)).Msg("Bs done") + log.Debug().Dur(fmt.Sprintf("MSMG2 %d took", len(wireValuesB)), time.Since(startBs)).Msg("Bs.MultiExp done") deltaS.FromAffine(&pk.G2.Delta) deltaS.ScalarMultiplication(&deltaS, &s) @@ -362,13 +420,15 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b g, _ := errgroup.WithContext(context.TODO()) // CPU MSM - g.Go(computeKRS1) + g.Go(computeKRS1_CPU) + g.Go(computeBS2_CPU) // Serial GPU MSM computeAR1() computeBS1() + // computeKRS1() computeKRS2() - computeBS2() + // computeBS2() if err := g.Wait(); err != nil { return nil, err } diff --git a/examples/p256/p256.go b/examples/p256/p256.go index 87b1fafbcf..1c89c4edff 100644 --- a/examples/p256/p256.go +++ b/examples/p256/p256.go @@ -180,7 +180,7 @@ func Groth16Prove(fileDir string) { } // GPU - for i := 0; i < 20; i++ { + for i := 0; i < 1; i++ { 
fmt.Printf("------ GPU Prove %d ------\n", i+1) witnessData, err := generateWitness() if err != nil { diff --git a/go.mod b/go.mod index 65b07a707e..b64287917e 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/icza/bitio v1.1.0 github.com/ingonyama-zk/iciclegnark v0.1.0 github.com/leanovate/gopter v0.2.11 - github.com/okx/cryptography_cuda v0.0.0-20241023112133-1756b0ee9527 + github.com/okx/zeknox v0.0.0-20241023112133-1756b0ee9527 github.com/ronanh/intcomp v1.1.0 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 @@ -38,4 +38,4 @@ require ( rsc.io/tmplfunc v0.0.3 // indirect ) -// replace github.com/okx/cryptography_cuda => /home/okxdex/data/zkdex-pap/workspace/jason-huang/cryptography_cuda +// replace github.com/okx/zeknox v0.0.0-20241023112133-1756b0ee9527 => /home/ubuntu/git/cryptography_cuda diff --git a/go.sum b/go.sum index dc772d6de9..81dc4a651b 100644 --- a/go.sum +++ b/go.sum @@ -230,8 +230,6 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/okx/cryptography_cuda v0.0.0-20241023112133-1756b0ee9527 h1:rItWN8zYu0DqhyQvKfzmYVRUJmIJonqNP9N8WNhYoAo= -github.com/okx/cryptography_cuda v0.0.0-20241023112133-1756b0ee9527/go.mod h1:Azb8uIJJqdXIq5np5A/RK2ga1sW1949bKFLg3yZCZjI= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= From 48475b8d165a91646c35799001d9f240dbe01e10 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: 
Wed, 20 Nov 2024 13:06:16 +0800 Subject: [PATCH 26/62] improve p256 example --- examples/main.go | 10 ++++++++-- examples/p256/p256.go | 2 +- go.mod | 2 +- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/main.go b/examples/main.go index a27d74c87c..1326e8e7a1 100644 --- a/examples/main.go +++ b/examples/main.go @@ -1,8 +1,14 @@ package main -import "github.com/consensys/gnark/examples/p256" +import ( + "os" + + "github.com/consensys/gnark/examples/p256" +) func main() { - // p256.Groth16Setup("build/") + if _, err := os.Stat("build/"); os.IsNotExist(err) { + p256.Groth16Setup("build/") + } p256.Groth16Prove("build/") } diff --git a/examples/p256/p256.go b/examples/p256/p256.go index 1c89c4edff..a6a34baac7 100644 --- a/examples/p256/p256.go +++ b/examples/p256/p256.go @@ -196,7 +196,7 @@ func Groth16Prove(fileDir string) { panic(err) } if err := groth16.Verify(proof, vk, publicWitness, solidity.WithVerifierTargetSolidityVerifier(backend.GROTH16)); err != nil { - fmt.Printf("\n!!! 
GPU Verify %d: %s\n\n", i+1, err) + fmt.Printf("\nError in GPU Verify %d: %s\n\n", i+1, err) // panic(err) } } diff --git a/go.mod b/go.mod index b64287917e..b7b7796b50 100644 --- a/go.mod +++ b/go.mod @@ -38,4 +38,4 @@ require ( rsc.io/tmplfunc v0.0.3 // indirect ) -// replace github.com/okx/zeknox v0.0.0-20241023112133-1756b0ee9527 => /home/ubuntu/git/cryptography_cuda +// replace github.com/okx/zeknox v0.0.0-20241023112133-1756b0ee9527 => /home/ubuntu/git/zeknox From 48d719dbc04de79ab90c13a489858d7b444b8c2f Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Thu, 21 Nov 2024 10:44:50 +0800 Subject: [PATCH 27/62] update zeknox to v1.0.0 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index b7b7796b50..bc96965316 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/icza/bitio v1.1.0 github.com/ingonyama-zk/iciclegnark v0.1.0 github.com/leanovate/gopter v0.2.11 - github.com/okx/zeknox v0.0.0-20241023112133-1756b0ee9527 + github.com/okx/zeknox v1.0.0 github.com/ronanh/intcomp v1.1.0 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 diff --git a/go.sum b/go.sum index 81dc4a651b..62b4183cd9 100644 --- a/go.sum +++ b/go.sum @@ -230,6 +230,8 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= +github.com/okx/zeknox v1.0.0 h1:W/nZnaBIQjB5LHK2DsVdlL5v7NHP4OBPLziZ6gh/14U= +github.com/okx/zeknox v1.0.0/go.mod h1:zlHemJhkN7W22xWWtANF66oPdzUJYT1frlkdSZhLQbc= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.9.3/go.mod 
h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= From 2e0ef674534ff9003a32d7654b25fdc54617ac3f Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 10 Oct 2024 18:13:18 +0800 Subject: [PATCH 28/62] init zeknox GPU acceleration --- backend/backend.go | 13 ++++ backend/groth16/bn254/zeknox/doc.go | 2 + backend/groth16/bn254/zeknox/marshal_test.go | 67 ++++++++++++++++++++ backend/groth16/bn254/zeknox/nozeknox.go | 18 ++++++ backend/groth16/bn254/zeknox/provingkey.go | 36 +++++++++++ backend/groth16/groth16.go | 21 ++++++ 6 files changed, 157 insertions(+) create mode 100644 backend/groth16/bn254/zeknox/doc.go create mode 100644 backend/groth16/bn254/zeknox/marshal_test.go create mode 100644 backend/groth16/bn254/zeknox/nozeknox.go create mode 100644 backend/groth16/bn254/zeknox/provingkey.go diff --git a/backend/backend.go b/backend/backend.go index 7c427e5825..acd89a0585 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -121,6 +121,19 @@ func WithProverKZGFoldingHashFunction(hFunc hash.Hash) ProverOption { } } +// WithZeknoxAcceleration requests to use [ZEKNOX] GPU proving backend for the +// prover. This option requires that the program is compiled with `zeknox` build +// tag and the ZEKNOX dependencies are properly installed. See [ZEKNOX] for +// installation description. +// +// [ZEKNOX]: https://github.com/okx/cryptography_cuda +func WithZeknoxAcceleration() ProverOption { + return func(pc *ProverConfig) error { + pc.Accelerator = "zeknox" + return nil + } +} + // WithIcicleAcceleration requests to use [ICICLE] GPU proving backend for the // prover. This option requires that the program is compiled with `icicle` build // tag and the ICICLE dependencies are properly installed. 
See [ICICLE] for diff --git a/backend/groth16/bn254/zeknox/doc.go b/backend/groth16/bn254/zeknox/doc.go new file mode 100644 index 0000000000..2200b550c7 --- /dev/null +++ b/backend/groth16/bn254/zeknox/doc.go @@ -0,0 +1,2 @@ +// Package zeknox_bn254 implements zeknox acceleration for BN254 Groth16 backend. +package zeknox_bn254 diff --git a/backend/groth16/bn254/zeknox/marshal_test.go b/backend/groth16/bn254/zeknox/marshal_test.go new file mode 100644 index 0000000000..5e9b2aeaea --- /dev/null +++ b/backend/groth16/bn254/zeknox/marshal_test.go @@ -0,0 +1,67 @@ +package zeknox_bn254_test + +import ( + "bytes" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + zeknox_bn254 "github.com/consensys/gnark/backend/groth16/bn254/zeknox" + cs_bn254 "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/test" +) + +type circuit struct { + A, B frontend.Variable `gnark:",public"` + Res frontend.Variable +} + +func (c *circuit) Define(api frontend.API) error { + api.AssertIsEqual(api.Mul(c.A, c.B), c.Res) + return nil +} + +func TestMarshalNative(t *testing.T) { + assert := test.NewAssert(t) + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit{}) + assert.NoError(err) + tCcs := ccs.(*cs_bn254.R1CS) + nativePK := groth16_bn254.ProvingKey{} + nativeVK := groth16_bn254.VerifyingKey{} + err = groth16_bn254.Setup(tCcs, &nativePK, &nativeVK) + assert.NoError(err) + + pk := groth16.NewProvingKey(ecc.BN254) + buf := new(bytes.Buffer) + _, err = nativePK.WriteTo(buf) + assert.NoError(err) + _, err = pk.ReadFrom(buf) + assert.NoError(err) + if pk.IsDifferent(&nativePK) { + t.Error("marshal output difference") + } +} + +func TestMarshalZeknox(t *testing.T) { + assert := test.NewAssert(t) + ccs, err := 
frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit{}) + assert.NoError(err) + tCcs := ccs.(*cs_bn254.R1CS) + zePK := zeknox_bn254.ProvingKey{} + VK := groth16_bn254.VerifyingKey{} + err = zeknox_bn254.Setup(tCcs, &zePK, &VK) + assert.NoError(err) + + nativePK := groth16_bn254.ProvingKey{} + buf := new(bytes.Buffer) + _, err = zePK.WriteTo(buf) + assert.NoError(err) + _, err = nativePK.ReadFrom(buf) + assert.NoError(err) + if zePK.IsDifferent(&nativePK) { + t.Error("marshal output difference") + } +} diff --git a/backend/groth16/bn254/zeknox/nozeknox.go b/backend/groth16/bn254/zeknox/nozeknox.go new file mode 100644 index 0000000000..a1c94bb97b --- /dev/null +++ b/backend/groth16/bn254/zeknox/nozeknox.go @@ -0,0 +1,18 @@ +//go:build !zeknox + +package zeknox_bn254 + +import ( + "fmt" + + "github.com/consensys/gnark/backend" + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + "github.com/consensys/gnark/backend/witness" + cs "github.com/consensys/gnark/constraint/bn254" +) + +const HasZeknox = false + +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*groth16_bn254.Proof, error) { + return nil, fmt.Errorf("zeknox backend requested but program compiled without 'zeknox' build tag") +} diff --git a/backend/groth16/bn254/zeknox/provingkey.go b/backend/groth16/bn254/zeknox/provingkey.go new file mode 100644 index 0000000000..b13c59b0d1 --- /dev/null +++ b/backend/groth16/bn254/zeknox/provingkey.go @@ -0,0 +1,36 @@ +package zeknox_bn254 + +import ( + "unsafe" + + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + cs "github.com/consensys/gnark/constraint/bn254" +) + +type deviceInfo struct { + G1Device struct { + A, B, K, Z unsafe.Pointer + } + DomainDevice struct { + Twiddles, TwiddlesInv unsafe.Pointer + CosetTable, CosetTableInv unsafe.Pointer + } + G2Device struct { + B unsafe.Pointer + } + DenDevice unsafe.Pointer + InfinityPointIndicesK []int +} + +type ProvingKey 
struct { + groth16_bn254.ProvingKey + *deviceInfo +} + +func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *groth16_bn254.VerifyingKey) error { + return groth16_bn254.Setup(r1cs, &pk.ProvingKey, vk) +} + +func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { + return groth16_bn254.DummySetup(r1cs, &pk.ProvingKey) +} diff --git a/backend/groth16/groth16.go b/backend/groth16/groth16.go index a56b5730a3..6f16787fe6 100644 --- a/backend/groth16/groth16.go +++ b/backend/groth16/groth16.go @@ -51,6 +51,7 @@ import ( groth16_bls24317 "github.com/consensys/gnark/backend/groth16/bls24-317" groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" icicle_bn254 "github.com/consensys/gnark/backend/groth16/bn254/icicle" + zeknox_bn254 "github.com/consensys/gnark/backend/groth16/bn254/zeknox" groth16_bw6633 "github.com/consensys/gnark/backend/groth16/bw6-633" groth16_bw6761 "github.com/consensys/gnark/backend/groth16/bw6-761" ) @@ -198,6 +199,9 @@ func Prove(r1cs constraint.ConstraintSystem, pk ProvingKey, fullWitness witness. return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), fullWitness, opts...) case *cs_bn254.R1CS: + if zeknox_bn254.HasZeknox { + return zeknox_bn254.Prove(_r1cs, pk.(*zeknox_bn254.ProvingKey), fullWitness, opts...) + } if icicle_bn254.HasIcicle { return icicle_bn254.Prove(_r1cs, pk.(*icicle_bn254.ProvingKey), fullWitness, opts...) 
} @@ -247,6 +251,13 @@ func Setup(r1cs constraint.ConstraintSystem) (ProvingKey, VerifyingKey, error) { return &pk, &vk, nil case *cs_bn254.R1CS: var vk groth16_bn254.VerifyingKey + if zeknox_bn254.HasZeknox { + var pk zeknox_bn254.ProvingKey + if err := zeknox_bn254.Setup(_r1cs, &pk, &vk); err != nil { + return nil, nil, err + } + return &pk, &vk, nil + } if icicle_bn254.HasIcicle { var pk icicle_bn254.ProvingKey if err := icicle_bn254.Setup(_r1cs, &pk, &vk); err != nil { @@ -309,6 +320,13 @@ func DummySetup(r1cs constraint.ConstraintSystem) (ProvingKey, error) { } return &pk, nil case *cs_bn254.R1CS: + if zeknox_bn254.HasZeknox { + var pk zeknox_bn254.ProvingKey + if err := zeknox_bn254.DummySetup(_r1cs, &pk); err != nil { + return nil, err + } + return &pk, nil + } if icicle_bn254.HasIcicle { var pk icicle_bn254.ProvingKey if err := icicle_bn254.DummySetup(_r1cs, &pk); err != nil { @@ -357,6 +375,9 @@ func NewProvingKey(curveID ecc.ID) ProvingKey { switch curveID { case ecc.BN254: pk = &groth16_bn254.ProvingKey{} + if zeknox_bn254.HasZeknox { + pk = &zeknox_bn254.ProvingKey{} + } if icicle_bn254.HasIcicle { pk = &icicle_bn254.ProvingKey{} } From 2da1853ddd5f1175c54a4bf36edc3124416b3cb3 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Wed, 16 Oct 2024 14:32:17 +0800 Subject: [PATCH 29/62] MSM G1 & G2 acclerating! 
with local cuda repo --- backend/groth16/bn254/zeknox/provingkey.go | 6 +- backend/groth16/bn254/zeknox/zeknox.go | 552 +++++++++++++++++++++ go.mod | 8 +- go.sum | 4 +- 4 files changed, 563 insertions(+), 7 deletions(-) create mode 100644 backend/groth16/bn254/zeknox/zeknox.go diff --git a/backend/groth16/bn254/zeknox/provingkey.go b/backend/groth16/bn254/zeknox/provingkey.go index b13c59b0d1..f54a97a669 100644 --- a/backend/groth16/bn254/zeknox/provingkey.go +++ b/backend/groth16/bn254/zeknox/provingkey.go @@ -4,19 +4,21 @@ import ( "unsafe" groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254" cs "github.com/consensys/gnark/constraint/bn254" + "github.com/okx/cryptography_cuda/wrappers/go/device" ) type deviceInfo struct { G1Device struct { - A, B, K, Z unsafe.Pointer + A, B, K, Z *device.HostOrDeviceSlice[bn254.G1Affine] } DomainDevice struct { Twiddles, TwiddlesInv unsafe.Pointer CosetTable, CosetTableInv unsafe.Pointer } G2Device struct { - B unsafe.Pointer + B *device.HostOrDeviceSlice[bn254.G2Affine] } DenDevice unsafe.Pointer InfinityPointIndicesK []int diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go new file mode 100644 index 0000000000..e49c9aa031 --- /dev/null +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -0,0 +1,552 @@ +//go:build zeknox + +package zeknox_bn254 + +import ( + "context" + "fmt" + "math/big" + "runtime" + "time" + "unsafe" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/hash_to_field" + "github.com/consensys/gnark/backend" + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + 
cs "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/constraint/solver" + fcs "github.com/consensys/gnark/frontend/cs" + "github.com/consensys/gnark/internal/utils" + "github.com/consensys/gnark/logger" + "github.com/okx/cryptography_cuda/wrappers/go/device" + "github.com/okx/cryptography_cuda/wrappers/go/msm" + "golang.org/x/sync/errgroup" +) + +const HasZeknox = true + +// Use single GPU +const deviceId = 0 + +func (pk *ProvingKey) setupDevicePointers() error { + if pk.deviceInfo != nil { + return nil + } + pk.deviceInfo = &deviceInfo{} + // TODO: setup FFT + + // MSM G1 & G2 Device Setup + g, _ := errgroup.WithContext(context.TODO()) + // G1.A + deviceA := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) + g.Go(func() error { return CopyToDevice(pk.G1.A, deviceA) }) + + // G1.B + deviceG1B := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) + g.Go(func() error { return CopyToDevice(pk.G1.B, deviceG1B) }) + + // G1.K + var pointsNoInfinity []curve.G1Affine + for i, gnarkPoint := range pk.G1.K { + if gnarkPoint.IsInfinity() { + pk.InfinityPointIndicesK = append(pk.InfinityPointIndicesK, i) + } else { + pointsNoInfinity = append(pointsNoInfinity, gnarkPoint) + } + } + deviceK := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) + g.Go(func() error { return CopyToDevice(pointsNoInfinity, deviceK) }) + + // G1.Z + deviceZ := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) + g.Go(func() error { return CopyToDevice(pk.G1.Z, deviceZ) }) + + // G2.B + deviceG2B := make(chan *device.HostOrDeviceSlice[curve.G2Affine], 1) + g.Go(func() error { return CopyToDevice(pk.G2.B, deviceG2B) }) + + // wait for all points to be copied to the device + // if any of the copy failed, return the error + if err := g.Wait(); err != nil { + return err + } + // if no error, store device pointers in pk + pk.G1Device.A = <-deviceA + pk.G1Device.B = <-deviceG1B + pk.G1Device.K = <-deviceK + pk.G1Device.Z = <-deviceZ + pk.G2Device.B = 
<-deviceG2B + + return nil +} + +// Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*groth16_bn254.Proof, error) { + fmt.Println("zeknox_bn254.Prove") + opt, err := backend.NewProverConfig(opts...) + if err != nil { + return nil, fmt.Errorf("new prover config: %w", err) + } + if opt.HashToFieldFn == nil { + opt.HashToFieldFn = hash_to_field.New([]byte(constraint.CommitmentDst)) + } + if opt.Accelerator != "zeknox" { + return groth16_bn254.Prove(r1cs, &pk.ProvingKey, fullWitness, opts...) + } + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Str("acceleration", "zeknox").Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + if pk.deviceInfo == nil { + log.Debug().Msg("precomputing proving key in GPU") + if err := pk.setupDevicePointers(); err != nil { + return nil, fmt.Errorf("setup device pointers: %w", err) + } + } + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &groth16_bn254.Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] + + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + + // override hints + bsb22ID := solver.GetHintID(fcs.Bsb22CommitmentComputePlaceholder) + solverOpts = append(solverOpts, solver.OverrideHint(bsb22ID, func(_ *big.Int, in []*big.Int, out []*big.Int) error { + i := int(in[0].Int64()) + in = in[1:] + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[+len(hashed):] + for j, inJ := range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } 
+ + opt.HashToFieldFn.Write(constraint.SerializeCommitment(proof.Commitments[i].Marshal(), hashed, (fr.Bits-1)/8+1)) + hashBts := opt.HashToFieldFn.Sum(nil) + opt.HashToFieldFn.Reset() + nbBuf := fr.Bytes + if opt.HashToFieldFn.Size() < fr.Bytes { + nbBuf = opt.HashToFieldFn.Size() + } + var res fr.Element + res.SetBytes(hashBts[:nbBuf]) + res.BigInt(out[0]) + return nil + })) + + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err + } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + + start := time.Now() + poks := make([]curve.G1Affine, len(pk.CommitmentKeys)) + + for i := range pk.CommitmentKeys { + var err error + if poks[i], err = pk.CommitmentKeys[i].ProveKnowledge(privateCommittedValues[i]); err != nil { + return nil, err + } + } + // compute challenge for folding the PoKs from the commitments + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + challenge, err := fr.Hash(commitmentsSerialized, []byte("G16-BSB22"), 1) + if err != nil { + return nil, err + } + if _, err = proof.CommitmentPok.Fold(poks, challenge[0], ecc.MultiExpConfig{NbTasks: 1}); err != nil { + return nil, err + } + + // quotient poly H (witness reduction / FFT part) + var h []fr.Element + chHDone := make(chan struct{}, 1) + go func() { + startH := time.Now() + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + log.Debug().Dur("computeH took", time.Since(startH)).Msg("computed H") + solution.A = nil + solution.B = nil + solution.C = nil + chHDone <- struct{}{} + }() + + // we need to copy and filter the wireValues for each multi exp + // as pk.G1.A, pk.G1.B and pk.G2.B may have (a significant) number of point at infinity + var deviceWireValuesA, deviceWireValuesB *device.HostOrDeviceSlice[fr.Element] + // indicate if the wire values have been 
copied to the device + chWireValuesA, chWireValuesB := make(chan error, 1), make(chan error, 1) + + go func() { + wireValuesA := make([]fr.Element, len(wireValues)-int(pk.NbInfinityA)) + for i, j := 0, 0; j < len(wireValuesA); i++ { + if pk.InfinityA[i] { + continue + } + wireValuesA[j] = wireValues[i] + j++ + } + chDeviceValues := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + if err := CopyToDevice(wireValuesA, chDeviceValues); err != nil { + chWireValuesA <- err + close(chWireValuesA) + return + } + deviceWireValuesA = <-chDeviceValues + close(chWireValuesA) + }() + go func() { + wireValuesB := make([]fr.Element, len(wireValues)-int(pk.NbInfinityB)) + for i, j := 0, 0; j < len(wireValuesB); i++ { + if pk.InfinityB[i] { + continue + } + wireValuesB[j] = wireValues[i] + j++ + } + chDeviceValues := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + if err := CopyToDevice(wireValuesB, chDeviceValues); err != nil { + chWireValuesB <- err + close(chWireValuesB) + return + } + deviceWireValuesB = <-chDeviceValues + close(chWireValuesB) + }() + + // sample random r and s + var r, s big.Int + var _r, _s, _kr fr.Element + if _, err := _r.SetRandom(); err != nil { + return nil, err + } + if _, err := _s.SetRandom(); err != nil { + return nil, err + } + // -rs + // Why it is called kr? not rs? 
-> notation from DIZK paper + _kr.Mul(&_r, &_s).Neg(&_kr) + + _r.BigInt(&r) + _s.BigInt(&s) + + // computes r[δ], s[δ], kr[δ] + deltas := curve.BatchScalarMultiplicationG1(&pk.G1.Delta, []fr.Element{_r, _s, _kr}) + + var bs1, ar curve.G1Jac + + n := runtime.NumCPU() + + chBs1Done := make(chan error, 1) + + computeBS1 := func() { + if err := <-chWireValuesB; err != nil { + chBs1Done <- err + close(chBs1Done) + return + } + startBs1 := time.Now() + if err := msmG1(&bs1, pk.G1Device.B, deviceWireValuesB); err != nil { + chBs1Done <- err + close(chBs1Done) + return + } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesB.Len()), time.Since(startBs1)).Msg("bs1.MultiExp done") + // + beta + s[δ] + bs1.AddMixed(&pk.G1.Beta) + bs1.AddMixed(&deltas[1]) + chBs1Done <- nil + } + + chArDone := make(chan error, 1) + computeAR1 := func() { + if err := <-chWireValuesA; err != nil { + chArDone <- err + close(chArDone) + return + } + startAr := time.Now() + if err := msmG1(&ar, pk.G1Device.A, deviceWireValuesA); err != nil { + chArDone <- err + close(chArDone) + return + } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesA.Len()), time.Since(startAr)).Msg("ar.MultiExp done") + ar.AddMixed(&pk.G1.Alpha) + ar.AddMixed(&deltas[0]) + proof.Ar.FromJacobian(&ar) + chArDone <- nil + } + + chKrsDone := make(chan error, 1) + var deviceH *device.HostOrDeviceSlice[fr.Element] + computeKRS := func() { + // we could NOT split the Krs multiExp in 2, and just append pk.G1.K and pk.G1.Z + // however, having similar lengths for our tasks helps with parallelism + + var krs, krs2, p1 curve.G1Jac + chKrs2Done := make(chan error, 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 + go func() { + startKrs2 := time.Now() + // Copy h poly to device, since we haven't implemented FFT on device + chDevice := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + if err := CopyToDevice(h[:sizeH], chDevice); err != nil { + chKrs2Done <- err 
+ close(chKrs2Done) + return + } + deviceH = <-chDevice + if err := msmG1(&krs2, pk.G1Device.Z, deviceH); err != nil { + chKrs2Done <- err + close(chKrs2Done) + return + } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", sizeH), time.Since(startKrs2)).Msg("krs2.MultiExp done") + chKrs2Done <- err + }() + + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) + + startKrs := time.Now() + if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + chKrsDone <- err + return + } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(_wireValues)), time.Since(startKrs)).Msg("krs.MultiExp done") + // -rs[δ] + krs.AddMixed(&deltas[2]) + n := 3 + for n != 0 { + select { + case err := <-chKrs2Done: + if err != nil { + chKrsDone <- err + return + } + krs.AddAssign(&krs2) + case err := <-chArDone: + if err != nil { + chKrsDone <- err + return + } + p1.ScalarMultiplication(&ar, &s) + krs.AddAssign(&p1) + case err := <-chBs1Done: + if err != nil { + chKrsDone <- err + return + } + p1.ScalarMultiplication(&bs1, &r) + krs.AddAssign(&p1) + } + n-- + } + + proof.Krs.FromJacobian(&krs) + chKrsDone <- nil + } + + computeBS2 := func() error { + // Bs2 (1 multi exp G2 - size = len(wires)) + var Bs, deltaS curve.G2Jac + + nbTasks := n + if nbTasks <= 16 { + // if we don't have a lot of CPUs, this may artificially split the MSM + nbTasks *= 2 + } + <-chWireValuesB + startBs := time.Now() + if err := msmG2(&Bs, pk.G2Device.B, deviceWireValuesB); err != nil { + return err + } + log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", deviceWireValuesB.Len()), time.Since(startBs)).Msg("Bs.MultiExp done") + + deltaS.FromAffine(&pk.G2.Delta) + deltaS.ScalarMultiplication(&deltaS, &s) + 
Bs.AddAssign(&deltaS) + Bs.AddMixed(&pk.G2.Beta) + + proof.Bs.FromJacobian(&Bs) + return nil + } + + // wait for FFT to end, as it uses all our CPUs + <-chHDone + + // schedule our proof part computations + go computeKRS() + go computeAR1() + go computeBS1() + if err := computeBS2(); err != nil { + return nil, err + } + + // wait for all parts of the proof to be computed. + if err := <-chKrsDone; err != nil { + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("prover done") + + // Free device memory + go func() { + deviceWireValuesA.Free() + deviceWireValuesB.Free() + deviceH.Free() + }() + + return proof, nil +} + +// if len(toRemove) == 0, returns slice +// else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { + + if len(toRemove) == 0 { + return slice + } + + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) + + // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) + for i := 0; i < len(slice); i++ { + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } + continue + } + r = append(r, slice[i]) + } + + return +} + +func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { + // H part of Krs + // Compute H (hz=ab-c, where z=-2 on ker X^n+1 (z(x)=x^n-1)) + // 1 - _a = ifft(a), _b = ifft(b), _c = ifft(c) + // 2 - ca = fft_coset(_a), ba = fft_coset(_b), cc = fft_coset(_c) + // 3 - h = ifft_coset(ca o cb - cc) + + n := len(a) + + // add padding to ensure input length is domain cardinality + padding := make([]fr.Element, int(domain.Cardinality)-n) + a = append(a, padding...) + b = append(b, padding...) + c = append(c, padding...) 
+ n = len(a) + + // a -> aPoly, b -> bPoly, c -> cPoly + // point-value form -> coefficient form + domain.FFTInverse(a, fft.DIF) + domain.FFTInverse(b, fft.DIF) + domain.FFTInverse(c, fft.DIF) + + // evaluate aPoly, bPoly, cPoly on coset (roots of unity) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) + + // vanishing poly t(x) = x^N - 1 + // calcualte 1/t(g), where g is the generator + var den, one fr.Element + one.SetOne() + // g^N + den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) + // 1/(g^N - 1) + den.Sub(&den, &one).Inverse(&den) + + // h = (a*b - c)/t + // h = ifft_coset(ca o cb - cc) + // reusing a to avoid unnecessary memory allocation + utils.Parallelize(n, func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &b[i]). + Sub(&a[i], &c[i]). + Mul(&a[i], &den) + } + }) + + // ifft_coset: point-value form -> coefficient form + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) + + return a +} + +func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { + if points.Len() != scalars.Len() { + return fmt.Errorf("MSM: len(points) != len(scalars)") + } + cfg := msm.DefaultMSMConfig() + cfg.ArePointsInMont = true + cfg.Npoints = uint32(points.Len()) + cfg.FfiAffineSz = 64 + if err := msm.MSM_G1(unsafe.Pointer(res), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { + return err + } + return nil +} + +func msmG2(res *curve.G2Jac, points *device.HostOrDeviceSlice[curve.G2Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { + if points.Len() != scalars.Len() { + return fmt.Errorf("MSM: len(points) != len(scalars)") + } + cfg := msm.DefaultMSMConfig() + cfg.AreInputsOnDevice = true + cfg.ArePointsInMont = true + cfg.Npoints = uint32(points.Len()) + cfg.LargeBucketFactor = 2 + // TODO: MSM_G2 should return Jacobian + // 
https://github.com/okx/cryptography_cuda/issues/90 + resAffine := curve.G2Affine{} + if err := msm.MSM_G2(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { + return err + } + res.FromAffine(&resAffine) + return nil +} + +func CopyToDevice[T any](hostData []T, chDeviceSlice chan *device.HostOrDeviceSlice[T]) error { + deviceSlice, err := device.CudaMalloc[T](deviceId, len(hostData)) + if err != nil { + chDeviceSlice <- nil + return err + } + if err := deviceSlice.CopyFromHost(hostData[:]); err != nil { + chDeviceSlice <- nil + return err + } + chDeviceSlice <- deviceSlice + return nil +} diff --git a/go.mod b/go.mod index 4ed242d670..e643586c0c 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module github.com/consensys/gnark -go 1.22 +go 1.22.2 -toolchain go1.22.6 +toolchain go1.23.1 require ( github.com/bits-and-blooms/bitset v1.14.2 @@ -16,6 +16,7 @@ require ( github.com/icza/bitio v1.1.0 github.com/ingonyama-zk/iciclegnark v0.1.0 github.com/leanovate/gopter v0.2.11 + github.com/okx/cryptography_cuda/wrappers/go v0.0.0-20241016023422-25c1f0f5f44e github.com/ronanh/intcomp v1.1.0 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 @@ -31,9 +32,10 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/mmcloughlin/addchain v0.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/x448/float16 v0.8.4 // indirect golang.org/x/sys v0.24.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect rsc.io/tmplfunc v0.0.3 // indirect ) + +replace github.com/okx/cryptography_cuda/wrappers/go => /home/okxdex/data/zkdex-pap/workspace/jason-huang/cryptography_cuda/wrappers/go diff --git a/go.sum b/go.sum index efc71ddf96..81dc4a651b 100644 --- a/go.sum +++ b/go.sum @@ -241,8 +241,8 @@ github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndr github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod 
h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/ronanh/intcomp v1.1.0 h1:i54kxmpmSoOZFcWPMWryuakN0vLxLswASsGa07zkvLU= github.com/ronanh/intcomp v1.1.0/go.mod h1:7FOLy3P3Zj3er/kVrU/pl+Ql7JFZj7bwliMGketo0IU= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= From 70d36b275155073cd4291fd057c3e6f62b5ab5d3 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 17 Oct 2024 11:53:09 +0800 Subject: [PATCH 30/62] sequencial GPU MSM & refactor --- backend/groth16/bn254/zeknox/nozeknox.go | 2 +- backend/groth16/bn254/zeknox/zeknox.go | 97 +++++++++++++++--------- go.mod | 2 +- go.sum | 2 + 4 files changed, 64 insertions(+), 39 deletions(-) diff --git a/backend/groth16/bn254/zeknox/nozeknox.go b/backend/groth16/bn254/zeknox/nozeknox.go index a1c94bb97b..8859d6f319 100644 --- a/backend/groth16/bn254/zeknox/nozeknox.go +++ b/backend/groth16/bn254/zeknox/nozeknox.go @@ -1,4 +1,4 @@ -//go:build !zeknox +//go:build zeknox package zeknox_bn254 diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index e49c9aa031..df967e908d 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -1,4 +1,4 @@ -//go:build zeknox +//go:build !zeknox package zeknox_bn254 @@ -89,7 +89,6 @@ func (pk *ProvingKey) setupDevicePointers() error { // Prove generates the proof of knowledge of a r1cs with 
full witness (secret + public part). func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*groth16_bn254.Proof, error) { - fmt.Println("zeknox_bn254.Prove") opt, err := backend.NewProverConfig(opts...) if err != nil { return nil, fmt.Errorf("new prover config: %w", err) @@ -102,10 +101,11 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b } log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Str("acceleration", "zeknox").Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() if pk.deviceInfo == nil { - log.Debug().Msg("precomputing proving key in GPU") + start := time.Now() if err := pk.setupDevicePointers(); err != nil { return nil, fmt.Errorf("setup device pointers: %w", err) } + log.Debug().Dur("took", time.Since(start)).Msg("Copy proving key to device") } commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) @@ -182,7 +182,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b go func() { startH := time.Now() h = computeH(solution.A, solution.B, solution.C, &pk.Domain) - log.Debug().Dur("computeH took", time.Since(startH)).Msg("computed H") + log.Debug().Dur("took", time.Since(startH)).Msg("computed H") solution.A = nil solution.B = nil solution.C = nil @@ -207,7 +207,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b chDeviceValues := make(chan *device.HostOrDeviceSlice[fr.Element], 1) if err := CopyToDevice(wireValuesA, chDeviceValues); err != nil { chWireValuesA <- err - close(chWireValuesA) return } deviceWireValuesA = <-chDeviceValues @@ -225,7 +224,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b chDeviceValues := make(chan *device.HostOrDeviceSlice[fr.Element], 1) if err := CopyToDevice(wireValuesB, chDeviceValues); err != nil { chWireValuesB <- err - close(chWireValuesB) return } deviceWireValuesB = 
<-chDeviceValues @@ -253,23 +251,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var bs1, ar curve.G1Jac - n := runtime.NumCPU() - chBs1Done := make(chan error, 1) computeBS1 := func() { if err := <-chWireValuesB; err != nil { chBs1Done <- err - close(chBs1Done) return } startBs1 := time.Now() if err := msmG1(&bs1, pk.G1Device.B, deviceWireValuesB); err != nil { chBs1Done <- err - close(chBs1Done) return } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesB.Len()), time.Since(startBs1)).Msg("bs1.MultiExp done") + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesB.Len()), time.Since(startBs1)).Msg("bs1 done") // + beta + s[δ] bs1.AddMixed(&pk.G1.Beta) bs1.AddMixed(&deltas[1]) @@ -280,16 +274,14 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b computeAR1 := func() { if err := <-chWireValuesA; err != nil { chArDone <- err - close(chArDone) return } startAr := time.Now() if err := msmG1(&ar, pk.G1Device.A, deviceWireValuesA); err != nil { chArDone <- err - close(chArDone) return } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesA.Len()), time.Since(startAr)).Msg("ar.MultiExp done") + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesA.Len()), time.Since(startAr)).Msg("ar done") ar.AddMixed(&pk.G1.Alpha) ar.AddMixed(&deltas[0]) proof.Ar.FromJacobian(&ar) @@ -304,40 +296,56 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var krs, krs2, p1 curve.G1Jac chKrs2Done := make(chan error, 1) - sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 go func() { startKrs2 := time.Now() // Copy h poly to device, since we haven't implemented FFT on device - chDevice := make(chan *device.HostOrDeviceSlice[fr.Element], 1) - if err := CopyToDevice(h[:sizeH], chDevice); err != nil { + chDeviceH := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + sizeH := int(pk.Domain.Cardinality - 1) // 
comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 + if err := CopyToDevice(h[:sizeH], chDeviceH); err != nil { chKrs2Done <- err - close(chKrs2Done) return } - deviceH = <-chDevice + deviceH = <-chDeviceH if err := msmG1(&krs2, pk.G1Device.Z, deviceH); err != nil { chKrs2Done <- err - close(chKrs2Done) return } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", sizeH), time.Since(startKrs2)).Msg("krs2.MultiExp done") - chKrs2Done <- err + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", sizeH), time.Since(startKrs2)).Msg("krs2 done") + close(chKrs2Done) }() // filter the wire values if needed // TODO Perf @Tabaie worst memory allocation offender toRemove := commitmentInfo.GetPrivateCommitted() toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) - _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) + // original Groth16 witness without pedersen commitment + wireValuesWithoutCom := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) startKrs := time.Now() - if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + // GPU runtime error + // var deviceWire *device.HostOrDeviceSlice[fr.Element] + // defer deviceWire.Free() + // chDeviceWire := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + // if err := CopyToDevice(wireValuesWithoutCom, chDeviceWire); err != nil { + // chKrsDone <- err + // return + // } + // deviceWire = <-chDeviceWire + // if err := msmG1(&krs, pk.G1Device.K, deviceWire); err != nil { + // chKrsDone <- err + // return + // } + + // CPU + // Compute this MSM on CPU, as it can be done in parallel with other MSM on GPU + if _, err := krs.MultiExp(pk.G1.K, wireValuesWithoutCom, ecc.MultiExpConfig{NbTasks: runtime.NumCPU() / 2}); err != nil { chKrsDone <- err return } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(_wireValues)), 
time.Since(startKrs)).Msg("krs.MultiExp done") + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValues)), time.Since(startKrs)).Msg("krs done") // -rs[δ] krs.AddMixed(&deltas[2]) + n := 3 for n != 0 { select { @@ -373,17 +381,14 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b // Bs2 (1 multi exp G2 - size = len(wires)) var Bs, deltaS curve.G2Jac - nbTasks := n - if nbTasks <= 16 { - // if we don't have a lot of CPUs, this may artificially split the MSM - nbTasks *= 2 + if err := <-chWireValuesB; err != nil { + return err } - <-chWireValuesB startBs := time.Now() if err := msmG2(&Bs, pk.G2Device.B, deviceWireValuesB); err != nil { return err } - log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", deviceWireValuesB.Len()), time.Since(startBs)).Msg("Bs.MultiExp done") + log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", deviceWireValuesB.Len()), time.Since(startBs)).Msg("Bs done") deltaS.FromAffine(&pk.G2.Delta) deltaS.ScalarMultiplication(&deltaS, &s) @@ -398,17 +403,34 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b <-chHDone // schedule our proof part computations - go computeKRS() - go computeAR1() - go computeBS1() - if err := computeBS2(); err != nil { + // Sequencial GPU execution + // TODO: see GPU utilization data + computeAR1() + if err := <-chArDone; err != nil { return nil, err } - - // wait for all parts of the proof to be computed. + computeBS1() + if err := <-chBs1Done; err != nil { + return nil, err + } + computeKRS() if err := <-chKrsDone; err != nil { return nil, err } + if err := computeBS2(); err != nil { + return nil, err + } + + // Parallel GPU execution, memory may hit limit + // go computeKRS() + // go computeAR1() + // go computeBS1() + // go computeBS2() + + // wait for all parts of the proof to be computed. 
+ // if err := <-chKrsDone; err != nil { + // return nil, err + // } log.Debug().Dur("took", time.Since(start)).Msg("prover done") @@ -423,6 +445,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b } // if len(toRemove) == 0, returns slice +// // else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex // this assumes len(slice) > len(toRemove) // filterHeap modifies toRemove diff --git a/go.mod b/go.mod index e643586c0c..52c8797024 100644 --- a/go.mod +++ b/go.mod @@ -38,4 +38,4 @@ require ( rsc.io/tmplfunc v0.0.3 // indirect ) -replace github.com/okx/cryptography_cuda/wrappers/go => /home/okxdex/data/zkdex-pap/workspace/jason-huang/cryptography_cuda/wrappers/go +// replace github.com/okx/cryptography_cuda/wrappers/go => /home/okxdex/data/zkdex-pap/workspace/jason-huang/cryptography_cuda/wrappers/go diff --git a/go.sum b/go.sum index 81dc4a651b..fc472ffae3 100644 --- a/go.sum +++ b/go.sum @@ -230,6 +230,8 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= +github.com/okx/cryptography_cuda/wrappers/go v0.0.0-20241016023422-25c1f0f5f44e h1:NT/U7+AJ93s0U4af9I5fEtpE33Etf68wEUif7Q/s1mo= +github.com/okx/cryptography_cuda/wrappers/go v0.0.0-20241016023422-25c1f0f5f44e/go.mod h1:y9SSivg7t0Fs0PZQJ/l2jUhWT67SeEj9XYgz5ysjyEw= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 
From be85d2a00b127844732bc5a71cc90562c6167330 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 17 Oct 2024 17:27:54 +0800 Subject: [PATCH 31/62] fix verify bug, delete channel --- backend/groth16/bn254/zeknox/zeknox.go | 86 +++++++------------------- 1 file changed, 22 insertions(+), 64 deletions(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index df967e908d..780fbb2a23 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -251,46 +251,38 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var bs1, ar curve.G1Jac - chBs1Done := make(chan error, 1) - - computeBS1 := func() { + computeBS1 := func() error { if err := <-chWireValuesB; err != nil { - chBs1Done <- err - return + return err } startBs1 := time.Now() if err := msmG1(&bs1, pk.G1Device.B, deviceWireValuesB); err != nil { - chBs1Done <- err - return + return err } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesB.Len()), time.Since(startBs1)).Msg("bs1 done") // + beta + s[δ] bs1.AddMixed(&pk.G1.Beta) bs1.AddMixed(&deltas[1]) - chBs1Done <- nil + return nil } - chArDone := make(chan error, 1) - computeAR1 := func() { + computeAR1 := func() error { if err := <-chWireValuesA; err != nil { - chArDone <- err - return + return err } startAr := time.Now() if err := msmG1(&ar, pk.G1Device.A, deviceWireValuesA); err != nil { - chArDone <- err - return + return err } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesA.Len()), time.Since(startAr)).Msg("ar done") ar.AddMixed(&pk.G1.Alpha) ar.AddMixed(&deltas[0]) proof.Ar.FromJacobian(&ar) - chArDone <- nil + return nil } - chKrsDone := make(chan error, 1) var deviceH *device.HostOrDeviceSlice[fr.Element] - computeKRS := func() { + computeKRS := func() error { // we could NOT split the Krs multiExp in 2, and just append pk.G1.K and pk.G1.Z // however, having similar lengths 
for our tasks helps with parallelism @@ -321,7 +313,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b // original Groth16 witness without pedersen commitment wireValuesWithoutCom := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - startKrs := time.Now() // GPU runtime error // var deviceWire *device.HostOrDeviceSlice[fr.Element] // defer deviceWire.Free() @@ -338,43 +329,25 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b // CPU // Compute this MSM on CPU, as it can be done in parallel with other MSM on GPU + startKrs := time.Now() if _, err := krs.MultiExp(pk.G1.K, wireValuesWithoutCom, ecc.MultiExpConfig{NbTasks: runtime.NumCPU() / 2}); err != nil { - chKrsDone <- err - return + return err } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValues)), time.Since(startKrs)).Msg("krs done") // -rs[δ] krs.AddMixed(&deltas[2]) - n := 3 - for n != 0 { - select { - case err := <-chKrs2Done: - if err != nil { - chKrsDone <- err - return - } - krs.AddAssign(&krs2) - case err := <-chArDone: - if err != nil { - chKrsDone <- err - return - } - p1.ScalarMultiplication(&ar, &s) - krs.AddAssign(&p1) - case err := <-chBs1Done: - if err != nil { - chKrsDone <- err - return - } - p1.ScalarMultiplication(&bs1, &r) - krs.AddAssign(&p1) - } - n-- + if err := <-chKrs2Done; err != nil { + return err } + krs.AddAssign(&krs2) + p1.ScalarMultiplication(&ar, &s) + krs.AddAssign(&p1) + p1.ScalarMultiplication(&bs1, &r) + krs.AddAssign(&p1) proof.Krs.FromJacobian(&krs) - chKrsDone <- nil + return nil } computeBS2 := func() error { @@ -405,33 +378,18 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b // schedule our proof part computations // Sequencial GPU execution // TODO: see GPU utilization data - computeAR1() - if err := <-chArDone; err != nil { + if err := computeAR1(); err != nil { return nil, err } - 
computeBS1() - if err := <-chBs1Done; err != nil { + if err := computeBS1(); err != nil { return nil, err } - computeKRS() - if err := <-chKrsDone; err != nil { + if err := computeKRS(); err != nil { return nil, err } if err := computeBS2(); err != nil { return nil, err } - - // Parallel GPU execution, memory may hit limit - // go computeKRS() - // go computeAR1() - // go computeBS1() - // go computeBS2() - - // wait for all parts of the proof to be computed. - // if err := <-chKrsDone; err != nil { - // return nil, err - // } - log.Debug().Dur("took", time.Since(start)).Msg("prover done") // Free device memory From 68f6664f892fe4d96e0f5e8a7ca80505532eee36 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Fri, 18 Oct 2024 11:15:55 +0800 Subject: [PATCH 32/62] parallel but verify fail in most cases --- backend/groth16/bn254/zeknox/zeknox.go | 99 +++++++++++++++++--------- 1 file changed, 66 insertions(+), 33 deletions(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 780fbb2a23..f81d48fcca 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -251,38 +251,45 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var bs1, ar curve.G1Jac - computeBS1 := func() error { + chBs1Done := make(chan error, 1) + computeBS1 := func() { if err := <-chWireValuesB; err != nil { - return err + chBs1Done <- err + return } startBs1 := time.Now() if err := msmG1(&bs1, pk.G1Device.B, deviceWireValuesB); err != nil { - return err + chBs1Done <- err + return } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesB.Len()), time.Since(startBs1)).Msg("bs1 done") // + beta + s[δ] bs1.AddMixed(&pk.G1.Beta) bs1.AddMixed(&deltas[1]) - return nil + chBs1Done <- nil } - computeAR1 := func() error { + chArDone := make(chan error, 1) + computeAR1 := func() { if err := <-chWireValuesA; err != nil { - return err + chArDone 
<- err + return } startAr := time.Now() if err := msmG1(&ar, pk.G1Device.A, deviceWireValuesA); err != nil { - return err + chArDone <- err + return } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesA.Len()), time.Since(startAr)).Msg("ar done") ar.AddMixed(&pk.G1.Alpha) ar.AddMixed(&deltas[0]) proof.Ar.FromJacobian(&ar) - return nil + chArDone <- nil } + chKrsDone := make(chan error, 1) var deviceH *device.HostOrDeviceSlice[fr.Element] - computeKRS := func() error { + computeKRS := func() { // we could NOT split the Krs multiExp in 2, and just append pk.G1.K and pk.G1.Z // however, having similar lengths for our tasks helps with parallelism @@ -329,37 +336,63 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b // CPU // Compute this MSM on CPU, as it can be done in parallel with other MSM on GPU + // Also, reduce data copy startKrs := time.Now() if _, err := krs.MultiExp(pk.G1.K, wireValuesWithoutCom, ecc.MultiExpConfig{NbTasks: runtime.NumCPU() / 2}); err != nil { - return err + chKrsDone <- err + return } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValues)), time.Since(startKrs)).Msg("krs done") // -rs[δ] krs.AddMixed(&deltas[2]) - if err := <-chKrs2Done; err != nil { - return err + n := 3 + for n != 0 { + select { + // wait krs2 + case err := <-chKrs2Done: + if err != nil { + chKrsDone <- err + return + } + krs.AddAssign(&krs2) + // wait ar + case err := <-chArDone: + if err != nil { + chKrsDone <- err + return + } + p1.ScalarMultiplication(&ar, &s) + krs.AddAssign(&p1) + // wait bs1 + case err := <-chBs1Done: + if err != nil { + chKrsDone <- err + return + } + p1.ScalarMultiplication(&bs1, &r) + krs.AddAssign(&p1) + } + n-- } - krs.AddAssign(&krs2) - p1.ScalarMultiplication(&ar, &s) - krs.AddAssign(&p1) - p1.ScalarMultiplication(&bs1, &r) - krs.AddAssign(&p1) proof.Krs.FromJacobian(&krs) - return nil + chKrsDone <- nil } - computeBS2 := func() error { + chBs2Done := make(chan error, 1) + computeBS2 := func() 
{ // Bs2 (1 multi exp G2 - size = len(wires)) var Bs, deltaS curve.G2Jac if err := <-chWireValuesB; err != nil { - return err + chBs2Done <- err + return } startBs := time.Now() if err := msmG2(&Bs, pk.G2Device.B, deviceWireValuesB); err != nil { - return err + chBs2Done <- err + return } log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", deviceWireValuesB.Len()), time.Since(startBs)).Msg("Bs done") @@ -369,27 +402,27 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b Bs.AddMixed(&pk.G2.Beta) proof.Bs.FromJacobian(&Bs) - return nil + chBs2Done <- nil } // wait for FFT to end, as it uses all our CPUs <-chHDone - // schedule our proof part computations - // Sequencial GPU execution - // TODO: see GPU utilization data - if err := computeAR1(); err != nil { - return nil, err - } - if err := computeBS1(); err != nil { - return nil, err - } - if err := computeKRS(); err != nil { + // Parallel GPU execution, memory may hit limit + go computeAR1() + go computeBS1() + go computeBS2() + go computeKRS() + + // wait for all parts of the proof to be computed. 
+ // Krs done means AR1, BS1 are done + if err := <-chKrsDone; err != nil { return nil, err } - if err := computeBS2(); err != nil { + if err := <-chBs2Done; err != nil { return nil, err } + log.Debug().Dur("took", time.Since(start)).Msg("prover done") // Free device memory From 3bd0c168639c322667a53f691776d92d44948f1d Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Fri, 18 Oct 2024 14:20:48 +0800 Subject: [PATCH 33/62] fix msm cfg & add input check --- backend/groth16/bn254/zeknox/zeknox.go | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index f81d48fcca..50f36b07ff 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -411,14 +411,15 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b // Parallel GPU execution, memory may hit limit go computeAR1() go computeBS1() - go computeBS2() go computeKRS() - // wait for all parts of the proof to be computed. 
- // Krs done means AR1, BS1 are done + // wait krs, ar1, bs1 + // krs done means AR1, BS1 are done if err := <-chKrsDone; err != nil { return nil, err } + // bs2 and bs1 both depend on wireValuesB + computeBS2() if err := <-chBs2Done; err != nil { return nil, err } @@ -518,11 +519,20 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { return a } -func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { +func checkMsmInputs[P, S any](points *device.HostOrDeviceSlice[P], scalars *device.HostOrDeviceSlice[S]) error { + if !points.IsOnDevice() || !scalars.IsOnDevice() { + return fmt.Errorf("MSM: points and scalars must be on device") + } if points.Len() != scalars.Len() { return fmt.Errorf("MSM: len(points) != len(scalars)") } + return nil +} + +func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { + checkMsmInputs(points, scalars) cfg := msm.DefaultMSMConfig() + cfg.AreInputsOnDevice = true cfg.ArePointsInMont = true cfg.Npoints = uint32(points.Len()) cfg.FfiAffineSz = 64 @@ -533,9 +543,7 @@ func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], s } func msmG2(res *curve.G2Jac, points *device.HostOrDeviceSlice[curve.G2Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { - if points.Len() != scalars.Len() { - return fmt.Errorf("MSM: len(points) != len(scalars)") - } + checkMsmInputs(points, scalars) cfg := msm.DefaultMSMConfig() cfg.AreInputsOnDevice = true cfg.ArePointsInMont = true From 3f8da269d70bbad8f59e5737bc5c7f3d45102ead Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Fri, 18 Oct 2024 15:10:24 +0800 Subject: [PATCH 34/62] fix parallel GPU proving, use errgroup --- backend/groth16/bn254/zeknox/zeknox.go | 253 +++++++++---------------- 1 file changed, 94 insertions(+), 159 deletions(-) diff --git 
a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 50f36b07ff..76dd72af7a 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -176,27 +176,13 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b return nil, err } - // quotient poly H (witness reduction / FFT part) - var h []fr.Element - chHDone := make(chan struct{}, 1) - go func() { - startH := time.Now() - h = computeH(solution.A, solution.B, solution.C, &pk.Domain) - log.Debug().Dur("took", time.Since(startH)).Msg("computed H") - solution.A = nil - solution.B = nil - solution.C = nil - chHDone <- struct{}{} - }() - // we need to copy and filter the wireValues for each multi exp // as pk.G1.A, pk.G1.B and pk.G2.B may have (a significant) number of point at infinity - var deviceWireValuesA, deviceWireValuesB *device.HostOrDeviceSlice[fr.Element] - // indicate if the wire values have been copied to the device - chWireValuesA, chWireValuesB := make(chan error, 1), make(chan error, 1) + var wireValuesA, wireValuesB []fr.Element + chWireValuesA, chWireValuesB := make(chan struct{}, 1), make(chan struct{}, 1) go func() { - wireValuesA := make([]fr.Element, len(wireValues)-int(pk.NbInfinityA)) + wireValuesA = make([]fr.Element, len(wireValues)-int(pk.NbInfinityA)) for i, j := 0, 0; j < len(wireValuesA); i++ { if pk.InfinityA[i] { continue @@ -204,16 +190,10 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireValuesA[j] = wireValues[i] j++ } - chDeviceValues := make(chan *device.HostOrDeviceSlice[fr.Element], 1) - if err := CopyToDevice(wireValuesA, chDeviceValues); err != nil { - chWireValuesA <- err - return - } - deviceWireValuesA = <-chDeviceValues close(chWireValuesA) }() go func() { - wireValuesB := make([]fr.Element, len(wireValues)-int(pk.NbInfinityB)) + wireValuesB = make([]fr.Element, len(wireValues)-int(pk.NbInfinityB)) for i, j := 0, 0; j < len(wireValuesB); 
i++ { if pk.InfinityB[i] { continue @@ -221,12 +201,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireValuesB[j] = wireValues[i] j++ } - chDeviceValues := make(chan *device.HostOrDeviceSlice[fr.Element], 1) - if err := CopyToDevice(wireValuesB, chDeviceValues); err != nil { - chWireValuesB <- err - return - } - deviceWireValuesB = <-chDeviceValues close(chWireValuesB) }() @@ -251,68 +225,78 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var bs1, ar curve.G1Jac - chBs1Done := make(chan error, 1) - computeBS1 := func() { - if err := <-chWireValuesB; err != nil { - chBs1Done <- err - return + computeBS1 := func() error { + <- chWireValuesB + var wireB *device.HostOrDeviceSlice[fr.Element] + chWireB := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + if err := CopyToDevice(wireValuesB, chWireB); err != nil { + return err } + wireB = <-chWireB + defer wireB.Free() startBs1 := time.Now() - if err := msmG1(&bs1, pk.G1Device.B, deviceWireValuesB); err != nil { - chBs1Done <- err - return + if err := msmG1(&bs1, pk.G1Device.B, wireB); err != nil { + return err } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesB.Len()), time.Since(startBs1)).Msg("bs1 done") + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", wireB.Len()), time.Since(startBs1)).Msg("bs1 done") // + beta + s[δ] bs1.AddMixed(&pk.G1.Beta) bs1.AddMixed(&deltas[1]) - chBs1Done <- nil + return nil } - chArDone := make(chan error, 1) - computeAR1 := func() { - if err := <-chWireValuesA; err != nil { - chArDone <- err - return + computeAR1 := func() error { + <- chWireValuesA + var wireA *device.HostOrDeviceSlice[fr.Element] + chWireA := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + if err := CopyToDevice(wireValuesA, chWireA); err != nil { + return err } + wireA = <-chWireA + defer wireA.Free() startAr := time.Now() - if err := msmG1(&ar, pk.G1Device.A, deviceWireValuesA); err != nil { - chArDone <- err - return + if 
err := msmG1(&ar, pk.G1Device.A, wireA); err != nil { + return err } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", deviceWireValuesA.Len()), time.Since(startAr)).Msg("ar done") + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", wireA.Len()), time.Since(startAr)).Msg("ar done") ar.AddMixed(&pk.G1.Alpha) ar.AddMixed(&deltas[0]) proof.Ar.FromJacobian(&ar) - chArDone <- nil + return nil } - chKrsDone := make(chan error, 1) - var deviceH *device.HostOrDeviceSlice[fr.Element] - computeKRS := func() { - // we could NOT split the Krs multiExp in 2, and just append pk.G1.K and pk.G1.Z - // however, having similar lengths for our tasks helps with parallelism - - var krs, krs2, p1 curve.G1Jac - chKrs2Done := make(chan error, 1) - go func() { - startKrs2 := time.Now() - // Copy h poly to device, since we haven't implemented FFT on device - chDeviceH := make(chan *device.HostOrDeviceSlice[fr.Element], 1) - sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 - if err := CopyToDevice(h[:sizeH], chDeviceH); err != nil { - chKrs2Done <- err - return - } - deviceH = <-chDeviceH - if err := msmG1(&krs2, pk.G1Device.Z, deviceH); err != nil { - chKrs2Done <- err - return - } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", sizeH), time.Since(startKrs2)).Msg("krs2 done") - close(chKrs2Done) - }() + var krs2 curve.G1Jac + computeKRS2 := func() error { + // quotient poly H (witness reduction / FFT part) + var h []fr.Element + { + startH := time.Now() + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + log.Debug().Dur("took", time.Since(startH)).Msg("computed H") + solution.A = nil + solution.B = nil + solution.C = nil + } + // Copy h poly to device, since we haven't implemented FFT on device + var deviceH *device.HostOrDeviceSlice[fr.Element] + chDeviceH := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 + if err := CopyToDevice(h[:sizeH], 
chDeviceH); err != nil { + return err + } + deviceH = <-chDeviceH + defer deviceH.Free() + // MSM G1 Krs2 + startKrs2 := time.Now() + if err := msmG1(&krs2, pk.G1Device.Z, deviceH); err != nil { + return err + } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", sizeH), time.Since(startKrs2)).Msg("krs2 done") + return nil + } + var krs1 curve.G1Jac + computeKRS1 := func() error { // filter the wire values if needed // TODO Perf @Tabaie worst memory allocation offender toRemove := commitmentInfo.GetPrivateCommitted() @@ -320,81 +304,35 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b // original Groth16 witness without pedersen commitment wireValuesWithoutCom := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - // GPU runtime error - // var deviceWire *device.HostOrDeviceSlice[fr.Element] - // defer deviceWire.Free() - // chDeviceWire := make(chan *device.HostOrDeviceSlice[fr.Element], 1) - // if err := CopyToDevice(wireValuesWithoutCom, chDeviceWire); err != nil { - // chKrsDone <- err - // return - // } - // deviceWire = <-chDeviceWire - // if err := msmG1(&krs, pk.G1Device.K, deviceWire); err != nil { - // chKrsDone <- err - // return - // } - // CPU - // Compute this MSM on CPU, as it can be done in parallel with other MSM on GPU - // Also, reduce data copy + // Compute this MSM on CPU, as it can be done in parallel with other MSM on GPU, also reduce data copy startKrs := time.Now() - if _, err := krs.MultiExp(pk.G1.K, wireValuesWithoutCom, ecc.MultiExpConfig{NbTasks: runtime.NumCPU() / 2}); err != nil { - chKrsDone <- err - return + if _, err := krs1.MultiExp(pk.G1.K, wireValuesWithoutCom, ecc.MultiExpConfig{NbTasks: runtime.NumCPU() / 2}); err != nil { + return err } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValues)), time.Since(startKrs)).Msg("krs done") // -rs[δ] - krs.AddMixed(&deltas[2]) - - n := 3 - for n != 0 { - select { - // wait krs2 - case err 
:= <-chKrs2Done: - if err != nil { - chKrsDone <- err - return - } - krs.AddAssign(&krs2) - // wait ar - case err := <-chArDone: - if err != nil { - chKrsDone <- err - return - } - p1.ScalarMultiplication(&ar, &s) - krs.AddAssign(&p1) - // wait bs1 - case err := <-chBs1Done: - if err != nil { - chKrsDone <- err - return - } - p1.ScalarMultiplication(&bs1, &r) - krs.AddAssign(&p1) - } - n-- - } - - proof.Krs.FromJacobian(&krs) - chKrsDone <- nil + krs1.AddMixed(&deltas[2]) + return nil } - chBs2Done := make(chan error, 1) - computeBS2 := func() { + computeBS2 := func() error { + <-chWireValuesB // Bs2 (1 multi exp G2 - size = len(wires)) var Bs, deltaS curve.G2Jac - if err := <-chWireValuesB; err != nil { - chBs2Done <- err - return + var wireB *device.HostOrDeviceSlice[fr.Element] + chWireB := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + if err := CopyToDevice(wireValuesB, chWireB); err != nil { + return err } + wireB = <-chWireB + defer wireB.Free() startBs := time.Now() - if err := msmG2(&Bs, pk.G2Device.B, deviceWireValuesB); err != nil { - chBs2Done <- err - return + if err := msmG2(&Bs, pk.G2Device.B, wireB); err != nil { + return err } - log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", deviceWireValuesB.Len()), time.Since(startBs)).Msg("Bs done") + log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", wireB.Len()), time.Since(startBs)).Msg("Bs done") deltaS.FromAffine(&pk.G2.Delta) deltaS.ScalarMultiplication(&deltaS, &s) @@ -402,37 +340,34 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b Bs.AddMixed(&pk.G2.Beta) proof.Bs.FromJacobian(&Bs) - chBs2Done <- nil + return nil } - // wait for FFT to end, as it uses all our CPUs - <-chHDone - - // Parallel GPU execution, memory may hit limit - go computeAR1() - go computeBS1() - go computeKRS() + // Parallel execution, memory may hit limit + g, _ := errgroup.WithContext(context.TODO()) + g.Go(computeAR1) + g.Go(computeBS1) + g.Go(computeKRS1) + g.Go(computeKRS2) + g.Go(computeBS2) - // 
wait krs, ar1, bs1 - // krs done means AR1, BS1 are done - if err := <-chKrsDone; err != nil { + if err := g.Wait(); err != nil { return nil, err } - // bs2 and bs1 both depend on wireValuesB - computeBS2() - if err := <-chBs2Done; err != nil { - return nil, err + + // FinalKRS = KRS1 + KRS2 + s*AR + r*BS1 + { + var p1 curve.G1Jac + krs1.AddAssign(&krs2) + p1.ScalarMultiplication(&ar, &s) + krs1.AddAssign(&p1) + p1.ScalarMultiplication(&bs1, &r) + krs1.AddAssign(&p1) + proof.Krs.FromJacobian(&krs1) } log.Debug().Dur("took", time.Since(start)).Msg("prover done") - // Free device memory - go func() { - deviceWireValuesA.Free() - deviceWireValuesB.Free() - deviceH.Free() - }() - return proof, nil } From 91f00e7faca9cb6c9467d59b0a667ca8362b5cd7 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Fri, 18 Oct 2024 15:29:55 +0800 Subject: [PATCH 35/62] small fix --- backend/groth16/bn254/zeknox/zeknox.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 76dd72af7a..6c60e081ff 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -310,7 +310,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b if _, err := krs1.MultiExp(pk.G1.K, wireValuesWithoutCom, ecc.MultiExpConfig{NbTasks: runtime.NumCPU() / 2}); err != nil { return err } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValues)), time.Since(startKrs)).Msg("krs done") + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValues)), time.Since(startKrs)).Msg("CPU krs done") // -rs[δ] krs1.AddMixed(&deltas[2]) return nil From 1faa8ba0163e858f5d7323dd033e5c64848c83ac Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Fri, 18 Oct 2024 18:49:08 +0800 Subject: [PATCH 36/62] add doc --- README.md | 30 +++++++++++++++++++++++++++++- go.mod | 4 ++-- go.sum | 6 
++++-- 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 351ac156f3..92786166c4 100644 --- a/README.md +++ b/README.md @@ -161,6 +161,34 @@ func main() { ### GPU Support +#### Zeknox Library +Unlock free GPU acceleration with [OKX Zeknox library](https://github.com/okx/cryptography_cuda) + +##### Download prebuilt binary +```sh +sudo cp libblst.a libcryptocuda.a /usr/local/lib/ +``` + +If you want to build from source, see guide in https://github.com/okx/cryptography_cuda + +##### Enjoy GPU +Run `groth16.Prove(r1cs, pk, witnessData, backend.WithZeknoxAcceleration())` + +Test +```go +assert.ProverSucceeded(&mimcCircuit, &Circuit{ + PreImage: "16130099170765464552823636852555369511329944820189892919423002775646948828469", + Hash: "12886436712380113721405259596386800092738845035233065858332878701083870690753", + }, test.WithCurves(ecc.BN254), test.WithProverOpts(backend.WithZeknoxAcceleration())) +``` + +```sh +go run -tags=zeknox examples/main.go +# (place -tags before the filename) + +go test github.com/consensys/gnark/examples/mimc -tags=prover_checks,zeknox +``` + #### Icicle Library The following schemes and curves support experimental use of Ingonyama's Icicle GPU library for low level zk-SNARK primitives such as MSM, NTT, and polynomial operations: @@ -178,7 +206,7 @@ You can then toggle on or off icicle acceleration by providing the `WithIcicleAc ```go // toggle on proofIci, err := groth16.Prove(ccs, pk, secretWitness, backend.WithIcicleAcceleration()) - + // toggle off proof, err := groth16.Prove(ccs, pk, secretWitness) ``` diff --git a/go.mod b/go.mod index 52c8797024..06e14c02e1 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/icza/bitio v1.1.0 github.com/ingonyama-zk/iciclegnark v0.1.0 github.com/leanovate/gopter v0.2.11 - github.com/okx/cryptography_cuda/wrappers/go v0.0.0-20241016023422-25c1f0f5f44e + github.com/okx/cryptography_cuda v0.0.0-20241018104554-bafea0c91f28 
github.com/ronanh/intcomp v1.1.0 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 @@ -38,4 +38,4 @@ require ( rsc.io/tmplfunc v0.0.3 // indirect ) -// replace github.com/okx/cryptography_cuda/wrappers/go => /home/okxdex/data/zkdex-pap/workspace/jason-huang/cryptography_cuda/wrappers/go +replace github.com/okx/cryptography_cuda/wrappers/go => /home/okxdex/data/zkdex-pap/workspace/jason-huang/cryptography_cuda/wrappers/go diff --git a/go.sum b/go.sum index fc472ffae3..9cca00bced 100644 --- a/go.sum +++ b/go.sum @@ -230,8 +230,10 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/okx/cryptography_cuda/wrappers/go v0.0.0-20241016023422-25c1f0f5f44e h1:NT/U7+AJ93s0U4af9I5fEtpE33Etf68wEUif7Q/s1mo= -github.com/okx/cryptography_cuda/wrappers/go v0.0.0-20241016023422-25c1f0f5f44e/go.mod h1:y9SSivg7t0Fs0PZQJ/l2jUhWT67SeEj9XYgz5ysjyEw= +github.com/okx/cryptography_cuda v0.0.0-20241018104030-628693daf868 h1:aPaETd6bRKs2VpM8C9bZrOJprtUMIN2MXQIcCOtovX8= +github.com/okx/cryptography_cuda v0.0.0-20241018104030-628693daf868/go.mod h1:uoZvaCZ82rXfJuYz+hXCzDaMtts0zTGJt96rBqkoucQ= +github.com/okx/cryptography_cuda v0.0.0-20241018104554-bafea0c91f28 h1:c3aLIA4Wje6nGEx4XksWkwRI0U6kC9ITXa7ZBp6d5DU= +github.com/okx/cryptography_cuda v0.0.0-20241018104554-bafea0c91f28/go.mod h1:uoZvaCZ82rXfJuYz+hXCzDaMtts0zTGJt96rBqkoucQ= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pkg/errors v0.8.1/go.mod 
h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= From ea745dd5fe4230687d81564be99248d803f6fb5e Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Tue, 22 Oct 2024 09:55:42 +0800 Subject: [PATCH 37/62] set msm LargeBucketFactor config --- backend/groth16/bn254/zeknox/zeknox.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 6c60e081ff..62417cc78e 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -470,7 +470,7 @@ func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], s cfg.AreInputsOnDevice = true cfg.ArePointsInMont = true cfg.Npoints = uint32(points.Len()) - cfg.FfiAffineSz = 64 + cfg.LargeBucketFactor = 2 if err := msm.MSM_G1(unsafe.Pointer(res), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { return err } From 915559f2a6b9f569c1c0fc2831324e5512fa2400 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Wed, 23 Oct 2024 14:05:20 +0800 Subject: [PATCH 38/62] update msmg1, msmg1 return affine --- backend/groth16/bn254/zeknox/zeknox.go | 6 +++--- go.mod | 2 +- go.sum | 6 ++---- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 62417cc78e..5513684775 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -471,9 +471,11 @@ func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], s cfg.ArePointsInMont = true cfg.Npoints = uint32(points.Len()) cfg.LargeBucketFactor = 2 - if err := msm.MSM_G1(unsafe.Pointer(res), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { + resAffine := curve.G1Affine{} + if err := msm.MSM_G1(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { return err } + 
res.FromAffine(&resAffine) return nil } @@ -484,8 +486,6 @@ func msmG2(res *curve.G2Jac, points *device.HostOrDeviceSlice[curve.G2Affine], s cfg.ArePointsInMont = true cfg.Npoints = uint32(points.Len()) cfg.LargeBucketFactor = 2 - // TODO: MSM_G2 should return Jacobian - // https://github.com/okx/cryptography_cuda/issues/90 resAffine := curve.G2Affine{} if err := msm.MSM_G2(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { return err diff --git a/go.mod b/go.mod index 06e14c02e1..53ef90659c 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/icza/bitio v1.1.0 github.com/ingonyama-zk/iciclegnark v0.1.0 github.com/leanovate/gopter v0.2.11 - github.com/okx/cryptography_cuda v0.0.0-20241018104554-bafea0c91f28 + github.com/okx/cryptography_cuda v0.0.0-20241023025010-e04a13d4df26 github.com/ronanh/intcomp v1.1.0 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 diff --git a/go.sum b/go.sum index 9cca00bced..8da0c561d1 100644 --- a/go.sum +++ b/go.sum @@ -230,10 +230,8 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/okx/cryptography_cuda v0.0.0-20241018104030-628693daf868 h1:aPaETd6bRKs2VpM8C9bZrOJprtUMIN2MXQIcCOtovX8= -github.com/okx/cryptography_cuda v0.0.0-20241018104030-628693daf868/go.mod h1:uoZvaCZ82rXfJuYz+hXCzDaMtts0zTGJt96rBqkoucQ= -github.com/okx/cryptography_cuda v0.0.0-20241018104554-bafea0c91f28 h1:c3aLIA4Wje6nGEx4XksWkwRI0U6kC9ITXa7ZBp6d5DU= -github.com/okx/cryptography_cuda v0.0.0-20241018104554-bafea0c91f28/go.mod h1:uoZvaCZ82rXfJuYz+hXCzDaMtts0zTGJt96rBqkoucQ= +github.com/okx/cryptography_cuda 
v0.0.0-20241023025010-e04a13d4df26 h1:HgiJDIO/n8DTRCTRaw7CYm042Ieyo00O7wD90ZUteO0= +github.com/okx/cryptography_cuda v0.0.0-20241023025010-e04a13d4df26/go.mod h1:uoZvaCZ82rXfJuYz+hXCzDaMtts0zTGJt96rBqkoucQ= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= From 134f437c859b7dd0f2445373b319112811bfd379 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Wed, 23 Oct 2024 18:59:22 +0800 Subject: [PATCH 39/62] fix cuda int --- backend/groth16/bn254/zeknox/zeknox.go | 114 ++++++++++++++++++++----- 1 file changed, 94 insertions(+), 20 deletions(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 5513684775..af085f7106 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -7,6 +7,7 @@ import ( "fmt" "math/big" "runtime" + "sync/atomic" "time" "unsafe" @@ -30,6 +31,11 @@ import ( "golang.org/x/sync/errgroup" ) +var g2_point_b_mont int32 = 0 +var g1_point_b_mont int32 = 0 +var g1_point_a_mont int32 = 0 +var g1_point_z_mont int32 = 0 + const HasZeknox = true // Use single GPU @@ -226,7 +232,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var bs1, ar curve.G1Jac computeBS1 := func() error { - <- chWireValuesB + <-chWireValuesB var wireB *device.HostOrDeviceSlice[fr.Element] chWireB := make(chan *device.HostOrDeviceSlice[fr.Element], 1) if err := CopyToDevice(wireValuesB, chWireB); err != nil { @@ -235,7 +241,17 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireB = <-chWireB defer wireB.Free() startBs1 := time.Now() - if err := msmG1(&bs1, pk.G1Device.B, wireB); err != nil { + + val := atomic.LoadInt32(&g1_point_b_mont) + mont := 
true + if val == 1 { + mont = false + } else { + atomic.StoreInt32(&g1_point_b_mont, 1) + mont = true + } + + if err := msmG1(&bs1, pk.G1Device.B, wireB, mont); err != nil { return err } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", wireB.Len()), time.Since(startBs1)).Msg("bs1 done") @@ -246,7 +262,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b } computeAR1 := func() error { - <- chWireValuesA + <-chWireValuesA var wireA *device.HostOrDeviceSlice[fr.Element] chWireA := make(chan *device.HostOrDeviceSlice[fr.Element], 1) if err := CopyToDevice(wireValuesA, chWireA); err != nil { @@ -255,9 +271,20 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireA = <-chWireA defer wireA.Free() startAr := time.Now() - if err := msmG1(&ar, pk.G1Device.A, wireA); err != nil { + + val := atomic.LoadInt32(&g1_point_a_mont) + mont := true + if val == 1 { + mont = false + } else { + atomic.StoreInt32(&g1_point_a_mont, 1) + mont = true + } + + if err := msmG1(&ar, pk.G1Device.A, wireA, mont); err != nil { return err } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", wireA.Len()), time.Since(startAr)).Msg("ar done") ar.AddMixed(&pk.G1.Alpha) ar.AddMixed(&deltas[0]) @@ -288,9 +315,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b defer deviceH.Free() // MSM G1 Krs2 startKrs2 := time.Now() - if err := msmG1(&krs2, pk.G1Device.Z, deviceH); err != nil { + + val := atomic.LoadInt32(&g1_point_z_mont) + mont := true + if val == 1 { + mont = false + } else { + atomic.StoreInt32(&g1_point_z_mont, 1) + mont = true + } + if err := msmG1(&krs2, pk.G1Device.Z, deviceH, mont); err != nil { return err } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", sizeH), time.Since(startKrs2)).Msg("krs2 done") return nil } @@ -329,9 +366,29 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireB = <-chWireB defer wireB.Free() startBs := time.Now() - if err := msmG2(&Bs, 
pk.G2Device.B, wireB); err != nil { + // scalar := onHost(wireValuesB[:]) + // point := onHost(pk.G2.B[:]) + // if err := msmG2(&Bs, &point, &scalar); err != nil { + // return err + // } + + val := atomic.LoadInt32(&g2_point_b_mont) + mont := true + if val == 1 { + mont = false + } else { + atomic.StoreInt32(&g2_point_b_mont, 1) + mont = true + } + + if err := msmG2(&Bs, pk.G2Device.B, wireB, mont); err != nil { + return err + } + + if _, err := Bs.MultiExp(pk.G2.B, wireValuesB, ecc.MultiExpConfig{NbTasks: 16}); err != nil { return err } + log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", wireB.Len()), time.Since(startBs)).Msg("Bs done") deltaS.FromAffine(&pk.G2.Delta) @@ -344,16 +401,21 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b } // Parallel execution, memory may hit limit - g, _ := errgroup.WithContext(context.TODO()) - g.Go(computeAR1) - g.Go(computeBS1) - g.Go(computeKRS1) - g.Go(computeKRS2) - g.Go(computeBS2) - - if err := g.Wait(); err != nil { - return nil, err - } + // g, _ := errgroup.WithContext(context.TODO()) + // g.Go(computeAR1) + computeAR1() + // g.Go(computeBS1) + computeBS1() + // g.Go(computeKRS1) + computeKRS1() + // g.Go(computeKRS2) + computeKRS2() + + // if err := g.Wait(); err != nil { + // return nil, err + // } + + computeBS2() // FinalKRS = KRS1 + KRS2 + s*AR + r*BS1 { @@ -371,6 +433,12 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b return proof, nil } +func onHost[T any](hostData []T) device.HostOrDeviceSlice[T] { + deviceSlice := device.NewEmpty[T]() + deviceSlice.OnHost(hostData) + return *deviceSlice +} + // if len(toRemove) == 0, returns slice // // else, returns a new slice without the indexes in toRemove. 
The first value in the slice is taken as indexes as sliceFirstIndex @@ -464,11 +532,14 @@ func checkMsmInputs[P, S any](points *device.HostOrDeviceSlice[P], scalars *devi return nil } -func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { +func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], scalars *device.HostOrDeviceSlice[fr.Element], input_point_in_mont bool) error { checkMsmInputs(points, scalars) cfg := msm.DefaultMSMConfig() cfg.AreInputsOnDevice = true - cfg.ArePointsInMont = true + + cfg.AreInputPointInMont = input_point_in_mont + cfg.AreInputScalarInMont = true + cfg.AreOutputPointInMont = true cfg.Npoints = uint32(points.Len()) cfg.LargeBucketFactor = 2 resAffine := curve.G1Affine{} @@ -479,11 +550,14 @@ func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], s return nil } -func msmG2(res *curve.G2Jac, points *device.HostOrDeviceSlice[curve.G2Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { +func msmG2(res *curve.G2Jac, points *device.HostOrDeviceSlice[curve.G2Affine], scalars *device.HostOrDeviceSlice[fr.Element], mont bool) error { checkMsmInputs(points, scalars) cfg := msm.DefaultMSMConfig() cfg.AreInputsOnDevice = true - cfg.ArePointsInMont = true + + cfg.AreInputPointInMont = mont + cfg.AreOutputPointInMont = true + cfg.AreInputScalarInMont = true cfg.Npoints = uint32(points.Len()) cfg.LargeBucketFactor = 2 resAffine := curve.G2Affine{} From 2dd91be13e5e62764586a6a42fd9cea193936173 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:20:06 +0800 Subject: [PATCH 40/62] delete unused deviceInfo --- backend/groth16/bn254/zeknox/provingkey.go | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/backend/groth16/bn254/zeknox/provingkey.go b/backend/groth16/bn254/zeknox/provingkey.go index f54a97a669..4f66ec9661 100644 --- 
a/backend/groth16/bn254/zeknox/provingkey.go +++ b/backend/groth16/bn254/zeknox/provingkey.go @@ -1,10 +1,8 @@ package zeknox_bn254 import ( - "unsafe" - - groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" "github.com/consensys/gnark-crypto/ecc/bn254" + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" cs "github.com/consensys/gnark/constraint/bn254" "github.com/okx/cryptography_cuda/wrappers/go/device" ) @@ -13,14 +11,9 @@ type deviceInfo struct { G1Device struct { A, B, K, Z *device.HostOrDeviceSlice[bn254.G1Affine] } - DomainDevice struct { - Twiddles, TwiddlesInv unsafe.Pointer - CosetTable, CosetTableInv unsafe.Pointer - } G2Device struct { B *device.HostOrDeviceSlice[bn254.G2Affine] } - DenDevice unsafe.Pointer InfinityPointIndicesK []int } From 5a3808678ea1044d57d2e03e14d177d634c4bf48 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:21:18 +0800 Subject: [PATCH 41/62] deviceInfo each points store ArePointsInMont --- backend/groth16/bn254/zeknox/provingkey.go | 26 ++++- backend/groth16/bn254/zeknox/zeknox.go | 105 +++++++-------------- 2 files changed, 59 insertions(+), 72 deletions(-) diff --git a/backend/groth16/bn254/zeknox/provingkey.go b/backend/groth16/bn254/zeknox/provingkey.go index 4f66ec9661..fc58d89a00 100644 --- a/backend/groth16/bn254/zeknox/provingkey.go +++ b/backend/groth16/bn254/zeknox/provingkey.go @@ -9,10 +9,10 @@ import ( type deviceInfo struct { G1Device struct { - A, B, K, Z *device.HostOrDeviceSlice[bn254.G1Affine] + A, B, K, Z DevicePoints[bn254.G1Affine] } G2Device struct { - B *device.HostOrDeviceSlice[bn254.G2Affine] + B DevicePoints[bn254.G2Affine] } InfinityPointIndicesK []int } @@ -22,6 +22,14 @@ type ProvingKey struct { *deviceInfo } +type DevicePoints[T bn254.G1Affine | bn254.G2Affine] struct { + *device.HostOrDeviceSlice[T] + // Gnark points are in Montgomery form + // After 1 GPU MSM, points in GPU are converted to affine form + // 
Pass it to MSM config + ArePointsInMont bool +} + func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *groth16_bn254.VerifyingKey) error { return groth16_bn254.Setup(r1cs, &pk.ProvingKey, vk) } @@ -29,3 +37,17 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *groth16_bn254.VerifyingKey) error func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { return groth16_bn254.DummySetup(r1cs, &pk.ProvingKey) } + +// You should call this method to free the GPU memory +// +// pk := groth16.NewProvingKey(ecc.BN254) +// defer pk.(*zeknox_bn254.ProvingKey).Free() +func (pk *ProvingKey) Free() { + if pk.deviceInfo != nil { + pk.deviceInfo.G1Device.A.Free() + pk.deviceInfo.G1Device.B.Free() + pk.deviceInfo.G1Device.K.Free() + pk.deviceInfo.G1Device.Z.Free() + pk.deviceInfo.G2Device.B.Free() + } +} \ No newline at end of file diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index af085f7106..9bf292d9f6 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -7,7 +7,6 @@ import ( "fmt" "math/big" "runtime" - "sync/atomic" "time" "unsafe" @@ -46,7 +45,6 @@ func (pk *ProvingKey) setupDevicePointers() error { return nil } pk.deviceInfo = &deviceInfo{} - // TODO: setup FFT // MSM G1 & G2 Device Setup g, _ := errgroup.WithContext(context.TODO()) @@ -84,11 +82,26 @@ func (pk *ProvingKey) setupDevicePointers() error { return err } // if no error, store device pointers in pk - pk.G1Device.A = <-deviceA - pk.G1Device.B = <-deviceG1B - pk.G1Device.K = <-deviceK - pk.G1Device.Z = <-deviceZ - pk.G2Device.B = <-deviceG2B + pk.G1Device.A = DevicePoints[curve.G1Affine]{ + HostOrDeviceSlice: <-deviceA, + ArePointsInMont: true, + } + pk.G1Device.B = DevicePoints[curve.G1Affine]{ + HostOrDeviceSlice: <-deviceG1B, + ArePointsInMont: true, + } + pk.G1Device.K = DevicePoints[curve.G1Affine]{ + HostOrDeviceSlice: <-deviceK, + ArePointsInMont: true, + } + pk.G1Device.Z = DevicePoints[curve.G1Affine]{ + HostOrDeviceSlice: 
<-deviceZ, + ArePointsInMont: true, + } + pk.G2Device.B = DevicePoints[curve.G2Affine]{ + HostOrDeviceSlice: <-deviceG2B, + ArePointsInMont: true, + } return nil } @@ -232,6 +245,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var bs1, ar curve.G1Jac computeBS1 := func() error { + <-chWireValuesB <-chWireValuesB var wireB *device.HostOrDeviceSlice[fr.Element] chWireB := make(chan *device.HostOrDeviceSlice[fr.Element], 1) @@ -241,17 +255,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireB = <-chWireB defer wireB.Free() startBs1 := time.Now() - - val := atomic.LoadInt32(&g1_point_b_mont) - mont := true - if val == 1 { - mont = false - } else { - atomic.StoreInt32(&g1_point_b_mont, 1) - mont = true - } - - if err := msmG1(&bs1, pk.G1Device.B, wireB, mont); err != nil { + if err := msmG1(&bs1, &pk.G1Device.B, wireB); err != nil { return err } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", wireB.Len()), time.Since(startBs1)).Msg("bs1 done") @@ -262,6 +266,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b } computeAR1 := func() error { + <-chWireValuesA <-chWireValuesA var wireA *device.HostOrDeviceSlice[fr.Element] chWireA := make(chan *device.HostOrDeviceSlice[fr.Element], 1) @@ -271,17 +276,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireA = <-chWireA defer wireA.Free() startAr := time.Now() - - val := atomic.LoadInt32(&g1_point_a_mont) - mont := true - if val == 1 { - mont = false - } else { - atomic.StoreInt32(&g1_point_a_mont, 1) - mont = true - } - - if err := msmG1(&ar, pk.G1Device.A, wireA, mont); err != nil { + if err := msmG1(&ar, &pk.G1Device.A, wireA); err != nil { return err } @@ -315,16 +310,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b defer deviceH.Free() // MSM G1 Krs2 startKrs2 := time.Now() - - val := atomic.LoadInt32(&g1_point_z_mont) - mont := true - if val 
== 1 { - mont = false - } else { - atomic.StoreInt32(&g1_point_z_mont, 1) - mont = true - } - if err := msmG1(&krs2, pk.G1Device.Z, deviceH, mont); err != nil { + if err := msmG1(&krs2, &pk.G1Device.Z, deviceH); err != nil { return err } @@ -366,26 +352,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireB = <-chWireB defer wireB.Free() startBs := time.Now() - // scalar := onHost(wireValuesB[:]) - // point := onHost(pk.G2.B[:]) - // if err := msmG2(&Bs, &point, &scalar); err != nil { - // return err - // } - - val := atomic.LoadInt32(&g2_point_b_mont) - mont := true - if val == 1 { - mont = false - } else { - atomic.StoreInt32(&g2_point_b_mont, 1) - mont = true - } - - if err := msmG2(&Bs, pk.G2Device.B, wireB, mont); err != nil { - return err - } - - if _, err := Bs.MultiExp(pk.G2.B, wireValuesB, ecc.MultiExpConfig{NbTasks: 16}); err != nil { + if err := msmG2(&Bs, &pk.G2Device.B, wireB); err != nil { return err } @@ -522,7 +489,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { return a } -func checkMsmInputs[P, S any](points *device.HostOrDeviceSlice[P], scalars *device.HostOrDeviceSlice[S]) error { +func checkMsmInputs[P curve.G1Affine | curve.G2Affine](points *DevicePoints[P], scalars *device.HostOrDeviceSlice[fr.Element]) error { if !points.IsOnDevice() || !scalars.IsOnDevice() { return fmt.Errorf("MSM: points and scalars must be on device") } @@ -532,38 +499,36 @@ func checkMsmInputs[P, S any](points *device.HostOrDeviceSlice[P], scalars *devi return nil } -func msmG1(res *curve.G1Jac, points *device.HostOrDeviceSlice[curve.G1Affine], scalars *device.HostOrDeviceSlice[fr.Element], input_point_in_mont bool) error { +func msmG1(res *curve.G1Jac, points *DevicePoints[curve.G1Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { checkMsmInputs(points, scalars) cfg := msm.DefaultMSMConfig() cfg.AreInputsOnDevice = true - - cfg.AreInputPointInMont = input_point_in_mont - 
cfg.AreInputScalarInMont = true - cfg.AreOutputPointInMont = true + cfg.ArePointsInMont = points.ArePointsInMont cfg.Npoints = uint32(points.Len()) cfg.LargeBucketFactor = 2 resAffine := curve.G1Affine{} if err := msm.MSM_G1(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { return err } + // After 1 GPU MSM, points in GPU are converted to affine form + points.ArePointsInMont = false res.FromAffine(&resAffine) return nil } -func msmG2(res *curve.G2Jac, points *device.HostOrDeviceSlice[curve.G2Affine], scalars *device.HostOrDeviceSlice[fr.Element], mont bool) error { +func msmG2(res *curve.G2Jac, points *DevicePoints[curve.G2Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { checkMsmInputs(points, scalars) cfg := msm.DefaultMSMConfig() cfg.AreInputsOnDevice = true - - cfg.AreInputPointInMont = mont - cfg.AreOutputPointInMont = true - cfg.AreInputScalarInMont = true + cfg.ArePointsInMont = points.ArePointsInMont cfg.Npoints = uint32(points.Len()) cfg.LargeBucketFactor = 2 resAffine := curve.G2Affine{} if err := msm.MSM_G2(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { return err } + // After 1 GPU MSM, points in GPU are converted to affine form + points.ArePointsInMont = false res.FromAffine(&resAffine) return nil } From 9ca0eba74a2ddbfe1a79a67d37cf9b96019869da Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:21:18 +0800 Subject: [PATCH 42/62] update cuda library, verify GPU proof success! 
--- backend/groth16/bn254/zeknox/provingkey.go | 4 ++-- backend/groth16/bn254/zeknox/zeknox.go | 21 ++++++++++++--------- go.mod | 2 +- go.sum | 4 ++-- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/backend/groth16/bn254/zeknox/provingkey.go b/backend/groth16/bn254/zeknox/provingkey.go index fc58d89a00..854dad9d78 100644 --- a/backend/groth16/bn254/zeknox/provingkey.go +++ b/backend/groth16/bn254/zeknox/provingkey.go @@ -27,7 +27,7 @@ type DevicePoints[T bn254.G1Affine | bn254.G2Affine] struct { // Gnark points are in Montgomery form // After 1 GPU MSM, points in GPU are converted to affine form // Pass it to MSM config - ArePointsInMont bool + Mont bool } func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *groth16_bn254.VerifyingKey) error { @@ -50,4 +50,4 @@ func (pk *ProvingKey) Free() { pk.deviceInfo.G1Device.Z.Free() pk.deviceInfo.G2Device.B.Free() } -} \ No newline at end of file +} diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 9bf292d9f6..8fc4e7740d 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -84,23 +84,23 @@ func (pk *ProvingKey) setupDevicePointers() error { // if no error, store device pointers in pk pk.G1Device.A = DevicePoints[curve.G1Affine]{ HostOrDeviceSlice: <-deviceA, - ArePointsInMont: true, + Mont: true, } pk.G1Device.B = DevicePoints[curve.G1Affine]{ HostOrDeviceSlice: <-deviceG1B, - ArePointsInMont: true, + Mont: true, } pk.G1Device.K = DevicePoints[curve.G1Affine]{ HostOrDeviceSlice: <-deviceK, - ArePointsInMont: true, + Mont: true, } pk.G1Device.Z = DevicePoints[curve.G1Affine]{ HostOrDeviceSlice: <-deviceZ, - ArePointsInMont: true, + Mont: true, } pk.G2Device.B = DevicePoints[curve.G2Affine]{ HostOrDeviceSlice: <-deviceG2B, - ArePointsInMont: true, + Mont: true, } return nil @@ -503,7 +503,8 @@ func msmG1(res *curve.G1Jac, points *DevicePoints[curve.G1Affine], scalars *devi checkMsmInputs(points, scalars) cfg := 
msm.DefaultMSMConfig() cfg.AreInputsOnDevice = true - cfg.ArePointsInMont = points.ArePointsInMont + cfg.AreInputScalarInMont = true + cfg.AreOutputPointInMont = true cfg.Npoints = uint32(points.Len()) cfg.LargeBucketFactor = 2 resAffine := curve.G1Affine{} @@ -511,7 +512,7 @@ func msmG1(res *curve.G1Jac, points *DevicePoints[curve.G1Affine], scalars *devi return err } // After 1 GPU MSM, points in GPU are converted to affine form - points.ArePointsInMont = false + points.Mont = false res.FromAffine(&resAffine) return nil } @@ -520,7 +521,9 @@ func msmG2(res *curve.G2Jac, points *DevicePoints[curve.G2Affine], scalars *devi checkMsmInputs(points, scalars) cfg := msm.DefaultMSMConfig() cfg.AreInputsOnDevice = true - cfg.ArePointsInMont = points.ArePointsInMont + cfg.AreInputPointInMont = points.Mont + cfg.AreInputScalarInMont = true + cfg.AreOutputPointInMont = true cfg.Npoints = uint32(points.Len()) cfg.LargeBucketFactor = 2 resAffine := curve.G2Affine{} @@ -528,7 +531,7 @@ func msmG2(res *curve.G2Jac, points *DevicePoints[curve.G2Affine], scalars *devi return err } // After 1 GPU MSM, points in GPU are converted to affine form - points.ArePointsInMont = false + points.Mont = false res.FromAffine(&resAffine) return nil } diff --git a/go.mod b/go.mod index 53ef90659c..d8bc4379f1 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/icza/bitio v1.1.0 github.com/ingonyama-zk/iciclegnark v0.1.0 github.com/leanovate/gopter v0.2.11 - github.com/okx/cryptography_cuda v0.0.0-20241023025010-e04a13d4df26 + github.com/okx/cryptography_cuda v0.0.0-20241023112133-1756b0ee9527 github.com/ronanh/intcomp v1.1.0 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 diff --git a/go.sum b/go.sum index 8da0c561d1..dc772d6de9 100644 --- a/go.sum +++ b/go.sum @@ -230,8 +230,8 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= 
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/okx/cryptography_cuda v0.0.0-20241023025010-e04a13d4df26 h1:HgiJDIO/n8DTRCTRaw7CYm042Ieyo00O7wD90ZUteO0= -github.com/okx/cryptography_cuda v0.0.0-20241023025010-e04a13d4df26/go.mod h1:uoZvaCZ82rXfJuYz+hXCzDaMtts0zTGJt96rBqkoucQ= +github.com/okx/cryptography_cuda v0.0.0-20241023112133-1756b0ee9527 h1:rItWN8zYu0DqhyQvKfzmYVRUJmIJonqNP9N8WNhYoAo= +github.com/okx/cryptography_cuda v0.0.0-20241023112133-1756b0ee9527/go.mod h1:Azb8uIJJqdXIq5np5A/RK2ga1sW1949bKFLg3yZCZjI= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= From 7ef565a85104ed9dc8a6f34d0d3a79c2b182a1f0 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:21:18 +0800 Subject: [PATCH 43/62] refactor msm, 1 msm func for both G1 and G2 --- backend/groth16/bn254/zeknox/zeknox.go | 75 ++++++++++++++------------ 1 file changed, 41 insertions(+), 34 deletions(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 8fc4e7740d..817b24cf63 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -255,7 +255,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireB = <-chWireB defer wireB.Free() startBs1 := time.Now() - if err := msmG1(&bs1, &pk.G1Device.B, wireB); err != nil { + if err := gpuMsm(&bs1, &pk.G1Device.B, wireB); err != nil { return err } log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", wireB.Len()), time.Since(startBs1)).Msg("bs1 done") @@ 
-276,7 +276,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireA = <-chWireA defer wireA.Free() startAr := time.Now() - if err := msmG1(&ar, &pk.G1Device.A, wireA); err != nil { + if err := gpuMsm(&ar, &pk.G1Device.A, wireA); err != nil { return err } @@ -310,7 +310,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b defer deviceH.Free() // MSM G1 Krs2 startKrs2 := time.Now() - if err := msmG1(&krs2, &pk.G1Device.Z, deviceH); err != nil { + if err := gpuMsm(&krs2, &pk.G1Device.Z, deviceH); err != nil { return err } @@ -352,7 +352,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b wireB = <-chWireB defer wireB.Free() startBs := time.Now() - if err := msmG2(&Bs, &pk.G2Device.B, wireB); err != nil { + if err := gpuMsm(&Bs, &pk.G2Device.B, wireB); err != nil { return err } @@ -489,50 +489,57 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { return a } -func checkMsmInputs[P curve.G1Affine | curve.G2Affine](points *DevicePoints[P], scalars *device.HostOrDeviceSlice[fr.Element]) error { +// GPU Msm for either G1 or G2 points +func gpuMsm[R curve.G1Jac | curve.G2Jac, P curve.G1Affine | curve.G2Affine]( + res *R, + points *DevicePoints[P], + scalars *device.HostOrDeviceSlice[fr.Element], +) error { + // Check inputs if !points.IsOnDevice() || !scalars.IsOnDevice() { - return fmt.Errorf("MSM: points and scalars must be on device") + panic("points and scalars must be on device") } if points.Len() != scalars.Len() { - return fmt.Errorf("MSM: len(points) != len(scalars)") + panic("points and scalars should be in the same length") } - return nil -} -func msmG1(res *curve.G1Jac, points *DevicePoints[curve.G1Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { - checkMsmInputs(points, scalars) + // Setup MSM config cfg := msm.DefaultMSMConfig() + cfg.AreInputPointInMont = points.Mont cfg.AreInputsOnDevice = true cfg.AreInputScalarInMont = 
true cfg.AreOutputPointInMont = true cfg.Npoints = uint32(points.Len()) cfg.LargeBucketFactor = 2 - resAffine := curve.G1Affine{} - if err := msm.MSM_G1(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { - return err - } - // After 1 GPU MSM, points in GPU are converted to affine form - points.Mont = false - res.FromAffine(&resAffine) - return nil -} -func msmG2(res *curve.G2Jac, points *DevicePoints[curve.G2Affine], scalars *device.HostOrDeviceSlice[fr.Element]) error { - checkMsmInputs(points, scalars) - cfg := msm.DefaultMSMConfig() - cfg.AreInputsOnDevice = true - cfg.AreInputPointInMont = points.Mont - cfg.AreInputScalarInMont = true - cfg.AreOutputPointInMont = true - cfg.Npoints = uint32(points.Len()) - cfg.LargeBucketFactor = 2 - resAffine := curve.G2Affine{} - if err := msm.MSM_G2(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { - return err + switch any(points).(type) { + case *DevicePoints[curve.G1Affine]: + resAffine := curve.G1Affine{} + if err := msm.MSM_G1(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { + return err + } + if r, ok := any(res).(*curve.G1Jac); ok { + r.FromAffine(&resAffine) + } else { + panic("res type should be *curve.G1Jac") + } + case *DevicePoints[curve.G2Affine]: + resAffine := curve.G2Affine{} + if err := msm.MSM_G2(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { + return err + } + if r, ok := any(res).(*curve.G2Jac); ok { + r.FromAffine(&resAffine) + } else { + panic("res type should be *curve.G2Jac") + } + default: + panic("invalid points type") } - // After 1 GPU MSM, points in GPU are converted to affine form + + // After GPU MSM, points in GPU are converted to affine form points.Mont = false - res.FromAffine(&resAffine) + return nil } From c0073c98b2acfc26d3ca40e576ae24dd4a8aa22d Mon Sep 17 00:00:00 2001 From: "Jason.Huang" 
<20609724+doutv@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:42:55 +0800 Subject: [PATCH 44/62] parallel msm, sometimes verify fail --- backend/groth16/bn254/zeknox/zeknox.go | 41 ++++++++------------------ 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 817b24cf63..2a3bfa89bb 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -30,11 +30,6 @@ import ( "golang.org/x/sync/errgroup" ) -var g2_point_b_mont int32 = 0 -var g1_point_b_mont int32 = 0 -var g1_point_a_mont int32 = 0 -var g1_point_z_mont int32 = 0 - const HasZeknox = true // Use single GPU @@ -245,7 +240,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var bs1, ar curve.G1Jac computeBS1 := func() error { - <-chWireValuesB <-chWireValuesB var wireB *device.HostOrDeviceSlice[fr.Element] chWireB := make(chan *device.HostOrDeviceSlice[fr.Element], 1) @@ -266,7 +260,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b } computeAR1 := func() error { - <-chWireValuesA <-chWireValuesA var wireA *device.HostOrDeviceSlice[fr.Element] chWireA := make(chan *device.HostOrDeviceSlice[fr.Element], 1) @@ -368,19 +361,15 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b } // Parallel execution, memory may hit limit - // g, _ := errgroup.WithContext(context.TODO()) - // g.Go(computeAR1) - computeAR1() - // g.Go(computeBS1) - computeBS1() - // g.Go(computeKRS1) - computeKRS1() - // g.Go(computeKRS2) - computeKRS2() - - // if err := g.Wait(); err != nil { - // return nil, err - // } + g, _ := errgroup.WithContext(context.TODO()) + g.Go(computeAR1) + g.Go(computeBS1) + g.Go(computeKRS1) + g.Go(computeKRS2) + + if err := g.Wait(); err != nil { + return nil, err + } computeBS2() @@ -400,12 +389,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts 
...b return proof, nil } -func onHost[T any](hostData []T) device.HostOrDeviceSlice[T] { - deviceSlice := device.NewEmpty[T]() - deviceSlice.OnHost(hostData) - return *deviceSlice -} - // if len(toRemove) == 0, returns slice // // else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex @@ -519,10 +502,10 @@ func gpuMsm[R curve.G1Jac | curve.G2Jac, P curve.G1Affine | curve.G2Affine]( return err } if r, ok := any(res).(*curve.G1Jac); ok { - r.FromAffine(&resAffine) - } else { + r.FromAffine(&resAffine) + } else { panic("res type should be *curve.G1Jac") - } + } case *DevicePoints[curve.G2Affine]: resAffine := curve.G2Affine{} if err := msm.MSM_G2(unsafe.Pointer(&resAffine), points.AsPtr(), scalars.AsPtr(), deviceId, cfg); err != nil { From c942abaca6f806d980a91cff2ac614eca15a4c35 Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:55:56 +0800 Subject: [PATCH 45/62] parallel + copy point every time --- backend/groth16/bn254/zeknox/zeknox.go | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 2a3bfa89bb..760a51a697 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -114,6 +114,10 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b return groth16_bn254.Prove(r1cs, &pk.ProvingKey, fullWitness, opts...) 
} log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Str("acceleration", "zeknox").Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + if pk.deviceInfo != nil { + pk.Free() + pk.deviceInfo = nil + } if pk.deviceInfo == nil { start := time.Now() if err := pk.setupDevicePointers(); err != nil { @@ -366,12 +370,17 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b g.Go(computeBS1) g.Go(computeKRS1) g.Go(computeKRS2) - + g.Go(computeBS2) if err := g.Wait(); err != nil { return nil, err } - computeBS2() + // Serial execution + // computeAR1() + // computeBS1() + // computeKRS1() + // computeKRS2() + // computeBS2() // FinalKRS = KRS1 + KRS2 + s*AR + r*BS1 { From 7826d9c5fce84fc8861b2b8772b4c27a348b743e Mon Sep 17 00:00:00 2001 From: "Jason.Huang" <20609724+doutv@users.noreply.github.com> Date: Thu, 24 Oct 2024 15:00:03 +0800 Subject: [PATCH 46/62] serial GPU msm, always success --- backend/groth16/bn254/zeknox/zeknox.go | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 760a51a697..2195c231c4 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -114,10 +114,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b return groth16_bn254.Prove(r1cs, &pk.ProvingKey, fullWitness, opts...) 
} log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Str("acceleration", "zeknox").Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() - if pk.deviceInfo != nil { - pk.Free() - pk.deviceInfo = nil - } if pk.deviceInfo == nil { start := time.Now() if err := pk.setupDevicePointers(); err != nil { @@ -364,24 +360,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b return nil } - // Parallel execution, memory may hit limit g, _ := errgroup.WithContext(context.TODO()) - g.Go(computeAR1) - g.Go(computeBS1) + // CPU MSM g.Go(computeKRS1) - g.Go(computeKRS2) - g.Go(computeBS2) + + // Serial GPU MSM + computeAR1() + computeBS1() + computeKRS2() + computeBS2() if err := g.Wait(); err != nil { return nil, err } - // Serial execution - // computeAR1() - // computeBS1() - // computeKRS1() - // computeKRS2() - // computeBS2() - // FinalKRS = KRS1 + KRS2 + s*AR + r*BS1 { var p1 curve.G1Jac From e067460058b9f1ec51ce1ec06ca9288efe663e7f Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Tue, 5 Nov 2024 18:56:20 +0800 Subject: [PATCH 47/62] small improvement in zeknox prover --- backend/groth16/bn254/zeknox/provingkey.go | 2 +- backend/groth16/bn254/zeknox/zeknox.go | 92 ++++++++++++++++++---- go.mod | 4 +- go.sum | 2 - 4 files changed, 79 insertions(+), 21 deletions(-) diff --git a/backend/groth16/bn254/zeknox/provingkey.go b/backend/groth16/bn254/zeknox/provingkey.go index 854dad9d78..fd6966e294 100644 --- a/backend/groth16/bn254/zeknox/provingkey.go +++ b/backend/groth16/bn254/zeknox/provingkey.go @@ -4,7 +4,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254" groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" cs "github.com/consensys/gnark/constraint/bn254" - "github.com/okx/cryptography_cuda/wrappers/go/device" + "github.com/okx/zeknox/wrappers/go/device" ) type deviceInfo struct { diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 
2195c231c4..6b046cdd42 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -25,8 +25,8 @@ import ( fcs "github.com/consensys/gnark/frontend/cs" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" - "github.com/okx/cryptography_cuda/wrappers/go/device" - "github.com/okx/cryptography_cuda/wrappers/go/msm" + "github.com/okx/zeknox/wrappers/go/device" + "github.com/okx/zeknox/wrappers/go/msm" "golang.org/x/sync/errgroup" ) @@ -62,6 +62,7 @@ func (pk *ProvingKey) setupDevicePointers() error { } deviceK := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) g.Go(func() error { return CopyToDevice(pointsNoInfinity, deviceK) }) + // g.Go(func() error { return CopyToDevice(pk.G1.K, deviceK) }) // G1.Z deviceZ := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) @@ -312,7 +313,8 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b } var krs1 curve.G1Jac - computeKRS1 := func() error { + + computeKRS1_CPU := func() error { // filter the wire values if needed // TODO Perf @Tabaie worst memory allocation offender toRemove := commitmentInfo.GetPrivateCommitted() @@ -332,24 +334,80 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b return nil } - computeBS2 := func() error { - <-chWireValuesB + /* + computeKRS1 := func() error { + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + // original Groth16 witness without pedersen commitment + wireValuesWithoutCom := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) + + var deviceWireValuesWithoutCom *device.HostOrDeviceSlice[fr.Element] + chDeviceW := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + sizeW := len(wireValuesWithoutCom) + // copy to GPU + if err := 
CopyToDevice(wireValuesWithoutCom[:sizeW], chDeviceW); err != nil { + return err + } + deviceWireValuesWithoutCom = <-chDeviceW + defer deviceWireValuesWithoutCom.Free() + // on GPU + startKrs := time.Now() + if err := gpuMsm(&krs1, &pk.G1Device.K, deviceWireValuesWithoutCom); err != nil { + return err + } + log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValues)), time.Since(startKrs)).Msg("GPU krs1 done") + // -rs[δ] + krs1.AddMixed(&deltas[2]) + return nil + } + */ + /* + computeBS2 := func() error { + <-chWireValuesB + // Bs2 (1 multi exp G2 - size = len(wires)) + var Bs, deltaS curve.G2Jac + + var wireB *device.HostOrDeviceSlice[fr.Element] + chWireB := make(chan *device.HostOrDeviceSlice[fr.Element], 1) + if err := CopyToDevice(wireValuesB, chWireB); err != nil { + return err + } + wireB = <-chWireB + defer wireB.Free() + startBs := time.Now() + if err := gpuMsm(&Bs, &pk.G2Device.B, wireB); err != nil { + return err + } + + log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", wireB.Len()), time.Since(startBs)).Msg("Bs done") + + deltaS.FromAffine(&pk.G2.Delta) + deltaS.ScalarMultiplication(&deltaS, &s) + Bs.AddAssign(&deltaS) + Bs.AddMixed(&pk.G2.Beta) + + proof.Bs.FromJacobian(&Bs) + return nil + } + */ + + computeBS2_CPU := func() error { // Bs2 (1 multi exp G2 - size = len(wires)) var Bs, deltaS curve.G2Jac - var wireB *device.HostOrDeviceSlice[fr.Element] - chWireB := make(chan *device.HostOrDeviceSlice[fr.Element], 1) - if err := CopyToDevice(wireValuesB, chWireB); err != nil { - return err + nbTasks := runtime.NumCPU() / 2 + if nbTasks <= 16 { + // if we don't have a lot of CPUs, this may artificially split the MSM + nbTasks *= 2 } - wireB = <-chWireB - defer wireB.Free() + <-chWireValuesB startBs := time.Now() - if err := gpuMsm(&Bs, &pk.G2Device.B, wireB); err != nil { + if _, err := Bs.MultiExp(pk.G2.B, wireValuesB, ecc.MultiExpConfig{NbTasks: nbTasks}); err != nil { return err } - - log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", wireB.Len()), 
time.Since(startBs)).Msg("Bs done") + log.Debug().Dur(fmt.Sprintf("MSMG2 %d took", len(wireValuesB)), time.Since(startBs)).Msg("Bs.MultiExp done") deltaS.FromAffine(&pk.G2.Delta) deltaS.ScalarMultiplication(&deltaS, &s) @@ -362,13 +420,15 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b g, _ := errgroup.WithContext(context.TODO()) // CPU MSM - g.Go(computeKRS1) + g.Go(computeKRS1_CPU) + g.Go(computeBS2_CPU) // Serial GPU MSM computeAR1() computeBS1() + // computeKRS1() computeKRS2() - computeBS2() + // computeBS2() if err := g.Wait(); err != nil { return nil, err } diff --git a/go.mod b/go.mod index d8bc4379f1..5b689029bf 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/icza/bitio v1.1.0 github.com/ingonyama-zk/iciclegnark v0.1.0 github.com/leanovate/gopter v0.2.11 - github.com/okx/cryptography_cuda v0.0.0-20241023112133-1756b0ee9527 + github.com/okx/zeknox v0.0.0-20241023112133-1756b0ee9527 github.com/ronanh/intcomp v1.1.0 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 @@ -38,4 +38,4 @@ require ( rsc.io/tmplfunc v0.0.3 // indirect ) -replace github.com/okx/cryptography_cuda/wrappers/go => /home/okxdex/data/zkdex-pap/workspace/jason-huang/cryptography_cuda/wrappers/go +// replace github.com/okx/cryptography_cuda => /home/okxdex/data/zkdex-pap/workspace/jason-huang/cryptography_cuda diff --git a/go.sum b/go.sum index dc772d6de9..81dc4a651b 100644 --- a/go.sum +++ b/go.sum @@ -230,8 +230,6 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/okx/cryptography_cuda v0.0.0-20241023112133-1756b0ee9527 
h1:rItWN8zYu0DqhyQvKfzmYVRUJmIJonqNP9N8WNhYoAo= -github.com/okx/cryptography_cuda v0.0.0-20241023112133-1756b0ee9527/go.mod h1:Azb8uIJJqdXIq5np5A/RK2ga1sW1949bKFLg3yZCZjI= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= From 80e74c1cdfc91ab521b199aab288360ddd736c39 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Thu, 21 Nov 2024 10:44:50 +0800 Subject: [PATCH 48/62] update zeknox to v1.0.0 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 5b689029bf..175d910796 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/icza/bitio v1.1.0 github.com/ingonyama-zk/iciclegnark v0.1.0 github.com/leanovate/gopter v0.2.11 - github.com/okx/zeknox v0.0.0-20241023112133-1756b0ee9527 + github.com/okx/zeknox v1.0.0 github.com/ronanh/intcomp v1.1.0 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 diff --git a/go.sum b/go.sum index 81dc4a651b..62b4183cd9 100644 --- a/go.sum +++ b/go.sum @@ -230,6 +230,8 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= +github.com/okx/zeknox v1.0.0 h1:W/nZnaBIQjB5LHK2DsVdlL5v7NHP4OBPLziZ6gh/14U= +github.com/okx/zeknox v1.0.0/go.mod h1:zlHemJhkN7W22xWWtANF66oPdzUJYT1frlkdSZhLQbc= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.9.3/go.mod 
h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= From 7fbb8ed1e81f3f6c08c4270f8706b4e1da2a4732 Mon Sep 17 00:00:00 2001 From: "jason.huang" <20609724+doutv@users.noreply.github.com> Date: Fri, 22 Nov 2024 11:42:56 +0800 Subject: [PATCH 49/62] disable zeknox by default --- backend/groth16/bn254/zeknox/nozeknox.go | 2 +- backend/groth16/bn254/zeknox/zeknox.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/groth16/bn254/zeknox/nozeknox.go b/backend/groth16/bn254/zeknox/nozeknox.go index 8859d6f319..a1c94bb97b 100644 --- a/backend/groth16/bn254/zeknox/nozeknox.go +++ b/backend/groth16/bn254/zeknox/nozeknox.go @@ -1,4 +1,4 @@ -//go:build zeknox +//go:build !zeknox package zeknox_bn254 diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 6b046cdd42..a7660f0671 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -1,4 +1,4 @@ -//go:build !zeknox +//go:build zeknox package zeknox_bn254 From bb7d86109a19069797e1323f5370d68f1ad72eee Mon Sep 17 00:00:00 2001 From: "jason.huang" <20609724+doutv@users.noreply.github.com> Date: Fri, 22 Nov 2024 11:51:19 +0800 Subject: [PATCH 50/62] delete comment --- go.mod | 2 -- 1 file changed, 2 deletions(-) diff --git a/go.mod b/go.mod index 175d910796..276c7e61c3 100644 --- a/go.mod +++ b/go.mod @@ -37,5 +37,3 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect rsc.io/tmplfunc v0.0.3 // indirect ) - -// replace github.com/okx/cryptography_cuda => /home/okxdex/data/zkdex-pap/workspace/jason-huang/cryptography_cuda From e26f6d13107940e18110aa30c1db57c1a341f3c7 Mon Sep 17 00:00:00 2001 From: "jason.huang" <20609724+doutv@users.noreply.github.com> Date: Fri, 22 Nov 2024 11:52:45 +0800 Subject: [PATCH 51/62] restore --- backend/groth16/bn254/prove.go | 26 ++------------------------ 1 file changed, 2 insertions(+), 24 deletions(-) diff --git 
a/backend/groth16/bn254/prove.go b/backend/groth16/bn254/prove.go index 6b58202c60..5f0d413133 100644 --- a/backend/groth16/bn254/prove.go +++ b/backend/groth16/bn254/prove.go @@ -138,7 +138,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b return nil, err } - // quotient poly H (witness reduction / FFT part) + // H (witness reduction / FFT part) var h []fr.Element chHDone := make(chan struct{}, 1) go func() { @@ -186,8 +186,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b if _, err := _s.SetRandom(); err != nil { return nil, err } - // -rs - // Why it is called kr? not rs? -> notation from DIZK paper _kr.Mul(&_r, &_s).Neg(&_kr) _r.BigInt(&r) @@ -203,14 +201,11 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b chBs1Done := make(chan error, 1) computeBS1 := func() { <-chWireValuesB - startBs1 := time.Now() if _, err := bs1.MultiExp(pk.G1.B, wireValuesB, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chBs1Done <- err close(chBs1Done) return } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValuesB)), time.Since(startBs1)).Msg("bs1.MultiExp done") - // + beta + s[δ] bs1.AddMixed(&pk.G1.Beta) bs1.AddMixed(&deltas[1]) chBs1Done <- nil @@ -219,13 +214,11 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b chArDone := make(chan error, 1) computeAR1 := func() { <-chWireValuesA - startAr := time.Now() if _, err := ar.MultiExp(pk.G1.A, wireValuesA, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chArDone <- err close(chArDone) return } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValuesA)), time.Since(startAr)).Msg("ar.MultiExp done") ar.AddMixed(&pk.G1.Alpha) ar.AddMixed(&deltas[0]) proof.Ar.FromJacobian(&ar) @@ -241,9 +234,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b chKrs2Done := make(chan error, 1) sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the 
deg(H)=(n-1)+(n-1)-n=n-2 go func() { - startKrs2 := time.Now() _, err := krs2.MultiExp(pk.G1.Z, h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", sizeH), time.Since(startKrs2)).Msg("krs2.MultiExp done") chKrs2Done <- err }() @@ -253,13 +244,10 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - startKrs := time.Now() if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chKrsDone <- err return } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(_wireValues)), time.Since(startKrs)).Msg("krs.MultiExp done") - // -rs[δ] krs.AddMixed(&deltas[2]) n := 3 for n != 0 { @@ -302,11 +290,9 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b nbTasks *= 2 } <-chWireValuesB - startBs := time.Now() if _, err := Bs.MultiExp(pk.G2.B, wireValuesB, ecc.MultiExpConfig{NbTasks: nbTasks}); err != nil { return err } - log.Debug().Dur(fmt.Sprintf("MSMG2 %d took", len(wireValuesB)), time.Since(startBs)).Msg("Bs.MultiExp done") deltaS.FromAffine(&pk.G2.Delta) deltaS.ScalarMultiplication(&deltaS, &s) @@ -383,27 +369,19 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { c = append(c, padding...) 
n = len(a) - // a -> aPoly, b -> bPoly, c -> cPoly - // point-value form -> coefficient form domain.FFTInverse(a, fft.DIF) domain.FFTInverse(b, fft.DIF) domain.FFTInverse(c, fft.DIF) - // evaluate aPoly, bPoly, cPoly on coset (roots of unity) domain.FFT(a, fft.DIT, fft.OnCoset()) domain.FFT(b, fft.DIT, fft.OnCoset()) domain.FFT(c, fft.DIT, fft.OnCoset()) - // vanishing poly t(x) = x^N - 1 - // calcualte 1/t(g), where g is the generator var den, one fr.Element one.SetOne() - // g^N den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) - // 1/(g^N - 1) den.Sub(&den, &one).Inverse(&den) - // h = (a*b - c)/t // h = ifft_coset(ca o cb - cc) // reusing a to avoid unnecessary memory allocation utils.Parallelize(n, func(start, end int) { @@ -414,7 +392,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { } }) - // ifft_coset: point-value form -> coefficient form + // ifft_coset domain.FFTInverse(a, fft.DIF, fft.OnCoset()) return a From 4e4fef85decfddc1840bd57587478c110341e595 Mon Sep 17 00:00:00 2001 From: "jason.huang" <20609724+doutv@users.noreply.github.com> Date: Fri, 22 Nov 2024 12:08:35 +0800 Subject: [PATCH 52/62] rename to zeknox --- README.md | 4 ++-- backend/backend.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 0ac1b35c10..b91f030d91 100644 --- a/README.md +++ b/README.md @@ -162,14 +162,14 @@ func main() { ### GPU Support #### Zeknox Library -Unlock free GPU acceleration with [OKX Zeknox library](https://github.com/okx/cryptography_cuda) +Unlock free GPU acceleration with [OKX Zeknox library](https://github.com/okx/zeknox) ##### Download prebuilt binary ```sh sudo cp libblst.a libcryptocuda.a /usr/local/lib/ ``` -If you want to build from source, see guide in https://github.com/okx/cryptography_cuda +If you want to build from source, see guide in https://github.com/okx/zeknox ##### Enjoy GPU Run `groth16.Prove(r1cs, pk, witnessData, 
backend.WithZeknoxAcceleration())` diff --git a/backend/backend.go b/backend/backend.go index 8b694f2bce..52da9c5671 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -138,7 +138,7 @@ func WithProverKZGFoldingHashFunction(hFunc hash.Hash) ProverOption { // tag and the ZEKNOX dependencies are properly installed. See [ZEKNOX] for // installation description. // -// [ZEKNOX]: https://github.com/okx/cryptography_cuda +// [ZEKNOX]: https://github.com/okx/zeknox func WithZeknoxAcceleration() ProverOption { return func(pc *ProverConfig) error { pc.Accelerator = "zeknox" From 41f319258dd814d324f57c91aef2aacca4a2ec8d Mon Sep 17 00:00:00 2001 From: "jason.huang" <20609724+doutv@users.noreply.github.com> Date: Fri, 22 Nov 2024 17:37:19 +0800 Subject: [PATCH 53/62] add guide in readme --- README.md | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index b91f030d91..9357181b3c 100644 --- a/README.md +++ b/README.md @@ -164,17 +164,24 @@ func main() { #### Zeknox Library Unlock free GPU acceleration with [OKX Zeknox library](https://github.com/okx/zeknox) -##### Download prebuilt binary -```sh -sudo cp libblst.a libcryptocuda.a /usr/local/lib/ -``` +Build from source, see guide in https://github.com/okx/zeknox -If you want to build from source, see guide in https://github.com/okx/zeknox +##### Run +```go +// main.go +groth16.Prove(r1cs, pk, witnessData) +// -> +groth16.Prove(r1cs, pk, witnessData, backend.WithZeknoxAcceleration()) +``` -##### Enjoy GPU -Run `groth16.Prove(r1cs, pk, witnessData, backend.WithZeknoxAcceleration())` +```sh +# run with zeknox build tag +go run -tags=zeknox main.go +# (place -tags before the filename) +``` -Test +##### Test +add the following to the [mimc test](examples/mimc/mimc_test.go) ```go assert.ProverSucceeded(&mimcCircuit, &Circuit{ PreImage: "16130099170765464552823636852555369511329944820189892919423002775646948828469", @@ -183,12 +190,11 @@ 
assert.ProverSucceeded(&mimcCircuit, &Circuit{ ``` ```sh -go run -tags=zeknox examples/main.go -# (place -tags before the filename) - +# test with zeknox build tag go test github.com/consensys/gnark/examples/mimc -tags=prover_checks,zeknox ``` + #### Icicle Library The following schemes and curves support experimental use of Ingonyama's Icicle GPU library for low level zk-SNARK primitives such as MSM, NTT, and polynomial operations: From ea037dd92601b456f5603eeb8cc16ac7abbc7081 Mon Sep 17 00:00:00 2001 From: Dumitrel Loghin Date: Wed, 27 Nov 2024 17:46:45 +0800 Subject: [PATCH 54/62] add zeknox build tag --- backend/groth16/bn254/zeknox/nozeknox.go | 2 +- backend/groth16/bn254/zeknox/provingkey.go | 2 + .../bn254/zeknox/provingkey_nozeknox.go | 39 +++++++++++++++++++ backend/groth16/bn254/zeknox/zeknox.go | 2 +- 4 files changed, 43 insertions(+), 2 deletions(-) create mode 100644 backend/groth16/bn254/zeknox/provingkey_nozeknox.go diff --git a/backend/groth16/bn254/zeknox/nozeknox.go b/backend/groth16/bn254/zeknox/nozeknox.go index 8859d6f319..a1c94bb97b 100644 --- a/backend/groth16/bn254/zeknox/nozeknox.go +++ b/backend/groth16/bn254/zeknox/nozeknox.go @@ -1,4 +1,4 @@ -//go:build zeknox +//go:build !zeknox package zeknox_bn254 diff --git a/backend/groth16/bn254/zeknox/provingkey.go b/backend/groth16/bn254/zeknox/provingkey.go index fd6966e294..32f80dfa51 100644 --- a/backend/groth16/bn254/zeknox/provingkey.go +++ b/backend/groth16/bn254/zeknox/provingkey.go @@ -1,3 +1,5 @@ +//go:build zeknox + package zeknox_bn254 import ( diff --git a/backend/groth16/bn254/zeknox/provingkey_nozeknox.go b/backend/groth16/bn254/zeknox/provingkey_nozeknox.go new file mode 100644 index 0000000000..5824af72d1 --- /dev/null +++ b/backend/groth16/bn254/zeknox/provingkey_nozeknox.go @@ -0,0 +1,39 @@ +package zeknox_bn254 + +import ( + "unsafe" + + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + cs "github.com/consensys/gnark/constraint/bn254" +) + +type deviceInfo 
struct { + G1Device struct { + A, B, K, Z unsafe.Pointer + } + DomainDevice struct { + Twiddles, TwiddlesInv unsafe.Pointer + CosetTable, CosetTableInv unsafe.Pointer + } + G2Device struct { + B unsafe.Pointer + } + DenDevice unsafe.Pointer + InfinityPointIndicesK []int +} + +type ProvingKey struct { + groth16_bn254.ProvingKey + *deviceInfo +} + +func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *groth16_bn254.VerifyingKey) error { + return groth16_bn254.Setup(r1cs, &pk.ProvingKey, vk) +} + +func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { + return groth16_bn254.DummySetup(r1cs, &pk.ProvingKey) +} + +func (pk *ProvingKey) Free() { +} diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index 6b046cdd42..a7660f0671 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -1,4 +1,4 @@ -//go:build !zeknox +//go:build zeknox package zeknox_bn254 From 70957a082b2b46cddd23831b5ffc3b64a6ab3851 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Wed, 27 Nov 2024 17:49:22 +0800 Subject: [PATCH 55/62] add zeknox build tag --- backend/groth16/bn254/zeknox/provingkey_nozeknox.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backend/groth16/bn254/zeknox/provingkey_nozeknox.go b/backend/groth16/bn254/zeknox/provingkey_nozeknox.go index 5824af72d1..6d55013b8f 100644 --- a/backend/groth16/bn254/zeknox/provingkey_nozeknox.go +++ b/backend/groth16/bn254/zeknox/provingkey_nozeknox.go @@ -1,3 +1,5 @@ +//go:build !zeknox + package zeknox_bn254 import ( From 8ff71e75a6044446f2bfe092c73e21758ccbea16 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Wed, 27 Nov 2024 19:01:14 +0800 Subject: [PATCH 56/62] add zeknox build tag --- backend/groth16/bn254/zeknox/nozeknox.go | 15 +++++++ .../bn254/zeknox/provingkey_nozeknox.go | 41 ------------------- examples/p256/p256.go | 4 +- 3 files changed, 18 insertions(+), 42 deletions(-) delete mode 100644 backend/groth16/bn254/zeknox/provingkey_nozeknox.go diff --git 
a/backend/groth16/bn254/zeknox/nozeknox.go b/backend/groth16/bn254/zeknox/nozeknox.go index a1c94bb97b..0058078cfe 100644 --- a/backend/groth16/bn254/zeknox/nozeknox.go +++ b/backend/groth16/bn254/zeknox/nozeknox.go @@ -11,8 +11,23 @@ import ( cs "github.com/consensys/gnark/constraint/bn254" ) +type ProvingKey struct { + groth16_bn254.ProvingKey +} + const HasZeknox = false func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*groth16_bn254.Proof, error) { return nil, fmt.Errorf("zeknox backend requested but program compiled without 'zeknox' build tag") } + +func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *groth16_bn254.VerifyingKey) error { + return groth16_bn254.Setup(r1cs, &pk.ProvingKey, vk) +} + +func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { + return groth16_bn254.DummySetup(r1cs, &pk.ProvingKey) +} + +func (pk *ProvingKey) Free() { +} diff --git a/backend/groth16/bn254/zeknox/provingkey_nozeknox.go b/backend/groth16/bn254/zeknox/provingkey_nozeknox.go deleted file mode 100644 index 6d55013b8f..0000000000 --- a/backend/groth16/bn254/zeknox/provingkey_nozeknox.go +++ /dev/null @@ -1,41 +0,0 @@ -//go:build !zeknox - -package zeknox_bn254 - -import ( - "unsafe" - - groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" - cs "github.com/consensys/gnark/constraint/bn254" -) - -type deviceInfo struct { - G1Device struct { - A, B, K, Z unsafe.Pointer - } - DomainDevice struct { - Twiddles, TwiddlesInv unsafe.Pointer - CosetTable, CosetTableInv unsafe.Pointer - } - G2Device struct { - B unsafe.Pointer - } - DenDevice unsafe.Pointer - InfinityPointIndicesK []int -} - -type ProvingKey struct { - groth16_bn254.ProvingKey - *deviceInfo -} - -func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *groth16_bn254.VerifyingKey) error { - return groth16_bn254.Setup(r1cs, &pk.ProvingKey, vk) -} - -func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { - return groth16_bn254.DummySetup(r1cs, &pk.ProvingKey) -} - -func (pk 
*ProvingKey) Free() { -} diff --git a/examples/p256/p256.go b/examples/p256/p256.go index a6a34baac7..1d9fb5c53b 100644 --- a/examples/p256/p256.go +++ b/examples/p256/p256.go @@ -149,7 +149,9 @@ func Groth16Prove(fileDir string) { // read zkey start = time.Now() pk := groth16.NewProvingKey(ecc.BN254) - defer pk.(*zeknox_bn254.ProvingKey).Free() + if zeknox_bn254.HasZeknox { + defer pk.(*zeknox_bn254.ProvingKey).Free() + } UnsafeReadFromFile(pk, fileDir+circuitName+".zkey") elapsed = time.Since(start) log.Printf("Read zkey: %d ms", elapsed.Milliseconds()) From 8fccf0cdc5f7830d4d74066d08c3a8bdf9a595e4 Mon Sep 17 00:00:00 2001 From: Dumitrel Loghin Date: Thu, 28 Nov 2024 12:04:22 +0800 Subject: [PATCH 57/62] update readme --- README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 92786166c4..fa97d518d4 100644 --- a/README.md +++ b/README.md @@ -162,14 +162,16 @@ func main() { ### GPU Support #### Zeknox Library -Unlock free GPU acceleration with [OKX Zeknox library](https://github.com/okx/cryptography_cuda) +Unlock free GPU acceleration with [OKX Zeknox library](https://github.com/okx/zeknox) ##### Download prebuilt binary ```sh -sudo cp libblst.a libcryptocuda.a /usr/local/lib/ +curl -L -o libzeknox.a https://github.com/okx/zeknox/releases/download/v1.0.0/bn254-msm-86-89-90-libzeknox.a +curl -L -o libblst.a https://github.com/okx/zeknox/releases/download/v1.0.0/libblst.a +sudo cp libblst.a libzeknox.a /usr/local/lib/ ``` -If you want to build from source, see guide in https://github.com/okx/cryptography_cuda +If you want to build from source, see guide in https://github.com/okx/zeknox ##### Enjoy GPU Run `groth16.Prove(r1cs, pk, witnessData, backend.WithZeknoxAcceleration())` From b1dc00f204e7dfcc3cbfaca3413fb5d428131473 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Fri, 29 Nov 2024 11:32:40 +0800 Subject: [PATCH 58/62] clean source code --- backend/groth16/bn254/zeknox/nozeknox.go | 1 + 
backend/groth16/bn254/zeknox/zeknox.go | 61 ------------------------ examples/p256/p256.go | 17 +++---- go.sum | 2 - 4 files changed, 7 insertions(+), 74 deletions(-) diff --git a/backend/groth16/bn254/zeknox/nozeknox.go b/backend/groth16/bn254/zeknox/nozeknox.go index 0058078cfe..b920f92c00 100644 --- a/backend/groth16/bn254/zeknox/nozeknox.go +++ b/backend/groth16/bn254/zeknox/nozeknox.go @@ -30,4 +30,5 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { } func (pk *ProvingKey) Free() { + // nothing to do here } diff --git a/backend/groth16/bn254/zeknox/zeknox.go b/backend/groth16/bn254/zeknox/zeknox.go index a7660f0671..52ece00523 100644 --- a/backend/groth16/bn254/zeknox/zeknox.go +++ b/backend/groth16/bn254/zeknox/zeknox.go @@ -46,7 +46,6 @@ func (pk *ProvingKey) setupDevicePointers() error { // G1.A deviceA := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) g.Go(func() error { return CopyToDevice(pk.G1.A, deviceA) }) - // G1.B deviceG1B := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) g.Go(func() error { return CopyToDevice(pk.G1.B, deviceG1B) }) @@ -62,7 +61,6 @@ func (pk *ProvingKey) setupDevicePointers() error { } deviceK := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) g.Go(func() error { return CopyToDevice(pointsNoInfinity, deviceK) }) - // g.Go(func() error { return CopyToDevice(pk.G1.K, deviceK) }) // G1.Z deviceZ := make(chan *device.HostOrDeviceSlice[curve.G1Affine], 1) @@ -334,65 +332,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b return nil } - /* - computeKRS1 := func() error { - // filter the wire values if needed - // TODO Perf @Tabaie worst memory allocation offender - toRemove := commitmentInfo.GetPrivateCommitted() - toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) - // original Groth16 witness without pedersen commitment - wireValuesWithoutCom := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), 
internal.ConcatAll(toRemove...)) - - var deviceWireValuesWithoutCom *device.HostOrDeviceSlice[fr.Element] - chDeviceW := make(chan *device.HostOrDeviceSlice[fr.Element], 1) - sizeW := len(wireValuesWithoutCom) - // copy to GPU - if err := CopyToDevice(wireValuesWithoutCom[:sizeW], chDeviceW); err != nil { - return err - } - deviceWireValuesWithoutCom = <-chDeviceW - defer deviceWireValuesWithoutCom.Free() - // on GPU - startKrs := time.Now() - if err := gpuMsm(&krs1, &pk.G1Device.K, deviceWireValuesWithoutCom); err != nil { - return err - } - log.Debug().Dur(fmt.Sprintf("MSMG1 %d took", len(wireValues)), time.Since(startKrs)).Msg("GPU krs1 done") - // -rs[δ] - krs1.AddMixed(&deltas[2]) - return nil - } - */ - /* - computeBS2 := func() error { - <-chWireValuesB - // Bs2 (1 multi exp G2 - size = len(wires)) - var Bs, deltaS curve.G2Jac - - var wireB *device.HostOrDeviceSlice[fr.Element] - chWireB := make(chan *device.HostOrDeviceSlice[fr.Element], 1) - if err := CopyToDevice(wireValuesB, chWireB); err != nil { - return err - } - wireB = <-chWireB - defer wireB.Free() - startBs := time.Now() - if err := gpuMsm(&Bs, &pk.G2Device.B, wireB); err != nil { - return err - } - - log.Debug().Dur(fmt.Sprintf("MSMG2 %v took", wireB.Len()), time.Since(startBs)).Msg("Bs done") - - deltaS.FromAffine(&pk.G2.Delta) - deltaS.ScalarMultiplication(&deltaS, &s) - Bs.AddAssign(&deltaS) - Bs.AddMixed(&pk.G2.Beta) - - proof.Bs.FromJacobian(&Bs) - return nil - } - */ - computeBS2_CPU := func() error { // Bs2 (1 multi exp G2 - size = len(wires)) var Bs, deltaS curve.G2Jac diff --git a/examples/p256/p256.go b/examples/p256/p256.go index 1d9fb5c53b..caf38e5cc1 100644 --- a/examples/p256/p256.go +++ b/examples/p256/p256.go @@ -4,7 +4,6 @@ import ( cryptoecdsa "crypto/ecdsa" "crypto/elliptic" "crypto/rand" - "fmt" "io" "log" "math/big" @@ -73,22 +72,20 @@ func generateWitnessCircuit() EcdsaCircuit[emulated.P256Fp, emulated.P256Fr] { !inner.ReadASN1Integer(r) || !inner.ReadASN1Integer(s) || 
!inner.Empty() { - panic("invalid sig") + log.Panicf("Invalid signature.") } flag := cryptoecdsa.Verify(&publicKey, msgHash[:], r, s) if !flag { - println("can't verify signature") + log.Panicf("Can't verify signature.") } // hashIn += Pub[i].X + Pub[i].Y + Msg[i] pubX := publicKey.X.Bytes() pubY := publicKey.Y.Bytes() - // println("pubX:", hex.EncodeToString(pubX)) - // println("pubY:", hex.EncodeToString(pubY)) - // println("msgHash:", hex.EncodeToString(msgHash[:])) hashIn = append(hashIn, pubX[:]...) hashIn = append(hashIn, pubY[:]...) hashIn = append(hashIn, msgHash[:]...) + // Assign to circuit witness witness.Sig[i] = Signature[emulated.P256Fr]{ R: emulated.ValueOf[emulated.P256Fr](r), @@ -102,7 +99,6 @@ func generateWitnessCircuit() EcdsaCircuit[emulated.P256Fp, emulated.P256Fr] { } hashOut := keccak256(hashIn) hashOut[0] = 0 // ignore the first byte, since BN254 order < uint256 - // println("hashOut:", hex.EncodeToString(hashOut[:])) witness.Commitment = hashOut[:] return witness } @@ -162,7 +158,7 @@ func Groth16Prove(fileDir string) { // CPU for i := 0; i < 1; i++ { - fmt.Printf("------ CPU Prove %d ------", i+1) + log.Printf("------ CPU Prove %d ------", i+1) witnessData, err := generateWitness() if err != nil { panic(err) @@ -183,7 +179,7 @@ func Groth16Prove(fileDir string) { // GPU for i := 0; i < 1; i++ { - fmt.Printf("------ GPU Prove %d ------\n", i+1) + log.Printf("------ GPU Prove %d ------\n", i+1) witnessData, err := generateWitness() if err != nil { panic(err) @@ -198,8 +194,7 @@ func Groth16Prove(fileDir string) { panic(err) } if err := groth16.Verify(proof, vk, publicWitness, solidity.WithVerifierTargetSolidityVerifier(backend.GROTH16)); err != nil { - fmt.Printf("\nError in GPU Verify %d: %s\n\n", i+1, err) - // panic(err) + log.Panicf("\nError in GPU Verify %d: %s\n\n", i+1, err) } } } diff --git a/go.sum b/go.sum index 198c8b862c..21ce300c0e 100644 --- a/go.sum +++ b/go.sum @@ -61,8 +61,6 @@ github.com/consensys/bavard v0.1.22 
h1:Uw2CGvbXSZWhqK59X0VG/zOjpTFuOMcPLStrp1ihI github.com/consensys/bavard v0.1.22/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/compress v0.2.5 h1:gJr1hKzbOD36JFsF1AN8lfXz1yevnJi1YolffY19Ntk= github.com/consensys/compress v0.2.5/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk= -github.com/consensys/gnark-crypto v0.14.1-0.20241010154951-6638408a49f3 h1:jVatckGR1s3OHs4QnGsppX+w2P3eedlWxi7ZFq56rjA= -github.com/consensys/gnark-crypto v0.14.1-0.20241010154951-6638408a49f3/go.mod h1:F/hJyWBcTr1sWeifAKfEN3aVb3G4U5zheEC8IbWQun4= github.com/consensys/gnark-crypto v0.14.1-0.20241122181107-03e007d865c0 h1:uFZaZWG0FOoiFN3fAQzH2JXDuybdNwiJzBujy81YtU4= github.com/consensys/gnark-crypto v0.14.1-0.20241122181107-03e007d865c0/go.mod h1:F/hJyWBcTr1sWeifAKfEN3aVb3G4U5zheEC8IbWQun4= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= From b7da2aed12c055ef6a38cb65c00d6b143a589206 Mon Sep 17 00:00:00 2001 From: "jason.huang" <20609724+doutv@users.noreply.github.com> Date: Tue, 10 Dec 2024 18:41:55 +0800 Subject: [PATCH 59/62] revert examples --- examples/main.go | 14 --- examples/mimc/mimc_test.go | 7 -- examples/p256/circuit.go | 111 ---------------- examples/p256/p256.go | 252 ------------------------------------- examples/p256/p256_test.go | 17 --- 5 files changed, 401 deletions(-) delete mode 100644 examples/main.go delete mode 100644 examples/p256/circuit.go delete mode 100644 examples/p256/p256.go delete mode 100644 examples/p256/p256_test.go diff --git a/examples/main.go b/examples/main.go deleted file mode 100644 index 1326e8e7a1..0000000000 --- a/examples/main.go +++ /dev/null @@ -1,14 +0,0 @@ -package main - -import ( - "os" - - "github.com/consensys/gnark/examples/p256" -) - -func main() { - if _, err := os.Stat("build/"); os.IsNotExist(err) { - p256.Groth16Setup("build/") - } - p256.Groth16Prove("build/") -} diff --git a/examples/mimc/mimc_test.go b/examples/mimc/mimc_test.go index 
7631974ad8..5583193bf6 100644 --- a/examples/mimc/mimc_test.go +++ b/examples/mimc/mimc_test.go @@ -18,11 +18,9 @@ import ( "testing" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/test" ) -// go test github.com/consensys/gnark/examples/mimc -tags=prover_checks func TestPreimage(t *testing.T) { assert := test.NewAssert(t) @@ -38,9 +36,4 @@ func TestPreimage(t *testing.T) { Hash: "12886436712380113721405259596386800092738845035233065858332878701083870690753", }, test.WithCurves(ecc.BN254)) - assert.ProverSucceeded(&mimcCircuit, &Circuit{ - PreImage: "16130099170765464552823636852555369511329944820189892919423002775646948828469", - Hash: "12886436712380113721405259596386800092738845035233065858332878701083870690753", - }, test.WithCurves(ecc.BN254), test.WithProverOpts(backend.WithZeknoxAcceleration())) - } diff --git a/examples/p256/circuit.go b/examples/p256/circuit.go deleted file mode 100644 index c6613beb89..0000000000 --- a/examples/p256/circuit.go +++ /dev/null @@ -1,111 +0,0 @@ -package p256 - -import ( - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" - "github.com/consensys/gnark/std/hash/sha3" - "github.com/consensys/gnark/std/math/emulated" - "github.com/consensys/gnark/std/math/uints" -) - -type EcdsaCircuit[T, S emulated.FieldParams] struct { - Commitment frontend.Variable `gnark:",public"` // Keccak256(Pub[0], Msg[0], Sig[1], Msg[1], ...)[1:32], ignore the first byte, since BN254 order < uint256 - - Pub [NumSignatures]PublicKey[T, S] `gnark:",secret"` - Msg [NumSignatures]emulated.Element[S] `gnark:",secret"` - Sig [NumSignatures]Signature[S] `gnark:",secret"` -} - -func (c *EcdsaCircuit[T, S]) Define(api frontend.API) error { - // Verify all ECDSA-P256 signatures - for i := range c.Sig { - c.Pub[i].Verify(api, sw_emulated.GetCurveParams[T](), &c.Msg[i], &c.Sig[i]) - } - // Keccak256 Commit to all signatures - h, err := 
sha3.NewLegacyKeccak256(api) - if err != nil { - return err - } - uapi, err := uints.New[uints.U64](api) - if err != nil { - return err - } - - var tInstance T - var sInstance S - perSignatureHashSize := 2*tInstance.NbLimbs() + sInstance.NbLimbs() - - hashIn := make([]uints.U8, 0, NumSignatures*perSignatureHashSize) - for i := 0; i < NumSignatures; i++ { - // hashIn += Pub[i].X - // Pay attention to the ordering! - for j := len(c.Pub[i].X.Limbs) - 1; j >= 0; j-- { - pubXLimb := uapi.UnpackMSB(uapi.ValueOf(c.Pub[i].X.Limbs[j])) - hashIn = append(hashIn, pubXLimb[:]...) - } - // hashIn += Pub[i].Y - for j := len(c.Pub[i].X.Limbs) - 1; j >= 0; j-- { - pubYLimb := uapi.UnpackMSB(uapi.ValueOf(c.Pub[i].Y.Limbs[j])) - hashIn = append(hashIn, pubYLimb[:]...) - } - // hashIn += Msg[i] - for j := len(c.Msg[i].Limbs) - 1; j >= 0; j-- { - msgLimb := uapi.UnpackMSB(uapi.ValueOf(c.Msg[i].Limbs[j])) - hashIn = append(hashIn, msgLimb[:]...) - } - } - h.Write(hashIn) - hashOutU8 := h.Sum() // Keccak256(Pub[0], Msg[0], Sig[1], Msg[1], ...)[0:32] - - // Commitment = hashoutU8[1:32] - hashOutU8[0] = uints.NewU8(0) // ignore the first byte, since BN254 order < uint256 - // Big endian [32]bytes to BigInt - for i := range hashOutU8 { - index := len(hashOutU8) - i - 1 - c.Commitment = api.MulAcc(c.Commitment, hashOutU8[index].Val, 1<<(i*8)) - } - return nil -} - -// Signature represents the signature for some message. -type Signature[Scalar emulated.FieldParams] struct { - R, S emulated.Element[Scalar] -} - -// PublicKey represents the public key to verify the signature for. -type PublicKey[Base, Scalar emulated.FieldParams] sw_emulated.AffinePoint[Base] - -// Verify asserts that the signature sig verifies for the message msg and public -// key pk. The curve parameters params define the elliptic curve. -// -// We assume that the message msg is already hashed to the scalar field. 
-func (pk PublicKey[T, S]) Verify(api frontend.API, params sw_emulated.CurveParams, msg *emulated.Element[S], sig *Signature[S]) { - cr, err := sw_emulated.New[T, S](api, params) - if err != nil { - panic(err) - } - scalarApi, err := emulated.NewField[S](api) - if err != nil { - panic(err) - } - baseApi, err := emulated.NewField[T](api) - if err != nil { - panic(err) - } - pkpt := sw_emulated.AffinePoint[T](pk) - sInv := scalarApi.Inverse(&sig.S) - msInv := scalarApi.MulMod(msg, sInv) - rsInv := scalarApi.MulMod(&sig.R, sInv) - - // q = [rsInv]pkpt + [msInv]g - q := cr.JointScalarMulBase(&pkpt, rsInv, msInv) - qx := baseApi.Reduce(&q.X) - qxBits := baseApi.ToBits(qx) - rbits := scalarApi.ToBits(&sig.R) - if len(rbits) != len(qxBits) { - panic("non-equal lengths") - } - for i := range rbits { - api.AssertIsEqual(rbits[i], qxBits[i]) - } -} diff --git a/examples/p256/p256.go b/examples/p256/p256.go deleted file mode 100644 index caf38e5cc1..0000000000 --- a/examples/p256/p256.go +++ /dev/null @@ -1,252 +0,0 @@ -package p256 - -import ( - cryptoecdsa "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "io" - "log" - "math/big" - "os" - "strconv" - "time" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/groth16" - zeknox_bn254 "github.com/consensys/gnark/backend/groth16/bn254/zeknox" - "github.com/consensys/gnark/backend/solidity" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/r1cs" - gnark_io "github.com/consensys/gnark/io" - "github.com/consensys/gnark/std/math/emulated" - "golang.org/x/crypto/cryptobyte" - "golang.org/x/crypto/cryptobyte/asn1" - "golang.org/x/crypto/sha3" -) - -const NumSignatures = 10 - -var circuitName string - -func init() { - circuitName = "p256-" + strconv.Itoa(NumSignatures) -} - -func compileCircuit(newBuilder frontend.NewBuilder) 
(constraint.ConstraintSystem, error) { - circuit := EcdsaCircuit[emulated.P256Fp, emulated.P256Fr]{} - r1cs, err := frontend.Compile(ecc.BN254.ScalarField(), newBuilder, &circuit) - if err != nil { - return nil, err - } - return r1cs, nil -} - -func generateWitnessCircuit() EcdsaCircuit[emulated.P256Fp, emulated.P256Fr] { - witness := EcdsaCircuit[emulated.P256Fp, emulated.P256Fr]{} - perSignatureHashSize := 2*emulated.P256Fp{}.NbLimbs() + emulated.P256Fr{}.NbLimbs() - hashIn := make([]byte, 0, NumSignatures*perSignatureHashSize) - for i := 0; i < NumSignatures; i++ { - // Keygen - privKey, _ := cryptoecdsa.GenerateKey(elliptic.P256(), rand.Reader) - publicKey := privKey.PublicKey - - // Sign - msg, err := genRandomBytes(i + 20) - if err != nil { - panic(err) - } - msgHash := keccak256(msg) - sigBin, _ := privKey.Sign(rand.Reader, msgHash[:], nil) - - // Try verify - var ( - r, s = &big.Int{}, &big.Int{} - inner cryptobyte.String - ) - input := cryptobyte.String(sigBin) - if !input.ReadASN1(&inner, asn1.SEQUENCE) || - !input.Empty() || - !inner.ReadASN1Integer(r) || - !inner.ReadASN1Integer(s) || - !inner.Empty() { - log.Panicf("Invalid signature.") - } - flag := cryptoecdsa.Verify(&publicKey, msgHash[:], r, s) - if !flag { - log.Panicf("Can't verify signature.") - } - - // hashIn += Pub[i].X + Pub[i].Y + Msg[i] - pubX := publicKey.X.Bytes() - pubY := publicKey.Y.Bytes() - hashIn = append(hashIn, pubX[:]...) - hashIn = append(hashIn, pubY[:]...) - hashIn = append(hashIn, msgHash[:]...) 
- - // Assign to circuit witness - witness.Sig[i] = Signature[emulated.P256Fr]{ - R: emulated.ValueOf[emulated.P256Fr](r), - S: emulated.ValueOf[emulated.P256Fr](s), - } - witness.Msg[i] = emulated.ValueOf[emulated.P256Fr](msgHash[:]) - witness.Pub[i] = PublicKey[emulated.P256Fp, emulated.P256Fr]{ - X: emulated.ValueOf[emulated.P256Fp](publicKey.X), - Y: emulated.ValueOf[emulated.P256Fp](publicKey.Y), - } - } - hashOut := keccak256(hashIn) - hashOut[0] = 0 // ignore the first byte, since BN254 order < uint256 - witness.Commitment = hashOut[:] - return witness -} - -func generateWitness() (witness.Witness, error) { - witness := generateWitnessCircuit() - witnessData, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) - if err != nil { - panic(err) - } - - return witnessData, nil -} - -func Groth16Setup(fileDir string) { - r1cs, err := compileCircuit(r1cs.NewBuilder) - if err != nil { - panic(err) - } - pk, vk, err := groth16.Setup(r1cs) - if err != nil { - panic(err) - } - // Write to file - if _, err := os.Stat(fileDir); os.IsNotExist(err) { - err := os.MkdirAll(fileDir, os.ModePerm) - if err != nil { - panic(err) - } - } - WriteToFile(pk, fileDir+circuitName+".zkey") - WriteToFile(r1cs, fileDir+circuitName+".r1cs") - WriteToFile(vk, fileDir+circuitName+".vkey") -} - -func Groth16Prove(fileDir string) { - // Read r1cs - start := time.Now() - r1cs := groth16.NewCS(ecc.BN254) - ReadFromFile(r1cs, fileDir+circuitName+".r1cs") - elapsed := time.Since(start) - log.Printf("Read r1cs: %d ms", elapsed.Milliseconds()) - - // read zkey - start = time.Now() - pk := groth16.NewProvingKey(ecc.BN254) - if zeknox_bn254.HasZeknox { - defer pk.(*zeknox_bn254.ProvingKey).Free() - } - UnsafeReadFromFile(pk, fileDir+circuitName+".zkey") - elapsed = time.Since(start) - log.Printf("Read zkey: %d ms", elapsed.Milliseconds()) - - // Proof generation & verification - vk := groth16.NewVerifyingKey(ecc.BN254) - ReadFromFile(vk, fileDir+circuitName+".vkey") - - // CPU - for i := 0; 
i < 1; i++ { - log.Printf("------ CPU Prove %d ------", i+1) - witnessData, err := generateWitness() - if err != nil { - panic(err) - } - - proof, err := groth16.Prove(r1cs, pk, witnessData, solidity.WithProverTargetSolidityVerifier(backend.GROTH16)) - if err != nil { - panic(err) - } - publicWitness, err := witnessData.Public() - if err != nil { - panic(err) - } - if err := groth16.Verify(proof, vk, publicWitness, solidity.WithVerifierTargetSolidityVerifier(backend.GROTH16)); err != nil { - panic(err) - } - } - - // GPU - for i := 0; i < 1; i++ { - log.Printf("------ GPU Prove %d ------\n", i+1) - witnessData, err := generateWitness() - if err != nil { - panic(err) - } - - proof, err := groth16.Prove(r1cs, pk, witnessData, solidity.WithProverTargetSolidityVerifier(backend.GROTH16), backend.WithZeknoxAcceleration()) - if err != nil { - panic(err) - } - publicWitness, err := witnessData.Public() - if err != nil { - panic(err) - } - if err := groth16.Verify(proof, vk, publicWitness, solidity.WithVerifierTargetSolidityVerifier(backend.GROTH16)); err != nil { - log.Panicf("\nError in GPU Verify %d: %s\n\n", i+1, err) - } - } -} - -func genRandomBytes(size int) ([]byte, error) { - blk := make([]byte, size) - _, err := rand.Read(blk) - if err != nil { - return nil, err - } - return blk, nil -} - -func keccak256(data []byte) (digest [32]byte) { - h := sha3.NewLegacyKeccak256() - h.Write(data) - h.Sum(digest[:0]) - return -} - -func WriteToFile(data io.WriterTo, fileName string) { - file, err := os.Create(fileName) - if err != nil { - panic(err) - } - defer file.Close() - _, err = data.WriteTo(file) - if err != nil { - panic(err) - } -} - -func ReadFromFile(data io.ReaderFrom, fileName string) { - file, err := os.Open(fileName) - if err != nil { - panic(err) - } - defer file.Close() - // Use the ReadFrom method to read the file's content into data. 
- if _, err := data.ReadFrom(file); err != nil { - panic(err) - } -} - -// faster than readFromFile -func UnsafeReadFromFile(data gnark_io.UnsafeReaderFrom, fileName string) { - file, err := os.Open(fileName) - if err != nil { - panic(err) - } - defer file.Close() - if _, err := data.UnsafeReadFrom(file); err != nil { - panic(err) - } -} diff --git a/examples/p256/p256_test.go b/examples/p256/p256_test.go deleted file mode 100644 index 863773e13b..0000000000 --- a/examples/p256/p256_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package p256 - -import ( - "testing" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/std/math/emulated" - "github.com/consensys/gnark/test" -) - -func TestP256(t *testing.T) { - assert := test.NewAssert(t) - witnessCircuit := generateWitnessCircuit() - circuit := EcdsaCircuit[emulated.P256Fp, emulated.P256Fr]{} - assert.CheckCircuit(&circuit, test.WithValidAssignment(&witnessCircuit), test.WithBackends(backend.GROTH16), test.WithCurves(ecc.BN254), test.WithProverOpts(backend.WithZeknoxAcceleration())) -} From 24c51c041f228ac0296827b2e306dbae81327700 Mon Sep 17 00:00:00 2001 From: "jason.huang" <20609724+doutv@users.noreply.github.com> Date: Tue, 10 Dec 2024 18:42:05 +0800 Subject: [PATCH 60/62] delete comment --- go.mod | 2 -- 1 file changed, 2 deletions(-) diff --git a/go.mod b/go.mod index d724ed9eba..95ee562819 100644 --- a/go.mod +++ b/go.mod @@ -37,5 +37,3 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect rsc.io/tmplfunc v0.0.3 // indirect ) - -// replace github.com/okx/zeknox v0.0.0-20241023112133-1756b0ee9527 => /home/ubuntu/git/zeknox From 508042309707fe7acfbb87545eb0a4cc05051808 Mon Sep 17 00:00:00 2001 From: "jason.huang" <20609724+doutv@users.noreply.github.com> Date: Tue, 10 Dec 2024 19:06:06 +0800 Subject: [PATCH 61/62] add zeknox sha3 example --- README.md | 12 +++-- examples/zeknox/main.go | 114 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 122 
insertions(+), 4 deletions(-) create mode 100644 examples/zeknox/main.go diff --git a/README.md b/README.md index 989abf9ac1..35787c62db 100644 --- a/README.md +++ b/README.md @@ -174,9 +174,16 @@ sudo cp libblst.a libzeknox.a /usr/local/lib/ If you want to build from source, see guide in https://github.com/okx/zeknox. ##### Enjoy GPU -Run `groth16.Prove(r1cs, pk, witnessData, backend.WithZeknoxAcceleration())` + +`groth16.Prove(r1cs, pk, witnessData, backend.WithZeknoxAcceleration())` + +```sh +go run -tags=zeknox examples/zeknox/main.go +# (place -tags before the filename) +``` ##### Test +Add following code to [mimc_test.go](examples/mimc/mimc_test.go) ```go assert.ProverSucceeded(&mimcCircuit, &Circuit{ PreImage: "16130099170765464552823636852555369511329944820189892919423002775646948828469", @@ -185,9 +192,6 @@ assert.ProverSucceeded(&mimcCircuit, &Circuit{ ``` ```sh -go run -tags=zeknox examples/main.go -# (place -tags before the filename) - go test github.com/consensys/gnark/examples/mimc -tags=prover_checks,zeknox ``` diff --git a/examples/zeknox/main.go b/examples/zeknox/main.go new file mode 100644 index 0000000000..c0cab6873d --- /dev/null +++ b/examples/zeknox/main.go @@ -0,0 +1,114 @@ +// main.go + +package main + +import ( + "log" + "time" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/hash/sha3" + "github.com/consensys/gnark/std/math/uints" + cryptosha3 "golang.org/x/crypto/sha3" +) + +type sha3Circuit struct { + In []uints.U8 `gnark:",secret"` + Expected [32]uints.U8 `gnark:",public"` +} + +func (c *sha3Circuit) Define(api frontend.API) error { + h, err := sha3.New256(api) + if err != nil { + return err + } + uapi, err := uints.New[uints.U64](api) + if err != 
nil { + return err + } + + h.Write(c.In) + res := h.Sum() + + for i := range c.Expected { + uapi.ByteAssertEq(c.Expected[i], res[i]) + } + return nil +} + +const inputLength = 128 + +func compileCircuit(newBuilder frontend.NewBuilder) (constraint.ConstraintSystem, error) { + circuit := sha3Circuit{ + In: make([]uints.U8, inputLength), + } + r1cs, err := frontend.Compile(ecc.BN254.ScalarField(), newBuilder, &circuit) + if err != nil { + return nil, err + } + return r1cs, nil +} + +func generateWitness() (witness.Witness, error) { + input := make([]byte, inputLength) + dgst := cryptosha3.Sum256(input) + witness := sha3Circuit{ + In: uints.NewU8Array(input[:]), + } + copy(witness.Expected[:], uints.NewU8Array(dgst[:])) + + witnessData, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } + return witnessData, nil +} + +func main() { + r1cs, err := compileCircuit(r1cs.NewBuilder) + if err != nil { + panic(err) + } + pk, vk, err := groth16.Setup(r1cs) + if err != nil { + panic(err) + } + + // Witness generation + witnessData, err := generateWitness() + if err != nil { + panic(err) + } + publicWitness, err := witnessData.Public() + if err != nil { + panic(err) + } + + // GPU Prove & Verify + start := time.Now() + proofZeknox, err := groth16.Prove(r1cs, pk, witnessData, backend.WithZeknoxAcceleration()) + if err != nil { + panic(err) + } + log.Printf("zeknox GPU prove: %d ms", time.Since(start).Milliseconds()) + if err := groth16.Verify(proofZeknox, vk, publicWitness); err != nil { + panic(err) + } + + // CPU Prove & Verify + start = time.Now() + proof, err := groth16.Prove(r1cs, pk, witnessData) + if err != nil { + panic(err) + } + log.Printf("CPU prove: %d ms", time.Since(start).Milliseconds()) + if err := groth16.Verify(proof, vk, publicWitness); err != nil { + panic(err) + } +} From ea139bc05ce761f42c17b9a28bbd6f70c08756f7 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Thu, 9 Jan 2025 15:52:43 +0800 Subject: [PATCH 62/62] 
add warmup and multiple runs flag --- examples/zeknox/main.go | 48 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/examples/zeknox/main.go b/examples/zeknox/main.go index c0cab6873d..bceab3edce 100644 --- a/examples/zeknox/main.go +++ b/examples/zeknox/main.go @@ -3,6 +3,7 @@ package main import ( + "flag" "log" "time" @@ -13,13 +14,14 @@ import ( "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/logger" "github.com/consensys/gnark/std/hash/sha3" "github.com/consensys/gnark/std/math/uints" cryptosha3 "golang.org/x/crypto/sha3" ) type sha3Circuit struct { - In []uints.U8 `gnark:",secret"` + In []uints.U8 `gnark:",secret"` Expected [32]uints.U8 `gnark:",public"` } @@ -71,6 +73,12 @@ func generateWitness() (witness.Witness, error) { } func main() { + logger.Disable() + + nRuns := flag.Int("r", 5, "number of runs") + flag.Parse() + log.Printf("Number of runs: %d", *nRuns) + r1cs, err := compileCircuit(r1cs.NewBuilder) if err != nil { panic(err) @@ -91,24 +99,54 @@ func main() { } // GPU Prove & Verify - start := time.Now() + // Warmup GPU proofZeknox, err := groth16.Prove(r1cs, pk, witnessData, backend.WithZeknoxAcceleration()) if err != nil { panic(err) } - log.Printf("zeknox GPU prove: %d ms", time.Since(start).Milliseconds()) if err := groth16.Verify(proofZeknox, vk, publicWitness); err != nil { panic(err) } + // Actual run + tgpu := float64(0) + for i := 0; i < *nRuns; i++ { + start := time.Now() + proofZeknox, err = groth16.Prove(r1cs, pk, witnessData, backend.WithZeknoxAcceleration()) + if err != nil { + panic(err) + } + tgpu += float64(time.Since(start).Milliseconds()) + if err := groth16.Verify(proofZeknox, vk, publicWitness); err != nil { + panic(err) + } + } + tgpu /= float64(*nRuns) + log.Printf("zeknox GPU prove average time: %v ms", tgpu) // CPU Prove & Verify - start = time.Now() + // Warmup CPU 
proof, err := groth16.Prove(r1cs, pk, witnessData) if err != nil { panic(err) } - log.Printf("CPU prove: %d ms", time.Since(start).Milliseconds()) if err := groth16.Verify(proof, vk, publicWitness); err != nil { panic(err) } + // Actual run + tcpu := float64(0) + for i := 0; i < *nRuns; i++ { + start := time.Now() + proof, err = groth16.Prove(r1cs, pk, witnessData) + if err != nil { + panic(err) + } + tcpu += float64(time.Since(start).Milliseconds()) + if err := groth16.Verify(proof, vk, publicWitness); err != nil { + panic(err) + } + } + tcpu /= float64(*nRuns) + log.Printf("CPU prove average time: %v ms", tcpu) + + log.Printf("Speedup: %f", tcpu/tgpu) }