Skip to content

Commit

Permalink
banderwagon: avoid allocations in scalar field conversions (#30)
Browse files Browse the repository at this point in the history
* avoid allocations

Signed-off-by: Ignacio Hagopian <[email protected]>

* avoid extra allocs also in simple fr conversion

Signed-off-by: Ignacio Hagopian <[email protected]>

* avoid further allocs in ElementToBytes

Signed-off-by: Ignacio Hagopian <[email protected]>

Signed-off-by: Ignacio Hagopian <[email protected]>
  • Loading branch information
jsign authored Nov 11, 2022
1 parent a661476 commit 9aa5d42
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 41 deletions.
49 changes: 22 additions & 27 deletions banderwagon/element.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ var Generator = Element{inner: bandersnatch.PointProj{
Y: bandersnatch.GetEdwardsCurve().Base.Y,
Z: fp.One(),
}}

var Identity = Element{inner: bandersnatch.PointProj{
X: fp.Zero(),
Y: fp.One(),
Expand All @@ -34,7 +35,7 @@ func (p Element) Bytes() [sizePointCompressed]byte {
affine_representation.FromProj(&p.inner)

// Serialisation takes the x co-ordinate and multiplies it by the sign of y
var x = affine_representation.X
x := affine_representation.X
if !affine_representation.Y.LexicographicallyLargest() {
x.Neg(&x)
}
Expand All @@ -44,15 +45,15 @@ func (p Element) Bytes() [sizePointCompressed]byte {
// Serialises multiple group elements using a batch multi inversion
func ElementsToBytes(elements []*Element) [][sizePointCompressed]byte {
// Collect all z co-ordinates
var zs []fp.Element
zs := make([]fp.Element, len(elements))
for i := 0; i < int(len(elements)); i++ {
zs = append(zs, elements[i].inner.Z)
zs[i] = elements[i].inner.Z
}

// Invert z co-ordinates
zInvs := fp.BatchInvert(zs)

var serialised_points [][sizePointCompressed]byte
serialised_points := make([][sizePointCompressed]byte, len(elements))

// Multiply x and y by zInv
for i := 0; i < int(len(elements)); i++ {
Expand All @@ -69,11 +70,10 @@ func ElementsToBytes(elements []*Element) [][sizePointCompressed]byte {
X.Neg(&X)
}

serialised_points = append(serialised_points, X.Bytes())
serialised_points[i] = X.Bytes()
}

return serialised_points

}

func (p *Element) setBytes(buf []byte, trusted bool) error {
Expand Down Expand Up @@ -116,49 +116,41 @@ func (p *Element) SetBytesTrusted(buf []byte) error {

// computes X/Y
func (p Element) mapToBaseField() fp.Element {

var res fp.Element
res.Div(&p.inner.X, &p.inner.Y)
return res
}

func (p Element) MapToScalarField() fr.Element {
func (p Element) MapToScalarField(res *fr.Element) {
basefield := p.mapToBaseField()
baseFieldBytes := basefield.BytesLE()

var res fr.Element
res.SetBytesLE(baseFieldBytes[:])

return res
}

// Maps each point to a field element in the scalar field
func MultiMapToScalarField(elements []*Element) []fr.Element {
func MultiMapToScalarField(result []*fr.Element, elements []*Element) {
if len(result) != len(elements) {
panic("MultiMapToScalarField expects the result slice to be the same length of elements")
}

// Collect all y co-ordinates
var ys []fp.Element
ys := make([]fp.Element, len(elements))
for i := 0; i < int(len(elements)); i++ {
ys = append(ys, elements[i].inner.Y)
ys[i] = elements[i].inner.Y
}

// Invert y co-ordinates
yInvs := fp.BatchInvert(ys)

var scalars []fr.Element

// Multiply x by yInv
for i := 0; i < int(len(elements)); i++ {
var mappedElement fp.Element

mappedElement.Mul(&elements[i].inner.X, &yInvs[i])
byts := mappedElement.BytesLE()

var res fr.Element
res.SetBytesLE(byts[:])
scalars = append(scalars, res)
result[i].SetBytesLE(byts[:])
}

return scalars

}

// TODO: change this to not use pointers
Expand Down Expand Up @@ -191,7 +183,7 @@ func (p *Element) Equal(other *Element) bool {
func subgroup_check(x fp.Element) error {
var res, one, ax_sq fp.Element
one.SetOne()
var A = bandersnatch.GetEdwardsCurve().A
A := bandersnatch.GetEdwardsCurve().A

// 1 - ax^2
ax_sq.Square(&x)
Expand All @@ -209,24 +201,27 @@ func (p *Element) Identity() *Element {
*p = Identity
return p
}

func (p *Element) Double(p1 *Element) *Element {
p.inner.Double(&p1.inner)
return p
}

func (p *Element) Add(p1, p2 *Element) *Element {
p.inner.Add(&p1.inner, &p2.inner)
return p
}

func (p *Element) AddMixed(p1 *Element, p2 bandersnatch.PointAffine) *Element {
p.inner.MixedAdd(&p1.inner, &p2)
return p
}

func (p *Element) Sub(p1, p2 *Element) *Element {
var neg_p2 Element
neg_p2.Neg(p2)

return p.Add(p1, &neg_p2)

}

func (p *Element) IsOnCurve() bool {
Expand All @@ -244,6 +239,7 @@ func (p *Element) Normalise() {
p.inner.Y.Set(&point_aff.Y)
p.inner.Z.SetOne()
}

func (p *Element) Set(p1 *Element) *Element {
p.inner.X.Set(&p1.inner.X)
p.inner.Y.Set(&p1.inner.Y)
Expand All @@ -255,6 +251,7 @@ func (p *Element) Neg(p1 *Element) *Element {
p.inner.Neg(&p1.inner)
return p
}

func (p *Element) ScalarMul(p1 *Element, scalar_mont *fr.Element) *Element {
p.inner.ScalarMul(&p1.inner, scalar_mont)
return p
Expand All @@ -269,7 +266,6 @@ func (p *Element) ScalarMul(p1 *Element, scalar_mont *fr.Element) *Element {
//
// we could increase storage by 2x and save CPU time by serialising the projective point
func UnsafeReadUncompressedPoint(r io.Reader) *Element {

affine_point := bandersnatch.ReadUncompressedPoint(r)
var proj_repr bandersnatch.PointProj
proj_repr.FromAffine(&affine_point)
Expand All @@ -281,7 +277,6 @@ func UnsafeReadUncompressedPoint(r io.Reader) *Element {

// Writes an uncompressed affine point to an io.Writer
func (element *Element) UnsafeWriteUncompressedPoint(w io.Writer) (int, error) {

// Convert underlying point to affine representation
var p bandersnatch.PointAffine
p.FromProj(&element.inner)
Expand Down
24 changes: 10 additions & 14 deletions banderwagon/element_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import (

"github.com/crate-crypto/go-ipa/bandersnatch"
"github.com/crate-crypto/go-ipa/bandersnatch/fp"
"github.com/crate-crypto/go-ipa/bandersnatch/fr"
)

func TestEncodingFixedVectors(t *testing.T) {

expected_bit_strings := [16]string{
"4a2c7486fd924882bf02c6908de395122843e3e05264d7991e18e7985dad51e9",
"43aa74ef706605705989e8fd38df46873b7eae5921fbed115ac9d937399ce4d5",
Expand Down Expand Up @@ -90,6 +90,7 @@ func TestTwoTorsionEqual(t *testing.T) {
point.Double(&point)
}
}

func TestPointAtInfinityComponent(t *testing.T) {
// These are all points which will be shown to be on the curve
// but are not in the correct subgroup
Expand Down Expand Up @@ -124,11 +125,9 @@ func TestPointAtInfinityComponent(t *testing.T) {
panic("point should not be in the correct subgroup as it has an infinity component")
}
}

}

func TestAddSubDouble(t *testing.T) {

var A, B Element

A.Add(&Generator, &Generator)
Expand All @@ -149,7 +148,6 @@ func TestAddSubDouble(t *testing.T) {
}

func TestSerde(t *testing.T) {

var point Element
var point_aff bandersnatch.PointAffine

Expand All @@ -164,11 +162,9 @@ func TestSerde(t *testing.T) {
if !point_aff.Equal(&got) {
panic("deserialised point does not equal serialised point ")
}

}

func TestBatchElementsToBytes(t *testing.T) {

var A, B Element

A.Add(&Generator, &Generator)
Expand All @@ -183,34 +179,34 @@ func TestBatchElementsToBytes(t *testing.T) {
got_serialised_b := serialised_points[1]
if expected_serialised_a != got_serialised_a {
panic("expected serialised point of A is incorrect ")

}
if expected_serialised_b != got_serialised_b {
panic("expected serialised point of B is incorrect ")
}

}

func TestMultiMapToBaseField(t *testing.T) {

var A, B Element

A.Add(&Generator, &Generator)
B.Double(&Generator)
B.Double(&B)

expected_a := A.MapToScalarField()
expected_b := B.MapToScalarField()
var expected_a, expected_b fr.Element
A.MapToScalarField(&expected_a)
B.MapToScalarField(&expected_b)

scalars := MultiMapToScalarField([]*Element{&A, &B})
var ARes, BRes fr.Element
scalars := []*fr.Element{&ARes, &BRes}
MultiMapToScalarField(scalars, []*Element{&A, &B})

got_a := scalars[0]
got_b := scalars[1]
if expected_a != got_a {
if expected_a != *got_a {
panic("expected scalar for point `A` is incorrect ")
}

if expected_b != got_b {
if expected_b != *got_b {
panic("expected scalar for point `A` is incorrect ")
}
}

0 comments on commit 9aa5d42

Please sign in to comment.