diff --git a/banderwagon/element.go b/banderwagon/element.go index a8b0ca8..fb13382 100644 --- a/banderwagon/element.go +++ b/banderwagon/element.go @@ -1,6 +1,7 @@ package banderwagon import ( + "bytes" "errors" "fmt" "math/big" @@ -243,17 +244,27 @@ func (p *Element) SetBytesUncompressed(buf []byte, trusted bool) error { var x fp.Element x.SetBytes(buf[:coordinateSize]) - // subgroup check + var y fp.Element + // point in curve & subgroup check if !trusted { + point := bandersnatch.GetPointFromX(&x, true) + if point == nil { + return fmt.Errorf("point not in the curve") + } + calculatedYBytes := point.Y.Bytes() + if !bytes.Equal(calculatedYBytes[:], buf[coordinateSize:]) { + return fmt.Errorf("provided Y coordinate doesn't correspond to X") + } + y = point.Y + err := subgroupCheck(x) if err != nil { return err } + } else { + y.SetBytes(buf[coordinateSize:]) } - var y fp.Element - y.SetBytes(buf[coordinateSize:]) - *p = Element{inner: bandersnatch.PointProj{ X: x, Y: y, diff --git a/banderwagon/element_test.go b/banderwagon/element_test.go index 4143904..5429a05 100644 --- a/banderwagon/element_test.go +++ b/banderwagon/element_test.go @@ -327,6 +327,48 @@ func TestBatchNormalize(t *testing.T) { }) } +func TestSetUncompressedFail(t *testing.T) { + t.Parallel() + one := fp.One() + + t.Run("X not in curve", func(t *testing.T) { + startX := one + // Find in startX a point that isn't in the curve + for { + point := bandersnatch.GetPointFromX(&startX, true) + if point == nil { + break + } + startX.Add(&startX, &one) + continue + } + var serializedPoint [UncompressedSize]byte + xBytes := startX.Bytes() + yBytes := Generator.inner.Y.Bytes() // Use some valid-ish Y, but this shouldn't matter much. + copy(serializedPoint[:], xBytes[:]) + copy(serializedPoint[CompressedSize:], yBytes[:]) + + var point2 Element + if err := point2.SetBytesUncompressed(serializedPoint[:], false); err == nil { + t.Fatalf("the point must be rejected") + } + }) + + t.Run("wrong Y", func(t *testing.T) { + gen := Generator + // Despite X would lead to a point in the curve, + // we modify Y+1 to check the provided (serialized) Y + // coordinate isn't trusted blindly. + gen.inner.Y.Add(&gen.inner.Y, &one) + + pointBytes := gen.BytesUncompressed() + var point2 Element + if err := point2.SetBytesUncompressed(pointBytes[:], false); err == nil { + t.Fatalf("the point must be rejected") + } + }) +} + func FuzzDeserializationCompressed(f *testing.F) { f.Fuzz(func(t *testing.T, serializedpoint []byte) { var point Element