Skip to content

Commit

Permalink
refactoring(share/ipld): rework getLeavesByNamespace
Browse files Browse the repository at this point in the history
  • Loading branch information
vgonkivs committed Mar 7, 2023
1 parent 0b81f59 commit 28284dd
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 86 deletions.
18 changes: 10 additions & 8 deletions share/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ import (
"github.com/ipfs/go-cid"
format "github.com/ipfs/go-ipld-format"

"github.com/celestiaorg/celestia-node/share/ipld"
"github.com/celestiaorg/nmt/namespace"

"github.com/celestiaorg/celestia-node/share/ipld"
)

// GetShare fetches and returns the data for leaf `leafIndex` of root `rootCid`.
Expand Down Expand Up @@ -49,24 +50,25 @@ func GetSharesByNamespace(
root cid.Cid,
nID namespace.ID,
maxShares int,
proofContainer *ipld.Proof,
) ([]Share, error) {
) ([]Share, *ipld.Proof, error) {
ctx, span := tracer.Start(ctx, "get-shares-by-namespace")
defer span.End()

leaves, err := ipld.GetLeavesByNamespace(ctx, bGetter, root, nID, maxShares, proofContainer)
if err != nil && leaves == nil {
return nil, err
data := ipld.NewRetrievedData(maxShares, ipld.WithLeaves(), ipld.WithProofs())
err := ipld.GetLeavesByNamespace(ctx, bGetter, root, nID, data)
if err != nil {
return nil, nil, err
}

leaves := data.CollectLeaves()

shares := make([]Share, len(leaves))
for i, leaf := range leaves {
if leaf != nil {
shares[i] = leafToShare(leaf)
}
}

return shares, err
return shares, data.CollectProofs(), err
}

// leafToShare converts an NMT leaf into a Share.
Expand Down
19 changes: 12 additions & 7 deletions share/get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func TestGetSharesByNamespace(t *testing.T) {
var shares []Share
for _, row := range eds.RowRoots() {
rcid := ipld.MustCidFromNamespacedSha256(row)
rowShares, err := GetSharesByNamespace(ctx, bServ, rcid, nID, len(eds.RowRoots()), nil)
rowShares, _, err := GetSharesByNamespace(ctx, bServ, rcid, nID, len(eds.RowRoots()))
require.NoError(t, err)

shares = append(shares, rowShares...)
Expand Down Expand Up @@ -222,7 +222,9 @@ func TestGetLeavesByNamespace_IncompleteData(t *testing.T) {
err = bServ.DeleteBlock(ctx, r.Cid())
require.NoError(t, err)

leaves, err := ipld.GetLeavesByNamespace(ctx, bServ, rcid, nid, len(shares), nil)
rData := ipld.NewRetrievedData(len(shares), ipld.WithLeaves())
err = ipld.GetLeavesByNamespace(ctx, bServ, rcid, nid, rData)
leaves := rData.CollectLeaves()
assert.Nil(t, leaves[1])
assert.Equal(t, 4, len(leaves))
require.Error(t, err)
Expand Down Expand Up @@ -303,9 +305,10 @@ func TestGetLeavesByNamespace_MultipleRowsContainingSameNamespaceId(t *testing.T

for _, row := range eds.RowRoots() {
rcid := ipld.MustCidFromNamespacedSha256(row)
leaves, err := ipld.GetLeavesByNamespace(ctx, bServ, rcid, nid, len(shares), nil)
data := ipld.NewRetrievedData(len(shares), ipld.WithLeaves())
err := ipld.GetLeavesByNamespace(ctx, bServ, rcid, nid, data)
assert.Nil(t, err)

leaves := data.CollectLeaves()
for _, node := range leaves {
// test that the data returned by getLeavesByNamespace for nid
// matches the commonNamespaceData that was copied across almost all data
Expand Down Expand Up @@ -353,10 +356,10 @@ func TestGetSharesWithProofsByNamespace(t *testing.T) {
var shares []Share
for _, row := range eds.RowRoots() {
rcid := ipld.MustCidFromNamespacedSha256(row)
proof := new(ipld.Proof)
rowShares, err := GetSharesByNamespace(ctx, bServ, rcid, nID, len(eds.RowRoots()), proof)
rowShares, proof, err := GetSharesByNamespace(ctx, bServ, rcid, nID, len(eds.RowRoots()))
require.NoError(t, err)
if rowShares != nil {
require.NotNil(t, proof)
// append shares to check integrity later
shares = append(shares, rowShares...)

Expand Down Expand Up @@ -448,7 +451,9 @@ func assertNoRowContainsNID(

// for each row root cid check if the minNID exists
for _, rowCID := range rowRootCIDs {
leaves, err := ipld.GetLeavesByNamespace(context.Background(), bServ, rowCID, nID, rowRootCount, nil)
data := ipld.NewRetrievedData(rowRootCount, ipld.WithProofs())
err := ipld.GetLeavesByNamespace(context.Background(), bServ, rowCID, nID, data)
leaves := data.CollectLeaves()
assert.Nil(t, leaves)
assert.Nil(t, err)
}
Expand Down
3 changes: 1 addition & 2 deletions share/getters/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ func collectSharesByNamespace(
// shadow loop variables, to ensure correct values are captured
i, rootCID := i, rootCID
errGroup.Go(func() error {
proof := new(ipld.Proof)
row, err := share.GetSharesByNamespace(ctx, bg, rootCID, nID, len(root.RowsRoots), proof)
row, proof, err := share.GetSharesByNamespace(ctx, bg, rootCID, nID, len(root.RowsRoots))
shares[i] = share.NamespacedRow{
Shares: row,
Proof: proof,
Expand Down
88 changes: 19 additions & 69 deletions share/ipld/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,22 +170,24 @@ func GetLeaves(ctx context.Context,
wg.Wait()
}

// GetLeavesByNamespace returns leaves and corresponding proof that could be used to verify leaves
// GetLeavesByNamespace collects leaves and corresponding proof that could be used to verify leaves
// inclusion. It returns as many leaves from the given root with the given namespace.ID as it can
// retrieve. If no shares are found, it returns both data and error as nil. If non-nil
// proofContainer param passed, it will be filled with data required for inclusion verification. A
retrieve. If no shares are found, it returns a nil error. A
// non-nil error means that only partial data is returned, because at least one share retrieval
// failed. The following implementation is based on `GetShares`.
func GetLeavesByNamespace(
ctx context.Context,
bGetter blockservice.BlockGetter,
root cid.Cid,
nID namespace.ID,
maxShares int,
proofContainer *Proof,
) ([]ipld.Node, error) {
retrievedData *RetrievedData,
) error {
if len(nID) != NamespaceSize {
return nil, fmt.Errorf("expected namespace ID of size %d, got %d", NamespaceSize, len(nID))
return fmt.Errorf("expected namespace ID of size %d, got %d", NamespaceSize, len(nID))
}

if err := retrievedData.validateBasic(); err != nil {
return err
}

ctx, span := tracer.Start(ctx, "get-leaves-by-namespace")
Expand All @@ -196,14 +198,9 @@ func GetLeavesByNamespace(
attribute.String("root", root.String()),
)

// we don't know where in the tree the leaves in the namespace are,
// so we keep track of the bounds to return the correct slice
// maxShares acts as a sentinel to know if we find any leaves
bounds := fetchedBounds{int64(maxShares), 0}

// buffer the jobs to avoid blocking, we only need as many
// queued as the number of shares in the second-to-last layer
jobs := make(chan *job, (maxShares+1)/2)
jobs := make(chan *job, (retrievedData.getMaxShares()+1)/2)
jobs <- &job{id: root, ctx: ctx}

var wg chanGroup
Expand All @@ -215,41 +212,23 @@ func GetLeavesByNamespace(
retrievalErr error
)

// we overallocate space for leaves since we do not know how many we will find
// on the level above, the length of the Row is passed in as maxShares
leaves := make([]ipld.Node, maxShares)

// if a non-nil proof container is provided, collect proofs while traversing the tree and put them
// into the container afterwards
var collectProofs = proofContainer != nil
var proofs *proofCollector
if collectProofs {
proofs = newProofCollector(maxShares)
}

for {
var j *job
var ok bool
select {
case j, ok = <-jobs:
case <-ctx.Done():
return nil, ctx.Err()
return ctx.Err()
}

if !ok {
// if there were no leaves under the given root in the given namespace,
// both return values are nil. otherwise, the error will also be non-nil.
if bounds.lowest == int64(maxShares) {
return nil, retrievalErr
}

if collectProofs {
proofContainer.Start = int(bounds.lowest)
proofContainer.End = int(bounds.highest) + 1
proofContainer.Nodes = proofs.Nodes()
// leaves and error will be nil. otherwise, the error will also be non-nil.
if !retrievedData.leavesAvailable() {
return retrievalErr
}

return leaves[bounds.lowest : bounds.highest+1], retrievalErr
return retrievalErr
}
pool.Submit(func() {
ctx, span := tracer.Start(j.ctx, "process-job")
Expand All @@ -271,18 +250,16 @@ func GetLeavesByNamespace(
log.Errorw("getLeavesWithProofsByNamespace: could not retrieve node", "nID", nID, "pos", j.sharePos, "err", err)
span.SetStatus(codes.Error, err.Error())
// we still need to update the bounds
bounds.update(int64(j.sharePos))
retrievedData.addLeaf(j.sharePos, nil)
return
}

links := nd.Links()
if len(links) == 0 {
// successfully fetched a leaf belonging to the namespace
span.SetStatus(codes.Ok, "")
leaves[j.sharePos] = nd
// we found a leaf, so we update the bounds
// the update routine is repeated until the atomic swap is successful
bounds.update(int64(j.sharePos))
retrievedData.addLeaf(j.sharePos, nd)
return
}

Expand All @@ -304,17 +281,13 @@ func GetLeavesByNamespace(

// proof is on the right side, if the nID is less than min namespace of jobNid
if nID.Less(nmt.MinNamespace(jobNid, nID.Size())) {
if collectProofs {
proofs.addRight(lnk.Cid, newJob.depth)
}
retrievedData.addProof(right, lnk.Cid, newJob.depth)
continue
}

// proof is on the left side, if the nID is bigger than max namespace of jobNid
if !nID.LessOrEqual(nmt.MaxNamespace(jobNid, nID.Size())) {
if collectProofs {
proofs.addLeft(lnk.Cid, newJob.depth)
}
retrievedData.addProof(left, lnk.Cid, newJob.depth)
continue
}

Expand Down Expand Up @@ -391,29 +364,6 @@ func (w *chanGroup) done() {
}
}

// fetchedBounds records the smallest and largest index that has been
// reported through update. Both fields are read and written atomically.
type fetchedBounds struct {
	lowest, highest int64
}

// update widens the bounds so that they include index. A bound is only
// ever moved outwards: lowest can only decrease and highest can only
// increase. Each bound is adjusted with its own compare-and-swap loop,
// so a value written by another goroutine between the load and the swap
// is re-read and re-compared instead of being overwritten.
func (b *fetchedBounds) update(index int64) {
	// pull the lower bound down to index, unless it is already at or below it
	for {
		current := atomic.LoadInt64(&b.lowest)
		if index >= current || atomic.CompareAndSwapInt64(&b.lowest, current, index) {
			break
		}
	}

	// the upper bound is always checked as well, because one index can be
	// both the lowest and the highest (e.g. a single share in the namespace)
	for {
		current := atomic.LoadInt64(&b.highest)
		if index <= current || atomic.CompareAndSwapInt64(&b.highest, current, index) {
			break
		}
	}
}

// job represents an encountered node to investigate during the `GetLeaves`
// and `GetLeavesByNamespace` routines.
type job struct {
Expand Down
Loading

0 comments on commit 28284dd

Please sign in to comment.