diff --git a/eventsourcing_eventstoredb/factory.py b/eventsourcing_eventstoredb/factory.py index 7ccbb61..ed32fe6 100644 --- a/eventsourcing_eventstoredb/factory.py +++ b/eventsourcing_eventstoredb/factory.py @@ -44,6 +44,8 @@ def __init__(self, env: Environment): f"'{self.EVENTSTOREDB_ROOT_CERTIFICATES}' " "when connecting to a secure server." ) from e + else: + raise def aggregate_recorder(self, purpose: str = "events") -> AggregateRecorder: return EventStoreDBAggregateRecorder( @@ -58,5 +60,6 @@ def process_recorder(self) -> ProcessRecorder: raise NotImplementedError() def __del__(self) -> None: - del self.client - # self.client.close() + if hasattr(self, "client"): + del self.client + # self.client.close() diff --git a/eventsourcing_eventstoredb/recorders.py b/eventsourcing_eventstoredb/recorders.py index 5b68dbe..20be5df 100644 --- a/eventsourcing_eventstoredb/recorders.py +++ b/eventsourcing_eventstoredb/recorders.py @@ -41,7 +41,6 @@ def insert_events( def _insert_events( self, stored_events: List[StoredEvent], **kwargs: Any ) -> Optional[Sequence[int]]: - if self.for_snapshotting: # Protect against appending old snapshot after new. assert len(stored_events) == 1, len(stored_events) @@ -54,7 +53,10 @@ def _insert_events( ) if len(recorded_snapshots) > 0: last_snapshot = recorded_snapshots[0] - if last_snapshot.originator_version > stored_events[0].originator_version: + if ( + last_snapshot.originator_version + > stored_events[0].originator_version + ): return [] else: # Make sure all stored events have same originator ID. @@ -68,7 +70,10 @@ def _insert_events( # Make sure stored events have a gapless sequence of originator_versions. for i in range(1, len(stored_events)): - if stored_events[i].originator_version != i + stored_events[0].originator_version: + if ( + stored_events[i].originator_version + != i + stored_events[0].originator_version + ): raise IntegrityError("Gap detected in originator versions") # Convert StoredEvent objects to NewEvent objects. @@ -151,7 +156,10 @@ def select_events( # noqa: C901 else: if lte is not None: - position = lte + current_position = self.client.get_current_version(stream_name) + if current_position is StreamState.NO_STREAM: + return [] + position = lte = min(current_position, lte) if gt is not None: _limit = max(0, lte - gt) if limit is None: @@ -170,15 +178,14 @@ def select_events( # noqa: C901 else: limit = min(limit, _limit) - try: - recorded_events = self.client.read_stream( - stream_name=stream_name, - stream_position=position, - backwards=desc, - limit=limit if limit is not None else sys.maxsize, - ) - except NotFound: + if limit == 0: return [] + recorded_events = self.client.read_stream( + stream_name=stream_name, + stream_position=position, + backwards=desc, + limit=limit if limit is not None else sys.maxsize, + ) stored_events = [] try: diff --git a/tests/test_application.py b/tests/test_application.py index 2f23f3d..3ffc2a0 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -31,8 +31,14 @@ def setUp(self) -> None: def tearDown(self) -> None: Aggregate.INITIAL_VERSION = self.original_initial_version - del os.environ["PERSISTENCE_MODULE"] - del os.environ["EVENTSTOREDB_URI"] + try: + del os.environ["PERSISTENCE_MODULE"] + except KeyError: + pass + try: + del os.environ["EVENTSTOREDB_URI"] + except KeyError: + pass super().tearDown() def test_example_application(self) -> None: @@ -274,5 +280,17 @@ class LoggedEvent(DomainEvent): self.assertEqual(events[0].name, "name1") self.assertEqual(events[1].name, "name2") + def test_construct_without_uri(self) -> None: + del os.environ["EVENTSTOREDB_URI"] + with self.assertRaises(EnvironmentError) as cm: + BankAccounts(env={"IS_SNAPSHOTTING_ENABLED": "y"}) + self.assertIn("EVENTSTOREDB_URI", str(cm.exception)) + + def test_construct_secure_without_root_certificates(self) -> None: + os.environ["EVENTSTOREDB_URI"] = "esdb://localhost" + with self.assertRaises(EnvironmentError) as cm: + BankAccounts(env={"IS_SNAPSHOTTING_ENABLED": "y"}) + self.assertIn("EVENTSTOREDB_ROOT_CERTIFICATES", str(cm.exception)) + del ExampleApplicationTestCase diff --git a/tests/test_docs.py b/tests/test_docs.py index a89218f..48b339c 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -4,7 +4,7 @@ import os import ssl from pathlib import Path -from typing import List +from typing import Any, Dict, List from unittest import TestCase from eventsourcing.utils import clear_topic_cache @@ -132,10 +132,8 @@ def check_code_snippets_in_file(self, doc_path: Path) -> None: # noqa: C901 source = "\n".join(lines) + "\n" - globals = {} - exec( - compile(source=source, filename=doc_path, mode="exec"), globals, globals - ) + globals: Dict[Any, Any] = {} + exec(compile(source=source, filename=doc_path, mode="exec"), globals, globals) # # Write the code into a temp file. # with NamedTemporaryFile("w+") as tempfile: diff --git a/tests/test_recorders.py b/tests/test_recorders.py index fbdf215..2dc09e2 100644 --- a/tests/test_recorders.py +++ b/tests/test_recorders.py @@ -5,6 +5,7 @@ from eventsourcing.persistence import ( AggregateRecorder, ApplicationRecorder, + ProgrammingError, StoredEvent, ) from eventsourcing.tests.persistence import ( @@ -30,6 +31,240 @@ def create_recorder(self) -> AggregateRecorder: def test_insert_and_select(self) -> None: super(TestEventStoreDBAggregateRecorder, self).test_insert_and_select() + # Construct the recorder. + recorder = self.create_recorder() + + # Write three stored events. + originator_id1 = uuid4() + stored_event1 = StoredEvent( + originator_id=originator_id1, + originator_version=self.INITIAL_VERSION, + topic="topic1", + state=b"state1", + ) + stored_event2 = StoredEvent( + originator_id=originator_id1, + originator_version=self.INITIAL_VERSION + 1, + topic="topic2", + state=b"state2", + ) + stored_event3 = StoredEvent( + originator_id=originator_id1, + originator_version=self.INITIAL_VERSION + 2, + topic="topic3", + state=b"state3", + ) + + # Insert three events. + recorder.insert_events([stored_event1, stored_event2, stored_event3]) + + # Select events with gt, lte and limit args. + self.assertEqual( # reads from after start, limited by limit + recorder.select_events(originator_id1, gt=0, lte=30, limit=0), + [], + ) + self.assertEqual( # reads from after start, limited by limit + recorder.select_events(originator_id1, gt=0, lte=30, limit=1), + [stored_event2], + ) + self.assertEqual( # reads from after start, limited by limit + recorder.select_events(originator_id1, gt=0, lte=30, limit=2), + [stored_event2, stored_event3], + ) + self.assertEqual( # reads from after start, limited by lte + recorder.select_events(originator_id1, gt=0, lte=0, limit=10), + [], + ) + self.assertEqual( # reads from after start, limited by lte + recorder.select_events(originator_id1, gt=0, lte=1, limit=10), + [stored_event2], + ) + self.assertEqual( # reads from after start, limited by lte + recorder.select_events(originator_id1, gt=0, lte=2, limit=10), + [stored_event2, stored_event3], + ) + self.assertEqual( # reads from after start, limited by lte + recorder.select_events(originator_id1, gt=1, lte=2, limit=10), + [stored_event3], + ) + self.assertEqual( # reads from after start, limited by lte + recorder.select_events(originator_id1, gt=2, lte=10, limit=10), + [], + ) + + # Select events with lte and limit args. + self.assertEqual( # read limited by limit + recorder.select_events(originator_id1, lte=10, limit=1), + [stored_event1], + ) + self.assertEqual( # read limited by limit + recorder.select_events(originator_id1, lte=10, limit=2), + [stored_event1, stored_event2], + ) + self.assertEqual( # read limited by lte + recorder.select_events(originator_id1, lte=0, limit=10), + [stored_event1], + ) + self.assertEqual( # read limited by lte + recorder.select_events(originator_id1, lte=1, limit=10), + [stored_event1, stored_event2], + ) + self.assertEqual( # read limited by lte + recorder.select_events(originator_id1, lte=10, limit=10), + [stored_event1, stored_event2, stored_event3], + ) + self.assertEqual( # read limited by both lte and limit + recorder.select_events(originator_id1, lte=1, limit=1), + [stored_event1], + ) + + # Select events with desc, gt, lte. + self.assertEqual( # reads from after end, limited by gt + recorder.select_events(originator_id1, desc=True, gt=5, lte=10), + [], + ) + self.assertEqual( # reads from after end, limited by gt + recorder.select_events(originator_id1, desc=True, gt=2, lte=10), + [], + ) + self.assertEqual( # reads from after end, limited by gt + recorder.select_events(originator_id1, desc=True, gt=1, lte=10), + [stored_event3], + ) + self.assertEqual( # reads from before end, limited by gt + recorder.select_events(originator_id1, desc=True, gt=1, lte=1), + [], + ) + self.assertEqual( # reads from before end, limited by gt + recorder.select_events(originator_id1, desc=True, gt=0, lte=1), + [stored_event2], + ) + + # Select events with desc, gt, lte and limit args. + self.assertEqual( # reads from after end, limited by given limit + recorder.select_events(originator_id1, desc=True, gt=0, lte=3, limit=1), + [stored_event3], + ) + self.assertEqual( # reads from end, limited by given limit + recorder.select_events(originator_id1, desc=True, gt=0, lte=2, limit=1), + [stored_event3], + ) + self.assertEqual( # reads from before end, limited by given limit + recorder.select_events(originator_id1, desc=True, gt=0, lte=1, limit=1), + [stored_event2], + ) + + self.assertEqual( # reads from after end, limited by gt + recorder.select_events(originator_id1, desc=True, gt=0, lte=3, limit=10), + [stored_event3, stored_event2], + ) + self.assertEqual( # reads from end, limited by gt + recorder.select_events(originator_id1, desc=True, gt=0, lte=2, limit=10), + [stored_event3, stored_event2], + ) + self.assertEqual( # reads from before end, limited by gt + recorder.select_events(originator_id1, desc=True, gt=0, lte=1, limit=10), + [stored_event2], + ) + + self.assertEqual( # reads from after end, limited by gt and limit + recorder.select_events(originator_id1, desc=True, gt=0, lte=3, limit=2), + [stored_event3, stored_event2], + ) + self.assertEqual( # reads from end, limited by gt and limit + recorder.select_events(originator_id1, desc=True, gt=0, lte=2, limit=2), + [stored_event3, stored_event2], + ) + self.assertEqual( # reads from before end, limited by gt and limit + recorder.select_events(originator_id1, desc=True, gt=0, lte=1, limit=1), + [stored_event2], + ) + + # Select events with desc, lte (NO STREAM). + self.assertEqual( # reads from after end, limited by limit + recorder.select_events(uuid4(), desc=True, lte=10, limit=1), + [], + ) + + # Select events with desc, lte and limit args. + self.assertEqual( # reads from after end, limited by limit + recorder.select_events(originator_id1, desc=True, lte=10, limit=1), + [stored_event3], + ) + self.assertEqual( # reads from end, limited by limit + recorder.select_events(originator_id1, desc=True, lte=2, limit=1), + [stored_event3], + ) + self.assertEqual( # reads from before end, limited by limit + recorder.select_events(originator_id1, desc=True, lte=1, limit=1), + [stored_event2], + ) + self.assertEqual( # reads from before end, limited by start of stream + recorder.select_events(originator_id1, desc=True, lte=1, limit=10), + [stored_event2, stored_event1], + ) + + # Select events with desc, gt + self.assertEqual( # reads until after end + recorder.select_events(originator_id1, desc=True, gt=10), + [], + ) + self.assertEqual( # reads until end + recorder.select_events(originator_id1, desc=True, gt=1), + [stored_event3], + ) + self.assertEqual( # reads until before end + recorder.select_events(originator_id1, desc=True, gt=0), + [stored_event3, stored_event2], + ) + + # Select events with desc, gt (NO STREAM) + self.assertEqual( # reads until before end + recorder.select_events(uuid4(), desc=True, gt=1), + [], + ) + + # Select events with desc, gt, limit + self.assertEqual( # reads until after end, limited by gt + recorder.select_events(originator_id1, desc=True, gt=10, limit=10), + [], + ) + self.assertEqual( # reads until end, limited by gt + recorder.select_events(originator_id1, desc=True, gt=1, limit=10), + [stored_event3], + ) + self.assertEqual( # reads until before end, limited by gt + recorder.select_events(originator_id1, desc=True, gt=0, limit=10), + [stored_event3, stored_event2], + ) + self.assertEqual( # reads until before end, limited by limit + recorder.select_events(originator_id1, desc=True, gt=0, limit=1), + [stored_event3], + ) + self.assertEqual( # reads until before end, limited by gt and limit + recorder.select_events(originator_id1, desc=True, gt=1, limit=1), + [stored_event3], + ) + self.assertEqual( # reads until before end, limited by limit + recorder.select_events(originator_id1, desc=True, gt=0, limit=2), + [stored_event3, stored_event2], + ) + + # Can't store events in more than one stream. + with self.assertRaises(ProgrammingError): + stored_event4 = StoredEvent( + originator_id=uuid4(), + originator_version=self.INITIAL_VERSION, + topic="topic4", + state=b"state4", + ) + stored_event5 = StoredEvent( + originator_id=uuid4(), + originator_version=self.INITIAL_VERSION, + topic="topic5", + state=b"state5", + ) + recorder.insert_events([stored_event4, stored_event5]) class TestEventStoreDBApplicationRecorder(ApplicationRecorderTestCase):