Skip to content

Commit

Permalink
Inters speed up 2 (#1694)
Browse files Browse the repository at this point in the history
* moving from big int usage for storage key calculation

* use common package to avoid big int when parsing an address

* alternative scalar to array using uint64 rather than big int

* remove unused util function
  • Loading branch information
hexoscott authored Jan 29, 2025
1 parent 2a22c67 commit 589e949
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 112 deletions.
8 changes: 4 additions & 4 deletions core/state/trie_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -927,10 +927,10 @@ func (tds *TrieDbState) ResolveSMTRetainList(inclusion map[libcommon.Address][]l
}

getSMTPath := func(ethAddr string, key string) ([]int, error) {
a := utils.ConvertHexToBigInt(ethAddr)
addr := utils.ScalarToArrayBig(a)

storageKey := utils.KeyContractStorage(addr, key)
storageKey, err := utils.KeyContractStorage(ethAddr, key)
if err != nil {
return nil, err
}

return storageKey.GetPath(), nil
}
Expand Down
16 changes: 8 additions & 8 deletions smt/pkg/smt/entity_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (s *SMT) SetAccountBalance(ethAddr string, balance *big.Int) (*big.Int, err
return nil, err
}

ks := utils.EncodeKeySource(utils.KEY_BALANCE, utils.ConvertHexToAddress(ethAddr), common.Hash{})
ks := utils.EncodeKeySource(utils.KEY_BALANCE, common.HexToAddress(ethAddr), common.Hash{})
err = s.Db.InsertKeySource(keyBalance, ks)
if err != nil {
return nil, err
Expand All @@ -56,7 +56,7 @@ func (s *SMT) SetAccountNonce(ethAddr string, nonce *big.Int) (*big.Int, error)
return nil, err
}

ks := utils.EncodeKeySource(utils.KEY_NONCE, utils.ConvertHexToAddress(ethAddr), common.Hash{})
ks := utils.EncodeKeySource(utils.KEY_NONCE, common.HexToAddress(ethAddr), common.Hash{})
err = s.Db.InsertKeySource(keyNonce, ks)
if err != nil {
return nil, err
Expand Down Expand Up @@ -90,7 +90,7 @@ func (s *SMT) SetContractBytecode(ethAddr string, bytecode string) error {
return err
}

ks := utils.EncodeKeySource(utils.SC_CODE, utils.ConvertHexToAddress(ethAddr), common.Hash{})
ks := utils.EncodeKeySource(utils.SC_CODE, common.HexToAddress(ethAddr), common.Hash{})

err = s.Db.InsertKeySource(keyContractCode, ks)

Expand All @@ -103,7 +103,7 @@ func (s *SMT) SetContractBytecode(ethAddr string, bytecode string) error {
return err
}

ks = utils.EncodeKeySource(utils.SC_LENGTH, utils.ConvertHexToAddress(ethAddr), common.Hash{})
ks = utils.EncodeKeySource(utils.SC_LENGTH, common.HexToAddress(ethAddr), common.Hash{})

return s.Db.InsertKeySource(keyContractLength, ks)
}
Expand Down Expand Up @@ -321,12 +321,12 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l
return nil, nil, fmt.Errorf("[%s] Context done", logPrefix)
default:
}
ethAddr := addr.String()
ethAddrBigInt := utils.ConvertHexToBigInt(ethAddr)
ethAddrBigIngArray := utils.ScalarToArrayBig(ethAddrBigInt)

