Skip to content

Commit

Permalink
test: use require in witness unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Stefan-Ethernal committed Oct 28, 2024
1 parent d87861f commit 15834d8
Showing 1 changed file with 45 additions and 71 deletions.
116 changes: 45 additions & 71 deletions smt/pkg/smt/witness_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
)

func prepareSMT(t *testing.T) (*smt.SMT, *trie.RetainList) {
t.Helper()

contract := libcommon.HexToAddress("0x71dd1027069078091B3ca48093B00E4735B20624")
balance := uint256.NewInt(1000000000)
sKey := libcommon.HexToHash("0x5")
Expand All @@ -44,46 +46,46 @@ func prepareSMT(t *testing.T) (*smt.SMT, *trie.RetainList) {
intraBlockState.AddBalance(contract, balance)
intraBlockState.SetState(contract, &sKey, *sVal)

if err := intraBlockState.FinalizeTx(&chain.Rules{}, tds.TrieStateWriter()); err != nil {
t.Errorf("error finalising 1st tx: %v", err)
}
if err := intraBlockState.CommitBlock(&chain.Rules{}, w); err != nil {
t.Errorf("error committing block: %v", err)
}
err := intraBlockState.FinalizeTx(&chain.Rules{}, tds.TrieStateWriter())
require.NoError(t, err, "error finalising 1st tx")

rl, err := tds.ResolveSMTRetainList()
err = intraBlockState.CommitBlock(&chain.Rules{}, w)
require.NoError(t, err, "error committing block")

if err != nil {
t.Errorf("error resolving state trie: %v", err)
}
rl, err := tds.ResolveSMTRetainList()
require.NoError(t, err, "error resolving state trie")

memdb := db.NewMemDb()

smtTrie := smt.NewSMT(memdb, false)

smtTrie.SetAccountState(contract.String(), balance.ToBig(), uint256.NewInt(1).ToBig())
smtTrie.SetContractBytecode(contract.String(), hex.EncodeToString(code))
err = memdb.AddCode(code)
_, err = smtTrie.SetAccountState(contract.String(), balance.ToBig(), uint256.NewInt(1).ToBig())
require.NoError(t, err)

if err != nil {
t.Errorf("error adding code to memdb: %v", err)
}
err = smtTrie.SetContractBytecode(contract.String(), hex.EncodeToString(code))
require.NoError(t, err)

err = memdb.AddCode(code)
require.NoError(t, err, "error adding code to memdb")

storage := make(map[string]string, 0)

for i := 0; i < 100; i++ {
k := libcommon.HexToHash(fmt.Sprintf("0x%d", i))
storage[k.String()] = k.String()
k := libcommon.HexToHash(fmt.Sprintf("0x%d", i)).String()
storage[k] = k
}

storage[sKey.String()] = sVal.String()

smtTrie.SetContractStorage(contract.String(), storage, nil)
_, err = smtTrie.SetContractStorage(contract.String(), storage, nil)
require.NoError(t, err)

return smtTrie, rl
}

func findNode(t *testing.T, w *trie.Witness, addr libcommon.Address, storageKey libcommon.Hash, nodeType int) []byte {
t.Helper()

for _, operator := range w.Operators {
switch op := operator.(type) {
case *trie.OperatorSMTLeafValue:
Expand All @@ -110,23 +112,19 @@ func TestSMTWitnessRetainList(t *testing.T) {
sVal := uint256.NewInt(0xdeadbeef)

witness, err := smt.BuildWitness(smtTrie, rl, context.Background())

if err != nil {
t.Errorf("error building witness: %v", err)
}
require.NoError(t, err, "error building witness")

foundCode := findNode(t, witness, contract, libcommon.Hash{}, utils.SC_CODE)
foundBalance := findNode(t, witness, contract, libcommon.Hash{}, utils.KEY_BALANCE)
foundNonce := findNode(t, witness, contract, libcommon.Hash{}, utils.KEY_NONCE)
foundStorage := findNode(t, witness, contract, sKey, utils.SC_STORAGE)

if foundCode == nil || foundBalance == nil || foundNonce == nil || foundStorage == nil {
t.Errorf("witness does not contain all expected operators")
}
require.NotNil(t, foundCode)
require.NotNil(t, foundBalance)
require.NotNil(t, foundNonce)
require.NotNil(t, foundStorage)

if !bytes.Equal(foundStorage, sVal.Bytes()) {
t.Errorf("witness contains unexpected storage value")
}
require.Equal(t, foundStorage, sVal.Bytes(), "witness contains unexpected storage value")
}

func TestSMTWitnessRetainListEmptyVal(t *testing.T) {
Expand All @@ -137,7 +135,8 @@ func TestSMTWitnessRetainListEmptyVal(t *testing.T) {
sKey := libcommon.HexToHash("0x5")

// Set nonce to 0
smtTrie.SetAccountState(contract.String(), balance.ToBig(), uint256.NewInt(0).ToBig())
_, err := smtTrie.SetAccountState(contract.String(), balance.ToBig(), uint256.NewInt(0).ToBig())
require.NoError(t, err)

witness, err := smt.BuildWitness(smtTrie, rl, context.Background())

Expand All @@ -150,34 +149,27 @@ func TestSMTWitnessRetainListEmptyVal(t *testing.T) {
foundNonce := findNode(t, witness, contract, libcommon.Hash{}, utils.KEY_NONCE)
foundStorage := findNode(t, witness, contract, sKey, utils.SC_STORAGE)

if foundCode == nil || foundBalance == nil || foundStorage == nil {
t.Errorf("witness does not contain all expected operators")
}
// Code, balance and storage should be present in the witness
require.NotNil(t, foundCode)
require.NotNil(t, foundBalance)
require.NotNil(t, foundStorage)

// Nonce should not be in witness
if foundNonce != nil {
t.Errorf("witness contains unexpected operator")
}
require.Nil(t, foundNonce, "witness contains unexpected operator")
}

// TestWitnessToSMT tests that the SMT built from a witness matches the original SMT
func TestWitnessToSMT(t *testing.T) {
smtTrie, rl := prepareSMT(t)

witness, err := smt.BuildWitness(smtTrie, rl, context.Background())
if err != nil {
t.Errorf("error building witness: %v", err)
}
require.NoError(t, err, "error building witness")

newSMT, err := smt.BuildSMTfromWitness(witness)
if err != nil {
t.Errorf("error building SMT from witness: %v", err)
}
require.NoError(t, err, "error building SMT from witness")

root, err := newSMT.Db.GetLastRoot()
if err != nil {
t.Errorf("error getting last root: %v", err)
}
require.NoError(t, err, "error getting last root from db")

// newSMT.Traverse(context.Background(), root, func(prefix []byte, k utils.NodeKey, v utils.NodeValue12) (bool, error) {
// fmt.Printf("[After] path: %v, hash: %x\n", prefix, libcommon.BigToHash(k.ToBigInt()))
Expand All @@ -188,9 +180,7 @@ func TestWitnessToSMT(t *testing.T) {
require.NoError(t, err, "error getting last root")

// assert that the roots are the same
if expectedRoot.Cmp(root) != 0 {
t.Errorf(fmt.Sprintf("SMT root mismatch, expected %x, got %x", expectedRoot.Bytes(), root.Bytes()))
}
require.Equal(t, expectedRoot, root, "SMT root mismatch")
}

// TestWitnessToSMTStateReader tests that the SMT built from a witness matches the state
Expand All @@ -200,27 +190,18 @@ func TestWitnessToSMTStateReader(t *testing.T) {
sKey := libcommon.HexToHash("0x5")

expectedRoot, err := smtTrie.Db.GetLastRoot()
if err != nil {
t.Errorf("error getting last root: %v", err)
}
require.NoError(t, err, "error getting last root")

witness, err := smt.BuildWitness(smtTrie, rl, context.Background())
if err != nil {
t.Errorf("error building witness: %v", err)
}
require.NoError(t, err, "error building witness")

newSMT, err := smt.BuildSMTfromWitness(witness)
if err != nil {
t.Errorf("error building SMT from witness: %v", err)
}
require.NoError(t, err, "error building SMT from witness")

root, err := newSMT.Db.GetLastRoot()
if err != nil {
t.Errorf("error building SMT from witness: %v", err)
}
require.NoError(t, err, "error getting the last root from db")

if expectedRoot.Cmp(root) != 0 {
t.Errorf(fmt.Sprintf("SMT root mismatch, expected %x, got %x", expectedRoot.Bytes(), root.Bytes()))
}
require.Equal(t, expectedRoot, root, "SMT root mismatch")

contract := libcommon.HexToAddress("0x71dd1027069078091B3ca48093B00E4735B20624")

Expand All @@ -237,17 +218,10 @@ func TestWitnessToSMTStateReader(t *testing.T) {
require.Equal(t, expectedAcc, newAcc)

require.Equal(t, expectedAccCode, newAccCode)
// TODO: @Stefan-Ethernal Check and remove
// // assert that the account code is the same
// if !bytes.Equal(expectedAccCode, newAccCode) {
// t.Error("Account Code Mismatch")
// }

// assert that the account code size is the same
require.Equal(t, expectedAccCodeSize, newAccCodeSize)

// assert that the storage value is the same
if !bytes.Equal(expectedStorageValue, newStorageValue) {
t.Error("Storage Value Mismatch")
}
require.Equal(t, expectedStorageValue, newStorageValue)
}

0 comments on commit 15834d8

Please sign in to comment.