diff --git a/smt/pkg/smt/witness_test.go b/smt/pkg/smt/witness_test.go index 7078fec5ed5..04914a30f01 100644 --- a/smt/pkg/smt/witness_test.go +++ b/smt/pkg/smt/witness_test.go @@ -20,6 +20,8 @@ import ( ) func prepareSMT(t *testing.T) (*smt.SMT, *trie.RetainList) { + t.Helper() + contract := libcommon.HexToAddress("0x71dd1027069078091B3ca48093B00E4735B20624") balance := uint256.NewInt(1000000000) sKey := libcommon.HexToHash("0x5") @@ -44,46 +46,46 @@ func prepareSMT(t *testing.T) (*smt.SMT, *trie.RetainList) { intraBlockState.AddBalance(contract, balance) intraBlockState.SetState(contract, &sKey, *sVal) - if err := intraBlockState.FinalizeTx(&chain.Rules{}, tds.TrieStateWriter()); err != nil { - t.Errorf("error finalising 1st tx: %v", err) - } - if err := intraBlockState.CommitBlock(&chain.Rules{}, w); err != nil { - t.Errorf("error committing block: %v", err) - } + err := intraBlockState.FinalizeTx(&chain.Rules{}, tds.TrieStateWriter()) + require.NoError(t, err, "error finalising 1st tx") - rl, err := tds.ResolveSMTRetainList() + err = intraBlockState.CommitBlock(&chain.Rules{}, w) + require.NoError(t, err, "error committing block") - if err != nil { - t.Errorf("error resolving state trie: %v", err) - } + rl, err := tds.ResolveSMTRetainList() + require.NoError(t, err, "error resolving state trie") memdb := db.NewMemDb() smtTrie := smt.NewSMT(memdb, false) - smtTrie.SetAccountState(contract.String(), balance.ToBig(), uint256.NewInt(1).ToBig()) - smtTrie.SetContractBytecode(contract.String(), hex.EncodeToString(code)) - err = memdb.AddCode(code) + _, err = smtTrie.SetAccountState(contract.String(), balance.ToBig(), uint256.NewInt(1).ToBig()) + require.NoError(t, err) - if err != nil { - t.Errorf("error adding code to memdb: %v", err) - } + err = smtTrie.SetContractBytecode(contract.String(), hex.EncodeToString(code)) + require.NoError(t, err) + + err = memdb.AddCode(code) + require.NoError(t, err, "error adding code to memdb") storage := make(map[string]string, 0) for i := 0; i < 100; i++ { - k := libcommon.HexToHash(fmt.Sprintf("0x%d", i)) - storage[k.String()] = k.String() + k := libcommon.HexToHash(fmt.Sprintf("0x%d", i)).String() + storage[k] = k } storage[sKey.String()] = sVal.String() - smtTrie.SetContractStorage(contract.String(), storage, nil) + _, err = smtTrie.SetContractStorage(contract.String(), storage, nil) + require.NoError(t, err) return smtTrie, rl } func findNode(t *testing.T, w *trie.Witness, addr libcommon.Address, storageKey libcommon.Hash, nodeType int) []byte { + t.Helper() + for _, operator := range w.Operators { switch op := operator.(type) { case *trie.OperatorSMTLeafValue: @@ -110,23 +112,19 @@ func TestSMTWitnessRetainList(t *testing.T) { sVal := uint256.NewInt(0xdeadbeef) witness, err := smt.BuildWitness(smtTrie, rl, context.Background()) - - if err != nil { - t.Errorf("error building witness: %v", err) - } + require.NoError(t, err, "error building witness") foundCode := findNode(t, witness, contract, libcommon.Hash{}, utils.SC_CODE) foundBalance := findNode(t, witness, contract, libcommon.Hash{}, utils.KEY_BALANCE) foundNonce := findNode(t, witness, contract, libcommon.Hash{}, utils.KEY_NONCE) foundStorage := findNode(t, witness, contract, sKey, utils.SC_STORAGE) - if foundCode == nil || foundBalance == nil || foundNonce == nil || foundStorage == nil { - t.Errorf("witness does not contain all expected operators") - } + require.NotNil(t, foundCode) + require.NotNil(t, foundBalance) + require.NotNil(t, foundNonce) + require.NotNil(t, foundStorage) - if !bytes.Equal(foundStorage, sVal.Bytes()) { - t.Errorf("witness contains unexpected storage value") - } + require.Equal(t, foundStorage, sVal.Bytes(), "witness contains unexpected storage value") } func TestSMTWitnessRetainListEmptyVal(t *testing.T) { @@ -137,7 +135,8 @@ func TestSMTWitnessRetainListEmptyVal(t *testing.T) { sKey := libcommon.HexToHash("0x5") // Set nonce to 0 - smtTrie.SetAccountState(contract.String(), balance.ToBig(), uint256.NewInt(0).ToBig()) + _, err := smtTrie.SetAccountState(contract.String(), balance.ToBig(), uint256.NewInt(0).ToBig()) + require.NoError(t, err) witness, err := smt.BuildWitness(smtTrie, rl, context.Background()) @@ -150,14 +149,13 @@ func TestSMTWitnessRetainListEmptyVal(t *testing.T) { foundNonce := findNode(t, witness, contract, libcommon.Hash{}, utils.KEY_NONCE) foundStorage := findNode(t, witness, contract, sKey, utils.SC_STORAGE) - if foundCode == nil || foundBalance == nil || foundStorage == nil { - t.Errorf("witness does not contain all expected operators") - } + // Code, balance and storage should be present in the witness + require.NotNil(t, foundCode) + require.NotNil(t, foundBalance) + require.NotNil(t, foundStorage) // Nonce should not be in witness - if foundNonce != nil { - t.Errorf("witness contains unexpected operator") - } + require.Nil(t, foundNonce, "witness contains unexpected operator") } // TestWitnessToSMT tests that the SMT built from a witness matches the original SMT @@ -165,19 +163,13 @@ func TestWitnessToSMT(t *testing.T) { smtTrie, rl := prepareSMT(t) witness, err := smt.BuildWitness(smtTrie, rl, context.Background()) - if err != nil { - t.Errorf("error building witness: %v", err) - } + require.NoError(t, err, "error building witness") newSMT, err := smt.BuildSMTfromWitness(witness) - if err != nil { - t.Errorf("error building SMT from witness: %v", err) - } + require.NoError(t, err, "error building SMT from witness") root, err := newSMT.Db.GetLastRoot() - if err != nil { - t.Errorf("error getting last root: %v", err) - } + require.NoError(t, err, "error getting last root from db") // newSMT.Traverse(context.Background(), root, func(prefix []byte, k utils.NodeKey, v utils.NodeValue12) (bool, error) { // fmt.Printf("[After] path: %v, hash: %x\n", prefix, libcommon.BigToHash(k.ToBigInt())) @@ -188,9 +180,7 @@ func TestWitnessToSMT(t *testing.T) { require.NoError(t, err, "error getting last root") // assert that the roots are the same - if expectedRoot.Cmp(root) != 0 { - t.Errorf(fmt.Sprintf("SMT root mismatch, expected %x, got %x", expectedRoot.Bytes(), root.Bytes())) - } + require.Equal(t, expectedRoot, root, "SMT root mismatch") } // TestWitnessToSMTStateReader tests that the SMT built from a witness matches the state @@ -200,27 +190,18 @@ func TestWitnessToSMTStateReader(t *testing.T) { sKey := libcommon.HexToHash("0x5") expectedRoot, err := smtTrie.Db.GetLastRoot() - if err != nil { - t.Errorf("error getting last root: %v", err) - } + require.NoError(t, err, "error getting last root") witness, err := smt.BuildWitness(smtTrie, rl, context.Background()) - if err != nil { - t.Errorf("error building witness: %v", err) - } + require.NoError(t, err, "error building witness") newSMT, err := smt.BuildSMTfromWitness(witness) - if err != nil { - t.Errorf("error building SMT from witness: %v", err) - } + require.NoError(t, err, "error building SMT from witness") + root, err := newSMT.Db.GetLastRoot() - if err != nil { - t.Errorf("error building SMT from witness: %v", err) - } + require.NoError(t, err, "error getting the last root from db") - if expectedRoot.Cmp(root) != 0 { - t.Errorf(fmt.Sprintf("SMT root mismatch, expected %x, got %x", expectedRoot.Bytes(), root.Bytes())) - } + require.Equal(t, expectedRoot, root, "SMT root mismatch") contract := libcommon.HexToAddress("0x71dd1027069078091B3ca48093B00E4735B20624") @@ -237,17 +218,10 @@ func TestWitnessToSMTStateReader(t *testing.T) { require.Equal(t, expectedAcc, newAcc) require.Equal(t, expectedAccCode, newAccCode) - // TODO: @Stefan-Ethernal Check and remove - // // assert that the account code is the same - // if !bytes.Equal(expectedAccCode, newAccCode) { - // t.Error("Account Code Mismatch") - // } // assert that the account code size is the same require.Equal(t, expectedAccCodeSize, newAccCodeSize) // assert that the storage value is the same - if !bytes.Equal(expectedStorageValue, newStorageValue) { - t.Error("Storage Value Mismatch") - } + require.Equal(t, expectedStorageValue, newStorageValue) }