for k, v := range storage {
keyStoragePosition := utils.KeyContractStorage(ethAddrBigIngArray, k)
keyStoragePosition, err := utils.KeyContractStorage(addr.String(), k)
if err != nil {
return nil, nil, err
}
valueBigInt := convertStringToBigInt(v)
keysBatchStorage = append(keysBatchStorage, &keyStoragePosition)
if valuesBatchStorage, isDelete, err = appendToValuesBatchStorageBigInt(valuesBatchStorage, valueBigInt); err != nil {
Expand Down
40 changes: 28 additions & 12 deletions smt/pkg/smt/proof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,19 @@ func TestVerifyAndGetVal(t *testing.T) {
t.Fatalf("BuildProofs() error = %v", err)
}

contractAddress := libcommon.HexToAddress("0x71dd1027069078091B3ca48093B00E4735B20624")
a := utils.ConvertHexToBigInt(contractAddress.String())
address := utils.ScalarToArrayBig(a)

smtRoot, _ := smtTrie.RoSMT.DbRo.GetLastRoot()
if err != nil {
t.Fatalf("GetLastRoot() error = %v", err)
}
root := utils.ScalarToRoot(smtRoot)

address := "0x71dd1027069078091B3ca48093B00E4735B20624"

t.Run("Value exists and proof is correct", func(t *testing.T) {
storageKey := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String())
storageKey, err := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String())
if err != nil {
t.Fatalf("KeyContractStorage() error = %v", err)
}
storageProof := smt.FilterProofs(proofs, storageKey)

val, err := smt.VerifyAndGetVal(root, storageProof, storageKey)
Expand All @@ -96,10 +97,13 @@ func TestVerifyAndGetVal(t *testing.T) {

// Fuzz with 1000 non-existent keys
for i := 0; i < 1000; i++ {
nonExistentKey := utils.KeyContractStorage(
nonExistentKey, err := utils.KeyContractStorage(
address,
libcommon.HexToHash(fmt.Sprintf("0xdeadbeefabcd1234%d", i)).String(),
)
if err != nil {
t.Fatalf("KeyContractStorage() error = %v", err)
}
nonExistentKeys = append(nonExistentKeys, nonExistentKey)
nonExistentKeyPath := nonExistentKey.GetPath()
keyBytes := make([]byte, 0, len(nonExistentKeyPath))
Expand Down Expand Up @@ -132,7 +136,10 @@ func TestVerifyAndGetVal(t *testing.T) {

t.Run("Value doesn't exist but non-existent proof is insufficient", func(t *testing.T) {
nonExistentRl := trie.NewRetainList(0)
nonExistentKey := utils.KeyContractStorage(address, libcommon.HexToHash("0x999").String())
nonExistentKey, err := utils.KeyContractStorage(address, libcommon.HexToHash("0x999").String())
if err != nil {
t.Fatalf("KeyContractStorage() error = %v", err)
}
nonExistentKeyPath := nonExistentKey.GetPath()
keyBytes := make([]byte, 0, len(nonExistentKeyPath))

Expand Down Expand Up @@ -165,15 +172,18 @@ func TestVerifyAndGetVal(t *testing.T) {
})

t.Run("Value exists but proof is incorrect (first value corrupted)", func(t *testing.T) {
storageKey := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String())
storageKey, err := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String())
if err != nil {
t.Fatalf("KeyContractStorage() error = %v", err)
}
storageProof := smt.FilterProofs(proofs, storageKey)

// Corrupt the proof by changing a byte
if len(storageProof) > 0 && len(storageProof[0]) > 0 {
storageProof[0][0] ^= 0xFF // Flip all bits in the first byte
}

_, err := smt.VerifyAndGetVal(root, storageProof, storageKey)
_, err = smt.VerifyAndGetVal(root, storageProof, storageKey)

if err == nil {
if err == nil || !strings.Contains(err.Error(), "root mismatch at level 0") {
Expand All @@ -183,7 +193,10 @@ func TestVerifyAndGetVal(t *testing.T) {
})

t.Run("Value exists but proof is incorrect (last value corrupted)", func(t *testing.T) {
storageKey := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String())
storageKey, err := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String())
if err != nil {
t.Fatalf("KeyContractStorage() error = %v", err)
}
storageProof := smt.FilterProofs(proofs, storageKey)

// Corrupt the proof by changing the last byte of the last proof element
Expand All @@ -194,7 +207,7 @@ func TestVerifyAndGetVal(t *testing.T) {
}
}

_, err := smt.VerifyAndGetVal(root, storageProof, storageKey)
_, err = smt.VerifyAndGetVal(root, storageProof, storageKey)

if err == nil {
if err == nil || !strings.Contains(err.Error(), fmt.Sprintf("root mismatch at level %d", len(storageProof)-1)) {
Expand All @@ -204,7 +217,10 @@ func TestVerifyAndGetVal(t *testing.T) {
})

t.Run("Value exists but proof is insufficient", func(t *testing.T) {
storageKey := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String())
storageKey, err := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String())
if err != nil {
t.Fatalf("KeyContractStorage() error = %v", err)
}
storageProof := smt.FilterProofs(proofs, storageKey)

// Modify the proof to claim the value doesn't exist
Expand Down
10 changes: 5 additions & 5 deletions smt/pkg/smt/smt.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,6 @@ func (s *SMT) InsertStorage(ethAddr string, storage *map[string]string, chm *map
s.clearUpMutex.Lock()
defer s.clearUpMutex.Unlock()

a := utils.ConvertHexToBigInt(ethAddr)
add := utils.ScalarToArrayBig(a)

or, err := s.getLastRoot()
if err != nil {
return nil, err
Expand All @@ -177,15 +174,18 @@ func (s *SMT) InsertStorage(ethAddr string, storage *map[string]string, chm *map
NewRootScalar: &or,
}
for k := range *storage {
keyStoragePosition := utils.KeyContractStorage(add, k)
keyStoragePosition, err := utils.KeyContractStorage(ethAddr, k)
if err != nil {
return nil, err
}
smtr, err = s.insert(keyStoragePosition, *(*chm)[k], (*vhm)[k], *smtr.NewRootScalar)
if err != nil {
return nil, err
}

sp, _ := utils.StrValToBigInt(k)

ks := utils.EncodeKeySource(utils.SC_STORAGE, utils.ConvertHexToAddress(ethAddr), common.BigToHash(sp))
ks := utils.EncodeKeySource(utils.SC_STORAGE, common.HexToAddress(ethAddr), common.BigToHash(sp))
err = s.Db.InsertKeySource(keyStoragePosition, ks)

if err != nil {
Expand Down
9 changes: 5 additions & 4 deletions smt/pkg/smt/smt_state_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,15 @@ func (s *SMT) GetAccountCodeHash(address libcommon.Address) (libcommon.Hash, err
// getValue returns the value of a key from SMT by traversing the SMT
func (s *SMT) getValue(key int, address libcommon.Address, storageKey *libcommon.Hash) ([]byte, error) {
var kn utils.NodeKey
var err error

if storageKey == nil {
kn = utils.Key(address.String(), key)
} else {
a := utils.ConvertHexToBigInt(address.String())
add := utils.ScalarToArrayBig(a)

kn = utils.KeyContractStorage(add, storageKey.String())
kn, err = utils.KeyContractStorage(address.String(), storageKey.String())
if err != nil {
return nil, err
}
}

return s.getValueInBytes(kn)
Expand Down
103 changes: 103 additions & 0 deletions smt/pkg/utils/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,31 @@ func TestScalarToArrayBig(t *testing.T) {
}
}

func TestScalarToArrayUint64(t *testing.T) {
scalar := big.NewInt(0x1234567890ABCDEF)

expected := [8]uint64{
0x90ABCDEF,
0x12345678,
0,
0,
0,
0,
0,
0,
}

result, err := ScalarToArrayUint64(scalar)

if err != nil {
t.Errorf("ScalarToArray = %v; want %v", result, expected)
}

if !reflect.DeepEqual(result, expected) {
t.Errorf("ScalarToArray = %v; want %v", result, expected)
}
}

func BenchmarkScalarToArrayBig(b *testing.B) {
scalar := big.NewInt(0x1234567890ABCDEF)
for i := 0; i < b.N; i++ {
Expand Down Expand Up @@ -771,3 +796,81 @@ func TestNodeKeyFromPath(t *testing.T) {
}
}
}

func Test_Key(t *testing.T) {
tests := []struct {
input string
output NodeKey
}{
{
input: "0xe859276098f208D003ca6904C6cC26629Ee364Ce",
output: NodeKey{
9755015262748197613,
11140630475045976694,
14930209430661078379,
6319951756608990063,
},
},
}

for _, test := range tests {
result := Key(test.input, 1)
if result != test.output {
t.Errorf("expected %v but got %v", test.output, result)
}
}
}

func TestKeyContractStorage(t *testing.T) {
tests := []struct {
input string
output NodeKey
}{
{
input: "0xe859276098f208D003ca6904C6cC26629Ee364Ce",
output: NodeKey{
9485388526025222793,
2844922146222416636,
12800508867551015356,
9480521524011931274,
},
},
}

for _, test := range tests {
result, err := KeyContractStorage(test.input, "0x1")
if err != nil {
t.Fatal(err)
}
if result != test.output {
t.Errorf("expected %v but got %v", test.output, result)
}
}
}

func TestKeyBig(t *testing.T) {
tests := []struct {
input *big.Int
output NodeKey
}{
{
input: big.NewInt(1092034958475866),
output: NodeKey{
11593000745318970063,
7942385326937081179,
13970824778267919554,
7405798476109204467,
},
},
}

for _, test := range tests {
result, err := KeyBig(test.input, 1)
if err != nil {
t.Fatal(err)
}
if *result != test.output {
t.Errorf("expected %v but got %v", test.output, result)
}
}
}
Loading

0 comments on commit 589e949

Please sign in to comment.