diff --git a/pypechain/render/contract.py b/pypechain/render/contract.py index c4da34e8..87373ab9 100644 --- a/pypechain/render/contract.py +++ b/pypechain/render/contract.py @@ -2,7 +2,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any, NamedTuple, TypedDict +from typing import Any, NamedTuple from web3.types import ABI @@ -10,29 +10,18 @@ get_abi_items, get_input_names, get_input_names_and_values, + get_input_types, get_output_names, + get_output_names_and_values, + get_output_types, + get_structs_for_abi, is_abi_constructor, is_abi_function, load_abi_from_file, ) from pypechain.utilities.format import capitalize_first_letter_only from pypechain.utilities.templates import get_jinja_env - - -class SignatureData(TypedDict): - """Define the structure of the signature_datas dictionary""" - - input_names_and_types: list[str] - input_names: list[str] - outputs: list[str] - - -class FunctionData(TypedDict): - """Define the structure of the function_data dictionary""" - - name: str - capitalized_name: str - signature_datas: list[SignatureData] +from pypechain.utilities.types import FunctionData, SignatureData, gather_matching_types def render_contract_file(contract_name: str, abi_file_path: Path) -> str: @@ -53,13 +42,14 @@ def render_contract_file(contract_name: str, abi_file_path: Path) -> str: env = get_jinja_env() templates = get_templates_for_contract_file(env) - # TODO: add return types to function calls - abi, bytecode = load_abi_from_file(abi_file_path) function_datas, constructor_data = get_function_datas(abi) has_overloading = any(len(function_data["signature_datas"]) > 1 for function_data in function_datas.values()) has_bytecode = bool(bytecode) + structs_for_abi = get_structs_for_abi(abi) + structs_used = gather_matching_types(list(function_datas.values()), list(structs_for_abi.keys())) + functions_block = templates.functions_template.render( abi=abi, has_overloading=has_overloading, @@ -84,6 +74,7 @@ def render_contract_file(contract_name: str, abi_file_path: Path) -> str: # Render the template return templates.base_template.render( contract_name=contract_name, + structs_used=structs_used, has_overloading=has_overloading, has_bytecode=has_bytecode, functions_block=functions_block, @@ -131,14 +122,14 @@ def get_function_datas(abi: ABI) -> tuple[dict[str, FunctionData], SignatureData constructor_data: SignatureData | None = None for abi_function in get_abi_items(abi): if is_abi_function(abi_function): - # TODO: investigate better typing here? templete.render expects an object so we'll have - # to convert. # hanndle constructor if is_abi_constructor(abi_function): constructor_data = { "input_names_and_types": get_input_names_and_values(abi_function), "input_names": get_input_names(abi_function), + "input_types": get_input_types(abi_function), "outputs": get_output_names(abi_function), + "output_types": get_output_names_and_values(abi_function), } # handle all other functions @@ -147,11 +138,11 @@ def get_function_datas(abi: ABI) -> tuple[dict[str, FunctionData], SignatureData signature_data: SignatureData = { "input_names_and_types": get_input_names_and_values(abi_function), "input_names": get_input_names(abi_function), + "input_types": get_input_types(abi_function), "outputs": get_output_names(abi_function), + "output_types": get_output_types(abi_function), } function_data: FunctionData = { - # TODO: pass a typeguarded ABIFunction that has only required fields? - # name is required in the typeguard. Should be safe to default to empty string. "name": name, "capitalized_name": capitalize_first_letter_only(name), "signature_datas": [signature_data], diff --git a/pypechain/templates/contract.py/base.py.jinja2 b/pypechain/templates/contract.py/base.py.jinja2 index ffd4f95b..3855d642 100644 --- a/pypechain/templates/contract.py/base.py.jinja2 +++ b/pypechain/templates/contract.py/base.py.jinja2 @@ -19,12 +19,13 @@ from typing import cast from eth_typing import ChecksumAddress{% if has_bytecode %}, HexStr{% endif %} {% if has_bytecode %}from hexbytes import HexBytes{% endif %} -from web3.types import ABI +from web3.types import ABI, BlockIdentifier, CallOverride, TxParams from web3.contract.contract import Contract, ContractFunction, ContractFunctions from web3.exceptions import FallbackNotFound {% if has_overloading %} from multimethod import multimethod {% endif %} +{% if structs_used|length > 0 %}from .{{contract_name}}Types import {{ structs_used|join(', ')}}{% endif %} {{functions_block}} diff --git a/pypechain/templates/contract.py/functions.py.jinja2 b/pypechain/templates/contract.py/functions.py.jinja2 index 894295b6..2f9dbb2c 100644 --- a/pypechain/templates/contract.py/functions.py.jinja2 +++ b/pypechain/templates/contract.py/functions.py.jinja2 @@ -13,6 +13,17 @@ class {{contract_name}}{{function_data.capitalized_name}}ContractFunction(Contra def __call__(self{% if signature_data.input_names_and_types %}, {{signature_data.input_names_and_types|join(', ')}}{% endif %}) -> "{{contract_name}}{{function_data.capitalized_name}}ContractFunction":{%- if has_overloading %} #type: ignore{% endif %} super().__call__({{signature_data.input_names|join(', ')}}) return self + + def call( + self, + transaction: TxParams | None = None, + block_identifier: BlockIdentifier = 'latest', + state_override: CallOverride | None = None, + ccip_read_enabled: bool | None = None){% if signature_data.output_types|length == 1 %} -> {{signature_data.output_types[0]}}{% elif signature_data.output_types|length > 1%} -> tuple[{{signature_data.output_types|join(', ')}}]{% endif %}: + {% if signature_data.output_types|length == 1 %}"""returns {{signature_data.output_types[0]}}"""{% elif signature_data.output_types|length > 1%}"""returns ({{signature_data.output_types|join(', ')}})"""{% else %}"""No return value"""{% endif %} + return super().call(transaction, block_identifier, state_override, ccip_read_enabled) + + {% endfor %} {% endfor %} class {{contract_name}}ContractFunctions(ContractFunctions): diff --git a/pypechain/utilities/abi.py b/pypechain/utilities/abi.py index 68b6b291..98286f3d 100644 --- a/pypechain/utilities/abi.py +++ b/pypechain/utilities/abi.py @@ -4,7 +4,7 @@ import json from dataclasses import dataclass from pathlib import Path -from typing import List, Literal, NamedTuple, Sequence, TypeGuard, cast +from typing import Literal, NamedTuple, Sequence, TypeGuard, cast from web3 import Web3 from web3.types import ABI, ABIElement, ABIEvent, ABIFunction, ABIFunctionComponents, ABIFunctionParams @@ -16,35 +16,6 @@ from pypechain.utilities.types import solidity_to_python_type -class Input(NamedTuple): - """An input of a function or event.""" - - internalType: str - name: str - type: str - indexed: bool | None = None - - -class Output(NamedTuple): - """An output of a function or event.""" - - internalType: str - internalType: str - name: str - type: str - - -class AbiItem(NamedTuple): - """An item of an ABI, can be an event, function or struct.""" - - type: str - inputs: List[Input] - stateMutability: str | None = None - anonymous: bool | None = None - name: str | None = None - outputs: List[Output] | None = None - - class AbiJson(NamedTuple): """A JSON representation of a solidity contract's Application Boundary Interface.""" @@ -557,6 +528,48 @@ def get_input_names_and_values(function: ABIFunction) -> list[str]: return _get_names_and_values(function, "inputs") +def get_input_types(function: ABIFunction) -> list[str]: + """Returns function input type strings for jinja templating. + + i.e. for the solidity function signature: function doThing(address who, uint256 amount, bool + flag, bytes extraData) + + the following list would be returned: ['str', 'int', 'bool', 'bytes'] + + Arguments + --------- + function : ABIFunction + A web3 dict of an ABI function description. + + Returns + ------- + list[str] + A list of function python values, i.e. ['str', 'bool'] + """ + return _get_param_types(function, "inputs") + + +def get_output_types(function: ABIFunction) -> list[str]: + """Returns function output type strings for jinja templating. + + i.e. for the solidity function signature: function doThing(address who, uint256 amount, bool + flag, bytes extraData) + + the following list would be returned: ['str', 'int', 'bool', 'bytes'] + + Arguments + --------- + function : ABIFunction + A web3 dict of an ABI function description. + + Returns + ------- + list[str] + A list of function python values, i.e. ['str', 'bool'] + """ + return _get_param_types(function, "outputs") + + def get_output_names_and_values(function: ABIFunction) -> list[str]: """Returns function input name/type strings for jinja templating. @@ -610,6 +623,48 @@ def _get_names_and_values(function: ABIFunction, parameters_type: Literal["input return stringified_function_parameters +def _get_param_types(function: ABIFunction, parameters_type: Literal["inputs", "outputs"]) -> list[str]: + """Returns function input or output type strings for jinja templating. + + i.e. for the solidity function signature: function doThing(address who, uint256 amount, bool + flag, bytes extraData) + + the following list would be returned: ['who: str', 'amount: int', 'flag: bool', 'extraData: + bytes'] + + Arguments + --------- + function : ABIFunction + A web3 dict of an ABI function description. + parameters_type : Literal["inputs", "outputs"] + If we are looking at the inputs or outputs of a function. + + Returns + ------- + list[str] + A list of function parameter python types, i.e. ['str', 'bool'] + """ + stringified_function_parameters: list[str] = [] + inputs_or_outputs = function.get(parameters_type, []) + inputs_or_outputs = cast(list[ABIFunctionParams], inputs_or_outputs) + + for param in inputs_or_outputs: + python_type = get_param_type(param) + stringified_function_parameters.append(f"{python_type}") + return stringified_function_parameters + + +def get_param_type(param: ABIFunctionParams): + """Gets the associated python type, including generated dataclasses""" + internal_type = cast(str, param.get("internalType", "")) + # if we find a struct, we'll add it to the dict of StructInfo's + if is_struct(internal_type): + python_type = get_struct_name(param) + else: + python_type = solidity_to_python_type(param.get("type", "unknown")) + return python_type + + def get_abi_from_json(json_abi: FoundryJson | SolcJson | ABI) -> ABI: """Gets the ABI from a supported json format.""" if is_foundry_json(json_abi): diff --git a/pypechain/utilities/format.py b/pypechain/utilities/format.py index 3ac9d59a..3f4f0951 100644 --- a/pypechain/utilities/format.py +++ b/pypechain/utilities/format.py @@ -125,4 +125,6 @@ def apply_black_formatting(code: str, line_length: int = 80) -> str: try: return black.format_file_contents(code, fast=False, mode=black.Mode(line_length=line_length)) except ValueError as exc: - raise ValueError(f"cannot format with Black\n code:\n{code}") from exc + print(f"cannot format with Black\n code:\n{code}") + print(f"{exc=}") + return code diff --git a/pypechain/utilities/types.py b/pypechain/utilities/types.py index 6f0aaa38..c639e450 100644 --- a/pypechain/utilities/types.py +++ b/pypechain/utilities/types.py @@ -2,6 +2,25 @@ import logging +from typing import TypedDict + + +class SignatureData(TypedDict): + """Define the structure of the signature_datas dictionary""" + + input_names_and_types: list[str] + input_names: list[str] + input_types: list[str] + outputs: list[str] + output_types: list[str] + + +class FunctionData(TypedDict): + """Define the structure of the function_data dictionary""" + + name: str + capitalized_name: str + signature_datas: list[SignatureData] def solidity_to_python_type(solidity_type: str) -> str: @@ -82,3 +101,34 @@ def solidity_to_python_type(solidity_type: str) -> str: logging.warning("Unknown Solidity type: %s", solidity_type) return solidity_type + + +def gather_matching_types(function_datas: list[FunctionData], known_types: list[str]) -> list[str]: + """Gather matching types from inputs and outputs in the function_datas. + + Parameters + ---------- + function_datas : list[FunctionData] + A list of function datas. + known_types : list[str] + A list of known types. + + Returns + ------- + list[str] + The matching list of types. + """ + matching_types = [] + + for function_data in function_datas: + for signature_data in function_data["signature_datas"]: + # Check input types + for input_type in signature_data["input_types"]: + if input_type in known_types: + matching_types.append(input_type) + # Check output types + for output_type in signature_data["output_types"]: + if output_type in known_types: + matching_types.append(output_type) + + return matching_types diff --git a/snapshots/expected_not_overloading.py b/snapshots/expected_not_overloading.py index c49f1571..3a1ae0a7 100644 --- a/snapshots/expected_not_overloading.py +++ b/snapshots/expected_not_overloading.py @@ -7,6 +7,17 @@ def __call__(self) -> "OverloadedBalanceOfContractFunction": super().__call__() return self + def call( + self, + transaction: TxParams | None = None, + block_identifier: BlockIdentifier = 'latest', + state_override: CallOverride | None = None, + ccip_read_enabled: bool | None = None) -> int: + """returns int""" + return super().call(transaction, block_identifier, state_override, ccip_read_enabled) + + + class OverloadedBalanceOfWhoContractFunction(ContractFunction): """ContractFunction for the balanceOfWho method.""" # super() call methods are generic, while our version adds values & types @@ -16,6 +27,17 @@ def __call__(self, who: str) -> "OverloadedBalanceOfWhoContractFunction": super().__call__(who) return self + def call( + self, + transaction: TxParams | None = None, + block_identifier: BlockIdentifier = 'latest', + state_override: CallOverride | None = None, + ccip_read_enabled: bool | None = None) -> bool: + """returns bool""" + return super().call(transaction, block_identifier, state_override, ccip_read_enabled) + + + class OverloadedContractFunctions(ContractFunctions): """ContractFunctions for the Overloaded contract.""" diff --git a/snapshots/expected_overloading.py b/snapshots/expected_overloading.py index f2a3f1bf..b45a8bc3 100644 --- a/snapshots/expected_overloading.py +++ b/snapshots/expected_overloading.py @@ -8,11 +8,33 @@ def __call__(self) -> "OverloadedBalanceOfContractFunction": super().__call__() return self + def call( + self, + transaction: TxParams | None = None, + block_identifier: BlockIdentifier = 'latest', + state_override: CallOverride | None = None, + ccip_read_enabled: bool | None = None) -> int: + """returns int""" + return super().call(transaction, block_identifier, state_override, ccip_read_enabled) + + + def __call__(self, who: str) -> "OverloadedBalanceOfContractFunction": super().__call__(who) return self + def call( + self, + transaction: TxParams | None = None, + block_identifier: BlockIdentifier = 'latest', + state_override: CallOverride | None = None, + ccip_read_enabled: bool | None = None) -> int: + """returns int""" + return super().call(transaction, block_identifier, state_override, ccip_read_enabled) + + + class OverloadedContractFunctions(ContractFunctions): """ContractFunctions for the Overloaded contract."""