Skip to content

Commit

Permalink
feat: inherit all Enums from str to make JSON serialization possible (e…
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik authored Jul 16, 2024
1 parent 999ee71 commit 2aca1ac
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
4 changes: 2 additions & 2 deletions aidial_sdk/chat_completion/enums.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from enum import Enum


class FinishReason(Enum):
class FinishReason(str, Enum):
STOP = "stop"
LENGTH = "length"
FUNCTION_CALL = "function_call"
TOOL_CALLS = "tool_calls"
CONTENT_FILTER = "content_filter"


class Status(Enum):
class Status(str, Enum):
COMPLETED = "completed"
FAILED = "failed"
2 changes: 1 addition & 1 deletion aidial_sdk/chat_completion/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class ToolCall(ExtraForbidModel):
function: FunctionCall


class Role(Enum):
class Role(str, Enum):
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
Expand Down
19 changes: 19 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import json

from aidial_sdk.chat_completion import Message, Role


def test_message_ser():
msg_obj = Message(role=Role.SYSTEM, content="test")
actual_dict = msg_obj.dict(exclude_none=True)
expected_dict = {"role": "system", "content": "test"}

assert json.loads(json.dumps(actual_dict)) == expected_dict


def test_message_deser():
msg_dict = {"role": "system", "content": "test"}
actual_obj = Message.parse_raw(json.dumps(msg_dict))
expected_obj = Message(role=Role.SYSTEM, content="test")

assert actual_obj == expected_obj

0 comments on commit 2aca1ac

Please sign in to comment.