Skip to content

Commit

Permalink
feat: ability to get deployment addresses before deploying (#2433)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Dec 20, 2024
1 parent cb43786 commit 52eb9bb
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/ape/api/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,28 @@ def prepare_transaction(self, txn: TransactionAPI) -> TransactionAPI:

return txn

def get_deployment_address(self, nonce: Optional[int] = None) -> AddressType:
"""
Get a contract address before it is deployed. This is useful
when you need to pass the contract address to another contract
before deploying it.
Args:
nonce (int | None): Optionally provide a nonce. Defaults
the account's current nonce.
Returns:
AddressType: The contract address.
"""
# Use the connected network, if available. Else, default to Ethereum.
ecosystem = (
self.network_manager.active_provider.network.ecosystem
if self.network_manager.active_provider
else self.network_manager.ethereum
)
nonce = self.nonce if nonce is None else nonce
return ecosystem.get_deployment_address(self.address, nonce)


class AccountContainerAPI(BaseInterfaceModel):
"""
Expand Down
12 changes: 12 additions & 0 deletions src/ape/api/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,18 @@ def decode_returndata(self, abi: "MethodABI", raw_data: bytes) -> Any:
Any: All of the values returned from the contract function.
"""

@raises_not_implemented
def get_deployment_address( # type: ignore[empty-body]
self,
address: AddressType,
nonce: int,
) -> AddressType:
"""
Calculate the deployment address of a contract before it is deployed.
This is useful if the address is an argument to another contract's deployment
and you have not yet deployed the first contract yet.
"""

def get_network(self, network_name: str) -> "NetworkAPI":
"""
Get the network for the given name.
Expand Down
12 changes: 12 additions & 0 deletions src/ape_ethereum/ecosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast

import rlp # type: ignore
from cchecksum import to_checksum_address
from eth_abi import decode, encode
from eth_abi.exceptions import InsufficientDataBytes, NonEmptyPaddingBytes
Expand Down Expand Up @@ -1476,6 +1477,17 @@ def decode_custom_error(
# error never found.
return None

def get_deployment_address(self, address: AddressType, nonce: int) -> AddressType:
"""
Calculate the deployment address of a contract before it is deployed.
This is useful if the address is an argument to another contract's deployment
and you have not yet deployed the first contract yet.
"""
sender_bytes = to_bytes(hexstr=address)
encoded = rlp.encode([sender_bytes, nonce])
address_bytes = keccak(encoded)[12:]
return self.decode_address(address_bytes)


def parse_type(type_: dict[str, Any]) -> Union[str, tuple, list]:
if "tuple" not in type_["type"]:
Expand Down
9 changes: 9 additions & 0 deletions tests/functional/test_accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,3 +921,12 @@ def test_import_account_from_private_key_insecure_passphrase(delete_account_afte
def test_load(account_manager, keyfile_account):
account = account_manager.load(keyfile_account.alias)
assert account == keyfile_account


def test_get_deployment_address(owner, vyper_contract_container):
deployment_address_1 = owner.get_deployment_address()
deployment_address_2 = owner.get_deployment_address(nonce=owner.nonce + 1)
instance_1 = owner.deploy(vyper_contract_container, 490)
assert instance_1.address == deployment_address_1
instance_2 = owner.deploy(vyper_contract_container, 490)
assert instance_2.address == deployment_address_2
6 changes: 6 additions & 0 deletions tests/functional/test_ecosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,3 +1203,9 @@ def get_calltree(self) -> CallTreeNode:
}
]
assert events == expected


def test_get_deployment_address(ethereum, owner, vyper_contract_container):
actual = ethereum.get_deployment_address(owner.address, owner.nonce)
expected = owner.deploy(vyper_contract_container, 490)
assert actual == expected

0 comments on commit 52eb9bb

Please sign in to comment.