From 78547b486ee5be626ec8f92cb7aadf45202714cf Mon Sep 17 00:00:00 2001 From: Ignacio Hagopian Date: Wed, 26 Apr 2023 12:51:21 -0300 Subject: [PATCH] tree: allow batching new leaf nodes Signed-off-by: Ignacio Hagopian tree: simplify batching of new leaves creation Signed-off-by: Ignacio Hagopian tree: fix insert new leaves test and benchmark Signed-off-by: Ignacio Hagopian remove comment Signed-off-by: Ignacio Hagopian remove conversion file Signed-off-by: Ignacio Hagopian remove unused method Signed-off-by: Ignacio Hagopian --- conversion.go | 159 -------------------------------------------------- tree.go | 156 +++++++++++++++++++++++++++++++++++++------------ tree_test.go | 83 ++++---------------------- 3 files changed, 130 insertions(+), 268 deletions(-) delete mode 100644 conversion.go diff --git a/conversion.go b/conversion.go deleted file mode 100644 index 82499d48..00000000 --- a/conversion.go +++ /dev/null @@ -1,159 +0,0 @@ -package verkle - -import ( - "bytes" - "fmt" - "runtime" - "sort" - "sync" -) - -// BatchNewLeafNodeData is a struct that contains the data needed to create a new leaf node. -type BatchNewLeafNodeData struct { - Stem []byte - Values map[byte][]byte -} - -// BatchNewLeafNode creates a new leaf node from the given data. It optimizes LeafNode creation -// by batching expensive cryptography operations. It returns the LeafNodes sorted by stem. -func BatchNewLeafNode(nodesValues []BatchNewLeafNodeData) []LeafNode { - cfg := GetConfig() - ret := make([]LeafNode, len(nodesValues)) - - numBatches := runtime.NumCPU() - batchSize := len(nodesValues) / numBatches - - var wg sync.WaitGroup - wg.Add(numBatches) - for i := 0; i < numBatches; i++ { - start := i * batchSize - end := (i + 1) * batchSize - if i == numBatches-1 { - end = len(nodesValues) - } - go func(ret []LeafNode, nodesValues []BatchNewLeafNodeData) { - defer wg.Done() - - c1c2points := make([]*Point, 2*len(nodesValues)) - c1c2frs := make([]*Fr, 2*len(nodesValues)) - for i, nv := range nodesValues { - valsslice := make([][]byte, NodeWidth) - for idx := range nv.Values { - valsslice[idx] = nv.Values[idx] - } - - ret[i] = *NewLeafNode(nv.Stem, valsslice) - - c1c2points[2*i], c1c2points[2*i+1] = ret[i].c1, ret[i].c2 - c1c2frs[2*i], c1c2frs[2*i+1] = new(Fr), new(Fr) - } - - toFrMultiple(c1c2frs, c1c2points) - - var poly [NodeWidth]Fr - poly[0].SetUint64(1) - for i, nv := range nodesValues { - StemFromBytes(&poly[1], nv.Stem) - poly[2] = *c1c2frs[2*i] - poly[3] = *c1c2frs[2*i+1] - - ret[i].commitment = cfg.CommitToPoly(poly[:], 252) - } - - }(ret[start:end], nodesValues[start:end]) - } - wg.Wait() - - sort.Slice(ret, func(i, j int) bool { - return bytes.Compare(ret[i].stem, ret[j].stem) < 0 - }) - - return ret -} - -// firstDiffByteIdx will return the first index in which the two stems differ. -// Both stems *must* be different. -func firstDiffByteIdx(stem1 []byte, stem2 []byte) int { - for i := range stem1 { - if stem1[i] != stem2[i] { - return i - } - } - panic("stems are equal") -} - -func (n *InternalNode) InsertMigratedLeaves(leaves []LeafNode, resolver NodeResolverFn) error { - for i := range leaves { - ln := leaves[i] - parent := n - - // Look for the appropriate parent for the leaf node. 
- for { - if hashedNode, ok := parent.children[ln.stem[parent.depth]].(*HashedNode); ok { - serialized, err := resolver(hashedNode.commitment) - if err != nil { - return fmt.Errorf("resolving node %x: %w", hashedNode.commitment, err) - } - resolved, err := ParseNode(serialized, parent.depth+1, hashedNode.commitment) - if err != nil { - return fmt.Errorf("parsing node %x: %w", serialized, err) - } - parent.children[ln.stem[parent.depth]] = resolved - } - - nextParent, ok := parent.children[ln.stem[parent.depth]].(*InternalNode) - if !ok { - break - } - - parent.cowChild(ln.stem[parent.depth]) - parent = nextParent - } - - switch node := parent.children[ln.stem[parent.depth]].(type) { - case Empty: - parent.cowChild(ln.stem[parent.depth]) - parent.children[ln.stem[parent.depth]] = &ln - ln.setDepth(parent.depth + 1) - case *LeafNode: - if bytes.Equal(node.stem, ln.stem) { - // In `ln` we have migrated key/values which should be copied to the leaf - // only if there isn't a value there. If there's a value, we skip it since - // our migrated value is stale. - nonPresentValues := make([][]byte, NodeWidth) - for i := range ln.values { - if node.values[i] == nil { - nonPresentValues[i] = ln.values[i] - } - } - - node.updateMultipleLeaves(nonPresentValues) - continue - } - - // Otherwise, we need to create the missing internal nodes depending in the fork point in their stems. - idx := firstDiffByteIdx(node.stem, ln.stem) - // We do a sanity check to make sure that the fork point is not before the current depth. - if byte(idx) <= parent.depth { - return fmt.Errorf("unexpected fork point %d for nodes %x and %x", idx, node.stem, ln.stem) - } - // Create the missing internal nodes. - for i := parent.depth + 1; i <= byte(idx); i++ { - nextParent := newInternalNode(parent.depth + 1).(*InternalNode) - parent.cowChild(ln.stem[parent.depth]) - parent.children[ln.stem[parent.depth]] = nextParent - parent = nextParent - } - // Add old and new leaf node to the latest created parent. - parent.cowChild(node.stem[parent.depth]) - parent.children[node.stem[parent.depth]] = node - node.setDepth(parent.depth + 1) - parent.cowChild(ln.stem[parent.depth]) - parent.children[ln.stem[parent.depth]] = &ln - ln.setDepth(parent.depth + 1) - default: - return fmt.Errorf("unexpected node type %T", node) - } - } - return nil -} diff --git a/tree.go b/tree.go index 56bedd4c..7dcd47ed 100644 --- a/tree.go +++ b/tree.go @@ -30,6 +30,8 @@ import ( "encoding/json" "errors" "fmt" + "runtime" + "sync" "github.com/crate-crypto/go-ipa/banderwagon" ) @@ -200,6 +202,7 @@ func (n *InternalNode) toExportable() *ExportableInternalNode { case *InternalNode: exportable.Children[i] = child.toExportable() case *LeafNode: + child.Commit() exportable.Children[i] = &ExportableLeafNode{ Stem: child.stem, Values: child.values, @@ -248,47 +251,14 @@ func NewStatelessInternal(depth byte, comm *Point) VerkleNode { // New creates a new leaf node func NewLeafNode(stem []byte, values [][]byte) *LeafNode { - cfg := GetConfig() - - // C1. - var c1poly [NodeWidth]Fr - var c1 *Point - count := fillSuffixTreePoly(c1poly[:], values[:NodeWidth/2]) - containsEmptyCodeHash := len(c1poly) >= EmptyCodeHashSecondHalfIdx && - c1poly[EmptyCodeHashFirstHalfIdx].Equal(&EmptyCodeHashFirstHalfValue) && - c1poly[EmptyCodeHashSecondHalfIdx].Equal(&EmptyCodeHashSecondHalfValue) - if containsEmptyCodeHash { - // Clear out values of the cached point. 
- c1poly[EmptyCodeHashFirstHalfIdx] = FrZero - c1poly[EmptyCodeHashSecondHalfIdx] = FrZero - // Calculate the remaining part of c1 and add to the base value. - partialc1 := cfg.CommitToPoly(c1poly[:], NodeWidth-count-2) - c1 = new(Point) - c1.Add(&EmptyCodeHashPoint, partialc1) - } else { - c1 = cfg.CommitToPoly(c1poly[:], NodeWidth-count) - } - - // C2. - var c2poly [NodeWidth]Fr - count = fillSuffixTreePoly(c2poly[:], values[NodeWidth/2:]) - c2 := cfg.CommitToPoly(c2poly[:], NodeWidth-count) - - // Root commitment preparation for calculation. - stem = stem[:StemSize] // enforce a 31-byte length - var poly [NodeWidth]Fr - poly[0].SetUint64(1) - StemFromBytes(&poly[1], stem) - toFrMultiple([]*Fr{&poly[2], &poly[3]}, []*Point{c1, c2}) - return &LeafNode{ // depth will be 0, but the commitment calculation // does not need it, and so it won't be free. values: values, stem: stem, - commitment: cfg.CommitToPoly(poly[:], NodeWidth-4), - c1: c1, - c2: c2, + commitment: nil, + c1: nil, + c2: nil, } } @@ -654,11 +624,33 @@ func (n *InternalNode) fillLevels(levels [][]*InternalNode) { } } +func (n *InternalNode) findNewLeafNodes(newLeaves []*LeafNode) []*LeafNode { + for idx := range n.cow { + child := n.children[idx] + if childInternalNode, ok := child.(*InternalNode); ok && len(childInternalNode.cow) > 0 { + newLeaves = childInternalNode.findNewLeafNodes(newLeaves) + } else if leafNode, ok := child.(*LeafNode); ok { + if leafNode.commitment == nil { + newLeaves = append(newLeaves, leafNode) + } + } + } + return newLeaves +} + func (n *InternalNode) Commit() *Point { if len(n.cow) == 0 { return n.commitment } + // New leaf nodes. + newLeaves := make([]*LeafNode, 0, 64) + newLeaves = n.findNewLeafNodes(newLeaves) + if len(newLeaves) > 0 { + batchCommitLeafNodes(newLeaves) + } + + // Internal nodes. internalNodeLevels := make([][]*InternalNode, StemSize) n.fillLevels(internalNodeLevels) @@ -1027,6 +1019,11 @@ func (n *LeafNode) updateCn(index byte, value []byte, c *Point) { } func (n *LeafNode) updateLeaf(index byte, value []byte) { + if n.commitment == nil { + n.values[index] = value + return + } + // Update the corresponding C1 or C2 commitment. var c *Point var oldC Point @@ -1051,6 +1048,15 @@ func (n *LeafNode) updateLeaf(index byte, value []byte) { } func (n *LeafNode) updateMultipleLeaves(values [][]byte) { + if n.commitment == nil { + for i, v := range values { + if len(v) != 0 && !bytes.Equal(v, n.values[i]) { + n.values[i] = v + } + } + return + } + var oldC1, oldC2 *Point // We iterate the values, and we update the C1 and/or C2 commitments depending on the index. @@ -1224,6 +1230,10 @@ func (n *LeafNode) Commitment() *Point { } func (n *LeafNode) Commit() *Point { + if n.commitment == nil { + commitLeafNodes([]*LeafNode{n}) + } + return n.commitment } @@ -1405,6 +1415,7 @@ func (n *LeafNode) GetProofItems(keys keylist) (*ProofElements, []byte, [][]byte // Serialize serializes a LeafNode. 
// The format is: func (n *LeafNode) Serialize() ([]byte, error) { + n.Commit() cBytes := banderwagon.ElementsToBytes([]*banderwagon.Element{n.c1, n.c2}) return n.serializeWithCompressedCommitments(cBytes[0], cBytes[1]), nil } @@ -1635,3 +1646,76 @@ func (n *LeafNode) serializeWithCompressedCommitments(c1Bytes [32]byte, c2Bytes return result } + +func batchCommitLeafNodes(leaves []*LeafNode) { + minBatchSize := 8 + if len(leaves) < minBatchSize { + commitLeafNodes(leaves) + return + } + + batchSize := len(leaves) / runtime.NumCPU() + if batchSize < minBatchSize { + batchSize = minBatchSize + } + + var wg sync.WaitGroup + for start := 0; start < len(leaves); start += batchSize { + end := start + batchSize + if end > len(leaves) { + end = len(leaves) + } + wg.Add(1) + go func(leaves []*LeafNode) { + defer wg.Done() + commitLeafNodes(leaves) + }(leaves[start:end]) + } + wg.Wait() +} + +func commitLeafNodes(leaves []*LeafNode) { + cfg := GetConfig() + + c1c2points := make([]*Point, 2*len(leaves)) + c1c2frs := make([]*Fr, 2*len(leaves)) + for i, n := range leaves { + // C1. + var c1poly [NodeWidth]Fr + count := fillSuffixTreePoly(c1poly[:], n.values[:NodeWidth/2]) + containsEmptyCodeHash := len(c1poly) >= EmptyCodeHashSecondHalfIdx && + c1poly[EmptyCodeHashFirstHalfIdx].Equal(&EmptyCodeHashFirstHalfValue) && + c1poly[EmptyCodeHashSecondHalfIdx].Equal(&EmptyCodeHashSecondHalfValue) + if containsEmptyCodeHash { + // Clear out values of the cached point. + c1poly[EmptyCodeHashFirstHalfIdx] = FrZero + c1poly[EmptyCodeHashSecondHalfIdx] = FrZero + // Calculate the remaining part of c1 and add to the base value. + partialc1 := cfg.CommitToPoly(c1poly[:], NodeWidth-count-2) + n.c1 = new(Point) + n.c1.Add(&EmptyCodeHashPoint, partialc1) + } else { + n.c1 = cfg.CommitToPoly(c1poly[:], NodeWidth-count) + } + + // C2. + var c2poly [NodeWidth]Fr + count = fillSuffixTreePoly(c2poly[:], n.values[NodeWidth/2:]) + n.c2 = cfg.CommitToPoly(c2poly[:], NodeWidth-count) + + c1c2points[2*i], c1c2points[2*i+1] = n.c1, n.c2 + c1c2frs[2*i], c1c2frs[2*i+1] = new(Fr), new(Fr) + } + + toFrMultiple(c1c2frs, c1c2points) + + var poly [NodeWidth]Fr + poly[0].SetUint64(1) + for i, nv := range leaves { + StemFromBytes(&poly[1], nv.stem) + poly[2] = *c1c2frs[2*i] + poly[3] = *c1c2frs[2*i+1] + + nv.commitment = cfg.CommitToPoly(poly[:], 252) + } +} diff --git a/tree_test.go b/tree_test.go index 4fbbbee6..a28d5d51 100644 --- a/tree_test.go +++ b/tree_test.go @@ -885,6 +885,7 @@ func TestLeafToCommsLessThan16(*testing.T) { func TestGetProofItemsNoPoaIfStemPresent(t *testing.T) { root := New() root.Insert(ffx32KeyTest, zeroKeyTest, nil) + root.Commit() // insert two keys that differ from the inserted stem // by one byte. @@ -1195,6 +1196,7 @@ func TestEmptyHashCodeCachedPoint(t *testing.T) { values := make([][]byte, NodeWidth) values[CodeHashVectorPosition] = emptyHashCode ln := NewLeafNode(zeroKeyTest, values) + ln.Commit() // Compare the result (which used the cached point) with the expected result which was // calculated by a previous version of the library that didn't use a cached point. 
@@ -1208,18 +1210,17 @@ func TestEmptyHashCodeCachedPoint(t *testing.T) { } } -func TestBatchMigratedKeyValues(t *testing.T) { +func TestInsertNewLeaves(t *testing.T) { _ = GetConfig() for _, treeInitialKeyValCount := range []int{0, 500, 1_000, 2_000, 5_000} { fmt.Printf("Assuming %d key/values touched by block execution:\n", treeInitialKeyValCount) for _, migrationKeyValueCount := range []int{1_000, 2_000, 5_000, 8_000} { - iterations := 5 - var batchedDuration, unbatchedDuration time.Duration + iterations := 10 + var unbatchedDuration time.Duration for i := 0; i < iterations; i++ { runtime.GC() - // ***Insert the key pairs 'naively' *** rand := mRand.New(mRand.NewSource(42)) //skipcq: GSC-G404 tree := genRandomTree(rand, treeInitialKeyValCount) randomKeyValues := genRandomKeyValues(rand, migrationKeyValueCount) @@ -1230,60 +1231,17 @@ func TestBatchMigratedKeyValues(t *testing.T) { t.Fatalf("failed to insert key: %v", err) } } - unbatchedRoot := tree.Commit().Bytes() + tree.Commit() if _, err := tree.(*InternalNode).BatchSerialize(); err != nil { t.Fatalf("failed to serialize unbatched tree: %v", err) } unbatchedDuration += time.Since(now) - - // ***Insert the key pairs with optimized strategy & methods*** - rand = mRand.New(mRand.NewSource(42)) //skipcq: GSC-G404 - tree = genRandomTree(rand, treeInitialKeyValCount) - randomKeyValues = genRandomKeyValues(rand, migrationKeyValueCount) - - now = time.Now() - // Create LeafNodes in batch mode. - nodeValues := make([]BatchNewLeafNodeData, 0, len(randomKeyValues)) - curr := BatchNewLeafNodeData{ - Stem: randomKeyValues[0].key[:StemSize], - Values: map[byte][]byte{randomKeyValues[0].key[StemSize]: randomKeyValues[0].value}, - } - for _, kv := range randomKeyValues[1:] { - if bytes.Equal(curr.Stem, kv.key[:StemSize]) { - curr.Values[kv.key[StemSize]] = kv.value - continue - } - nodeValues = append(nodeValues, curr) - curr = BatchNewLeafNodeData{ - Stem: kv.key[:StemSize], - Values: map[byte][]byte{kv.key[StemSize]: kv.value}, - } - } - // Append last remaining node. - nodeValues = append(nodeValues, curr) - - // Create all leaves in batch mode so we can optimize cryptography operations. 
- newLeaves := BatchNewLeafNode(nodeValues) - if err := tree.(*InternalNode).InsertMigratedLeaves(newLeaves, nil); err != nil { - t.Fatalf("failed to insert key: %v", err) - } - - batchedRoot := tree.Commit().Bytes() - if _, err := tree.(*InternalNode).BatchSerialize(); err != nil { - t.Fatalf("failed to serialize batched tree: %v", err) - } - batchedDuration += time.Since(now) - - if unbatchedRoot != batchedRoot { - t.Fatalf("expected %x, got %x", unbatchedRoot, batchedRoot) - } } - fmt.Printf("\tIf %d extra key-values are migrated: unbatched %dms, batched %dms, %.02fx\n", migrationKeyValueCount, (unbatchedDuration / time.Duration(iterations)).Milliseconds(), (batchedDuration / time.Duration(iterations)).Milliseconds(), float64(unbatchedDuration.Milliseconds())/float64(batchedDuration.Milliseconds())) + fmt.Printf("\tIf %d extra key-values are migrated: unbatched %dms\n", migrationKeyValueCount, (unbatchedDuration / time.Duration(iterations)).Milliseconds()) } } } - func genRandomTree(rand *mRand.Rand, keyValueCount int) VerkleNode { tree := New() for _, kv := range genRandomKeyValues(rand, keyValueCount) { @@ -1310,7 +1268,7 @@ func genRandomKeyValues(rand *mRand.Rand, count int) []keyValue { return ret } -func BenchmarkBatchLeavesInsert(b *testing.B) { +func BenchmarkNewLeavesInsert(b *testing.B) { treeInitialKeyValCount := 1_000 migrationKeyValueCount := 5_000 @@ -1326,31 +1284,10 @@ func BenchmarkBatchLeavesInsert(b *testing.B) { b.StartTimer() // Create LeafNodes in batch mode. - nodeValues := make([]BatchNewLeafNodeData, 0, len(randomKeyValues)) - curr := BatchNewLeafNodeData{ - Stem: randomKeyValues[0].key[:StemSize], - Values: map[byte][]byte{randomKeyValues[0].key[StemSize]: randomKeyValues[0].value}, - } for _, kv := range randomKeyValues[1:] { - if bytes.Equal(curr.Stem, kv.key[:StemSize]) { - curr.Values[kv.key[StemSize]] = kv.value - continue - } - nodeValues = append(nodeValues, curr) - curr = BatchNewLeafNodeData{ - Stem: kv.key[:StemSize], - Values: map[byte][]byte{kv.key[StemSize]: kv.value}, - } + tree.Insert(kv.key, kv.value, nil) } - // Append last remaining node. - nodeValues = append(nodeValues, curr) - - // Create all leaves in batch mode so we can optimize cryptography operations. - newLeaves := BatchNewLeafNode(nodeValues) - if err := tree.(*InternalNode).InsertMigratedLeaves(newLeaves, nil); err != nil { - b.Fatalf("failed to insert key: %v", err) - } - + tree.Commit() if _, err := tree.(*InternalNode).BatchSerialize(); err != nil { b.Fatalf("failed to serialize batched tree: %v", err) }
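
For context on how the new flow is meant to be used: leaf nodes created through Insert now carry a nil commitment, and the first Commit() on the root collects them (findNewLeafNodes) and computes their c1/c2/commitments in parallel batches (batchCommitLeafNodes / commitLeafNodes) before updating the internal nodes. The sketch below is illustrative only, not part of the patch; it assumes the exported verkle.New / Insert / Commit API shown here, an assumed import path (github.com/ethereum/go-verkle), and a key layout invented purely for the example.

package main

import (
	"encoding/binary"
	"fmt"

	"github.com/ethereum/go-verkle" // assumed import path for package verkle
)

func main() {
	root := verkle.New()

	// Inserts stay cheap with this patch: NewLeafNode no longer commits eagerly,
	// and updateLeaf/updateMultipleLeaves only store values while commitment == nil.
	for i := 0; i < 5000; i++ {
		var key [32]byte
		binary.BigEndian.PutUint64(key[24:], uint64(i)) // example key layout, for illustration only
		value := make([]byte, 32)
		value[0] = byte(i)
		if err := root.Insert(key[:], value, nil); err != nil {
			panic(err)
		}
	}

	// A single Commit gathers every leaf whose commitment is still nil and
	// batch-commits them across runtime.NumCPU() goroutines, then recomputes
	// the touched internal nodes.
	comm := root.Commit()
	fmt.Printf("root commitment: %x\n", comm.Bytes())
}

Deferring the leaf commitments this way lets toFrMultiple and CommitToPoly amortize their cost over many leaves at once instead of running once per NewLeafNode call, which is the path TestInsertNewLeaves and BenchmarkNewLeavesInsert now exercise through plain Insert-then-Commit.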