From 10605188ee34484cfef86b77ad61523b1586bfae Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Fri, 9 Aug 2024 14:51:09 -0700 Subject: [PATCH] Remove n^2 algorithm from signature/key aggregation CountEnabled and IndexOfNthEnabled are both O(n) in the size of the mask, making this loop n^2. The BLS operations still tend to be the slow part, but the n^2 factor will start to show up with thousands of keys. --- sign/bdn/bdn.go | 45 ++++++++++++++++++++++++++------------------- sign/mask.go | 11 +++++++++++ 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/sign/bdn/bdn.go b/sign/bdn/bdn.go index 4b1ab1b9c..d6a86edc5 100644 --- a/sign/bdn/bdn.go +++ b/sign/bdn/bdn.go @@ -12,6 +12,7 @@ package bdn import ( "crypto/cipher" "errors" + "fmt" "math/big" "github.com/drand/kyber" @@ -129,31 +130,36 @@ func (scheme *Scheme) Verify(x kyber.Point, msg, sig []byte) error { // AggregateSignatures aggregates the signatures using a coefficient for each // one of them where c = H(pk) and H: keyGroup -> R with R = {1, ..., 2^128} func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber.Point, error) { - if len(sigs) != mask.CountEnabled() { - return nil, errors.New("length of signatures and public keys must match") - } - - coefs, err := hashPointToR(mask.Publics()) + publics := mask.Publics() + coefs, err := hashPointToR(publics) if err != nil { return nil, err } agg := scheme.sigGroup.Point() - for i, buf := range sigs { - peerIndex := mask.IndexOfNthEnabled(i) - if peerIndex < 0 { - // this should never happen as we check the lenths at the beginning - // an error here is probably a bug in the mask - return nil, errors.New("couldn't find the index") + for i := range publics { + if enabled, err := mask.GetBit(i); err != nil { + // this should never happen because of the loop boundary + // an error here is probably a bug in the mask implementation + return nil, fmt.Errorf("couldn't find the index %d: %w", i, err) + } else if !enabled { + continue } + if len(sigs) == 0 { + return nil, errors.New("length of signatures and public keys must match") + } + + buf := sigs[0] + sigs = sigs[1:] + sig := scheme.sigGroup.Point() err = sig.UnmarshalBinary(buf) if err != nil { return nil, err } - sigC := sig.Clone().Mul(coefs[peerIndex], sig) + sigC := sig.Clone().Mul(coefs[i], sig) // c+1 because R is in the range [1, 2^128] and not [0, 2^128-1] sigC = sigC.Add(sigC, sig) agg = agg.Add(agg, sigC) @@ -166,22 +172,23 @@ func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber // AggregateSignatures for signatures) using the hash function // H: keyGroup -> R with R = {1, ..., 2^128}. func (scheme *Scheme) AggregatePublicKeys(mask *sign.Mask) (kyber.Point, error) { - coefs, err := hashPointToR(mask.Publics()) + publics := mask.Publics() + coefs, err := hashPointToR(publics) if err != nil { return nil, err } agg := scheme.keyGroup.Point() - for i := 0; i < mask.CountEnabled(); i++ { - peerIndex := mask.IndexOfNthEnabled(i) - if peerIndex < 0 { + for i, pub := range publics { + if enabled, err := mask.GetBit(i); err != nil { // this should never happen because of the loop boundary // an error here is probably a bug in the mask implementation - return nil, errors.New("couldn't find the index") + return nil, fmt.Errorf("couldn't find the index %d: %w", i, err) + } else if !enabled { + continue } - pub := mask.Publics()[peerIndex] - pubC := pub.Clone().Mul(coefs[peerIndex], pub) + pubC := pub.Clone().Mul(coefs[i], pub) pubC = pubC.Add(pubC, pub) agg = agg.Add(agg, pubC) } diff --git a/sign/mask.go b/sign/mask.go index 98e96f0f6..f0b97a144 100644 --- a/sign/mask.go +++ b/sign/mask.go @@ -59,6 +59,17 @@ func (m *Mask) SetMask(mask []byte) error { return nil } +// GetBit returns true if the given bit is set. +func (m *Mask) GetBit(i int) (bool, error) { + if i >= len(m.publics) || i < 0 { + return false, errors.New("index out of range") + } + + byteIndex := i / 8 + mask := byte(1) << uint(i&7) + return m.mask[byteIndex]&mask != 0, nil +} + // SetBit turns on or off the bit at the given index. func (m *Mask) SetBit(i int, enable bool) error { if i >= len(m.publics) || i < 0 {