Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(core): digest backwards compatibility #451

Merged
merged 10 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/src/uagents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .agent import Agent, Bureau # noqa
from .context import Context # noqa
from .models import Model # noqa
from .models import Model, Field # noqa
Archento marked this conversation as resolved.
Show resolved Hide resolved
from .protocol import Protocol # noqa
2 changes: 1 addition & 1 deletion python/src/uagents/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,6 @@ async def send_sync_message(
if isinstance(response, Envelope):
json_message = response.decode_payload()
if response_type:
return response_type.model_validate_json(json_message)
return response_type.model_validate(json_message)
return json_message
return response
2 changes: 1 addition & 1 deletion python/src/uagents/envelope.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Envelope(BaseModel):
target (str): The target's address.
session (UUID4): The session UUID that persists for back-and-forth
dialogues between agents.
schema_digest (str): The schema digest for the enclosed message (alias for protocol).
schema_digest (str): The schema digest for the enclosed message.
protocol_digest (Optional[str]): The digest of the protocol associated with the message
(optional).
payload (Optional[str]): The encoded message payload of the envelope (optional).
Expand Down
25 changes: 16 additions & 9 deletions python/src/uagents/models.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
import hashlib
import json
from typing import Type, Union
from typing import Any, Type, Union

from pydantic import BaseModel
from pydantic.v1 import BaseModel, Field # noqa


# reverting back to pydantic.v1 BaseModel for backwards compatibility
class Model(BaseModel):
@classmethod
def model_json_schema(cls) -> str:
return cls.schema_json()

def model_dump_json(self) -> str:
return self.json()
Archento marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def model_validate(cls, obj: Any) -> "Model":
return cls.parse_obj(obj)

@staticmethod
def build_schema_digest(model: Union["Model", Type["Model"]]) -> str:
schema = model.model_json_schema()
digest = (
hashlib.sha256(json.dumps(schema, sort_keys=True).encode("utf8"))
.digest()
.hex()
)
schema = model.schema_json(indent=None, sort_keys=True)
qati marked this conversation as resolved.
Show resolved Hide resolved
digest = hashlib.sha256(schema.encode("utf8")).digest().hex()

return f"model:{digest}"

Expand Down
2 changes: 1 addition & 1 deletion python/src/uagents/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def manifest(self) -> Dict[str, Any]:

for schema_digest, model in all_models.items():
manifest["models"].append(
{"digest": schema_digest, "schema": model.model_json_schema()}
{"digest": schema_digest, "schema": model.schema()}
Archento marked this conversation as resolved.
Show resolved Hide resolved
Archento marked this conversation as resolved.
Show resolved Hide resolved
)

for request, responses in self._replies.items():
Expand Down
36 changes: 36 additions & 0 deletions python/tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import unittest
from enum import Enum
from typing import List, Literal, Optional

from uagents import Model

Expand All @@ -25,6 +27,40 @@ class SuperImportantCheck(Model):

self.assertEqual(result, target_digest, "Digest mismatch")

def test_calculate_nested_digest_backcompat(self):
Archento marked this conversation as resolved.
Show resolved Hide resolved
"""
Test the digest calculation of nested models.
"""

class UAgentResponseType(Enum):
FINAL = "final"
ERROR = "error"
VALIDATION_ERROR = "validation_error"
SELECT_FROM_OPTIONS = "select_from_options"
FINAL_OPTIONS = "final_options"

class KeyValue(Model):
key: str
value: str

class UAgentResponse(Model):
version: Literal["v1"] = "v1"
type: UAgentResponseType
request_id: Optional[str]
agent_address: Optional[str]
message: Optional[str]
options: Optional[List[KeyValue]]
verbose_message: Optional[str]
verbose_options: Optional[List[KeyValue]]

target_digest = (
"model:cf0d1367c5f9ed8a269de559b2fbca4b653693bb8315d47eda146946a168200e"
)

result = Model.build_schema_digest(UAgentResponse)

self.assertEqual(result, target_digest, "Digest mismatch")


if __name__ == "__main__":
unittest.main()
31 changes: 30 additions & 1 deletion python/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import unittest
from typing import Callable

from uagents import Agent, Model, Protocol
from uagents import Agent, Field, Model, Protocol


class Message(Model):
Expand Down Expand Up @@ -80,3 +80,32 @@ def test_protocol_to_include(self):
self.assertEqual(len(models), 2)
self.assertEqual(len(unsigned_msg_handlers), 1)
self.assertEqual(len(signed_msg_handlers), 1)

def test_calculate_protocol_digest_backcompat(self):
# default protocol
proto = Protocol()

digest = proto.manifest()["metadata"]["digest"]
target_digest = (
"proto:a98290009c0891bc431c5159357074527d10eff6b2e86a61fcf7721b472f1125"
)
self.assertEqual(digest, target_digest, "Digest mismatch")

# non-empty protocol
proto = Protocol(name="SampleProtocol", version="0.1.0")

class SampleMessageResponse(Model):
field_2: str = Field(description="Field 2 description")

class SampleMessage(Model):
field_1: int = Field(description="Field 1 description")

@proto.on_message(model=SampleMessage, replies=SampleMessageResponse)
async def handle_query_request():
pass

digest = proto.manifest()["metadata"]["digest"]
target_digest = (
"proto:75259efe00580e5987363935b9180773293970a59463fecc61a97412dd25a1c6"
)
self.assertEqual(digest, target_digest, "Digest mismatch")
Loading