diff --git a/ecc/bls12381/g1.go b/ecc/bls12381/g1.go index 5ecc5873f..7ad211948 100644 --- a/ecc/bls12381/g1.go +++ b/ecc/bls12381/g1.go @@ -392,7 +392,8 @@ func G1Generator() *G1 { } // affinize converts an entire slice to affine at once -func affinize(points []*G1) { +func affinize(points []*G1) (out []G1) { + out = make([]G1, len(points)) if len(points) == 0 { return } @@ -410,8 +411,9 @@ func affinize(points []*G1) { zinv.Mul(w, &ws[i]) w.Mul(w, &points[i].z) - points[i].x.Mul(&points[i].x, zinv) - points[i].y.Mul(&points[i].y, zinv) - points[i].z.SetOne() + out[i].x.Mul(&points[i].x, zinv) + out[i].y.Mul(&points[i].y, zinv) + out[i].z.SetOne() } + return } diff --git a/ecc/bls12381/g1_test.go b/ecc/bls12381/g1_test.go index 2f85211fd..cb601176e 100644 --- a/ecc/bls12381/g1_test.go +++ b/ecc/bls12381/g1_test.go @@ -218,17 +218,14 @@ func TestG1Affinize(t *testing.T) { N := 20 testTimes := 1 << 6 g1 := make([]*G1, N) - g2 := make([]*G1, N) for i := 0; i < testTimes; i++ { for j := 0; j < N; j++ { g1[j] = randomG1(t) - g2[j] = &G1{} - *g2[j] = *g1[j] } - affinize(g2) + g2 := affinize(g1) for j := 0; j < N; j++ { g1[j].toAffine() - if !g1[j].IsEqual(g2[j]) { + if !g1[j].IsEqual(&g2[j]) { t.Fatal("failure to preserve points") } if g2[j].z.IsEqual(&g1[j].z) != 1 { diff --git a/ecc/bls12381/pair.go b/ecc/bls12381/pair.go index ead99d03c..5cf108f3f 100644 --- a/ecc/bls12381/pair.go +++ b/ecc/bls12381/pair.go @@ -4,9 +4,10 @@ import "github.com/cloudflare/circl/ecc/bls12381/ff" // Pair calculates the ate-pairing of P and Q. func Pair(P *G1, Q *G2) *Gt { - P.toAffine() + affP := *P + affP.toAffine() mi := &ff.Fp12{} - miller(mi, P, Q) + miller(mi, &affP, Q) e := &Gt{} finalExp(e, mi) return e @@ -82,9 +83,9 @@ func ProdPair(P []*G1, Q []*G2, n []*Scalar) *Gt { out := new(ff.Fp12) out.SetOne() - affinize(P) - for i := range P { - miller(mi, P[i], Q[i]) + affineP := affinize(P) + for i := range affineP { + miller(mi, &affineP[i], Q[i]) nb, _ := n[i].MarshalBinary() ei.Exp(mi, nb) out.Mul(out, ei) @@ -105,13 +106,12 @@ func ProdPairFrac(P []*G1, Q []*G2, signs []int) *Gt { out := new(ff.Fp12) out.SetOne() - affinize(P) - for i := range P { - g := *P[i] + affineP := affinize(P) + for i := range affineP { if signs[i] == -1 { - g.Neg() + affineP[i].Neg() } - miller(mi, &g, Q[i]) + miller(mi, &affineP[i], Q[i]) out.Mul(mi, out) } diff --git a/ecc/bls12381/pair_test.go b/ecc/bls12381/pair_test.go index 46751d854..14ebe4b38 100644 --- a/ecc/bls12381/pair_test.go +++ b/ecc/bls12381/pair_test.go @@ -79,6 +79,52 @@ func TestProdPairFrac(t *testing.T) { } } +func TestInputs(t *testing.T) { + t.Run("Pair", func(t *testing.T) { + P := *randomG1(t) + Q := *randomG2(t) + oldP := P + oldQ := Q + _ = Pair(&P, &Q) + test.CheckOk(P == oldP, "the point P was overwritten", t) + test.CheckOk(Q == oldQ, "the point Q was overwritten", t) + }) + + t.Run("ProdPair", func(t *testing.T) { + P0, P1 := *randomG1(t), *randomG1(t) + Q0, Q1 := *randomG2(t), *randomG2(t) + n0, n1 := *randomScalar(t), *randomScalar(t) + + oldP0, oldP1 := P0, P1 + oldQ0, oldQ1 := Q0, Q1 + oldn0, oldn1 := n0, n1 + + _ = ProdPair([]*G1{&P0, &P1}, []*G2{&Q0, &Q1}, []*Scalar{&n0, &n1}) + + test.CheckOk(P0 == oldP0, "the point P0 was overwritten", t) + test.CheckOk(P1 == oldP1, "the point P1 was overwritten", t) + test.CheckOk(Q0 == oldQ0, "the point Q0 was overwritten", t) + test.CheckOk(Q1 == oldQ1, "the point Q1 was overwritten", t) + test.CheckOk(n0 == oldn0, "the scalar n0 was overwritten", t) + test.CheckOk(n1 == oldn1, "the scalar n1 was overwritten", t) + }) + + t.Run("ProdPairFrac", func(t *testing.T) { + P0, P1 := *randomG1(t), *randomG1(t) + Q0, Q1 := *randomG2(t), *randomG2(t) + + oldP0, oldP1 := P0, P1 + oldQ0, oldQ1 := Q0, Q1 + + _ = ProdPairFrac([]*G1{&P0, &P1}, []*G2{&Q0, &Q1}, []int{1, -1}) + + test.CheckOk(P0 == oldP0, "the point P0 was overwritten", t) + test.CheckOk(P1 == oldP1, "the point P1 was overwritten", t) + test.CheckOk(Q0 == oldQ0, "the point Q0 was overwritten", t) + test.CheckOk(Q1 == oldQ1, "the point Q1 was overwritten", t) + }) +} + func TestPairBilinear(t *testing.T) { testTimes := 1 << 5 for i := 0; i < testTimes; i++ {