Skip to content

Commit

Permalink
upd code
Browse files Browse the repository at this point in the history
  • Loading branch information
Filipp Makarov authored and Filipp Makarov committed Sep 24, 2024
1 parent 840a645 commit 1feea2f
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 32 deletions.
2 changes: 1 addition & 1 deletion cache_forge/solidity-files-cache.json

Large diffs are not rendered by default.

137 changes: 111 additions & 26 deletions src/AssociatedArrayLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,45 +10,85 @@ library AssociatedArrayLib {
uint256 _spacer;
}

function _length(Array storage s, address account) private view returns (uint256 __length) {
function _slot(Array storage s, address account) private pure returns (bytes32 __slot) {
assembly {
mstore(0x00, account)
mstore(0x20, s.slot)
__length := sload(keccak256(0x00, 0x40))
__slot := keccak256(0x00, 0x40)
}
}

function _length(Array storage s, address account) private view returns (uint256 __length) {
bytes32 slot = _slot(s, account);
assembly {
__length := sload(slot)
}
}

function _get(Array storage s, address account, uint256 index) private view returns (bytes32 value) {
if (index >= _length(s, account)) revert AssociatedArray_OutOfBounds(index);
return _get(_slot(s, account), index);
}

function _get(bytes32 slot, uint256 index) private view returns (bytes32 value) {
assembly {
mstore(0x00, account)
mstore(0x20, s.slot)
value := sload(add(keccak256(0x00, 0x40), mul(0x20, add(index, 1))))
//if (index >= _length(s, account)) revert AssociatedArray_OutOfBounds(index);
if iszero(lt(index, sload(slot))) {
mstore(0, 0x8277484f) // `AssociatedArray_OutOfBounds(uint256)`
mstore(0x20, index)
revert(0x1c, 0x24)
}
value := sload(add(slot, mul(0x20, add(index, 1))))
}
}

function _getAll(Array storage s, address account) private view returns (bytes32[] memory values) {
uint256 __length = _length(s, account);
bytes32 slot = _slot(s, account);
uint256 __length;
assembly {
__length := sload(slot)
}
values = new bytes32[](__length);
for (uint256 i; i < __length; i++) {
values[i] = _get(s, account, i);
values[i] = _get(slot, i);
}
}

// inefficient. complexity = O(n)
// use with caution
// in case of large arrays, consider using EnumerableSet4337 instead
function _contains(Array storage s, address account, bytes32 value) private view returns (bool) {
bytes32 slot = _slot(s, account);
uint256 __length;
assembly {
__length := sload(slot)
}
for (uint256 i; i < __length; i++) {
if (_get(slot, i) == value) {
return true;
}
}
return false;
}

function _set(Array storage s, address account, uint256 index, bytes32 value) private {
if (index >= _length(s, account)) revert AssociatedArray_OutOfBounds(index);
_set(_slot(s, account), index, value);
}

function _set(bytes32 slot, uint256 index, bytes32 value) private {
assembly {
mstore(0x00, account)
mstore(0x20, s.slot)
sstore(add(keccak256(0x00, 0x40), mul(0x20, add(index, 1))), value)
//if (index >= _length(s, account)) revert AssociatedArray_OutOfBounds(index);
if iszero(lt(index, sload(slot))) {
mstore(0, 0x8277484f) // `AssociatedArray_OutOfBounds(uint256)`
mstore(0x20, index)
revert(0x1c, 0x24)
}
sstore(add(slot, mul(0x20, add(index, 1))), value)
}
}

function _push(Array storage s, address account, bytes32 value) private {
bytes32 slot = _slot(s, account);
assembly {
mstore(0x00, account) // store a
mstore(0x20, s.slot) //store x
let slot := keccak256(0x00, 0x40)
// load length (stored @ slot), add 1 to it => index.
// mul index by 0x20 and add it to orig slot to get the next free slot
let index := add(sload(slot), 1)
Expand All @@ -58,24 +98,39 @@ library AssociatedArrayLib {
}

function _pop(Array storage s, address account) private {
uint256 __length = _length(s, account);
bytes32 slot = _slot(s, account);
uint256 __length;
assembly {
__length := sload(slot)
}
if (__length == 0) return;
_set(s, account, __length - 1, 0);
_set(slot, __length - 1, 0);
assembly {
mstore(0x00, account)
mstore(0x20, s.slot)
sstore(keccak256(0x00, 0x40), sub(__length, 1))
sstore(slot, sub(__length, 1))
}
}

function _remove(Array storage s, address account, uint256 index) private {
uint256 __length = _length(s, account);
if (index >= __length) revert AssociatedArray_OutOfBounds(index);
_set(s, account, index, _get(s, account, __length - 1));
bytes32 slot = _slot(s, account);
uint256 __length;
assembly {
mstore(0x00, account)
mstore(0x20, s.slot)
sstore(keccak256(0x00, 0x40), sub(__length, 1))
__length := sload(slot)
if iszero(lt(index, __length)) {
mstore(0, 0x8277484f) // `AssociatedArray_OutOfBounds(uint256)`
mstore(0x20, index)
revert(0x1c, 0x24)
}
}
_set(slot, index, _get(s, account, __length - 1));

assembly {
// clear the last slot
// this is the 'unchecked' version of _set(slot, __length - 1, 0)
// as we use length-1 as index, so the check is excessive.
// also removes extra -1 and +1 operations
sstore(add(slot, mul(0x20, __length)), 0)
// store new length
sstore(slot, sub(__length, 1))
}
}

Expand All @@ -95,6 +150,16 @@ library AssociatedArrayLib {
return _getAll(s._inner, account);
}

function contains(Bytes32Array storage s, address account, bytes32 value) internal view returns (bool) {
return _contains(s._inner, account, value);
}

function add(Bytes32Array storage s, address account, bytes32 value) internal {
if (!_contains(s._inner, account, value)) {
_push(s._inner, account, value);
}
}

function set(Bytes32Array storage s, address account, uint256 index, bytes32 value) internal {
_set(s._inner, account, index, value);
}
Expand Down Expand Up @@ -134,6 +199,16 @@ library AssociatedArrayLib {
return addressArray;
}

function contains(AddressArray storage s, address account, address value) internal view returns (bool) {
return _contains(s._inner, account, bytes32(uint256(uint160(value))));
}

function add(AddressArray storage s, address account, address value) internal {
if (!_contains(s._inner, account, bytes32(uint256(uint160(value))))) {
_push(s._inner, account, bytes32(uint256(uint160(value))));
}
}

function set(AddressArray storage s, address account, uint256 index, address value) internal {
_set(s._inner, account, index, bytes32(uint256(uint160(value))));
}
Expand Down Expand Up @@ -173,6 +248,16 @@ library AssociatedArrayLib {
return uintArray;
}

function contains(UintArray storage s, address account, uint256 value) internal view returns (bool) {
return _contains(s._inner, account, bytes32(value));
}

function add(UintArray storage s, address account, uint256 value) internal {
if (!_contains(s._inner, account, bytes32(value))) {
_push(s._inner, account, bytes32(value));
}
}

function set(UintArray storage s, address account, uint256 index, uint256 value) internal {
_set(s._inner, account, index, bytes32(value));
}
Expand Down
2 changes: 1 addition & 1 deletion src/EnumerableMap.sol → src/EnumerableMap4337.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

pragma solidity ^0.8.20;

import { EnumerableSet } from "./EnumerableSet.sol";
import { EnumerableSet } from "./EnumerableSet4337.sol";

/**
* Fork of OZ's EnumerableSet that makes all storage access ERC-4337 compliant via associated storage
Expand Down
7 changes: 5 additions & 2 deletions src/EnumerableSet.sol → src/EnumerableSet4337.sol
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,12 @@ library EnumerableSet {
}

function _removeAll(Set storage set, address account) internal {
// get length of the array
uint256 len = _length(set, account);
for (uint256 i; i < len; i++) {
_remove(set, account, _at(set, account, i));
for (uint256 i = 1; i <= len; i++) {
// get last value
bytes32 value = _at(set, account, len - i);
_remove(set, account, value);
}
}

Expand Down
2 changes: 1 addition & 1 deletion test/EnumerableMap.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
pragma solidity ^0.8.20;

import "forge-std/Test.sol";
import "src/EnumerableMap.sol";
import "src/EnumerableMap4337.sol";

contract EnumerableMapTest is Test {
using EnumerableMap for EnumerableMap.Bytes32ToBytes32Map;
Expand Down
2 changes: 1 addition & 1 deletion test/EnumerableSet.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
pragma solidity ^0.8.20;

import "forge-std/Test.sol";
import "src/EnumerableSet.sol";
import "src/EnumerableSet4337.sol";

contract EnumerableSetTest is Test {
using EnumerableSet for EnumerableSet.Bytes32Set;
Expand Down

0 comments on commit 1feea2f

Please sign in to comment.