Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add return types to call() methods. #32

Merged
merged 6 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 14 additions & 23 deletions pypechain/render/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,26 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, NamedTuple, TypedDict
from typing import Any, NamedTuple

from web3.types import ABI

from pypechain.utilities.abi import (
get_abi_items,
get_input_names,
get_input_names_and_values,
get_input_types,
get_output_names,
get_output_names_and_values,
get_output_types,
get_structs_for_abi,
is_abi_constructor,
is_abi_function,
load_abi_from_file,
)
from pypechain.utilities.format import capitalize_first_letter_only
from pypechain.utilities.templates import get_jinja_env


class SignatureData(TypedDict):
"""Define the structure of the signature_datas dictionary"""

input_names_and_types: list[str]
input_names: list[str]
outputs: list[str]


class FunctionData(TypedDict):
"""Define the structure of the function_data dictionary"""

name: str
capitalized_name: str
signature_datas: list[SignatureData]
from pypechain.utilities.types import FunctionData, SignatureData, gather_matching_types


def render_contract_file(contract_name: str, abi_file_path: Path) -> str:
Expand All @@ -53,13 +42,14 @@ def render_contract_file(contract_name: str, abi_file_path: Path) -> str:
env = get_jinja_env()
templates = get_templates_for_contract_file(env)

# TODO: add return types to function calls

abi, bytecode = load_abi_from_file(abi_file_path)
function_datas, constructor_data = get_function_datas(abi)
has_overloading = any(len(function_data["signature_datas"]) > 1 for function_data in function_datas.values())
has_bytecode = bool(bytecode)

structs_for_abi = get_structs_for_abi(abi)
structs_used = gather_matching_types(list(function_datas.values()), list(structs_for_abi.keys()))

