Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix!: UnmarshalJSON is limited to the default Tree bug #277

Merged
merged 37 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
80f81c4
refator default tree to seperated file
sontrinh16 Dec 12, 2023
5ee186c
add tree contructors global map
sontrinh16 Dec 13, 2023
936527b
Merge branch 'main' into global_tree
sontrinh16 Dec 13, 2023
cde39fa
accept tree name as arg instead
sontrinh16 Dec 13, 2023
3502a53
pull upstream
sontrinh16 Dec 13, 2023
c79bb47
register test tree constructor
sontrinh16 Dec 13, 2023
857a175
naming convention for tree name
sontrinh16 Dec 13, 2023
503c757
fix test
sontrinh16 Dec 13, 2023
e52fbdf
fix lint
sontrinh16 Dec 13, 2023
404dfa1
Merge branch 'main' into global_tree
sontrinh16 Dec 13, 2023
15cfd82
move constructor fetching to unmarshalJSON
sontrinh16 Dec 14, 2023
051f085
pull upstream
sontrinh16 Dec 14, 2023
28aac5b
minor
sontrinh16 Dec 14, 2023
4e25dd3
adding comment and privatise func
sontrinh16 Dec 15, 2023
98c9f3a
use sync map
sontrinh16 Dec 15, 2023
55db3a0
add comment
sontrinh16 Dec 15, 2023
b523182
fix lint
sontrinh16 Dec 15, 2023
f910bcf
minor public func command change
sontrinh16 Dec 15, 2023
a855577
add todo
sontrinh16 Dec 15, 2023
3b59a47
add unit test for treeFns global map
sontrinh16 Dec 15, 2023
b4c5c90
add clean up function
sontrinh16 Dec 15, 2023
c2b5e32
minor update and remove unnecessary testcases
sontrinh16 Dec 15, 2023
ed86b82
addback invalid interface testcases
sontrinh16 Dec 15, 2023
7c8dcad
minor comment update
sontrinh16 Dec 15, 2023
e51b12a
Merge branch 'main' into global_tree
sontrinh16 Dec 20, 2023
2d5247b
Merge branch 'main' into global_tree
sontrinh16 Jan 3, 2024
b53cf8c
minor update
sontrinh16 Jan 5, 2024
64f3e0a
Merge branch 'main' into global_tree
sontrinh16 Jan 5, 2024
2c07d00
add comment for test functions
sontrinh16 Jan 5, 2024
a7d1eb0
minor
sontrinh16 Jan 5, 2024
a2baef2
add default tree as fallback
sontrinh16 Jan 10, 2024
0b16ad8
fix typo
sontrinh16 Jan 10, 2024
3c31d8e
use assert for tree fn compare
sontrinh16 Jan 10, 2024
8067e13
lint fix
sontrinh16 Jan 10, 2024
6dde5d7
add test for unmarshal json
sontrinh16 Jan 10, 2024
002c3b0
minor nit
sontrinh16 Jan 10, 2024
c988b3d
Merge branch 'main' into global_tree
sontrinh16 Jan 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions default_tree.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package rsmt2d

import (
"crypto/sha256"
"fmt"

"github.com/celestiaorg/merkletree"
)

var DefaultTreeName = "default-tree"

func init() {
err := RegisterTree(DefaultTreeName, NewDefaultTree)
if err != nil {
panic(fmt.Sprintf("%s already registered", DefaultTreeName))

Check warning on line 15 in default_tree.go

View check run for this annotation

Codecov / codecov/patch

default_tree.go#L15

Added line #L15 was not covered by tests
}
}

var _ Tree = &DefaultTree{}

type DefaultTree struct {
*merkletree.Tree
leaves [][]byte
root []byte
}

func NewDefaultTree(_ Axis, _ uint) Tree {
return &DefaultTree{
Tree: merkletree.New(sha256.New()),
leaves: make([][]byte, 0, 128),
}
}

func (d *DefaultTree) Push(data []byte) error {
// ignore the idx, as this implementation doesn't need that info
d.leaves = append(d.leaves, data)
return nil
}

