diff --git a/faststream/_internal/subscriber/proto.py b/faststream/_internal/subscriber/proto.py index cb24b32295..4a4aed8808 100644 --- a/faststream/_internal/subscriber/proto.py +++ b/faststream/_internal/subscriber/proto.py @@ -91,3 +91,7 @@ def add_call( middlewares_: Sequence["SubscriberMiddleware[Any]"], dependencies_: Iterable["Dependant"], ) -> Self: ... + + @abstractmethod + def __aiter__(self) -> StreamMessage[MsgType]: + ... diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index e96489c1e4..05322795ee 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -1,7 +1,7 @@ from abc import abstractmethod from collections.abc import Iterable, Sequence from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Optional import anyio from aiokafka import ConsumerRecord, TopicPartition @@ -179,6 +179,25 @@ async def get_one( decoder=self._decoder, ) + @override + async def __aiter__(self) -> AsyncIterator["StreamMessage[MsgType]"]: # type: ignore[override] + assert self.consumer, "You should start subscriber at first." # nosec B101 + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." + + async for raw_message in self.consumer: + context = self._state.get().di_state.context + msg: StreamMessage[MsgType] = await process_msg( # type: ignore[assignment] + msg=raw_message, + middlewares=( + m(raw_message, context=context) for m in self._broker_middlewares + ), + parser=self._parser, + decoder=self._decoder, + ) + yield msg + def _make_response_publisher( self, message: "StreamMessage[Any]", diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index 5f08df4571..d90223fee2 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -1,11 +1,7 @@ import asyncio import contextlib -from collections.abc import Iterable, Sequence -from typing import ( - TYPE_CHECKING, - Any, - Optional, -) +from collections.abc import AsyncIterator, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Optional, cast import anyio from typing_extensions import override @@ -184,6 +180,30 @@ async def get_one( ) return msg + @override + async def __aiter__(self) -> AsyncIterator[RabbitMessage]: # type: ignore[override] + assert self._queue_obj, "You should start subscriber at first." # nosec B101 + assert ( # nosec B101 + not self.calls + ), "You can't use iterator method if subscriber has registered handlers." + + context = self._state.get().di_state.context + + async with self._queue_obj.iterator() as queue_iter: + async for raw_message in queue_iter: + raw_message = cast("IncomingMessage", raw_message) + + msg: RabbitMessage = await process_msg( # type: ignore[assignment] + msg=raw_message, + middlewares=( + m(raw_message, context=context) + for m in self._broker_middlewares + ), + parser=self._parser, + decoder=self._decoder, + ) + yield msg + def _make_response_publisher( self, message: "StreamMessage[Any]", diff --git a/tests/brokers/rabbit/test_consume.py b/tests/brokers/rabbit/test_consume.py index 45b0ad0026..050c546f7b 100644 --- a/tests/brokers/rabbit/test_consume.py +++ b/tests/brokers/rabbit/test_consume.py @@ -423,3 +423,35 @@ async def handler(msg: RabbitMessage) -> None: m.mock.assert_not_called() assert event.is_set() + + @pytest.mark.asyncio() + async def test_iteration(self, queue: str, exchange: RabbitExchange) -> None: + expected_messages = ("test_message_1", "test_message_2") + consume_broker = self.get_broker(apply_types=True) + + subscirber = consume_broker.subscriber( + queue, + exchange=exchange, + ack_policy=AckPolicy.DO_NOTHING, + ) + + async def publish_test_message(): + for msg in expected_messages: + await br.publish(msg, queue=queue, exchange=exchange) + + async with self.patch_broker(consume_broker) as br: + await br.start() + _ = await asyncio.create_task(publish_test_message()) + + index_message = 0 + async for msg in subscirber: + + assert msg is not None + + result_message = await msg.decode() + + assert result_message == expected_messages[index_message] + + index_message += 1 + if index_message >= len(expected_messages): + break