From 9fe540fa8dddf8c186d20f300a30f292c2566b29 Mon Sep 17 00:00:00 2001 From: "herve.le-bars" Date: Sun, 27 Oct 2024 21:01:03 +0100 Subject: [PATCH 1/3] refacto: begin refacto uniformisation repositories --- .../infra/repositories/repository_vessel.py | 28 ++++-- backend/bloom/infra/repository.py | 86 +++++++++++++++++++ 2 files changed, 109 insertions(+), 5 deletions(-) create mode 100644 backend/bloom/infra/repository.py diff --git a/backend/bloom/infra/repositories/repository_vessel.py b/backend/bloom/infra/repositories/repository_vessel.py index a0ea4bf0..946dc307 100644 --- a/backend/bloom/infra/repositories/repository_vessel.py +++ b/backend/bloom/infra/repositories/repository_vessel.py @@ -1,5 +1,5 @@ from contextlib import AbstractContextManager -from typing import Any, Generator, Union +from typing import Union from bloom.domain.vessel import Vessel from bloom.domain.metrics import VesselTimeInZone @@ -11,7 +11,24 @@ OrderByRequest, OrderByEnum) +from bloom.infra.repository import GenericRepository, GenericSqlRepository +from abc import ABC, abstractmethod +from bloom.domain.vessel import Vessel +from bloom.infra.database import sql_model + + + + +class VesselRepositoryBase(GenericRepository[Vessel], ABC): + pass + +class VesselRepository(GenericSqlRepository[Vessel,sql_model.Vessel],VesselRepositoryBase): + def __init__(self, session: Session) -> None: + super().__init__(session, sql_model.Vessel) + + +""" class VesselRepository: def __init__( self, @@ -54,9 +71,9 @@ def get_activated_vessel_by_mmsi(self, session: Session, mmsi: int) -> Union[Ves return VesselRepository.map_to_domain(vessel) def get_vessels_list(self, session: Session) -> list[Vessel]: - """ + """""" Liste l'ensemble des vessels actifs - """ + """""" stmt = select(sql_model.Vessel).where(sql_model.Vessel.tracking_activated == True) e = session.execute(stmt).scalars() if not e: @@ -64,9 +81,9 @@ def get_vessels_list(self, session: Session) -> list[Vessel]: return 
[VesselRepository.map_to_domain(vessel) for vessel in e] def get_all_vessels_list(self, session: Session) -> list[Vessel]: - """ + """""" Liste l'ensemble des vessels actifs ou inactifs - """ + """""" stmt = select(sql_model.Vessel) e = session.execute(stmt).scalars() @@ -212,3 +229,4 @@ def map_to_sql(vessel: Vessel) -> sql_model.Vessel: check=vessel.check, length_class=vessel.length_class, ) + """ \ No newline at end of file diff --git a/backend/bloom/infra/repository.py b/backend/bloom/infra/repository.py new file mode 100644 index 00000000..f54a643c --- /dev/null +++ b/backend/bloom/infra/repository.py @@ -0,0 +1,86 @@ +from typing import TypeVar,Type,Generic, Optional, List, Any +from abc import ABC,abstractmethod +from sqlalchemy import select +from pydantic import BaseModel +from sqlalchemy.orm import Session +from sqlalchemy.sql.expression import ScalarSelect, and_, or_ + +SCHEMA = TypeVar("SCHEMA", bound=BaseModel) +MODEL = TypeVar("MODEL", bound=Any) + +class GenericRepository(Generic[SCHEMA], ABC): + + @abstractmethod + def get_by_id(self, id: int) -> Optional[SCHEMA]: + raise NotImplementedError() + + @abstractmethod + def list(self, **filters) -> List[SCHEMA]: + raise NotImplementedError() + + @abstractmethod + def add(self, record: SCHEMA) -> SCHEMA: + raise NotImplementedError() + + @abstractmethod + def update(self, record: SCHEMA) -> SCHEMA: + # mapper SCHEMA->M missing for record (I know) + self._session.add(record) + self._session.flush() + self._session.refresh(record) + return record + + @abstractmethod + def delete(self, id: int) -> None: + raise NotImplementedError() + +class GenericSqlRepository(GenericRepository[SCHEMA],Generic[SCHEMA,MODEL], ABC): + def __init__(self, + session: Session, + model_cls: Type[MODEL]) -> None: + self._session = session + self._model_cls = model_cls + + def _construct_get_stmt(self, id: int) -> ScalarSelect: + stmt = select(self._model_cls).where(self._model_cls.id == id) + return stmt + + def get_by_id(self, 
id: int) -> Optional[SCHEMA]: + stmt = self._construct_get_stmt(id) + return self._session.execute(stmt).first() + + def _construct_list_stmt(self, **filters) -> ScalarSelect: + stmt = select(self._model_cls) + where_clauses = [] + for c, v in filters.items(): + if not hasattr(self._model_cls, c): + raise ValueError(f"Invalid column name {c}") + where_clauses.append(getattr(self._model_cls, c) == v) + + if len(where_clauses) == 1: + stmt = stmt.where(where_clauses[0]) + elif len(where_clauses) > 1: + stmt = stmt.where(and_(*where_clauses)) + return stmt + + def list(self, **filters) -> List[SCHEMA]: + stmt = self._construct_list_stmt(**filters) + return self._session.execute(stmt).all() + + def add(self, record: SCHEMA) -> SCHEMA: + self._session.add(record) + self._session.flush() + self._session.refresh(record) + return record + + def update(self, record: SCHEMA) -> SCHEMA: + self._session.add(record) + self._session.flush() + self._session.refresh(record) + return record + + def delete(self, id: int) -> None: + record = self.get_by_id(id) + if record is not None: + self._session.delete(record) + self._session.flush() \ No newline at end of file From a83346ab378cf48585a31589d9087ca5928a6cf1 Mon Sep 17 00:00:00 2001 From: RV Date: Mon, 23 Dec 2024 19:10:03 +0100 Subject: [PATCH 2/3] refacto: repository vessel # Conflicts: # backend/bloom/routers/v1/vessels.py --- backend/bloom/container.py | 1 - .../infra/repositories/repository_vessel.py | 94 ++++++++++++++----- backend/bloom/infra/repository.py | 56 +++++++++-- backend/bloom/routers/v1/vessels.py | 3 +- .../bloom/tasks/load_dim_vessel_from_csv.py | 14 +-- 5 files changed, 124 insertions(+), 44 deletions(-) diff --git a/backend/bloom/container.py b/backend/bloom/container.py index 26efb6cf..0f6dc09c 100644 --- a/backend/bloom/container.py +++ b/backend/bloom/container.py @@ -36,7 +36,6 @@ class UseCases(containers.DeclarativeContainer): vessel_repository = providers.Factory( VesselRepository, - 
session_factory=db.provided.session, ) alert_repository = providers.Factory( diff --git a/backend/bloom/infra/repositories/repository_vessel.py b/backend/bloom/infra/repositories/repository_vessel.py index 946dc307..461bf0b3 100644 --- a/backend/bloom/infra/repositories/repository_vessel.py +++ b/backend/bloom/infra/repositories/repository_vessel.py @@ -16,16 +16,60 @@ from bloom.domain.vessel import Vessel from bloom.infra.database import sql_model - - +from dependency_injector.providers import Callable class VesselRepositoryBase(GenericRepository[Vessel], ABC): - pass + @abstractmethod + def set_tracking(self, vessel_ids: list[int], tracking_activated: bool, + tracking_status: str) -> None: + raise NotImplementedError() + def check_mmsi_integrity(self) -> list[(int, int)]: + raise NotImplementedError() class VesselRepository(GenericSqlRepository[Vessel,sql_model.Vessel],VesselRepositoryBase): - def __init__(self, session: Session) -> None: - super().__init__(session, sql_model.Vessel) + def __init__(self,session:Session) -> None: + super().__init__(session=session,model_cls=sql_model.Vessel, schema_cls=Vessel) + + def set_tracking(self, vessel_ids: list[int], tracking_activated: bool, + tracking_status: str) -> None: + updates = [{"id": id, "tracking_activated": tracking_activated, "tracking_status": tracking_status} for id in + vessel_ids] + self._session.execute(update(sql_model.Vessel), updates) + + def check_mmsi_integrity(self) -> list[(int, int)]: + # Recherche des valeurs distinctes de MMSI ayant un nombre de résultats actif > 1 + stmt = select(sql_model.Vessel.mmsi, func.count(sql_model.Vessel.id).label("count")).group_by( + sql_model.Vessel.mmsi).having( + func.count(sql_model.Vessel.id) > 1).where( + sql_model.Vessel.tracking_activated == True) + return self._session.execute(stmt).all() + + def map_to_domain(self, model: sql_model.Vessel) -> Vessel: + return Vessel( + id=model.id, + mmsi=model.mmsi, + ship_name=model.ship_name, + width=model.width, + 
length=model.length, + country_iso3=model.country_iso3, + type=model.type, + imo=model.imo, + cfr=model.cfr, + external_marking=model.external_marking, + ircs=model.ircs, + tracking_activated=model.tracking_activated, + tracking_status=model.tracking_status, + home_port_id=model.home_port_id, + created_at=model.created_at, + updated_at=model.updated_at, + details=model.details, + check=model.check, + length_class=model.length_class, + ) + + def map_to_model(self, schema: Vessel) -> sql_model.Vessel: + return sql_model.Vessel(**schema.__dict__) """ @@ -183,27 +227,27 @@ def check_mmsi_integrity(self, session: Session) -> list[(int, int)]: return session.execute(stmt).all() @staticmethod - def map_to_domain(sql_vessel: sql_model.Vessel) -> Vessel: + def map_to_domain(model: sql_model.Vessel) -> Vessel: return Vessel( - id=sql_vessel.id, - mmsi=sql_vessel.mmsi, - ship_name=sql_vessel.ship_name, - width=sql_vessel.width, - length=sql_vessel.length, - country_iso3=sql_vessel.country_iso3, - type=sql_vessel.type, - imo=sql_vessel.imo, - cfr=sql_vessel.cfr, - external_marking=sql_vessel.external_marking, - ircs=sql_vessel.ircs, - tracking_activated=sql_vessel.tracking_activated, - tracking_status=sql_vessel.tracking_status, - home_port_id=sql_vessel.home_port_id, - created_at=sql_vessel.created_at, - updated_at=sql_vessel.updated_at, - details=sql_vessel.details, - check=sql_vessel.check, - length_class=sql_vessel.length_class, + id=model.id, + mmsi=model.mmsi, + ship_name=model.ship_name, + width=model.width, + length=model.length, + country_iso3=model.country_iso3, + type=model.type, + imo=model.imo, + cfr=model.cfr, + external_marking=model.external_marking, + ircs=model.ircs, + tracking_activated=model.tracking_activated, + tracking_status=model.tracking_status, + home_port_id=model.home_port_id, + created_at=model.created_at, + updated_at=model.updated_at, + details=model.details, + check=model.check, + length_class=model.length_class, ) @staticmethod diff --git 
a/backend/bloom/infra/repository.py b/backend/bloom/infra/repository.py index f54a643c..194ee92b 100644 --- a/backend/bloom/infra/repository.py +++ b/backend/bloom/infra/repository.py @@ -4,6 +4,7 @@ from pydantic import BaseModel from sqlalchemy.orm import Session from sqlalchemy.sql.expression import ScalarSelect, and_, or_ +from dependency_injector.providers import Callable SCHEMA = TypeVar("SCHEMA", bound=BaseModel) MODEL = TypeVar("MODEL", bound=Any) @@ -21,25 +22,34 @@ def list(self, **filters) -> List[SCHEMA]: @abstractmethod def add(self, record: SCHEMA) -> SCHEMA: raise NotImplementedError() + + @abstractmethod + def add(self, records: List[SCHEMA]) -> List[SCHEMA]: + raise NotImplementedError() @abstractmethod def update(self, record: SCHEMA) -> SCHEMA: - # mapper SCHEMA->M missing for record (I know) - self._session.add(record) - self._session.flush() - self._session.refresh(record) - return record + raise NotImplementedError() + + @abstractmethod + def update(self, records: List[SCHEMA]) -> List[SCHEMA]: + raise NotImplementedError() @abstractmethod def delete(self, id: int) -> None: raise NotImplementedError() + + def delete(self, ids: List[int]) -> None: + raise NotImplementedError() class GenericSqlRepository(GenericRepository[SCHEMA],Generic[SCHEMA,MODEL], ABC): def __init__(self, session: Session, - model_cls: Type[MODEL]) -> None: + model_cls: Type[MODEL], + schema_cls: Type[SCHEMA]) -> None: self._session = session self._model_cls = model_cls + self._schema_cls = schema_cls def _construct_get_stmt(self, id: int) -> ScalarSelect: stmt = select(self._model_cls).where(self._model_cls.id == id) @@ -47,7 +57,7 @@ def _construct_get_stmt(self, id: int) -> ScalarSelect: def get_by_id(self, id: int) -> Optional[SCHEMA]: stmt = self._construct_get_stmt(id) - return self._session.execute(stmt).first() + return self._session.execute(stmt).scalar_one_or_none() def _construct_list_stmt(self, **filters) -> ScalarSelect: stmt = select(self._model_cls) @@ -65,22 
+75,50 @@ def _construct_list_stmt(self, **filters) -> ScalarSelect: def list(self, **filters) -> List[SCHEMA]: stmt = self._construct_list_stmt(**filters) - return self._session.execute(stmt).all() + return self._session.execute(stmt).scalars() def add(self, record: SCHEMA) -> SCHEMA: self._session.add(record) self._session.flush() self._session.refresh(record) return record + + def add(self, records: List[SCHEMA]) -> List[SCHEMA]: + [self._session.add(record) for record in records] + self._session.flush() + self._session.refresh(records) + return records def update(self, record: SCHEMA) -> SCHEMA: self._session.add(record) self._session.flush() self._session.refresh(record) return record + + def update(self, records: List[SCHEMA]) -> List[SCHEMA]: + [self._session.add(record) for record in records] + self._session.flush() + self._session.refresh(records) + return records def delete(self, id: int) -> None: record = self.get_by_id(id) if record is not None: self._session.delete(record) - self._session.flush() \ No newline at end of file + self._session.flush() + + def delete(self, ids: List[int]) -> None: + for id in ids: + record = self.get_by_id(id) + if record is not None: + self._session.delete(record) + self._session.flush() + + + @abstractmethod + def map_to_domain(self,model: MODEL) -> SCHEMA: + raise NotImplementedError() + + @abstractmethod + def map_to_model(self,schema: SCHEMA) -> MODEL: + raise NotImplementedError() \ No newline at end of file diff --git a/backend/bloom/routers/v1/vessels.py b/backend/bloom/routers/v1/vessels.py index c33f11e7..1a13e274 100644 --- a/backend/bloom/routers/v1/vessels.py +++ b/backend/bloom/routers/v1/vessels.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Depends, Request +from fastapi import APIRouter, Depends, HTTPException, Request from redis import Redis from bloom.config import settings from bloom.container import UseCases @@ -18,7 +18,6 @@ async def list_vessel_tracked(request: Request, # used by @cache key: str 
= Depends(X_API_KEY_HEADER)): check_apikey(key) use_cases = UseCases() - vessel_repository = use_cases.vessel_repository() db = use_cases.db() with db.session() as session: return vessel_repository.get_vessel_tracked_count(session) diff --git a/backend/bloom/tasks/load_dim_vessel_from_csv.py b/backend/bloom/tasks/load_dim_vessel_from_csv.py index 1d0a1e96..abbce380 100644 --- a/backend/bloom/tasks/load_dim_vessel_from_csv.py +++ b/backend/bloom/tasks/load_dim_vessel_from_csv.py @@ -34,7 +34,6 @@ def map_to_domain(row: pd.Series) -> Vessel: def run(csv_file_name: str) -> None: use_cases = UseCases() - vessel_repository = use_cases.vessel_repository() db = use_cases.db() inserted_ports = [] @@ -43,11 +42,12 @@ def run(csv_file_name: str) -> None: df = pd.read_csv(csv_file_name, sep=",") vessels = df.apply(map_to_domain, axis=1) with db.session() as session: + vessel_repository = use_cases.vessel_repository(session) ports_inserts = [] ports_updates = [] # Pour chaque enregistrement du fichier CSV for vessel in vessels: - if vessel.id and vessel_repository.get_vessel_by_id(session, vessel.id): + if vessel.id and vessel_repository.get_by_id(vessel.id): # si la valeur du champ id n'est pas vide: # rechercher l'enregistrement correspondant dans la table dim_vessel # mettre à jour l'enregistrement à partir des données CSV. 
@@ -57,20 +57,20 @@ def run(csv_file_name: str) -> None: # insérer les données CSV dans la table dim_vessel; ports_inserts.append(vessel) # Insertions / MAJ en batch - inserted_ports = vessel_repository.batch_create_vessel(session, ports_inserts) - vessel_repository.batch_update_vessel(session, ports_updates) + inserted_ports = vessel_repository.add(ports_inserts) + vessel_repository.update(ports_updates) # En fin de traitement: # les enregistrements de la table dim_vessel pourtant un MMSI absent du fichier CSV sont mis à jour # avec la valeur tracking_activated=FALSE csv_mmsi = list(df['mmsi']) deleted_ports = list( - filter(lambda v: v.mmsi not in csv_mmsi, vessel_repository.get_all_vessels_list(session))) - vessel_repository.set_tracking(session, [v.id for v in deleted_ports], False, + filter(lambda v: v.mmsi not in csv_mmsi, vessel_repository.list())) + vessel_repository.set_tracking([v.id for v in deleted_ports], False, "Suppression logique suite import nouveau fichier CSV") # le traitement vérifie qu'il n'existe qu'un seul enregistrement à l'état tracking_activated==True # pour chaque valeur distincte de MMSI. 
- integrity_errors = vessel_repository.check_mmsi_integrity(session) + integrity_errors = vessel_repository.check_mmsi_integrity() if not integrity_errors: session.commit() else: From e1e3a068430cbc15eaf6827c8cc784e81034a1f4 Mon Sep 17 00:00:00 2001 From: RV Date: Mon, 23 Dec 2024 19:15:44 +0100 Subject: [PATCH 3/3] ++ # Conflicts: # backend/bloom/infra/database/sql_model.py # backend/bloom/infra/repositories/repository_port.py # backend/bloom/routers/v1/ports.py # backend/bloom/routers/v1/vessels.py --- .env.template | 1 + backend/bloom/config.py | 1 + backend/bloom/infra/database/sql_model.py | 14 +++ .../infra/repositories/repository_port.py | 114 ++++++++---------- backend/bloom/infra/repository.py | 4 +- backend/bloom/routers/v1/ports.py | 4 +- backend/bloom/routers/v1/vessels.py | 1 - backend/bloom/services/geo.py | 4 +- backend/pyproject.toml | 13 ++ docker/Dockerfile | 9 +- 10 files changed, 90 insertions(+), 75 deletions(-) diff --git a/.env.template b/.env.template index 78735791..a801e445 100644 --- a/.env.template +++ b/.env.template @@ -3,6 +3,7 @@ ############################################################################################### # these values are used in the local docker env. 
You can use "localhost" hostname # if you run the application without docker +POSTGRES_DRIVER=postgresql POSTGRES_HOSTNAME=postgres_bloom POSTGRES_USER=bloom_user POSTGRES_PASSWORD=bloom diff --git a/backend/bloom/config.py b/backend/bloom/config.py index 5ec6cdf5..b1c73395 100644 --- a/backend/bloom/config.py +++ b/backend/bloom/config.py @@ -41,6 +41,7 @@ class Settings(BaseSettings): default=5432) postgres_db:str = Field(min_length=1,max_length=32,pattern=r'^(?:[a-zA-Z]|_)[\w\d_]*$') + postgres_schema:str = Field(default='public') srid: int = Field(default=4326) spire_token:str = Field(default='') data_folder:str=Field(default=str(Path(__file__).parent.parent.parent.joinpath('./data'))) diff --git a/backend/bloom/infra/database/sql_model.py b/backend/bloom/infra/database/sql_model.py index ec5b24ca..c462ff82 100644 --- a/backend/bloom/infra/database/sql_model.py +++ b/backend/bloom/infra/database/sql_model.py @@ -22,6 +22,7 @@ class Vessel(Base): __tablename__ = "dim_vessel" + __table_args__ = {'schema': settings.postgres_schema} id = Column("id", Integer, primary_key=True) mmsi = Column("mmsi", Integer) ship_name = Column("ship_name", String, nullable=False) @@ -47,6 +48,7 @@ class Vessel(Base): class Alert(Base): __tablename__ = "alert" + __table_args__ = {'schema': settings.postgres_schema} id = Column("id", Integer, primary_key=True, index=True) timestamp = Column("timestamp", DateTime) mpa_id = Column("mpa_id", Integer) @@ -55,6 +57,7 @@ class Alert(Base): class Port(Base): __tablename__ = "dim_port" + __table_args__ = {'schema': settings.postgres_schema} id = Column("id", Integer, primary_key=True, index=True) name = Column("name", String, nullable=False) locode = Column("locode", String, nullable=False) @@ -71,6 +74,7 @@ class Port(Base): class SpireAisData(Base): __tablename__ = "spire_ais_data" + __table_args__ = {'schema': settings.postgres_schema} id = Column("id", Integer, primary_key=True) spire_update_statement = Column("spire_update_statement", 
DateTime(timezone=True)) @@ -108,6 +112,7 @@ class SpireAisData(Base): class Zone(Base): __tablename__ = "dim_zone" + __table_args__ = {'schema': settings.postgres_schema} id = Column("id", Integer, primary_key=True) category = Column("category", String, nullable=False) sub_category = Column("sub_category", String) @@ -121,6 +126,7 @@ class Zone(Base): class WhiteZone(Base): __tablename__ = "dim_white_zone" + __table_args__ = {'schema': settings.postgres_schema} id = Column("id", Integer, primary_key=True) geometry = Column("geometry", Geometry(geometry_type="GEOMETRY", srid=settings.srid)) created_at = Column("created_at", DateTime(timezone=True), server_default=func.now()) @@ -129,6 +135,7 @@ class WhiteZone(Base): class VesselPosition(Base): __tablename__ = "vessel_positions" + __table_args__ = {'schema': settings.postgres_schema} id = Column("id", Integer, primary_key=True) timestamp = Column("timestamp", DateTime(timezone=True), nullable=False) accuracy = Column("accuracy", String) @@ -148,6 +155,7 @@ class VesselPosition(Base): class VesselData(Base): __tablename__ = "vessel_data" + __table_args__ = {'schema': settings.postgres_schema} id = Column("id", Integer, primary_key=True) timestamp = Column("timestamp", DateTime(timezone=True), nullable=False) ais_class = Column("ais_class", String) @@ -166,6 +174,7 @@ class VesselData(Base): class VesselVoyage(Base): __tablename__ = "vessel_voyage" + __table_args__ = {'schema': settings.postgres_schema} id = Column("id", Integer, primary_key=True) timestamp = Column("timestamp", DateTime(timezone=True), nullable=False) destination = Column("destination", String) @@ -177,6 +186,7 @@ class VesselVoyage(Base): class Excursion(Base): __tablename__ = "fct_excursion" + __table_args__ = {'schema': settings.postgres_schema} id = Column("id", Integer, primary_key=True) vessel_id = Column("vessel_id", Integer, ForeignKey("dim_vessel.id"), nullable=False) departure_port_id = Column("departure_port_id", Integer, 
ForeignKey("dim_port.id")) @@ -201,6 +211,7 @@ class Excursion(Base): class Segment(Base): __tablename__ = "fct_segment" + __table_args__ = {'schema': settings.postgres_schema} id = Column("id", Integer, primary_key=True) excursion_id = Column("excursion_id", Integer, ForeignKey("fct_excursion.id"), nullable=False) timestamp_start = Column("timestamp_start", DateTime(timezone=True)) @@ -226,6 +237,7 @@ class Segment(Base): class TaskExecution(Base): __tablename__ = "tasks_executions" + __table_args__ = {'schema': settings.postgres_schema} id = Column("id", Integer, Identity(), primary_key=True) task_name = Column("task_name", String) point_in_time = Column("point_in_time", DateTime(timezone=True)) @@ -241,6 +253,7 @@ class RelSegmentZone(Base): __tablename__ = "rel_segment_zone" __table_args__ = ( PrimaryKeyConstraint('segment_id', 'zone_id'), + {'schema': settings.postgres_schema} ) segment_id = Column("segment_id", Integer, ForeignKey("fct_segment.id"), nullable=False) zone_id = Column("zone_id", Integer, ForeignKey("dim_zone.id"), nullable=False) @@ -260,6 +273,7 @@ class RelSegmentZone(Base): class MetricsVesselInActivity(Base): __table__ = vessel_in_activity_request + __table_args__ = {'schema': settings.postgres_schema} #vessel_id: Mapped[Optional[int]] #total_time_at_sea: Mapped[Optional[timedelta]] diff --git a/backend/bloom/infra/repositories/repository_port.py b/backend/bloom/infra/repositories/repository_port.py index cc89409e..7aeac4c6 100644 --- a/backend/bloom/infra/repositories/repository_port.py +++ b/backend/bloom/infra/repositories/repository_port.py @@ -13,10 +13,12 @@ from sqlalchemy.orm import Session from bloom.domain.excursion import Excursion +from bloom.infra.repository import GenericRepository, GenericSqlRepository +from abc import ABC, abstractmethod -class PortRepository: - def __init__(self, session_factory: Callable) -> None: - self.session_factory = session_factory +class PortRepositoryBase(GenericRepository[Port], ABC): + def 
get_empty_geometry_buffer_ports(self, session: Session) -> list[Port]: + raise NotImplementedError() def get_port_by_id(self, session: Session, port_id: int) -> Union[Port, None]: entity = session.get(sql_model.Port, port_id) @@ -37,67 +39,51 @@ def get_empty_geometry_buffer_ports(self, session: Session) -> list[Port]: if not q: return [] return [PortRepository.map_to_domain(entity) for entity in q] - - def get_ports_updated_created_after(self, session: Session, created_updated_after: datetime) -> list[Port]: - stmt = select(sql_model.Port).where(or_(sql_model.Port.created_at >= created_updated_after, - sql_model.Port.updated_at >= created_updated_after)) - q = session.execute(stmt).scalars() - if not q: - return [] - return [PortRepository.map_to_domain(entity) for entity in q] - - def update_geometry_buffer(self, session: Session, port_id: int, buffer: Polygon) -> None: - session.execute(update(sql_model.Port), [{"id": port_id, "geometry_buffer": from_shape(buffer)}]) - - def batch_update_geometry_buffer(self, session: Session, id_buffers: list[dict[str, Any]]) -> None: - items = [{"id": item["id"], "geometry_buffer": from_shape(item["geometry_buffer"])} for item in id_buffers] - session.execute(update(sql_model.Port), items) - - def create_port(self, session: Session, port: Port) -> Port: - orm_port = PortRepository.map_to_sql(port) - session.add(orm_port) - return PortRepository.map_to_domain(orm_port) - - def batch_create_port(self, session: Session, ports: list[Port]) -> list[Port]: - orm_list = [PortRepository.map_to_sql(port) for port in ports] - session.add_all(orm_list) - return [PortRepository.map_to_domain(orm) for orm in orm_list] - - def find_port_by_position_in_port_buffer(self, session: Session, position: Point) -> Union[Port, None]: - stmt = select(sql_model.Port).where( - func.ST_contains(sql_model.Port.geometry_buffer, from_shape(position, srid=settings.srid)) == True) - port = session.execute(stmt).scalar() - if not port: - return None - else: - 
return PortRepository.map_to_domain(port) - - def find_port_by_distance(self, - session: Session, - longitude: float, - latitude: float, - threshold_distance_to_port: float) -> Union[Port, None]: - position = Point(longitude, latitude) - stmt = select(sql_model.Port).where( - and_( - func.ST_within(from_shape(position, srid=settings.srid), - sql_model.Port.geometry_buffer) == True, - func.ST_distance(from_shape(position, srid=settings.srid), - sql_model.Port.geometry_point) < threshold_distance_to_port - ) - ).order_by(asc(func.ST_distance(from_shape(position, srid=settings.srid), - sql_model.Port.geometry_point))) - result = session.execute(stmt).scalars() - return [PortRepository.map_to_domain(e) for e in result] - - def get_closest_port_in_range(self, session: Session, longitude: float, latitude: float, range: float) -> Union[ - tuple[int, float], None]: - res = session.execute(text("""SELECT id,ST_Distance(ST_POINT(:longitude,:latitude, 4326)::geography, geometry_point::geography) - FROM dim_port WHERE ST_Within(ST_POINT(:longitude,:latitude, 4326),geometry_buffer) = true - AND ST_Distance(ST_POINT(:longitude,:latitude, 4326)::geography, geometry_point::geography) < :range - ORDER by ST_Distance(ST_POINT(:longitude,:latitude, 4326)::geography, geometry_point::geography) ASC LIMIT 1"""), - {"longitude": longitude, "latitude": latitude, "range": range}).first() - return res + pass + +# class PortRepository: +# def __init__(self, session_factory: Callable) -> None: +# self.session_factory = session_factory + +# def get_port_by_id(self, session: Session, port_id: int) -> Union[Port, None]: +# entity = session.get(sql_model.Port, port_id) +# if entity is not None: +# return PortRepository.map_to_domain(entity) +# else: +# return None + +# def get_all_ports(self, session: Session) -> List[Port]: +# q = session.query(sql_model.Port) +# if not q: +# return [] +# return [PortRepository.map_to_domain(entity) for entity in q] + +# def get_empty_geometry_buffer_ports(self, 
session: Session) -> list[Port]: +# stmt = select(sql_model.Port).where(sql_model.Port.geometry_buffer.is_(None)) +# q = session.execute(stmt).scalars() +# if not q: +# return [] +# return [PortRepository.map_to_domain(entity) for entity in q] + +# def get_ports_updated_created_after(self, session: Session, created_updated_after: datetime) -> list[Port]: +# stmt = select(sql_model.Port).where(or_(sql_model.Port.created_at >= created_updated_after, +# sql_model.Port.updated_at >= created_updated_after)) +# q = session.execute(stmt).scalars() +# if not q: +# return [] +# return [PortRepository.map_to_domain(entity) for entity in q] + +# def update_geometry_buffer(self, session: Session, port_id: int, buffer: Polygon) -> None: +# session.execute(update(sql_model.Port), [{"id": port_id, "geometry_buffer": from_shape(buffer)}]) + +# def batch_update_geometry_buffer(self, session: Session, id_buffers: list[dict[str, Any]]) -> None: +# items = [{"id": item["id"], "geometry_buffer": from_shape(item["geometry_buffer"])} for item in id_buffers] +# session.execute(update(sql_model.Port), items) + +# def create_port(self, session: Session, port: Port) -> Port: +# orm_port = PortRepository.map_to_sql(port) +# session.add(orm_port) +# return PortRepository.map_to_domain(orm_port) def update_port_has_excursion(self, session : Session, port_id: int ): diff --git a/backend/bloom/infra/repository.py b/backend/bloom/infra/repository.py index 194ee92b..b47f7778 100644 --- a/backend/bloom/infra/repository.py +++ b/backend/bloom/infra/repository.py @@ -57,7 +57,7 @@ def _construct_get_stmt(self, id: int) -> ScalarSelect: def get_by_id(self, id: int) -> Optional[SCHEMA]: stmt = self._construct_get_stmt(id) - return self._session.execute(stmt).scalar_one_or_none() + return self.map_to_domain(m) if (m := self._session.execute(stmt).scalar_one_or_none()) is not None else None def _construct_list_stmt(self, **filters) -> ScalarSelect: stmt = select(self._model_cls) @@ -75,7 +75,7 @@ def _construct_list_stmt(self, **filters) 
-> ScalarSelect: def list(self, **filters) -> List[SCHEMA]: stmt = self._construct_list_stmt(**filters) - return self._session.execute(stmt).scalars() + return [ self.map_to_domain(item) for item in self._session.execute(stmt).scalars()] def add(self, record: SCHEMA) -> SCHEMA: self._session.add(record) diff --git a/backend/bloom/routers/v1/ports.py b/backend/bloom/routers/v1/ports.py index 91ded158..20742a84 100644 --- a/backend/bloom/routers/v1/ports.py +++ b/backend/bloom/routers/v1/ports.py @@ -35,7 +35,7 @@ async def get_port(port_id:int, key: str = Depends(X_API_KEY_HEADER)): check_apikey(key) use_cases = UseCases() - port_repository = use_cases.port_repository() db = use_cases.db() with db.session() as session: - return port_repository.get_port_by_id(session,port_id) \ No newline at end of file + port_repository = use_cases.port_repository(session) + return port_repository.get_by_id(port_id) \ No newline at end of file diff --git a/backend/bloom/routers/v1/vessels.py b/backend/bloom/routers/v1/vessels.py index 1a13e274..5230461e 100644 --- a/backend/bloom/routers/v1/vessels.py +++ b/backend/bloom/routers/v1/vessels.py @@ -12,7 +12,6 @@ from fastapi.encoders import jsonable_encoder router = APIRouter() - @router.get("/vessels/trackedCount") async def list_vessel_tracked(request: Request, # used by @cache key: str = Depends(X_API_KEY_HEADER)): diff --git a/backend/bloom/services/geo.py b/backend/bloom/services/geo.py index 81f3d674..976e19a2 100644 --- a/backend/bloom/services/geo.py +++ b/backend/bloom/services/geo.py @@ -29,10 +29,10 @@ def find_positions_in_port_buffer(vessel_positions: List[tuple]) -> List[tuple]: # Get all ports from DataBase use_cases = UseCases() - port_repository = use_cases.port_repository() db = use_cases.db() with db.session() as session: - ports = port_repository.get_all_ports(session) + port_repository = use_cases.port_repository(session) + ports = port_repository.list() df_ports = pd.DataFrame( [[p.id, p.name, p.geometry_buffer] 
for p in ports], diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 3c5c6ecc..e3a5d535 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -46,6 +46,8 @@ dependencies = [ "fastapi[standard]>=0.115.0,<1.0.0", "uvicorn~=0.32", "redis~=5.0", + "pytest>=8.3.3", + "pytest-env>=1.1.5", ] name = "bloom" version = "0.1.0" @@ -157,3 +159,14 @@ target-version = "py310" [tool.ruff.mccabe] max-complexity = 10 + +[tool.pytest.ini_options] +env = [ + "POSTGRES_DRIVER=sqlite", + "POSTGRES_USER=", + "POSTGRES_PASSWORD=", + "POSTGRES_HOSTNAME=", + "POSTGRES_PORT=", + "POSTGRES_DB=:memory:", +] + diff --git a/docker/Dockerfile b/docker/Dockerfile index e0160a76..c718a293 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -23,9 +23,9 @@ COPY ./backend/ ${PROJECT_DIR}/backend COPY docker/rsyslog.conf /etc/rsyslog.conf # Install requirements package for python with poetry -ARG POETRY_VERSION=1.8.2 -ENV POETRY_VERSION=${POETRY_VERSION} -RUN pip install --upgrade pip && pip install --user "poetry==$POETRY_VERSION" +#ARG POETRY_VERSION=1.8.2 +#ENV POETRY_VERSION=${POETRY_VERSION} +#RUN pip install --upgrade pip && pip install --user "poetry==$POETRY_VERSION" ENV PATH="${PATH}:/root/.local/bin" COPY ./backend/pyproject.toml ./backend/alembic.ini ./backend/ @@ -37,7 +37,8 @@ ENV UV_PROJECT_ENVIRONMENT=${VIRTUAL_ENV} RUN \ cd backend &&\ uv venv ${VIRTUAL_ENV} &&\ - echo ". ${VIRTUAL_ENV}/bin/activate" >> /root/.bashrc &&\ + echo ". ${VIRTUAL_ENV}/bin/activate" >> ~/.bashrc &&\ + . ${VIRTUAL_ENV}/bin/activate &&\ uv sync # Launch cron services