Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subscriber iterator #2039

Draft
wants to merge 6 commits into
base: 0.6.0
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions faststream/_internal/subscriber/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,7 @@ def add_call(
middlewares_: Sequence["SubscriberMiddleware[Any]"],
dependencies_: Iterable["Dependant"],
) -> Self: ...

@abstractmethod
def __aiter__(self) -> StreamMessage[MsgType]:
...
21 changes: 20 additions & 1 deletion faststream/kafka/subscriber/usecase.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod
from collections.abc import Iterable, Sequence
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Optional

import anyio
from aiokafka import ConsumerRecord, TopicPartition
Expand Down Expand Up @@ -179,6 +179,25 @@ async def get_one(
decoder=self._decoder,
)

@override
async def __aiter__(self) -> AsyncIterator["StreamMessage[MsgType]"]: # type: ignore[override]
assert self.consumer, "You should start subscriber at first." # nosec B101
assert ( # nosec B101
not self.calls
), "You can't use `get_one` method if subscriber has registered handlers."

async for raw_message in self.consumer:
context = self._state.get().di_state.context
msg: StreamMessage[MsgType] = await process_msg( # type: ignore[assignment]
msg=raw_message,
middlewares=(
m(raw_message, context=context) for m in self._broker_middlewares
),
parser=self._parser,
decoder=self._decoder,
)
yield msg

def _make_response_publisher(
self,
message: "StreamMessage[Any]",
Expand Down
32 changes: 26 additions & 6 deletions faststream/rabbit/subscriber/usecase.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import asyncio
import contextlib
from collections.abc import Iterable, Sequence
from typing import (
TYPE_CHECKING,
Any,
Optional,
)
from collections.abc import AsyncIterator, Iterable, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast

import anyio
from typing_extensions import override
Expand Down Expand Up @@ -184,6 +180,30 @@ async def get_one(
)
return msg

@override
async def __aiter__(self) -> AsyncIterator[RabbitMessage]: # type: ignore[override]
assert self._queue_obj, "You should start subscriber at first." # nosec B101
assert ( # nosec B101
not self.calls
), "You can't use iterator method if subscriber has registered handlers."

context = self._state.get().di_state.context

async with self._queue_obj.iterator() as queue_iter:
async for raw_message in queue_iter:
raw_message = cast("IncomingMessage", raw_message)

msg: RabbitMessage = await process_msg( # type: ignore[assignment]
msg=raw_message,
middlewares=(
m(raw_message, context=context)
for m in self._broker_middlewares
),
parser=self._parser,
decoder=self._decoder,
)
yield msg

def _make_response_publisher(
self,
message: "StreamMessage[Any]",
Expand Down
32 changes: 32 additions & 0 deletions tests/brokers/rabbit/test_consume.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,3 +423,35 @@ async def handler(msg: RabbitMessage) -> None:
m.mock.assert_not_called()

assert event.is_set()

@pytest.mark.asyncio()
async def test_iteration(self, queue: str, exchange: RabbitExchange) -> None:
expected_messages = ("test_message_1", "test_message_2")
consume_broker = self.get_broker(apply_types=True)

subscirber = consume_broker.subscriber(
queue,
exchange=exchange,
ack_policy=AckPolicy.DO_NOTHING,
)

async def publish_test_message():
for msg in expected_messages:
await br.publish(msg, queue=queue, exchange=exchange)

async with self.patch_broker(consume_broker) as br:
await br.start()
_ = await asyncio.create_task(publish_test_message())

index_message = 0
async for msg in subscirber:

assert msg is not None

result_message = await msg.decode()

assert result_message == expected_messages[index_message]

index_message += 1
if index_message >= len(expected_messages):
break