func (d *DefaultTree) Root() ([]byte, error) {
if d.root == nil {
for _, l := range d.leaves {
d.Tree.Push(l)
}
d.root = d.Tree.Root()
}
return d.root, nil
}
13 changes: 10 additions & 3 deletions extendeddatacrossword_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,14 @@ func TestCorruptedEdsReturnsErrByzantineData_UnorderedShares(t *testing.T) {

codec := NewLeoRSCodec()

edsWidth := 4 // number of shares per row/column in the extended data square
odsWidth := edsWidth / 2 // number of shares per row/column in the original data square
err := RegisterTree("testing-tree", newConstructor(uint64(odsWidth), nmt.NamespaceIDSize(namespaceSize)))
assert.NoError(t, err)

// create a DA header
eds := createTestEdsWithNMT(t, codec, shareSize, namespaceSize, 1, 2, 3, 4)

assert.NotNil(t, eds)
dAHeaderRoots, err := eds.getRowRoots()
assert.NoError(t, err)
Expand Down Expand Up @@ -436,10 +442,11 @@ func createTestEdsWithNMT(t *testing.T, codec Codec, shareSize, namespaceSize in
for i, shareValue := range sharesValue {
shares[i] = bytes.Repeat([]byte{byte(shareValue)}, shareSize)
}
edsWidth := 4 // number of shares per row/column in the extended data square
odsWidth := edsWidth / 2 // number of shares per row/column in the original data square

eds, err := ComputeExtendedDataSquare(shares, codec, newConstructor(uint64(odsWidth), nmt.NamespaceIDSize(namespaceSize)))
treeConstructorFn, err := TreeFn("testing-tree")
require.NoError(t, err)

eds, err := ComputeExtendedDataSquare(shares, codec, treeConstructorFn)
require.NoError(t, err)

return eds
Expand Down
36 changes: 32 additions & 4 deletions extendeddatasquare.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,45 @@
type ExtendedDataSquare struct {
*dataSquare
codec Codec
treeName string
originalDataWidth uint
}

func (eds *ExtendedDataSquare) MarshalJSON() ([]byte, error) {
return json.Marshal(&struct {
DataSquare [][]byte `json:"data_square"`
Codec string `json:"codec"`
Tree string `json:"tree"`
}{
DataSquare: eds.dataSquare.Flattened(),
Codec: eds.codec.Name(),
Tree: eds.treeName,
})
}
sontrinh16 marked this conversation as resolved.
Show resolved Hide resolved

func (eds *ExtendedDataSquare) UnmarshalJSON(b []byte) error {
var aux struct {
DataSquare [][]byte `json:"data_square"`
Codec string `json:"codec"`
Tree string `json:"tree"`
}

if err := json.Unmarshal(b, &aux); err != nil {
err := json.Unmarshal(b, &aux)
if err != nil {
return err
}

Check warning on line 44 in extendeddatasquare.go

View check run for this annotation

Codecov / codecov/patch

extendeddatasquare.go#L43-L44

Added lines #L43 - L44 were not covered by tests

var treeConstructor TreeConstructorFn
if aux.Tree == "" {
aux.Tree = DefaultTreeName
}

treeConstructor, err = TreeFn(aux.Tree)
if err != nil {
return err
rootulp marked this conversation as resolved.
Show resolved Hide resolved
}
rootulp marked this conversation as resolved.
Show resolved Hide resolved
importedEds, err := ImportExtendedDataSquare(aux.DataSquare, codecs[aux.Codec], NewDefaultTree)

importedEds, err := ImportExtendedDataSquare(aux.DataSquare, codecs[aux.Codec], treeConstructor)
if err != nil {
return err
}
Expand All @@ -61,12 +77,18 @@
if err != nil {
return nil, err
}

ds, err := newDataSquare(data, treeCreatorFn, uint(chunkSize))
if err != nil {
return nil, err
}

eds := ExtendedDataSquare{dataSquare: ds, codec: codec}
treeName := getTreeNameFromConstructorFn(treeCreatorFn)
if treeName == "" {
return nil, errors.New("tree name not found")
}

Check warning on line 89 in extendeddatasquare.go

View check run for this annotation

Codecov / codecov/patch

extendeddatasquare.go#L88-L89

Added lines #L88 - L89 were not covered by tests

eds := ExtendedDataSquare{dataSquare: ds, codec: codec, treeName: treeName}
err = eds.erasureExtendSquare(codec)
if err != nil {
return nil, err
Expand All @@ -90,12 +112,18 @@
if err != nil {
return nil, err
}

ds, err := newDataSquare(data, treeCreatorFn, uint(chunkSize))
if err != nil {
return nil, err
}

eds := ExtendedDataSquare{dataSquare: ds, codec: codec}
treeName := getTreeNameFromConstructorFn(treeCreatorFn)
if treeName == "" {
return nil, errors.New("tree name not found")
}

Check warning on line 124 in extendeddatasquare.go

View check run for this annotation

Codecov / codecov/patch

extendeddatasquare.go#L123-L124

Added lines #L123 - L124 were not covered by tests

eds := ExtendedDataSquare{dataSquare: ds, codec: codec, treeName: treeName}
err = validateEdsWidth(eds.width)
if err != nil {
return nil, err
Expand Down
63 changes: 63 additions & 0 deletions extendeddatasquare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,69 @@ func TestMarshalJSON(t *testing.T) {
}
}

// TestUnmarshalJSON test the UnmarshalJSON function.
func TestUnmarshalJSON(t *testing.T) {
treeName := "testing_unmarshalJSON_tree"
treeConstructorFn := sudoConstructorFn
err := RegisterTree(treeName, treeConstructorFn)
require.NoError(t, err)

codec := NewLeoRSCodec()
result, err := ComputeExtendedDataSquare([][]byte{
ones, twos,
threes, fours,
}, codec, treeConstructorFn)
if err != nil {
panic(err)
}

tests := []struct {
name string
malleate func()
expectedTreeName string
cleanUp func()
}{
{
"Tree field exists",
func() {},
treeName,
func() {
cleanUp(treeName)
},
},
{
"Tree field missing",
func() {
// clear the tree name value in the eds before marshal
result.treeName = ""
},
DefaultTreeName,
func() {},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
test.malleate()
edsBytes, err := json.Marshal(result)
if err != nil {
t.Errorf("failed to marshal EDS: %v", err)
}

var eds ExtendedDataSquare
err = json.Unmarshal(edsBytes, &eds)
if err != nil {
t.Errorf("failed to unmarshal EDS: %v", err)
}
if !reflect.DeepEqual(result.squareRow, eds.squareRow) {
t.Errorf("eds not equal after json marshal/unmarshal")
}
require.Equal(t, test.expectedTreeName, eds.treeName)

test.cleanUp()
})
}
}

func TestNewExtendedDataSquare(t *testing.T) {
t.Run("returns an error if edsWidth is not even", func(t *testing.T) {
edsWidth := uint(1)
Expand Down
78 changes: 55 additions & 23 deletions tree.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package rsmt2d

import (
"crypto/sha256"

"github.com/celestiaorg/merkletree"
"fmt"
"reflect"
"sync"
)

// TreeConstructorFn creates a fresh Tree instance to be used as the Merkle tree
Expand All @@ -22,33 +22,65 @@ type Tree interface {
Root() ([]byte, error)
}

var _ Tree = &DefaultTree{}
// treeFns is a global map used for keeping track of registered tree constructors for JSON serialization
// The keys of this map should be kebab cased. E.g. "default-tree"
staheri14 marked this conversation as resolved.
Show resolved Hide resolved
var treeFns = sync.Map{}

// RegisterTree must be called in the init function
func RegisterTree(treeName string, treeConstructor TreeConstructorFn) error {
sontrinh16 marked this conversation as resolved.
Show resolved Hide resolved
sontrinh16 marked this conversation as resolved.
Show resolved Hide resolved
if _, ok := treeFns.Load(treeName); ok {
return fmt.Errorf("%s already registered", treeName)
}

treeFns.Store(treeName, treeConstructor)

type DefaultTree struct {
*merkletree.Tree
leaves [][]byte
root []byte
return nil
}

func NewDefaultTree(_ Axis, _ uint) Tree {
return &DefaultTree{
Tree: merkletree.New(sha256.New()),
leaves: make([][]byte, 0, 128),
// TreeFn get tree constructor function by tree name from the global map registry
func TreeFn(treeName string) (TreeConstructorFn, error) {
var treeFn TreeConstructorFn
v, ok := treeFns.Load(treeName)
if !ok {
return nil, fmt.Errorf("%s not registered yet", treeName)
}
treeFn, ok = v.(TreeConstructorFn)
if !ok {
return nil, fmt.Errorf("key %s has invalid interface", treeName)
}

return treeFn, nil
}

func (d *DefaultTree) Push(data []byte) error {
// ignore the idx, as this implementation doesn't need that info
d.leaves = append(d.leaves, data)
return nil
// removeTreeFn removes a treeConstructorFn by treeName.
// Only use for test cleanup. Proceed with caution.
func removeTreeFn(treeName string) {
treeFns.Delete(treeName)
}

func (d *DefaultTree) Root() ([]byte, error) {
if d.root == nil {
for _, l := range d.leaves {
d.Tree.Push(l)
// Get the tree name by the tree constructor function from the global map registry
// TODO: this code is temporary until all breaking changes is handle here: https://github.com/celestiaorg/rsmt2d/pull/278
func getTreeNameFromConstructorFn(treeConstructor TreeConstructorFn) string {
key := ""
treeFns.Range(func(k, v interface{}) bool {
keyString, ok := k.(string)
if !ok {
// continue checking other key, value
return true
}
d.root = d.Tree.Root()
}
return d.root, nil
treeFn, ok := v.(TreeConstructorFn)
if !ok {
// continue checking other key, value
return true
}

if reflect.DeepEqual(reflect.ValueOf(treeFn), reflect.ValueOf(treeConstructor)) {
key = keyString
return false
}

return true
})

return key
}
Loading
Loading