From 6b09955cbe011509a5f36dfdb5fa1697887f0f3d Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Mon, 26 Aug 2024 17:40:46 -0700 Subject: [PATCH] Introduce a new CachedMask for BDN This new mask will pre-compute reusable values, speeding up repeated verification and aggregation of aggregate signatures (mostly the former). --- sign/bdn/bdn.go | 25 ++++------ sign/bdn/bdn_test.go | 41 ++++++++++++++++ sign/bdn/mask.go | 112 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 163 insertions(+), 15 deletions(-) create mode 100644 sign/bdn/mask.go diff --git a/sign/bdn/bdn.go b/sign/bdn/bdn.go index b1a8c25e..ce220259 100644 --- a/sign/bdn/bdn.go +++ b/sign/bdn/bdn.go @@ -122,15 +122,13 @@ func (scheme *Scheme) Verify(x kyber.Point, msg, sig []byte) error { // AggregateSignatures aggregates the signatures using a coefficient for each // one of them where c = H(pk) and H: keyGroup -> R with R = {1, ..., 2^128} -func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber.Point, error) { - publics := mask.Publics() - coefs, err := hashPointToR(publics) +func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask Mask) (kyber.Point, error) { + bdnMask, err := newCachedMask(mask, false) if err != nil { return nil, err } - agg := scheme.sigGroup.Point() - for i := range publics { + for i := range bdnMask.publics { if enabled, err := mask.GetBit(i); err != nil { // this should never happen because of the loop boundary // an error here is probably a bug in the mask implementation @@ -152,7 +150,7 @@ func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber return nil, err } - sigC := sig.Clone().Mul(coefs[i], sig) + sigC := sig.Clone().Mul(bdnMask.coefs[i], sig) // c+1 because R is in the range [1, 2^128] and not [0, 2^128-1] sigC = sigC.Add(sigC, sig) agg = agg.Add(agg, sigC) @@ -164,15 +162,14 @@ func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber // AggregatePublicKeys aggregates a set of public keys (similarly to // AggregateSignatures for signatures) using the hash function // H: keyGroup -> R with R = {1, ..., 2^128}. -func (scheme *Scheme) AggregatePublicKeys(mask *sign.Mask) (kyber.Point, error) { - publics := mask.Publics() - coefs, err := hashPointToR(publics) +func (scheme *Scheme) AggregatePublicKeys(mask Mask) (kyber.Point, error) { + bdnMask, err := newCachedMask(mask, false) if err != nil { return nil, err } agg := scheme.keyGroup.Point() - for i, pub := range publics { + for i := range bdnMask.publics { if enabled, err := mask.GetBit(i); err != nil { // this should never happen because of the loop boundary // an error here is probably a bug in the mask implementation @@ -181,9 +178,7 @@ func (scheme *Scheme) AggregatePublicKeys(mask *sign.Mask) (kyber.Point, error) continue } - pubC := pub.Clone().Mul(coefs[i], pub) - pubC = pubC.Add(pubC, pub) - agg = agg.Add(agg, pubC) + agg = agg.Add(agg, bdnMask.getOrComputePubC(i)) } return agg, nil @@ -217,7 +212,7 @@ func Verify(suite pairing.Suite, x kyber.Point, msg, sig []byte) error { // AggregateSignatures aggregates the signatures using a coefficient for each // one of them where c = H(pk) and H: G2 -> R with R = {1, ..., 2^128} // Deprecated: use the new scheme methods instead. -func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask *sign.Mask) (kyber.Point, error) { +func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask Mask) (kyber.Point, error) { return NewSchemeOnG1(suite).AggregateSignatures(sigs, mask) } @@ -225,6 +220,6 @@ func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask *sign.Mask) (k // AggregateSignatures for signatures) using the hash function // H: G2 -> R with R = {1, ..., 2^128}. // Deprecated: use the new scheme methods instead. -func AggregatePublicKeys(suite pairing.Suite, mask *sign.Mask) (kyber.Point, error) { +func AggregatePublicKeys(suite pairing.Suite, mask Mask) (kyber.Point, error) { return NewSchemeOnG1(suite).AggregatePublicKeys(mask) } diff --git a/sign/bdn/bdn_test.go b/sign/bdn/bdn_test.go index cacfc99a..55f929aa 100644 --- a/sign/bdn/bdn_test.go +++ b/sign/bdn/bdn_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" "go.dedis.ch/kyber/v4" + "go.dedis.ch/kyber/v4/pairing/bls12381/kilic" "go.dedis.ch/kyber/v4/pairing/bn256" "go.dedis.ch/kyber/v4/sign" "go.dedis.ch/kyber/v4/sign/bls" @@ -161,6 +162,46 @@ func Benchmark_BDN_AggregateSigs(b *testing.B) { } } +func Benchmark_BDN_BLS12381_AggregateVerify(b *testing.B) { + suite := kilic.NewBLS12381Suite() + schemeOnG2 := NewSchemeOnG2(suite) + + rng := random.New() + pubKeys := make([]kyber.Point, 3000) + privKeys := make([]kyber.Scalar, 3000) + for i := range pubKeys { + privKeys[i], pubKeys[i] = schemeOnG2.NewKeyPair(rng) + } + + baseMask, err := sign.NewMask(pubKeys, nil) + require.NoError(b, err) + mask, err := NewCachedMask(baseMask) + require.NoError(b, err) + for i := range pubKeys { + require.NoError(b, mask.SetBit(i, true)) + } + + msg := []byte("Hello many times Boneh-Lynn-Shacham") + sigs := make([][]byte, len(privKeys)) + for i, k := range privKeys { + s, err := schemeOnG2.Sign(k, msg) + require.NoError(b, err) + sigs[i] = s + } + + sig, err := schemeOnG2.AggregateSignatures(sigs, mask) + require.NoError(b, err) + sigb, err := sig.MarshalBinary() + require.NoError(b, err) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pk, err := schemeOnG2.AggregatePublicKeys(mask) + require.NoError(b, err) + require.NoError(b, schemeOnG2.Verify(pk, msg, sigb)) + } +} + func unmarshalHex[T encoding.BinaryUnmarshaler](t *testing.T, into T, s string) T { t.Helper() b, err := hex.DecodeString(s) diff --git a/sign/bdn/mask.go b/sign/bdn/mask.go new file mode 100644 index 00000000..f3899c70 --- /dev/null +++ b/sign/bdn/mask.go @@ -0,0 +1,112 @@ +package bdn + +import ( + "fmt" + + "go.dedis.ch/kyber/v4" + "go.dedis.ch/kyber/v4/sign" +) + +type Mask interface { + GetBit(i int) (bool, error) + SetBit(i int, enable bool) error + + IndexOfNthEnabled(nth int) int + NthEnabledAtIndex(idx int) int + + Publics() []kyber.Point + Participants() []kyber.Point + + CountEnabled() int + CountTotal() int + + Len() int + Mask() []byte + SetMask(mask []byte) error + Merge(mask []byte) error +} + +var _ Mask = (*sign.Mask)(nil) + +// We need to rename this, otherwise we have a public field named Mask (when we embed it) which +// conflicts with the function named Mask. It also makes it private, which is nice. +type maskI = Mask + +type CachedMask struct { + maskI + coefs []kyber.Scalar + pubKeyC []kyber.Point + // We could call Mask.Publics() instead of keeping these here, but that function copies the + // slice and this field lets us avoid that copy. + publics []kyber.Point +} + +// Convert the passed mask (likely a *sign.Mask) into a BDN-specific mask with pre-computed terms. +// +// This cached mask will: +// +// 1. Pre-compute coefficients for signature aggregation. Once the CachedMask has been instantiated, +// distinct sets of signatures can be aggregated without any BLAKE2S hashing. +// 2. Pre-computes the terms for public key aggregation. Once the CachedMask has been instantiated, +// distinct sets of public keys can be aggregated by simply summing the cached terms, ~2 orders +// of magnitude faster than aggregating from scratch. +func NewCachedMask(mask Mask) (*CachedMask, error) { + return newCachedMask(mask, true) +} + +func newCachedMask(mask Mask, precomputePubC bool) (*CachedMask, error) { + if m, ok := mask.(*CachedMask); ok { + return m, nil + } + + publics := mask.Publics() + coefs, err := hashPointToR(publics) + if err != nil { + return nil, fmt.Errorf("failed to hash public keys: %w", err) + } + + cm := &CachedMask{ + maskI: mask, + coefs: coefs, + publics: publics, + } + + if precomputePubC { + pubKeyC := make([]kyber.Point, len(publics)) + for i := range publics { + pubKeyC[i] = cm.getOrComputePubC(i) + } + cm.pubKeyC = pubKeyC + } + + return cm, err +} + +// Clone copies the BDN mask while keeping the precomputed coefficients, etc. +func (cm *CachedMask) Clone() *CachedMask { + newMask, err := sign.NewMask(cm.publics, nil) + if err != nil { + // Not possible given that we didn't pass our own key. + panic(fmt.Sprintf("failed to create mask: %s", err)) + } + if err := newMask.SetMask(cm.Mask()); err != nil { + // Not possible given that we're using the same sized mask. + panic(fmt.Sprintf("failed to create mask: %s", err)) + } + return &CachedMask{ + maskI: newMask, + coefs: cm.coefs, + pubKeyC: cm.pubKeyC, + publics: cm.publics, + } +} + +func (cm *CachedMask) getOrComputePubC(i int) kyber.Point { + if cm.pubKeyC == nil { + // NOTE: don't cache here as we may be sharing this mask between threads. + pub := cm.publics[i] + pubC := pub.Clone().Mul(cm.coefs[i], pub) + return pubC.Add(pubC, pub) + } + return cm.pubKeyC[i] +}