functions_block = templates.functions_template.render(
abi=abi,
has_overloading=has_overloading,
Expand All @@ -84,6 +74,7 @@ def render_contract_file(contract_name: str, abi_file_path: Path) -> str:
# Render the template
return templates.base_template.render(
contract_name=contract_name,
structs_used=structs_used,
has_overloading=has_overloading,
has_bytecode=has_bytecode,
functions_block=functions_block,
Expand Down Expand Up @@ -131,14 +122,14 @@ def get_function_datas(abi: ABI) -> tuple[dict[str, FunctionData], SignatureData
constructor_data: SignatureData | None = None
for abi_function in get_abi_items(abi):
if is_abi_function(abi_function):
# TODO: investigate better typing here? templete.render expects an object so we'll have
# to convert.
# hanndle constructor
if is_abi_constructor(abi_function):
constructor_data = {
"input_names_and_types": get_input_names_and_values(abi_function),
"input_names": get_input_names(abi_function),
"input_types": get_input_types(abi_function),
"outputs": get_output_names(abi_function),
"output_types": get_output_names_and_values(abi_function),
}

# handle all other functions
Expand All @@ -147,11 +138,11 @@ def get_function_datas(abi: ABI) -> tuple[dict[str, FunctionData], SignatureData
signature_data: SignatureData = {
"input_names_and_types": get_input_names_and_values(abi_function),
"input_names": get_input_names(abi_function),
"input_types": get_input_types(abi_function),
"outputs": get_output_names(abi_function),
"output_types": get_output_types(abi_function),
}
function_data: FunctionData = {
# TODO: pass a typeguarded ABIFunction that has only required fields?
# name is required in the typeguard. Should be safe to default to empty string.
"name": name,
"capitalized_name": capitalize_first_letter_only(name),
"signature_datas": [signature_data],
Expand Down
3 changes: 2 additions & 1 deletion pypechain/templates/contract.py/base.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ from typing import cast

from eth_typing import ChecksumAddress{% if has_bytecode %}, HexStr{% endif %}
{% if has_bytecode %}from hexbytes import HexBytes{% endif %}
from web3.types import ABI
from web3.types import ABI, BlockIdentifier, CallOverride, TxParams
from web3.contract.contract import Contract, ContractFunction, ContractFunctions
from web3.exceptions import FallbackNotFound
{% if has_overloading %}
from multimethod import multimethod
{% endif %}
{% if structs_used|length > 0 %}from .{{contract_name}}Types import {{ structs_used|join(', ')}}{% endif %}

{{functions_block}}

Expand Down
11 changes: 11 additions & 0 deletions pypechain/templates/contract.py/functions.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@ class {{contract_name}}{{function_data.capitalized_name}}ContractFunction(Contra
def __call__(self{% if signature_data.input_names_and_types %}, {{signature_data.input_names_and_types|join(', ')}}{% endif %}) -> "{{contract_name}}{{function_data.capitalized_name}}ContractFunction":{%- if has_overloading %} #type: ignore{% endif %}
super().__call__({{signature_data.input_names|join(', ')}})
return self

def call(
self,
transaction: TxParams | None = None,
block_identifier: BlockIdentifier = 'latest',
state_override: CallOverride | None = None,
ccip_read_enabled: bool | None = None){% if signature_data.output_types|length == 1 %} -> {{signature_data.output_types[0]}}{% elif signature_data.output_types|length > 1%} -> tuple[{{signature_data.output_types|join(', ')}}]{% endif %}:
{% if signature_data.output_types|length == 1 %}"""returns {{signature_data.output_types[0]}}"""{% elif signature_data.output_types|length > 1%}"""returns ({{signature_data.output_types|join(', ')}})"""{% else %}"""No return value"""{% endif %}
return super().call(transaction, block_identifier, state_override, ccip_read_enabled)


{% endfor %}
{% endfor %}
class {{contract_name}}ContractFunctions(ContractFunctions):
Expand Down
115 changes: 85 additions & 30 deletions pypechain/utilities/abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
from dataclasses import dataclass
from pathlib import Path
from typing import List, Literal, NamedTuple, Sequence, TypeGuard, cast
from typing import Literal, NamedTuple, Sequence, TypeGuard, cast

from web3 import Web3
from web3.types import ABI, ABIElement, ABIEvent, ABIFunction, ABIFunctionComponents, ABIFunctionParams
Expand All @@ -16,35 +16,6 @@
from pypechain.utilities.types import solidity_to_python_type


class Input(NamedTuple):
"""An input of a function or event."""

internalType: str
name: str
type: str
indexed: bool | None = None


class Output(NamedTuple):
"""An output of a function or event."""

internalType: str
internalType: str
name: str
type: str


class AbiItem(NamedTuple):
"""An item of an ABI, can be an event, function or struct."""

type: str
inputs: List[Input]
stateMutability: str | None = None
anonymous: bool | None = None
name: str | None = None
outputs: List[Output] | None = None


class AbiJson(NamedTuple):
"""A JSON representation of a solidity contract's Application Boundary Interface."""

Expand Down Expand Up @@ -557,6 +528,48 @@ def get_input_names_and_values(function: ABIFunction) -> list[str]:
return _get_names_and_values(function, "inputs")


def get_input_types(function: ABIFunction) -> list[str]:
"""Returns function input type strings for jinja templating.

i.e. for the solidity function signature: function doThing(address who, uint256 amount, bool
flag, bytes extraData)

the following list would be returned: ['str', 'int', 'bool', 'bytes']

Arguments
---------
function : ABIFunction
A web3 dict of an ABI function description.

Returns
-------
list[str]
A list of function python values, i.e. ['str', 'bool']
"""
return _get_param_types(function, "inputs")


def get_output_types(function: ABIFunction) -> list[str]:
"""Returns function output type strings for jinja templating.

i.e. for the solidity function signature: function doThing(address who, uint256 amount, bool
flag, bytes extraData)

the following list would be returned: ['str', 'int', 'bool', 'bytes']

Arguments
---------
function : ABIFunction
A web3 dict of an ABI function description.

Returns
-------
list[str]
A list of function python values, i.e. ['str', 'bool']
sentilesdal marked this conversation as resolved.
Show resolved Hide resolved
"""
return _get_param_types(function, "outputs")


def get_output_names_and_values(function: ABIFunction) -> list[str]:
"""Returns function input name/type strings for jinja templating.

Expand Down Expand Up @@ -610,6 +623,48 @@ def _get_names_and_values(function: ABIFunction, parameters_type: Literal["input
return stringified_function_parameters


def _get_param_types(function: ABIFunction, parameters_type: Literal["inputs", "outputs"]) -> list[str]:
"""Returns function input or output type strings for jinja templating.

i.e. for the solidity function signature: function doThing(address who, uint256 amount, bool
flag, bytes extraData)

the following list would be returned: ['who: str', 'amount: int', 'flag: bool', 'extraData:
bytes']

Arguments
---------
function : ABIFunction
A web3 dict of an ABI function description.
parameters_type : Literal["inputs", "outputs"]
If we are looking at the inputs or outputs of a function.

Returns
-------
list[str]
A list of function parameter python types, i.e. ['str', 'bool']
"""
stringified_function_parameters: list[str] = []
inputs_or_outputs = function.get(parameters_type, [])
inputs_or_outputs = cast(list[ABIFunctionParams], inputs_or_outputs)

for param in inputs_or_outputs:
python_type = get_param_type(param)
stringified_function_parameters.append(f"{python_type}")
return stringified_function_parameters


def get_param_type(param: ABIFunctionParams):
"""Gets the associated python type, including generated dataclasses"""
internal_type = cast(str, param.get("internalType", ""))
# if we find a struct, we'll add it to the dict of StructInfo's
if is_struct(internal_type):
python_type = get_struct_name(param)
else:
python_type = solidity_to_python_type(param.get("type", "unknown"))
return python_type


def get_abi_from_json(json_abi: FoundryJson | SolcJson | ABI) -> ABI:
"""Gets the ABI from a supported json format."""
if is_foundry_json(json_abi):
Expand Down
4 changes: 3 additions & 1 deletion pypechain/utilities/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,6 @@ def apply_black_formatting(code: str, line_length: int = 80) -> str:
try:
return black.format_file_contents(code, fast=False, mode=black.Mode(line_length=line_length))
except ValueError as exc:
raise ValueError(f"cannot format with Black\n code:\n{code}") from exc
print(f"cannot format with Black\n code:\n{code}")
print(f"{exc=}")
return code
50 changes: 50 additions & 0 deletions pypechain/utilities/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,25 @@


import logging
from typing import TypedDict


class SignatureData(TypedDict):
"""Define the structure of the signature_datas dictionary"""

input_names_and_types: list[str]
input_names: list[str]
input_types: list[str]
outputs: list[str]
output_types: list[str]


class FunctionData(TypedDict):
"""Define the structure of the function_data dictionary"""

name: str
capitalized_name: str
signature_datas: list[SignatureData]


def solidity_to_python_type(solidity_type: str) -> str:
Expand Down Expand Up @@ -82,3 +101,34 @@ def solidity_to_python_type(solidity_type: str) -> str:
logging.warning("Unknown Solidity type: %s", solidity_type)

return solidity_type


def gather_matching_types(function_datas: list[FunctionData], known_types: list[str]) -> list[str]:
"""Gather matching types from inputs and outputs in the function_datas.

Parameters
----------
function_datas : list[FunctionData]
A list of function datas.
known_types : list[str]
A list of known types.

Returns
-------
list[str]
The matching list of types.
"""
matching_types = []

for function_data in function_datas:
for signature_data in function_data["signature_datas"]:
# Check input types
for input_type in signature_data["input_types"]:
if input_type in known_types:
matching_types.append(input_type)
# Check output types
for output_type in signature_data["output_types"]:
if output_type in known_types:
matching_types.append(output_type)

return matching_types
22 changes: 22 additions & 0 deletions snapshots/expected_not_overloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@ def __call__(self) -> "OverloadedBalanceOfContractFunction":
super().__call__()
return self

def call(
self,
transaction: TxParams | None = None,
block_identifier: BlockIdentifier = 'latest',
state_override: CallOverride | None = None,
ccip_read_enabled: bool | None = None) -> int:
"""returns int"""
return super().call(transaction, block_identifier, state_override, ccip_read_enabled)



class OverloadedBalanceOfWhoContractFunction(ContractFunction):
"""ContractFunction for the balanceOfWho method."""
# super() call methods are generic, while our version adds values & types
Expand All @@ -16,6 +27,17 @@ def __call__(self, who: str) -> "OverloadedBalanceOfWhoContractFunction":
super().__call__(who)
return self

def call(
self,
transaction: TxParams | None = None,
block_identifier: BlockIdentifier = 'latest',
state_override: CallOverride | None = None,
ccip_read_enabled: bool | None = None) -> bool:
"""returns bool"""
return super().call(transaction, block_identifier, state_override, ccip_read_enabled)




class OverloadedContractFunctions(ContractFunctions):
"""ContractFunctions for the Overloaded contract."""
Expand Down
Loading