diff --git a/README.md b/README.md index c945a2e..953d77b 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,6 @@ across mobile, web, and desktop environments. | Quality Control | - | ✅ | ✅ | | Billing and Invoicing | - | ✅ | ✅ | | Yard Management | - | ✅ | ✅ | -| Security and Compliance | ✅ | ✅ | ✅ | | System Administration | - | ✅ | ✅ | | Offline Database Management | - | - | ✅ | | Advanced Reporting | - | - | ✅ | diff --git a/desktop_app/resources/config.json b/desktop_app/resources/config.json index f885013..f4e7caa 100644 --- a/desktop_app/resources/config.json +++ b/desktop_app/resources/config.json @@ -12,9 +12,9 @@ "request_timeout": 10, "font": "Candara", "font_size": "Large", - "icons_path": "resources/icons", "styles_path": "resources/styles", "templates_path": "resources/templates", + "translations_path": "resources/translations", "organization_domain": "nexusware.com" } \ No newline at end of file diff --git a/desktop_app/src/main.py b/desktop_app/src/main.py index a376d7d..272dffc 100644 --- a/desktop_app/src/main.py +++ b/desktop_app/src/main.py @@ -5,6 +5,7 @@ from PySide6.QtWidgets import QApplication, QMessageBox from requests import HTTPError +from desktop_app.src.ui.components.error_dialog import global_exception_handler from public_api.api import APIClient from public_api.api import UsersAPI from services.authentication import AuthenticationService @@ -16,118 +17,100 @@ from utils.logger import setup_logger -def load_stylesheet(filename): - file = QFile(filename) - if file.open(QFile.ReadOnly | QFile.Text): - stream = QTextStream(file) - return stream.readAll() - return "" +class AppContext: + def __init__(self): + self.config_manager = ConfigManager() + self.logger = setup_logger("nexusware") + self.api_client = APIClient(base_url=self.config_manager.get("api_base_url", + "http://127.0.0.1:8000/api/v1")) + self.users_api = UsersAPI(self.api_client) + self.auth_service = AuthenticationService(self.users_api) + self.offline_manager = OfflineManager("offline_data.db") + self.update_manager = UpdateManager(self.config_manager) + + def initialize_app(self): + app = QApplication(sys.argv) + app.setApplicationName("NexusWare WMS") + app.setOrganizationName("NexusWare") + app.setOrganizationDomain(self.config_manager.get("organization_domain", "nexusware.com")) + + QDir.addSearchPath("icons", self.config_manager.get("icons_path", "resources/icons")) + QDir.addSearchPath("styles", self.config_manager.get("styles_path", "resources/styles")) + QDir.addSearchPath("templates", self.config_manager.get("templates_path", + "resources/templates")) + QDir.addSearchPath("translations", self.config_manager.get("translations_path", + "resources/translations")) + + app_icon = QIcon("icons:app_icon.png") + app.setWindowIcon(app_icon) + self.apply_appearance_settings(app) + + language = self.config_manager.get("language", "English") + translator = QTranslator() + if language != "English": + if translator.load(f"translations:{language.lower()}.qm"): + app.installTranslator(translator) + + return app + + def apply_appearance_settings(self, app): + theme = self.config_manager.get("theme", "light") + stylesheet = self.load_stylesheet(f"styles:{theme}_theme.qss") + app.setStyleSheet(stylesheet) + + font_family = self.config_manager.get("font", "Arial") + font_size_name = self.config_manager.get("font_size", "Medium") + font = QFont(font_family) + + if font_size_name == "Small": + font.setPointSize(8) + elif font_size_name == "Medium": + font.setPointSize(10) + elif font_size_name == "Large": + font.setPointSize(12) + + app.setFont(font) + + def load_stylesheet(self, filename): + file = QFile(filename) + if file.open(QFile.ReadOnly | QFile.Text): + stream = QTextStream(file) + return stream.readAll() + return "" + + def create_and_show_main_window(self): + user_permissions = self.users_api.get_current_user_permissions() + main_window = MainWindow(api_client=self.api_client, + config_manager=self.config_manager, + permission_manager=user_permissions) + main_window.show() + return main_window -def apply_appearance_settings(app, config_manager): - # Apply theme - theme = config_manager.get("theme", "light") - stylesheet = load_stylesheet(f"styles:{theme}_theme.qss") - app.setStyleSheet(stylesheet) +def main(): + app_context = AppContext() + app_context.logger.info("Starting NexusWare WMS") - # Apply font - font_family = config_manager.get("font", "Arial") - font_size_name = config_manager.get("font_size", "Medium") - font = QFont(font_family) + app = app_context.initialize_app() - if font_size_name == "Small": - font.setPointSize(8) - elif font_size_name == "Medium": - font.setPointSize(10) - elif font_size_name == "Large": - font.setPointSize(12) + sys.excepthook = global_exception_handler(app_context) - app.setFont(font) + app_context.offline_manager.clear_all_actions() + if app_context.config_manager.get("auto_update", True) and app_context.update_manager.check_for_updates(): + app_context.update_manager.perform_update() -def main(): - # Load configuration - config_manager = ConfigManager() - - # Set up logging - logger = setup_logger("nexusware") - logger.info("Starting NexusWare WMS") - - # Initialize the application - app = QApplication(sys.argv) - app.setApplicationName("NexusWare WMS") - app.setOrganizationName("NexusWare") - app.setOrganizationDomain(config_manager.get("organization_domain", - "nexusware.com")) - - # Adding resource path - QDir.addSearchPath("icons", config_manager.get("icons_path", - "resources/icons")) - QDir.addSearchPath("styles", config_manager.get("styles_path", - "resources/styles")) - QDir.addSearchPath("templates", config_manager.get("templates_path", - "resources/templates")) - - # Set application icon - app_icon = QIcon("icons:app_icon.png") - app.setWindowIcon(app_icon) - - # Apply appearance settings - apply_appearance_settings(app, config_manager) - - # Load language - language = config_manager.get("language", "English") - translator = QTranslator() - if language != "English": - if translator.load(f"resources/translations/{language.lower()}.qm"): - app.installTranslator(translator) - - # Initialize API client - api_client = APIClient(base_url=config_manager.get("api_base_url", - "http://127.0.0.1:8000/api/v1")) - - # Initialize services - users_api = UsersAPI(api_client) - auth_service = AuthenticationService(users_api) - offline_manager = OfflineManager("offline_data.db") - - # TODO: Implement the UpdateManager class correctly - offline_manager.clear_all_actions() - update_manager = UpdateManager(config_manager) - - # Check for updates - if config_manager.get("auto_update", True) and update_manager.check_for_updates(): - update_manager.perform_update() - - # Show login dialog - login_dialog = LoginDialog(auth_service) + login_dialog = LoginDialog(app_context.auth_service) if login_dialog.exec() != LoginDialog.Accepted: sys.exit(0) - # Set up main window - user_permissions = users_api.get_current_user_permissions() - main_window = MainWindow(api_client=api_client, config_manager=config_manager, permission_manager=user_permissions) - - def handle_auth_error(): - QMessageBox.warning(None, "Authentication Error", - "Your session has expired. Please log in again.") - main_window.close() - if login_dialog.exec() == LoginDialog.Accepted: - main_window.show() - else: - sys.exit(0) - - # Show main window - main_window.show() - - # Start the event loop with error handling try: - sys.exit(app.exec()) + main_window = app_context.create_and_show_main_window() # noqa except HTTPError as e: - if e.response.status_code == 403: - handle_auth_error() - else: - raise + print(e) + QMessageBox.critical(None, "Error", str(e)) + + sys.exit(app.exec()) if __name__ == "__main__": diff --git a/desktop_app/src/services/authentication.py b/desktop_app/src/services/authentication.py index e184394..6e1d5c8 100644 --- a/desktop_app/src/services/authentication.py +++ b/desktop_app/src/services/authentication.py @@ -1,26 +1,33 @@ from public_api.api import UsersAPI from public_api.shared_schemas import ( - UserCreate, UserUpdate, UserSanitizedWithRole, Token, Message, User + UserCreate, UserUpdate, UserSanitized, Token, Message ) + class AuthenticationService: def __init__(self, users_api: UsersAPI): self.users_api = users_api - def login(self, email: str, password: str) -> Token: - return self.users_api.login(email, password) + def login(self, username: str, password: str) -> Token: + return self.users_api.login(username, password) - def login_2fa(self, email: str, password: str, two_factor_code: str) -> Token: - return self.users_api.login_2fa(email, password, two_factor_code) + def login_2fa(self, username: str, password: str, two_factor_code: str) -> Token: + return self.users_api.login_2fa(username, password, two_factor_code) - def register(self, user_data: UserCreate) -> User: + def register(self, user_data: UserCreate) -> UserSanitized: return self.users_api.register(user_data) def reset_password(self, email: str) -> Message: return self.users_api.reset_password(email) - def get_current_user(self) -> UserSanitizedWithRole: + def refresh_token(self) -> Token: + return self.users_api.refresh_token() + + def get_current_user(self) -> UserSanitized: return self.users_api.get_current_user() - def update_current_user(self, user_data: UserUpdate) -> User: - return self.users_api.update_current_user(user_data) \ No newline at end of file + def update_current_user(self, user_data: UserUpdate) -> UserSanitized: + return self.users_api.update_current_user(user_data) + + def change_password(self, current_password: str, new_password: str) -> Message: + return self.users_api.change_password(current_password, new_password) diff --git a/desktop_app/src/ui/components/__init__.py b/desktop_app/src/ui/components/__init__.py index ea03da2..5f80fb8 100644 --- a/desktop_app/src/ui/components/__init__.py +++ b/desktop_app/src/ui/components/__init__.py @@ -16,7 +16,7 @@ MessageBox, FileDialog ) -from .error_dialog import ErrorDialog +from .error_dialog import DetailedErrorDialog, global_exception_handler from .inventory_widget_dialogs import ( InventoryDialog, AdjustmentDialog @@ -38,10 +38,10 @@ "ProgressDialog", "MessageBox", "FileDialog", - "ErrorDialog", "InventoryDialog", "AdjustmentDialog", "OrderDialog", "OrderDetailsDialog", - "ShippingDialog" + "ShippingDialog", + "DetailedErrorDialog", ] diff --git a/desktop_app/src/ui/components/error_dialog.py b/desktop_app/src/ui/components/error_dialog.py index e95c428..924687b 100644 --- a/desktop_app/src/ui/components/error_dialog.py +++ b/desktop_app/src/ui/components/error_dialog.py @@ -1,19 +1,72 @@ +import json + from PySide6.QtCore import Qt -from PySide6.QtWidgets import QDialog, QVBoxLayout, QLabel, QPushButton +from PySide6.QtWidgets import (QDialog, QPushButton, QVBoxLayout, QTextEdit, + QLabel, QDialogButtonBox, QApplication) +from requests.exceptions import HTTPError -class ErrorDialog(QDialog): - def __init__(self, error_message, parent=None): - super().__init__(parent) +class DetailedErrorDialog(QDialog): + def __init__(self, exctype, value, traceback, app_context): + super().__init__() + self.app_context = app_context self.setWindowTitle("Error") - self.setFixedSize(300, 150) + self.setMinimumSize(400, 300) + self.setWindowFlags(self.windowFlags() | Qt.WindowMaximizeButtonHint | Qt.WindowMinimizeButtonHint) layout = QVBoxLayout(self) - self.error_label = QLabel(error_message) - self.error_label.setWordWrap(True) # Wrap long error messages - layout.addWidget(self.error_label) + error_label = QLabel("An error has occurred.") + error_label.setStyleSheet("font-weight: bold; color: red;") + layout.addWidget(error_label) + + detailed_text = f"Error Type: {exctype.__name__}\n" + detailed_text += f"Error Message: {str(value)}\n\n" + + if isinstance(value, HTTPError): + response = value.response + detailed_text += f"Status Code: {response.status_code}\n" + detailed_text += f"URL: {response.url}\n" + detailed_text += "Response Headers:\n" + for key, value in response.headers.items(): + detailed_text += f" {key}: {value}\n" + detailed_text += "\nResponse Content:\n" + try: + content = json.loads(response.content) + detailed_text += json.dumps(content, indent=2) + except json.JSONDecodeError: + detailed_text += response.text + + text_edit = QTextEdit() + text_edit.setPlainText(detailed_text) + text_edit.setReadOnly(True) + layout.addWidget(text_edit) + + button_box = QDialogButtonBox(QDialogButtonBox.Ok) + button_box.accepted.connect(self.accept) + layout.addWidget(button_box) + + relogin_button = QPushButton("Re-login") + relogin_button.clicked.connect(self.relogin) + button_box.addButton(relogin_button, QDialogButtonBox.ActionRole) + + def relogin(self): + self.accept() + main_window = QApplication.activeWindow() + if main_window: + main_window.close() + + from desktop_app.src.ui import LoginDialog + login_dialog = LoginDialog(self.app_context.auth_service) + if login_dialog.exec() == LoginDialog.Accepted: + self.app_context.create_and_show_main_window() + else: + QApplication.quit() + + +def global_exception_handler(app_context): + def handler(exctype, value, traceback): + error_dialog = DetailedErrorDialog(exctype, value, traceback, app_context) + error_dialog.exec() - self.ok_button = QPushButton("OK") - self.ok_button.clicked.connect(self.accept) - layout.addWidget(self.ok_button, alignment=Qt.AlignCenter) + return handler diff --git a/desktop_app/src/ui/components/order_view_dialogs.py b/desktop_app/src/ui/components/order_view_dialogs.py index ae1327b..0f2780c 100644 --- a/desktop_app/src/ui/components/order_view_dialogs.py +++ b/desktop_app/src/ui/components/order_view_dialogs.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Optional from PySide6.QtWidgets import (QVBoxLayout, QTableWidget, QTableWidgetItem, QDialog, QComboBox, @@ -15,7 +14,7 @@ class OrderDialog(QDialog): def __init__(self, orders_api: OrdersAPI, customers_api: CustomersAPI, products_api: ProductsAPI, - order_data: Optional[OrderWithDetails] = None, parent=None): + order_data: OrderWithDetails | None = None, parent=None): super().__init__(parent) self.orders_api = orders_api self.customers_api = customers_api diff --git a/desktop_app/src/ui/customer_view.py b/desktop_app/src/ui/customer_view.py index 5bff118..a3a23be 100644 --- a/desktop_app/src/ui/customer_view.py +++ b/desktop_app/src/ui/customer_view.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import List, Optional from PySide6.QtCore import Signal from PySide6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QTableWidget, QTableWidgetItem, @@ -67,7 +66,7 @@ def refresh_customers(self): customers = self.customers_api.get_customers() self.update_table(customers) - def update_table(self, customers: List[Customer]): + def update_table(self, customers: list[Customer]): self.customers_table.setRowCount(len(customers)) self.customers_table.verticalHeader().setSectionResizeMode(QHeaderView.ResizeToContents) for row, customer in enumerate(customers): @@ -139,7 +138,7 @@ def delete_customer(self, customer: Customer): class CustomerDialog(QDialog): - def __init__(self, customers_api: CustomersAPI, customer_data: Optional[Customer] = None, parent=None): + def __init__(self, customers_api: CustomersAPI, customer_data: Customer | None = None, parent=None): super().__init__(parent) self.customers_api = customers_api self.customer_data = customer_data diff --git a/desktop_app/src/ui/main_window.py b/desktop_app/src/ui/main_window.py index cea4b29..e505cb0 100644 --- a/desktop_app/src/ui/main_window.py +++ b/desktop_app/src/ui/main_window.py @@ -1,7 +1,6 @@ from PySide6.QtCore import Qt, QSize from PySide6.QtGui import QIcon -from PySide6.QtWidgets import QMainWindow, QTabWidget, QVBoxLayout, QWidget, QStatusBar, QMessageBox, QToolButton, \ - QMenu, QSizePolicy +from PySide6.QtWidgets import QMainWindow, QTabWidget, QVBoxLayout, QWidget, QStatusBar, QMessageBox, QToolButton from desktop_app.src.utils import ConfigManager from public_api.api import APIClient @@ -38,7 +37,7 @@ def __init__(self, api_client: APIClient, config_manager: ConfigManager, def init_ui(self): self.setWindowTitle("NexusWare WMS") - self.setWindowIcon(QIcon("resources/icons/app_icon.png")) + self.setWindowIcon(QIcon("icons:app_icon.png")) self.setMinimumSize(1200, 800) central_widget = QWidget() diff --git a/desktop_app/src/ui/notification_center.py b/desktop_app/src/ui/notification_center.py index 5d38efc..c4daa96 100644 --- a/desktop_app/src/ui/notification_center.py +++ b/desktop_app/src/ui/notification_center.py @@ -1,7 +1,7 @@ from PySide6.QtCore import Qt, QTimer, QDateTime +from PySide6.QtGui import QColor, QBrush from PySide6.QtWidgets import (QVBoxLayout, QWidget, QTableWidget, QTableWidgetItem, QPushButton, QHeaderView, QAbstractItemView) -from PySide6.QtGui import QColor, QBrush, QFont from public_api.api import APIClient, NotificationsAPI @@ -89,4 +89,4 @@ def mark_as_read(self, notification_id): def mark_all_as_read(self): self.notifications_api.mark_all_as_read() - self.fetch_notifications() \ No newline at end of file + self.fetch_notifications() diff --git a/desktop_app/src/ui/product_view.py b/desktop_app/src/ui/product_view.py index c750684..f95a565 100644 --- a/desktop_app/src/ui/product_view.py +++ b/desktop_app/src/ui/product_view.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import List, Optional from PySide6.QtCore import Signal from PySide6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QTableWidget, QTableWidgetItem, @@ -84,7 +83,7 @@ def refresh_products(self): products = self.products_api.get_products(product_filter=filter_params) self.update_table(products) - def update_table(self, items: List[ProductWithCategoryAndInventory]): + def update_table(self, items: list[ProductWithCategoryAndInventory]): self.table.setRowCount(len(items)) self.table.verticalHeader().setSectionResizeMode(QHeaderView.ResizeToContents) for row, item in enumerate(items): @@ -185,10 +184,9 @@ def delete_product(self, product_id): QMessageBox.critical(self, "Error", f"Failed to delete product: {str(e)}") - class ProductDialog(QDialog): def __init__(self, products_api: ProductsAPI, categories_api: ProductCategoriesAPI, - product_data: Optional[ProductWithCategoryAndInventory] = None, parent=None): + product_data: ProductWithCategoryAndInventory | None = None, parent=None): super().__init__(parent) self.products_api = products_api self.categories_api = categories_api diff --git a/desktop_app/src/ui/search_filter.py b/desktop_app/src/ui/search_filter.py index 66e2f42..0e2dce9 100644 --- a/desktop_app/src/ui/search_filter.py +++ b/desktop_app/src/ui/search_filter.py @@ -6,18 +6,14 @@ QDialog, QPushButton, QLabel) from desktop_app.src.ui.components import StyledButton -from public_api.api import APIClient -from public_api.api import ProductsAPI, CustomersAPI, OrdersAPI -from public_api.shared_schemas import ProductFilter, CustomerFilter +from public_api.api import APIClient, SearchAPI class AdvancedSearchDialog(QDialog): def __init__(self, api_client: APIClient, parent=None): super().__init__(parent) self.api_client = api_client - self.products_api = ProductsAPI(api_client) - self.customers_api = CustomersAPI(api_client) - self.orders_api = OrdersAPI(api_client) + self.search_api = SearchAPI(api_client) self.init_ui() def init_ui(self): @@ -31,7 +27,7 @@ def init_ui(self): self.search_input = QLineEdit() self.search_input.setPlaceholderText("Enter search term") self.search_type = QComboBox() - self.search_type.addItems(["Products", "Customers"]) + self.search_type.addItems(["Products", "Orders"]) self.search_button = StyledButton("Search") self.search_button.clicked.connect(self.perform_search) search_layout.addWidget(QLabel("Search:")) @@ -55,9 +51,9 @@ def perform_search(self): search_type = self.search_type.currentText() if search_type == "Products": - results = self.products_api.get_products(product_filter=ProductFilter(name=search_term)) - elif search_type == "Customers": - results = self.customers_api.get_customers(customer_filter=CustomerFilter(name=search_term)) + results = self.search_api.search_products(q=search_term) + elif search_type == "Orders": + results = self.search_api.search_orders(q=search_term) else: results = [] diff --git a/desktop_app/src/ui/shipment_view.py b/desktop_app/src/ui/shipment_view.py index 58dee15..fdf3321 100644 --- a/desktop_app/src/ui/shipment_view.py +++ b/desktop_app/src/ui/shipment_view.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import List, Optional from PySide6.QtCore import Signal, QDate from PySide6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QTableWidget, QTableWidgetItem, @@ -80,7 +79,7 @@ def refresh_shipments(self): shipments = self.shipments_api.get_shipments(filter_params=filter_params) self.update_table(shipments) - def update_table(self, shipments: List[Shipment]): + def update_table(self, shipments: list[Shipment]): self.table.setRowCount(len(shipments)) for row, shipment in enumerate(shipments): shipment_details = self.shipments_api.get_shipment_with_details(shipment.id) @@ -184,7 +183,7 @@ def generate_label(self, shipment: Shipment): class ShipmentDialog(QDialog): def __init__(self, shipments_api: ShipmentsAPI, orders_api: OrdersAPI, - carriers_api: CarriersAPI, shipment_data: Optional[Shipment] = None, parent=None): + carriers_api: CarriersAPI, shipment_data: Shipment | None = None, parent=None): super().__init__(parent) self.shipments_api = shipments_api self.orders_api = orders_api diff --git a/desktop_app/src/ui/supplier_view.py b/desktop_app/src/ui/supplier_view.py index 16de23f..ce72f00 100644 --- a/desktop_app/src/ui/supplier_view.py +++ b/desktop_app/src/ui/supplier_view.py @@ -1,5 +1,3 @@ -from typing import List, Optional - from PySide6.QtCore import Signal from PySide6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QTableWidget, QTableWidgetItem, QHeaderView, QDialog, QLineEdit, QFormLayout, QDialogButtonBox, @@ -66,7 +64,7 @@ def refresh_suppliers(self): suppliers = self.suppliers_api.get_suppliers() self.update_table(suppliers) - def update_table(self, suppliers: List[Supplier]): + def update_table(self, suppliers: list[Supplier]): self.suppliers_table.setRowCount(len(suppliers)) self.suppliers_table.verticalHeader().setSectionResizeMode(QHeaderView.ResizeToContents) for row, supplier in enumerate(suppliers): @@ -138,9 +136,8 @@ def delete_supplier(self, supplier: Supplier): QMessageBox.critical(self, "Error", f"Failed to delete supplier: {str(e)}") - class SupplierDialog(QDialog): - def __init__(self, suppliers_api: SuppliersAPI, supplier_data: Optional[Supplier] = None, parent=None): + def __init__(self, suppliers_api: SuppliersAPI, supplier_data: Supplier | None = None, parent=None): super().__init__(parent) self.suppliers_api = suppliers_api self.supplier_data = supplier_data diff --git a/desktop_app/src/ui/task_view.py b/desktop_app/src/ui/task_view.py index bc98df7..a322247 100644 --- a/desktop_app/src/ui/task_view.py +++ b/desktop_app/src/ui/task_view.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import List, Optional from PySide6.QtCharts import QChart, QChartView, QPieSeries from PySide6.QtCore import Qt, Signal, QDate @@ -10,7 +9,7 @@ from desktop_app.src.ui.components import StyledButton from public_api.api import APIClient, TasksAPI, UsersAPI -from public_api.shared_schemas import TaskCreate, TaskUpdate, TaskWithAssignee, TaskFilter, UserSanitizedWithRole +from public_api.shared_schemas import TaskCreate, TaskUpdate, TaskWithAssignee, TaskFilter, UserSanitized from public_api.shared_schemas.task import TaskStatus, TaskPriority @@ -130,7 +129,7 @@ def refresh_tasks(self): tasks = self.tasks_api.get_tasks(filter_params=filter_params) self.update_task_table(tasks) - def update_task_table(self, tasks: List[TaskWithAssignee]): + def update_task_table(self, tasks: list[TaskWithAssignee]): self.task_table.setRowCount(len(tasks)) for row, task in enumerate(tasks): self.task_table.setItem(row, 0, QTableWidgetItem(task.task_type.value)) @@ -222,15 +221,14 @@ def refresh_statistics(self): class TaskDialog(QDialog): - def __init__(self, tasks_api: TasksAPI, users: List[UserSanitizedWithRole], - task_data: Optional[TaskWithAssignee] = None, parent=None): + def __init__(self, tasks_api: TasksAPI, users: list[UserSanitized], + task_data: TaskWithAssignee | None = None, parent=None): super().__init__(parent) self.tasks_api = tasks_api self.users = users self.task_data = task_data self.init_ui() - def init_ui(self): self.setWindowTitle("Create Task" if not self.task_data else "Edit Task") layout = QVBoxLayout(self) @@ -319,7 +317,7 @@ def accept(self): class TaskDetailsDialog(QDialog): - def __init__(self, task: TaskWithAssignee, users: List[UserSanitizedWithRole], parent=None): + def __init__(self, task: TaskWithAssignee, users: list[UserSanitized], parent=None): super().__init__(parent) self.task = task self.users = users @@ -366,4 +364,4 @@ def init_ui(self): close_button = QPushButton("Close") close_button.clicked.connect(self.accept) - main_layout.addWidget(close_button) \ No newline at end of file + main_layout.addWidget(close_button) diff --git a/desktop_app/src/ui/user_management.py b/desktop_app/src/ui/user_management.py index 7ffdc66..6561104 100644 --- a/desktop_app/src/ui/user_management.py +++ b/desktop_app/src/ui/user_management.py @@ -5,8 +5,8 @@ from desktop_app.src.ui.components import StyledButton from public_api.api import UsersAPI, APIClient -from public_api.shared_schemas import UserSanitizedWithRole, UserFilter, AllPermissions, UserWithPermissions, \ - UserCreate, UserUpdate, AllRoles +from public_api.shared_schemas import UserSanitized, UserFilter, AllPermissions, UserWithPermissions, \ + UserCreate, UserUpdate, AllRoles, UserPermissionUpdate class UserManagementWidget(QWidget): @@ -75,7 +75,7 @@ def refresh_users(self): users = self.users_api.get_users(filter_params=filter_params) self.update_table(users) - def update_table(self, users: list[UserSanitizedWithRole]): + def update_table(self, users: list[UserSanitized]): self.table.setRowCount(len(users)) self.table.verticalHeader().setSectionResizeMode(QHeaderView.ResizeToContents) for row, user in enumerate(users): @@ -173,7 +173,7 @@ def delete_user(self, user_id): class UserDialog(QDialog): - def __init__(self, users_api: UsersAPI, user: UserSanitizedWithRole = None, parent=None): + def __init__(self, users_api: UsersAPI, user: UserSanitized = None, parent=None): super().__init__(parent) self.users_api = users_api self.user = user @@ -298,7 +298,8 @@ def save_permissions(self): selected_permissions.append(item.data(Qt.UserRole)) try: - self.users_api.update_user_permissions(self.user_id, selected_permissions) + permission_update = UserPermissionUpdate(permissions=selected_permissions) + self.users_api.update_user_permissions(self.user_id, permission_update) QMessageBox.information(self, "Success", "User permissions updated successfully.") self.accept() except Exception as e: diff --git a/public_api/api/__init__.py b/public_api/api/__init__.py index 631958d..f433005 100644 --- a/public_api/api/__init__.py +++ b/public_api/api/__init__.py @@ -15,6 +15,7 @@ from .purchase_orders import PurchaseOrdersAPI from .quality import QualityAPI from .reports import ReportsAPI +from .search import SearchAPI from .shipments import ShipmentsAPI from .suppliers import SuppliersAPI from .tasks import TasksAPI diff --git a/public_api/api/assets.py b/public_api/api/assets.py index 826fe55..4e3ccd1 100644 --- a/public_api/api/assets.py +++ b/public_api/api/assets.py @@ -1,5 +1,3 @@ -from typing import Optional - from public_api.shared_schemas import ( AssetCreate, AssetUpdate, Asset, AssetWithMaintenance, AssetFilter, AssetMaintenanceCreate, AssetMaintenanceUpdate, AssetMaintenance, @@ -18,7 +16,7 @@ def create_asset(self, asset_data: AssetCreate) -> Asset: return Asset.model_validate(response) def get_assets(self, skip: int = 0, limit: int = 100, - asset_filter: Optional[AssetFilter] = None) -> AssetWithMaintenanceList: + asset_filter: AssetFilter | None = None) -> AssetWithMaintenanceList: params = {"skip": skip, "limit": limit} if asset_filter: params.update(asset_filter.model_dump(mode="json", exclude_unset=True)) @@ -50,7 +48,7 @@ def create_asset_maintenance(self, maintenance_data: AssetMaintenanceCreate) -> return AssetMaintenance.model_validate(response) def get_asset_maintenances(self, skip: int = 0, limit: int = 100, - maintenance_filter: Optional[AssetMaintenanceFilter] = None) -> list[AssetMaintenance]: + maintenance_filter: AssetMaintenanceFilter | None = None) -> list[AssetMaintenance]: params = {"skip": skip, "limit": limit} if maintenance_filter: params.update(maintenance_filter.model_dump(mode="json", exclude_unset=True)) diff --git a/public_api/api/audit.py b/public_api/api/audit.py index e1d5276..0b3eb6b 100644 --- a/public_api/api/audit.py +++ b/public_api/api/audit.py @@ -1,5 +1,3 @@ -from typing import Optional, List - from public_api.shared_schemas import ( AuditLogCreate, AuditLog, AuditLogWithUser, AuditLogFilter, AuditSummary, AuditLogExport @@ -16,7 +14,7 @@ def create_audit_log(self, log: AuditLogCreate) -> AuditLog: return AuditLog.model_validate(response) def get_audit_logs(self, skip: int = 0, limit: int = 100, - filter_params: Optional[AuditLogFilter] = None) -> List[AuditLogWithUser]: + filter_params: AuditLogFilter | None = None) -> list[AuditLogWithUser]: params = {"skip": skip, "limit": limit} if filter_params: params.update(filter_params.model_dump(mode="json", exclude_unset=True)) @@ -27,8 +25,8 @@ def get_audit_log(self, log_id: int) -> AuditLogWithUser: response = self.client.get(f"/audit/logs/{log_id}") return AuditLogWithUser.model_validate(response) - def get_audit_summary(self, date_from: Optional[int] = None, - date_to: Optional[int] = None) -> AuditSummary: + def get_audit_summary(self, date_from: int | None = None, + date_to: int | None = None) -> AuditSummary: params = {} if date_from: params["date_from"] = date_from @@ -37,22 +35,22 @@ def get_audit_summary(self, date_from: Optional[int] = None, response = self.client.get("/audit/logs/summary", params=params) return AuditSummary.model_validate(response) - def get_user_audit_logs(self, user_id: int, skip: int = 0, limit: int = 100) -> List[AuditLog]: + def get_user_audit_logs(self, user_id: int, skip: int = 0, limit: int = 100) -> list[AuditLog]: response = self.client.get(f"/audit/logs/user/{user_id}", params={"skip": skip, "limit": limit}) return [AuditLog.model_validate(item) for item in response] - def get_table_audit_logs(self, table_name: str, skip: int = 0, limit: int = 100) -> List[AuditLog]: + def get_table_audit_logs(self, table_name: str, skip: int = 0, limit: int = 100) -> list[AuditLog]: response = self.client.get(f"/audit/logs/table/{table_name}", params={"skip": skip, "limit": limit}) return [AuditLog.model_validate(item) for item in response] def get_record_audit_logs(self, table_name: str, record_id: int, - skip: int = 0, limit: int = 100) -> List[AuditLog]: + skip: int = 0, limit: int = 100) -> list[AuditLog]: response = self.client.get(f"/audit/logs/record/{table_name}/{record_id}", params={"skip": skip, "limit": limit}) return [AuditLog.model_validate(item) for item in response] - def export_audit_logs(self, date_from: Optional[int] = None, - date_to: Optional[int] = None) -> AuditLogExport: + def export_audit_logs(self, date_from: int | None = None, + date_to: int | None = None) -> AuditLogExport: params = {} if date_from: params["date_from"] = date_from @@ -61,10 +59,10 @@ def export_audit_logs(self, date_from: Optional[int] = None, response = self.client.get("/audit/logs/export", params=params) return AuditLogExport.model_validate(response) - def get_audit_log_actions(self) -> List[str]: + def get_audit_log_actions(self) -> list[str]: response = self.client.get("/audit/logs/actions") return [str(item) for item in response] - def get_audited_tables(self) -> List[str]: + def get_audited_tables(self) -> list[str]: response = self.client.get("/audit/logs/tables") return [str(item) for item in response] diff --git a/public_api/api/client.py b/public_api/api/client.py index ff2f2db..38091bb 100644 --- a/public_api/api/client.py +++ b/public_api/api/client.py @@ -1,34 +1,74 @@ +# public_api/api/client.py +from datetime import datetime, timedelta + import requests +from requests import HTTPError + +from public_api.shared_schemas import Token class APIClient: def __init__(self, base_url: str): self.base_url = base_url self.session = requests.Session() - self.token = None + self.access_token: str | None = None + self.refresh_token: str | None = None + self.token_expiry: datetime | None = None - def set_token(self, token): - self.token = token - self.session.headers.update({"Authorization": f"Bearer {token}"}) + def set_tokens(self, access_token: str, refresh_token: str, expires_in: int): + self.access_token = access_token + self.refresh_token = refresh_token + self.token_expiry = datetime.utcnow() + timedelta(seconds=expires_in) + self.session.headers.update({"Authorization": f"Bearer {access_token}"}) - def get(self, endpoint, params=None, headers=None): - response = self.session.get(f"{self.base_url}{endpoint}", params=params, headers=headers) - response.raise_for_status() - return response.json() + def refresh_access_token(self) -> bool: + if not self.refresh_token: + return False - def post(self, endpoint, data=None, json=None, headers=None, params=None): - response = self.session.post(f"{self.base_url}{endpoint}", - data=data, json=json, headers=headers, - params=params) - response.raise_for_status() - return response.json() + try: + response = self.request_call("POST", "/users/refresh-token", + json={ + "refresh_token": self.refresh_token + }) + token = Token.model_validate(response) + self.set_tokens(token.access_token, token.refresh_token, token.expires_in) + return True + except Exception as e: + print(f"Error refreshing token: {e}") + return False - def put(self, endpoint, data=None, json=None, headers=None): - response = self.session.put(f"{self.base_url}{endpoint}", data=data, json=json, headers=headers) - response.raise_for_status() - return response.json() + def is_token_expired(self) -> bool: + if self.token_expiry is None: + return self.access_token is None + return datetime.utcnow() >= self.token_expiry + + def request(self, method: str, endpoint: str, **kwargs): + if self.is_token_expired(): + if self.refresh_token and not self.refresh_access_token(): + raise HTTPError("Unable to refresh token") - def delete(self, endpoint, headers=None): - response = self.session.delete(f"{self.base_url}{endpoint}", headers=headers) + try: + return self.request_call(method, endpoint, **kwargs) + except HTTPError as e: + if e.response.status_code == 401: # Unauthorized + if self.refresh_access_token(): + return self.request_call(method, endpoint, **kwargs) + raise + + def request_call(self, method: str, endpoint: str, **kwargs): + response = self.session.request(method, f"{self.base_url}{endpoint}", **kwargs) response.raise_for_status() return response.json() + + def get(self, endpoint: str, params: dict | None = None, headers: dict | None = None): + return self.request("GET", endpoint, params=params, headers=headers) + + def post(self, endpoint: str, data: dict | None = None, json: dict | None = None, headers: dict | None = None, + params: dict | None = None): + return self.request("POST", endpoint, data=data, json=json, headers=headers, params=params) + + def put(self, endpoint: str, data: dict | None = None, json: dict | None = None, headers: dict | None = None): + return self.request("PUT", endpoint, data=data, json=json, headers=headers) + + def delete(self, endpoint: str, headers: dict | None = None): + return self.request("DELETE", endpoint, headers=headers) diff --git a/public_api/api/customers.py b/public_api/api/customers.py index 26d7e69..8eb4033 100644 --- a/public_api/api/customers.py +++ b/public_api/api/customers.py @@ -1,5 +1,3 @@ -from typing import List, Optional - from public_api.shared_schemas import ( CustomerCreate, CustomerUpdate, Customer, CustomerFilter, Order ) @@ -15,7 +13,7 @@ def create_customer(self, customer_data: CustomerCreate) -> Customer: return Customer.model_validate(response) def get_customers(self, skip: int = 0, limit: int = 100, - customer_filter: Optional[CustomerFilter] = None) -> List[Customer]: + customer_filter: CustomerFilter | None = None) -> list[Customer]: params = {"skip": skip, "limit": limit} if customer_filter: params.update(customer_filter.model_dump(mode="json", exclude_unset=True)) @@ -35,6 +33,6 @@ def delete_customer(self, customer_id: int) -> Customer: response = self.client.delete(f"/customers/{customer_id}") return Customer.model_validate(response) - def get_customer_orders(self, customer_id: int, skip: int = 0, limit: int = 100) -> List[Order]: + def get_customer_orders(self, customer_id: int, skip: int = 0, limit: int = 100) -> list[Order]: response = self.client.get(f"/customers/{customer_id}/orders", params={"skip": skip, "limit": limit}) return [Order.model_validate(item) for item in response] diff --git a/public_api/api/inventory.py b/public_api/api/inventory.py index 9e5df5c..17f8e36 100644 --- a/public_api/api/inventory.py +++ b/public_api/api/inventory.py @@ -1,5 +1,3 @@ -from typing import Optional, List, Dict - from public_api.shared_schemas import ( InventoryCreate, InventoryUpdate, Inventory, InventoryList, InventoryFilter, InventoryTransfer, InventoryReport, ProductWithInventory, Product, @@ -19,7 +17,7 @@ def create_inventory(self, inventory_data: InventoryCreate) -> Inventory: return Inventory.model_validate(response) def get_inventory(self, skip: int = 0, limit: int = 100, - inventory_filter: Optional[InventoryFilter] = None) -> InventoryList: + inventory_filter: InventoryFilter | None = None) -> InventoryList: params = {"skip": skip, "limit": limit} if inventory_filter: params.update(inventory_filter.model_dump(mode="json", exclude_unset=True)) @@ -47,36 +45,36 @@ def get_inventory_report(self) -> InventoryReport: response = self.client.get("/inventory/report") return InventoryReport.model_validate(response) - def perform_cycle_count(self, location_id: int, counted_items: List[InventoryUpdate]) -> List[Inventory]: + def perform_cycle_count(self, location_id: int, counted_items: list[InventoryUpdate]) -> list[Inventory]: response = self.client.post("/inventory/cycle_count", json={ "location_id": location_id, "counted_items": [item.model_dump(mode="json") for item in counted_items] }) return [Inventory.model_validate(item) for item in response] - def get_low_stock_items(self, threshold: int = 10) -> List[ProductWithInventory]: + def get_low_stock_items(self, threshold: int = 10) -> list[ProductWithInventory]: response = self.client.get("/inventory/low_stock", params={"threshold": threshold}) return [ProductWithInventory.model_validate(item) for item in response] - def get_out_of_stock_items(self) -> List[Product]: + def get_out_of_stock_items(self) -> list[Product]: response = self.client.get("/inventory/out_of_stock") return [Product.model_validate(item) for item in response] - def create_reorder_list(self, threshold: int = 10) -> List[Product]: + def create_reorder_list(self, threshold: int = 10) -> list[Product]: response = self.client.post("/inventory/reorder", params={"threshold": threshold}) return [Product.model_validate(item) for item in response] - def get_product_locations(self, product_id: int) -> List[LocationWithInventory]: + def get_product_locations(self, product_id: int) -> list[LocationWithInventory]: response = self.client.get(f"/inventory/product_locations/{product_id}") return [LocationWithInventory.model_validate(item) for item in response] - def batch_update_inventory(self, updates: List[InventoryUpdate]) -> List[Inventory]: + def batch_update_inventory(self, updates: list[InventoryUpdate]) -> list[Inventory]: response = self.client.post("/inventory/batch_update", json=[update.model_dump(mode="json") for update in updates]) return [Inventory.model_validate(item) for item in response] - def get_inventory_movement_history(self, product_id: int, start_date: Optional[int] = None, - end_date: Optional[int] = None) -> List[InventoryMovement]: + def get_inventory_movement_history(self, product_id: int, start_date: int | None = None, + end_date: int | None = None) -> list[InventoryMovement]: params = {} if start_date: params["start_date"] = start_date @@ -97,11 +95,11 @@ def perform_abc_analysis(self) -> ABCAnalysisResult: response = self.client.get("/inventory/abc_analysis") return ABCAnalysisResult.model_validate(response) - def optimize_inventory_locations(self) -> List[InventoryLocationSuggestion]: + def optimize_inventory_locations(self) -> list[InventoryLocationSuggestion]: response = self.client.post("/inventory/optimize_locations") return [InventoryLocationSuggestion.model_validate(item) for item in response] - def get_expiring_soon_inventory(self, days: int = 30) -> List[ProductWithInventory]: + def get_expiring_soon_inventory(self, days: int = 30) -> list[ProductWithInventory]: response = self.client.get("/inventory/expiring_soon", params={"days": days}) return [ProductWithInventory.model_validate(item) for item in response] @@ -113,11 +111,11 @@ def get_storage_utilization(self) -> StorageUtilization: response = self.client.get("/inventory/storage_utilization") return StorageUtilization.model_validate(response) - def get_inventory_forecast(self, product_id: int) -> Dict: + def get_inventory_forecast(self, product_id: int) -> dict: response = self.client.get(f"/inventory/forecast/{product_id}") return response - def get_reorder_suggestions(self) -> List[Dict]: + def get_reorder_suggestions(self) -> list[dict]: response = self.client.get("/inventory/reorder_suggestions") return response diff --git a/public_api/api/locations.py b/public_api/api/locations.py index 3d93daf..4345c95 100644 --- a/public_api/api/locations.py +++ b/public_api/api/locations.py @@ -1,5 +1,3 @@ -from typing import List, Optional - from public_api.shared_schemas.inventory import ( LocationCreate, LocationUpdate, Location, LocationWithInventory, LocationFilter ) @@ -15,7 +13,7 @@ def create_location(self, location_data: LocationCreate) -> Location: return Location.model_validate(response) def get_locations(self, skip: int = 0, limit: int = 100, - location_filter: Optional[LocationFilter] = None) -> List[LocationWithInventory]: + location_filter: LocationFilter | None = None) -> list[LocationWithInventory]: params = {"skip": skip, "limit": limit} if location_filter: params.update(location_filter.model_dump(mode="json", exclude_unset=True)) diff --git a/public_api/api/orders.py b/public_api/api/orders.py index 3a243f3..b622f88 100644 --- a/public_api/api/orders.py +++ b/public_api/api/orders.py @@ -1,5 +1,3 @@ -from typing import List, Optional - from public_api.shared_schemas import ( OrderCreate, OrderUpdate, Order, OrderWithDetails, OrderFilter, OrderSummary, ShippingInfo, OrderItemCreate, BulkOrderImportData, @@ -17,7 +15,7 @@ def create_order(self, order_data: OrderCreate) -> Order: return Order.model_validate(response) def get_orders(self, skip: int = 0, limit: int = 100, - filter_params: Optional[OrderFilter] = None) -> List[OrderWithDetails]: + filter_params: OrderFilter | None = None) -> list[OrderWithDetails]: params = {"skip": skip, "limit": limit} if filter_params: params.update(filter_params.model_dump(mode="json", exclude_unset=True)) @@ -47,8 +45,8 @@ def delete_order(self, order_id: int) -> Order: response = self.client.delete(f"/orders/{order_id}") return Order.model_validate(response) - def get_order_summary(self, date_from: Optional[int] = None, - date_to: Optional[int] = None) -> OrderSummary: + def get_order_summary(self, date_from: int | None = None, + date_to: int | None = None) -> OrderSummary: params = {} if date_from: params["date_from"] = date_from @@ -73,7 +71,7 @@ def add_order_item(self, order_id: int, item_data: OrderItemCreate) -> Order: response = self.client.post(f"/orders/{order_id}/add_item", json=item_data.model_dump(mode="json")) return Order.model_validate(response) - def get_backorders(self) -> List[Order]: + def get_backorders(self) -> list[Order]: response = self.client.get("/orders/backorders") return [Order.model_validate(item) for item in response] diff --git a/public_api/api/pick_lists.py b/public_api/api/pick_lists.py index 1909717..bf77522 100644 --- a/public_api/api/pick_lists.py +++ b/public_api/api/pick_lists.py @@ -1,5 +1,3 @@ -from typing import List, Optional - from public_api.shared_schemas import ( PickListCreate, PickListUpdate, PickList, PickListFilter, OptimizedPickingRoute, PickingPerformance @@ -16,7 +14,7 @@ def create_pick_list(self, pick_list_data: PickListCreate) -> PickList: return PickList.model_validate(response) def get_pick_lists(self, skip: int = 0, limit: int = 100, - filter_params: Optional[PickListFilter] = None) -> List[PickList]: + filter_params: PickListFilter | None = None) -> list[PickList]: params = {"skip": skip, "limit": limit} if filter_params: params.update(filter_params.model_dump(mode="json", exclude_unset=True)) diff --git a/public_api/api/products.py b/public_api/api/products.py index b73f244..248cc3a 100644 --- a/public_api/api/products.py +++ b/public_api/api/products.py @@ -1,5 +1,3 @@ -from typing import List, Optional - from public_api.shared_schemas.inventory import ( ProductCreate, ProductUpdate, Product, ProductWithCategoryAndInventory, ProductFilter, BarcodeData @@ -16,7 +14,7 @@ def create_product(self, product_data: ProductCreate) -> Product: return Product.model_validate(response) def get_products(self, skip: int = 0, limit: int = 100, - product_filter: Optional[ProductFilter] = None) -> List[ProductWithCategoryAndInventory]: + product_filter: ProductFilter | None = None) -> list[ProductWithCategoryAndInventory]: params = {"skip": skip, "limit": limit} if product_filter: params.update(product_filter.model_dump(mode="json", exclude_unset=True)) @@ -41,7 +39,7 @@ def get_product_by_barcode(self, barcode: str) -> Product: response = self.client.post("/products/barcode", json=barcode_data.model_dump(mode="json")) return Product.model_validate(response) - def get_product_substitutes(self, product_id: int) -> List[Product]: + def get_product_substitutes(self, product_id: int) -> list[Product]: response = self.client.get(f"/products/{product_id}/substitutes") return [Product.model_validate(item) for item in response] diff --git a/public_api/api/purchase_orders.py b/public_api/api/purchase_orders.py index d512d4c..26cd6c1 100644 --- a/public_api/api/purchase_orders.py +++ b/public_api/api/purchase_orders.py @@ -1,5 +1,3 @@ -from typing import List, Optional - from public_api.shared_schemas import ( PurchaseOrderCreate, PurchaseOrderUpdate, PurchaseOrder, PurchaseOrderWithDetails, PurchaseOrderFilter, POItemReceive @@ -16,7 +14,7 @@ def create_purchase_order(self, purchase_order_data: PurchaseOrderCreate) -> Pur return PurchaseOrder.model_validate(response) def get_purchase_orders(self, skip: int = 0, limit: int = 100, - po_filter: Optional[PurchaseOrderFilter] = None) -> List[PurchaseOrderWithDetails]: + po_filter: PurchaseOrderFilter | None = None) -> list[PurchaseOrderWithDetails]: params = {"skip": skip, "limit": limit} if po_filter: params.update(po_filter.model_dump(mode="json", exclude_unset=True)) @@ -36,7 +34,7 @@ def delete_purchase_order(self, po_id: int) -> PurchaseOrder: response = self.client.delete(f"/purchase_orders/{po_id}") return PurchaseOrder.model_validate(response) - def receive_purchase_order(self, po_id: int, received_items: List[POItemReceive]) -> PurchaseOrder: + def receive_purchase_order(self, po_id: int, received_items: list[POItemReceive]) -> PurchaseOrder: response = self.client.post(f"/purchase_orders/{po_id}/receive", json={"received_items": [item.model_dump(mode="json") for item in received_items]}) return PurchaseOrder.model_validate(response) diff --git a/public_api/api/quality.py b/public_api/api/quality.py index 2365f89..99b3427 100644 --- a/public_api/api/quality.py +++ b/public_api/api/quality.py @@ -1,5 +1,3 @@ -from typing import Optional, List - from public_api.shared_schemas.quality import ( QualityCheckCreate, QualityCheckUpdate, QualityCheckWithProduct, QualityCheckFilter, QualityMetrics, QualityStandardCreate, QualityStandardUpdate, QualityStandard, @@ -18,7 +16,7 @@ def create_quality_check(self, check_data: QualityCheckCreate) -> QualityCheckWi return QualityCheckWithProduct.model_validate(response) def get_quality_checks(self, skip: int = 0, limit: int = 100, - filter_params: Optional[QualityCheckFilter] = None) -> List[QualityCheckWithProduct]: + filter_params: QualityCheckFilter | None = None) -> list[QualityCheckWithProduct]: params = {"skip": skip, "limit": limit} if filter_params: params.update(filter_params.model_dump(mode="json", exclude_unset=True)) @@ -38,8 +36,7 @@ def delete_quality_check(self, check_id: int) -> QualityCheckWithProduct: response = self.client.delete(f"/quality/checks/{check_id}") return QualityCheckWithProduct.model_validate(response) - def get_quality_metrics(self, date_from: Optional[int] = None, - date_to: Optional[int] = None) -> QualityMetrics: + def get_quality_metrics(self, date_from: int | None = None, date_to: int | None = None) -> QualityMetrics: params = {} if date_from: params["date_from"] = date_from @@ -52,7 +49,7 @@ def create_quality_standard(self, standard_data: QualityStandardCreate) -> Quali response = self.client.post("/quality/standards", json=standard_data.model_dump(mode="json")) return QualityStandard.model_validate(response) - def get_quality_standards(self, skip: int = 0, limit: int = 100) -> List[QualityStandard]: + def get_quality_standards(self, skip: int = 0, limit: int = 100) -> list[QualityStandard]: response = self.client.get("/quality/standards", params={"skip": skip, "limit": limit}) return [QualityStandard.model_validate(item) for item in response] @@ -73,7 +70,7 @@ def create_quality_alert(self, alert_data: QualityAlertCreate) -> QualityAlert: response = self.client.post("/quality/alerts", json=alert_data.model_dump(mode="json")) return QualityAlert.model_validate(response) - def get_quality_alerts(self, skip: int = 0, limit: int = 100) -> List[QualityAlert]: + def get_quality_alerts(self, skip: int = 0, limit: int = 100) -> list[QualityAlert]: response = self.client.get("/quality/alerts", params={"skip": skip, "limit": limit}) return [QualityAlert.model_validate(item) for item in response] @@ -83,13 +80,13 @@ def resolve_quality_alert(self, alert_id: int, resolution_data: QualityAlertUpda return QualityAlert.model_validate(response) def get_product_quality_history(self, product_id: int, - skip: int = 0, limit: int = 100) -> List[QualityCheckWithProduct]: + skip: int = 0, limit: int = 100) -> list[QualityCheckWithProduct]: response = self.client.get(f"/quality/product/{product_id}/history", params={"skip": skip, "limit": limit}) return [QualityCheckWithProduct.model_validate(item) for item in response] - def get_quality_check_summary(self, date_from: Optional[int] = None, - date_to: Optional[int] = None) -> QualityCheckSummary: + def get_quality_check_summary(self, + date_from: int | None = None, date_to: int | None = None) -> QualityCheckSummary: params = {} if date_from: params["date_from"] = date_from @@ -98,16 +95,16 @@ def get_quality_check_summary(self, date_from: Optional[int] = None, response = self.client.get("/quality/checks/summary", params=params) return QualityCheckSummary.model_validate(response) - def get_product_quality_standards(self, product_id: int) -> List[QualityStandard]: + def get_product_quality_standards(self, product_id: int) -> list[QualityStandard]: response = self.client.get(f"/quality/product/{product_id}/standards") return [QualityStandard.model_validate(item) for item in response] - def create_batch_quality_check(self, checks: List[QualityCheckCreate]) -> List[QualityCheckWithProduct]: + def create_batch_quality_check(self, checks: list[QualityCheckCreate]) -> list[QualityCheckWithProduct]: response = self.client.post("/quality/batch_check", json=[check.model_dump(mode="json") for check in checks]) return [QualityCheckWithProduct.model_validate(item) for item in response] - def get_active_quality_alerts(self, skip: int = 0, limit: int = 100) -> List[QualityAlert]: + def get_active_quality_alerts(self, skip: int = 0, limit: int = 100) -> list[QualityAlert]: response = self.client.get("/quality/alerts/active", params={"skip": skip, "limit": limit}) return [QualityAlert.model_validate(item) for item in response] @@ -118,8 +115,8 @@ def add_comment_to_quality_check(self, check_id: int, return QualityCheckComment.model_validate(response) def get_product_defect_rates(self, - date_from: Optional[int] = None, - date_to: Optional[int] = None) -> List[ProductDefectRate]: + date_from: int | None = None, + date_to: int | None = None) -> list[ProductDefectRate]: params = {} if date_from: params["date_from"] = date_from diff --git a/public_api/api/search.py b/public_api/api/search.py new file mode 100644 index 0000000..3e2f879 --- /dev/null +++ b/public_api/api/search.py @@ -0,0 +1,53 @@ +from public_api.shared_schemas import Product, Order +from .client import APIClient + + +class SearchAPI: + def __init__(self, client: APIClient): + self.client = client + + def search_products( + self, + q: str | None = None, + category_id: int | None = None, + min_price: float | None = None, + max_price: float | None = None, + in_stock: str | bool = None, + sort_by: str | None = None, + sort_order: str | None = "asc" + ) -> list[Product]: + params = { + "q": q, + "category_id": category_id, + "min_price": min_price, + "max_price": max_price, + "in_stock": in_stock, + "sort_by": sort_by, + "sort_order": sort_order + } + response = self.client.get("/search/products", params={k: v for k, v in params.items() if v is not None}) + return [Product.model_validate(item) for item in response] + + def search_orders( + self, + q: str | None = None, + status: str | None = None, + min_total: float | None = None, + max_total: float | None = None, + start_date: int | None = None, + end_date: int | None = None, + sort_by: str | None = None, + sort_order: str | None = "asc" + ) -> list[Order]: + params = { + "q": q, + "status": status, + "min_total": min_total, + "max_total": max_total, + "start_date": start_date, + "end_date": end_date, + "sort_by": sort_by, + "sort_order": sort_order + } + response = self.client.get("/search/orders", params={k: v for k, v in params.items() if v is not None}) + return [Order.model_validate(item) for item in response] diff --git a/public_api/api/shipments.py b/public_api/api/shipments.py index 268ef61..c0747e1 100644 --- a/public_api/api/shipments.py +++ b/public_api/api/shipments.py @@ -1,5 +1,3 @@ -from typing import List, Optional - from public_api.shared_schemas import ( Shipment, ShipmentCreate, ShipmentUpdate, ShipmentFilter, CarrierRate, ShippingLabel, ShipmentTracking, ShipmentWithDetails @@ -12,7 +10,7 @@ def __init__(self, client: APIClient): self.client = client def get_shipments(self, skip: int = 0, limit: int = 100, - filter_params: Optional[ShipmentFilter] = None) -> List[Shipment]: + filter_params: ShipmentFilter | None = None) -> list[Shipment]: params = {"skip": skip, "limit": limit} if filter_params: params.update(filter_params.model_dump(mode="json", exclude_unset=True)) @@ -40,7 +38,7 @@ def generate_shipping_label(self, shipment_id: int) -> ShippingLabel: response = self.client.post(f"/shipments/{shipment_id}/generate_label") return ShippingLabel.model_validate(response) - def get_carrier_rates(self, weight: float, dimensions: str, destination_zip: str) -> List[CarrierRate]: + def get_carrier_rates(self, weight: float, dimensions: str, destination_zip: str) -> list[CarrierRate]: params = {"weight": weight, "dimensions": dimensions, "destination_zip": destination_zip} response = self.client.get("/shipments/carrier_rates", params=params) return [CarrierRate.model_validate(item) for item in response] diff --git a/public_api/api/suppliers.py b/public_api/api/suppliers.py index a48dc2d..1fe38b1 100644 --- a/public_api/api/suppliers.py +++ b/public_api/api/suppliers.py @@ -1,5 +1,3 @@ -from typing import List, Optional - from public_api.shared_schemas import ( Supplier, SupplierCreate, SupplierUpdate, SupplierFilter, PurchaseOrder ) @@ -11,7 +9,7 @@ def __init__(self, client: APIClient): self.client = client def get_suppliers(self, skip: int = 0, limit: int = 100, - filter_params: Optional[SupplierFilter] = None) -> List[Supplier]: + filter_params: SupplierFilter | None = None) -> list[Supplier]: params = {"skip": skip, "limit": limit} if filter_params: params.update(filter_params.model_dump(mode="json", exclude_unset=True)) @@ -36,7 +34,7 @@ def delete_supplier(self, supplier_id: int) -> Supplier: return Supplier.model_validate(response) def get_supplier_purchase_orders(self, supplier_id: int, - skip: int = 0, limit: int = 100) -> List[PurchaseOrder]: + skip: int = 0, limit: int = 100) -> list[PurchaseOrder]: response = self.client.get(f"/suppliers/{supplier_id}/purchase_orders", params={"skip": skip, "limit": limit}) return [PurchaseOrder.model_validate(item) for item in response] diff --git a/public_api/api/tasks.py b/public_api/api/tasks.py index 3c41450..b0989e3 100644 --- a/public_api/api/tasks.py +++ b/public_api/api/tasks.py @@ -1,5 +1,3 @@ -from typing import List, Optional - from public_api.shared_schemas import ( Task, TaskCreate, TaskUpdate, TaskWithAssignee, TaskFilter, TaskStatistics, UserTaskSummary, TaskComment, TaskCommentCreate @@ -15,8 +13,9 @@ def create_task(self, task: TaskCreate) -> Task: response = self.client.post("/tasks/", json=task.model_dump()) return Task.model_validate(response) - def get_tasks(self, skip: int = 0, limit: int = 100, filter_params: Optional[TaskFilter] = None) -> List[ - TaskWithAssignee]: + def get_tasks(self, + skip: int = 0, limit: int = 100, + filter_params: TaskFilter | None = None) -> list[TaskWithAssignee]: params = {"skip": skip, "limit": limit} if filter_params: params.update(filter_params.model_dump(exclude_none=True)) @@ -27,20 +26,20 @@ def get_task_statistics(self) -> TaskStatistics: response = self.client.get("/tasks/statistics") return TaskStatistics.model_validate(response) - def get_user_task_summary(self) -> List[UserTaskSummary]: + def get_user_task_summary(self) -> list[UserTaskSummary]: response = self.client.get("/tasks/user_summary") return [UserTaskSummary.model_validate(item) for item in response] - def get_overdue_tasks(self, skip: int = 0, limit: int = 100) -> List[TaskWithAssignee]: + def get_overdue_tasks(self, skip: int = 0, limit: int = 100) -> list[TaskWithAssignee]: params = {"skip": skip, "limit": limit} response = self.client.get("/tasks/overdue", params=params) return [TaskWithAssignee.model_validate(item) for item in response] - def create_batch_tasks(self, tasks: List[TaskCreate]) -> List[Task]: + def create_batch_tasks(self, tasks: list[TaskCreate]) -> list[Task]: response = self.client.post("/tasks/batch_create", json=[task.model_dump() for task in tasks]) return [Task.model_validate(item) for item in response] - def get_my_tasks(self, skip: int = 0, limit: int = 100) -> List[Task]: + def get_my_tasks(self, skip: int = 0, limit: int = 100) -> list[Task]: params = {"skip": skip, "limit": limit} response = self.client.get("/tasks/my_tasks", params=params) return [Task.model_validate(item) for item in response] @@ -65,7 +64,7 @@ def add_task_comment(self, task_id: int, comment: TaskCommentCreate) -> TaskComm response = self.client.post(f"/tasks/{task_id}/comment", json=comment.model_dump()) return TaskComment.model_validate(response) - def get_task_comments(self, task_id: int, skip: int = 0, limit: int = 100) -> List[TaskComment]: + def get_task_comments(self, task_id: int, skip: int = 0, limit: int = 100) -> list[TaskComment]: params = {"skip": skip, "limit": limit} response = self.client.get(f"/tasks/{task_id}/comments", params=params) return [TaskComment.model_validate(item) for item in response] diff --git a/public_api/api/users.py b/public_api/api/users.py index 0ebb293..f7b3ad9 100644 --- a/public_api/api/users.py +++ b/public_api/api/users.py @@ -1,10 +1,7 @@ -from typing import List, Optional - from public_api.shared_schemas import ( - UserCreate, UserUpdate, User, UserSanitizedWithRole, Token, - Message, PasswordResetConfirm, UserFilter, - UserActivity, RoleWithUsers, BulkUserCreate, - BulkUserCreateResult, AllPermissions, AllRoles, UserWithPermissions, UserPermissionUpdate + UserCreate, UserUpdate, UserSanitized, Token, + Message, UserFilter, AllPermissions, AllRoles, UserWithPermissions, + UserPermissionUpdate, TwoFactorLogin ) from .client import APIClient from ..permission_manager import PermissionManager @@ -17,130 +14,110 @@ def __init__(self, client: APIClient): def login(self, username: str, password: str) -> Token: data = { - "grant_type": "password", "username": username, "password": password, } - headers = {"content-type": "application/x-www-form-urlencoded"} - - response = self.client.post("/users/login", data=data, headers=headers) - access_token = response.get("access_token") - - if access_token: - self.client.set_token(access_token) - - return Token.model_validate(response) + response = self.client.post("/users/login", data=data) + token = Token.model_validate(response) + self.client.set_tokens(token.access_token, token.refresh_token, token.expires_in) + return token def login_2fa(self, username: str, password: str, two_factor_code: str) -> Token: - data = { - "username": username, - "password": password, - "two_factor_code": two_factor_code - } - response = self.client.post("/users/login/2fa", json=data) - access_token = response.get("access_token") - - if access_token: - self.client.set_token(access_token) + data = TwoFactorLogin( + username=username, + password=password, + two_factor_code=two_factor_code + ) + response = self.client.post("/users/login/2fa", json=data.model_dump()) + token = Token.model_validate(response) + self.client.set_tokens(token.access_token, token.refresh_token, token.expires_in) + return token + + def register(self, user: UserCreate) -> UserSanitized: + response = self.client.post("/users/register", json=user.model_dump()) + return UserSanitized.model_validate(response) - return Token.model_validate(response) - - def register(self, user: UserCreate) -> User: - response = self.client.post("/users/register", json=user.model_dump(mode="json")) - return User.model_validate(response) + def reset_password(self, email: str) -> Message: + response = self.client.post("/users/reset_password", json={"email": email}) + return Message.model_validate(response) def change_password(self, current_password: str, new_password: str) -> Message: data = {"current_password": current_password, "new_password": new_password} response = self.client.post("/users/change_password", json=data) return Message.model_validate(response) - def reset_password(self, email: str) -> Message: - response = self.client.post("/users/reset_password", json={"email": email}) - return Message.model_validate(response) + def refresh_token(self) -> Token: + refresh_token = self.client.refresh_token + if not refresh_token: + raise ValueError("No refresh token available") - def get_current_user(self) -> UserSanitizedWithRole: - response = self.client.get("/users/me") - return UserSanitizedWithRole.model_validate(response) + response = self.client.post("/users/refresh-token", json={ + "refresh_token": refresh_token + }) + token = Token.model_validate(response) + self.client.set_tokens(token.access_token, token.refresh_token, token.expires_in) + return token - def get_current_user_permissions(self) -> PermissionManager: - if not self._permission_manager: - response = self.client.get("/users/my_permissions") - permissions = response.get('permissions', []) - self._permission_manager = PermissionManager(permissions) - return self._permission_manager + def update_current_user(self, user_update: UserUpdate) -> UserSanitized: + response = self.client.put("/users/me", json=user_update.model_dump(exclude_unset=True)) + return UserSanitized.model_validate(response) - def has_permission(self, permission_name: str, action: str) -> bool: - return self.get_current_user_permissions().has_permission(permission_name, action) + def get_current_user(self) -> UserSanitized: + response = self.client.get("/users/me") + return UserSanitized.model_validate(response) - def clear_permissions_cache(self): - self._permission_manager = None + def get_all_permissions(self) -> AllPermissions: + response = self.client.get("/users/permissions") + return AllPermissions.model_validate(response) + + def get_my_permissions(self) -> AllPermissions: + response = self.client.get("/users/my_permissions") + return AllPermissions.model_validate(response) - def update_current_user(self, user_update: UserUpdate) -> User: - response = self.client.put("/users/me", json=user_update.model_dump(mode="json", exclude_unset=True)) - return User.model_validate(response) + def get_all_roles(self) -> AllRoles: + response = self.client.get("/users/roles") + return AllRoles.model_validate(response) - def get_users(self, filter_params: Optional[UserFilter] = None, - skip: int = 0, limit: int = 100) -> List[UserSanitizedWithRole]: + def get_users(self, filter_params: UserFilter | None = None, + skip: int = 0, limit: int = 100) -> list[UserSanitized]: params = {"skip": skip, "limit": limit} if filter_params: params.update(filter_params.model_dump(exclude_unset=True)) response = self.client.get("/users/", params=params) - return [UserSanitizedWithRole.model_validate(item) for item in response] + return [UserSanitized.model_validate(item) for item in response] - def create_user(self, user: UserCreate) -> UserSanitizedWithRole: - response = self.client.post("/users/", json=user.model_dump(mode="json")) - return UserSanitizedWithRole.model_validate(response) + def create_user(self, user: UserCreate) -> UserSanitized: + response = self.client.post("/users/", json=user.model_dump()) + return UserSanitized.model_validate(response) - def get_user(self, user_id: int) -> UserSanitizedWithRole: - response = self.client.get(f"/users/{user_id}") - return UserSanitizedWithRole.model_validate(response) - - def update_user(self, user_id: int, user_update: UserUpdate) -> User: - response = self.client.put(f"/users/{user_id}", json=user_update.model_dump(mode="json", exclude_unset=True)) - return User.model_validate(response) - - def delete_user(self, user_id: int) -> User: - response = self.client.delete(f"/users/{user_id}") - return User.model_validate(response) - - def confirm_password_reset(self, token: str, new_password: str) -> Message: - data = PasswordResetConfirm(token=token, new_password=new_password) - response = self.client.post("/users/reset_password_confirm", json=data.model_dump(mode="json")) - return Message.model_validate(response) - - def get_filtered_users(self, user_filter: UserFilter) -> List[UserSanitizedWithRole]: - response = self.client.get("/users/filter", params=user_filter.model_dump(mode="json", exclude_unset=True)) - return [UserSanitizedWithRole.model_validate(item) for item in response] - - def get_user_activity(self) -> List[UserActivity]: - response = self.client.get("/users/activity") - return [UserActivity.model_validate(item) for item in response] + def get_user_permissions(self, user_id: int) -> UserWithPermissions: + response = self.client.get(f"/users/{user_id}/permissions") + return UserWithPermissions.model_validate(response) - def get_role_with_users(self, role_id: int) -> RoleWithUsers: - response = self.client.get(f"/users/role/{role_id}") - return RoleWithUsers.model_validate(response) + def update_user_permissions(self, user_id: int, permission_update: UserPermissionUpdate) -> UserWithPermissions: + response = self.client.put(f"/users/{user_id}/permissions", json=permission_update.model_dump()) + return UserWithPermissions.model_validate(response) - def bulk_create_users(self, users: BulkUserCreate) -> BulkUserCreateResult: - response = self.client.post("/users/bulk", json=users.model_dump(mode="json")) - return BulkUserCreateResult.model_validate(response) + def get_user(self, user_id: int) -> UserSanitized: + response = self.client.get(f"/users/{user_id}") + return UserSanitized.model_validate(response) - def get_all_permissions(self) -> AllPermissions: - response = self.client.get("/users/permissions") - return AllPermissions.model_validate(response) + def update_user(self, user_id: int, user_update: UserUpdate) -> UserSanitized: + response = self.client.put(f"/users/{user_id}", json=user_update.model_dump(exclude_unset=True)) + return UserSanitized.model_validate(response) - def get_all_roles(self) -> AllRoles: - response = self.client.get("/users/roles") - return AllRoles.model_validate(response) + def delete_user(self, user_id: int) -> UserSanitized: + response = self.client.delete(f"/users/{user_id}") + return UserSanitized.model_validate(response) - def my_permissions(self) -> AllPermissions: - response = self.client.get("/users/my_permissions") - return AllPermissions.model_validate(response) + def get_current_user_permissions(self) -> PermissionManager: + if not self._permission_manager: + response = self.get_my_permissions() + self._permission_manager = PermissionManager(response.permissions) + return self._permission_manager - def get_user_permissions(self, user_id: int) -> UserWithPermissions: - response = self.client.get(f"/users/{user_id}/permissions") - return UserWithPermissions.model_validate(response) + def has_permission(self, permission_name: str, action: str) -> bool: + return self.get_current_user_permissions().has_permission(permission_name, action) - def update_user_permissions(self, user_id: int, permissions: List[int]) -> UserWithPermissions: - data = UserPermissionUpdate(user_id=user_id, permissions=permissions) - response = self.client.put(f"/users/{user_id}/permissions", json=data.model_dump()) - return UserWithPermissions.model_validate(response) + def clear_permissions_cache(self): + self._permission_manager = None diff --git a/public_api/permission_manager.py b/public_api/permission_manager.py index d4889e5..21372fd 100644 --- a/public_api/permission_manager.py +++ b/public_api/permission_manager.py @@ -1,18 +1,19 @@ -from typing import List, Dict, Any +from typing import List +from public_api.shared_schemas.user import Permission class PermissionManager: - def __init__(self, permissions: List[Dict[str, Any]]): - self.permissions = {p['permission_name']: p for p in permissions} + def __init__(self, permissions: List[Permission]): + self.permissions = {p.permission_name: p for p in permissions} def has_permission(self, permission_name: str, action: str) -> bool: if permission_name in self.permissions: permission = self.permissions[permission_name] - if action == 'read' and permission['can_read']: + if action == 'read' and permission.can_read: return True - if action == 'write' and permission['can_write']: + if action == 'write' and permission.can_write: return True - if action == 'delete' and permission['can_delete']: + if action == 'delete' and permission.can_delete: return True return False @@ -23,4 +24,4 @@ def has_write_permission(self, permission_name: str) -> bool: return self.has_permission(permission_name, 'write') def has_delete_permission(self, permission_name: str) -> bool: - return self.has_permission(permission_name, 'delete') + return self.has_permission(permission_name, 'delete') \ No newline at end of file diff --git a/public_api/shared_schemas/__init__.py b/public_api/shared_schemas/__init__.py index 619baf8..0f6deec 100644 --- a/public_api/shared_schemas/__init__.py +++ b/public_api/shared_schemas/__init__.py @@ -72,13 +72,10 @@ BulkTaskCreate, BulkTaskCreateResult) # User shared_schemas from .user import ( - PermissionBase, PermissionCreate, PermissionUpdate, Permission, - RoleBase, RoleCreate, RoleUpdate, Role, - UserBase, UserCreate, UserUpdate, User, UserInDB, Token, TokenData, - Message, UserSanitizedWithRole, PasswordResetConfirm, UserFilter, - UserActivity, RoleWithUsers, UserPermissions, BulkUserCreate, - BulkUserCreateResult, UserWithPermissions, UserPermissionUpdate, - AllRoles, AllPermissions, UserWithRole, TwoFactorLogin + RoleName, PermissionBase, PermissionCreate, PermissionUpdate, Permission, + RoleBase, RoleCreate, RoleUpdate, Role, UserBase, UserUpdate, UserSanitized, + UserInDB, TwoFactorLogin, Token, Message, UserFilter, UserWithPermissions, + AllRoles, AllPermissions, UserPermissionUpdate, UserCreate, RefreshTokenRequest ) # Warehouse shared_schemas from .warehouse import ( diff --git a/public_api/shared_schemas/asset.py b/public_api/shared_schemas/asset.py index ec12503..e639265 100644 --- a/public_api/shared_schemas/asset.py +++ b/public_api/shared_schemas/asset.py @@ -1,5 +1,4 @@ # /server/app/shared_schemas/asset.py -from typing import Optional, List from pydantic import BaseModel @@ -10,7 +9,7 @@ class AssetBase(BaseModel): serial_number: str purchase_date: int status: str - location_id: Optional[int] = None + location_id: int | None = None class AssetCreate(AssetBase): @@ -18,12 +17,12 @@ class AssetCreate(AssetBase): class AssetUpdate(BaseModel): - asset_type: Optional[str] = None - asset_name: Optional[str] = None - serial_number: Optional[str] = None - purchase_date: Optional[int] = None - status: Optional[str] = None - location_id: Optional[int] = None + asset_type: str | None = None + asset_name: str | None = None + serial_number: str | None = None + purchase_date: int | None = None + status: str | None = None + location_id: int | None = None class Asset(AssetBase): @@ -37,9 +36,9 @@ class AssetMaintenanceBase(BaseModel): asset_id: int maintenance_type: str scheduled_date: int - completed_date: Optional[int] = None - performed_by: Optional[int] = None - notes: Optional[str] = None + completed_date: int | None = None + performed_by: int | None = None + notes: str | None = None class AssetMaintenanceCreate(AssetMaintenanceBase): @@ -47,11 +46,11 @@ class AssetMaintenanceCreate(AssetMaintenanceBase): class AssetMaintenanceUpdate(BaseModel): - maintenance_type: Optional[str] = None - scheduled_date: Optional[int] = None - completed_date: Optional[int] = None - performed_by: Optional[int] = None - notes: Optional[str] = None + maintenance_type: str | None = None + scheduled_date: int | None = None + completed_date: int | None = None + performed_by: int | None = None + notes: str | None = None class AssetMaintenance(AssetMaintenanceBase): @@ -62,32 +61,32 @@ class Config: class AssetWithMaintenance(Asset): - maintenance_records: List[AssetMaintenance] = [] + maintenance_records: list[AssetMaintenance] = [] class Config: from_attributes = True class AssetFilter(BaseModel): - asset_type: Optional[str] = None - status: Optional[str] = None - purchase_date_from: Optional[int] = None - purchase_date_to: Optional[int] = None - location_id: Optional[int] = None + asset_type: str | None = None + status: str | None = None + purchase_date_from: int | None = None + purchase_date_to: int | None = None + location_id: int | None = None class AssetMaintenanceFilter(BaseModel): - asset_id: Optional[int] = None - maintenance_type: Optional[str] = None - scheduled_date_from: Optional[int] = None - scheduled_date_to: Optional[int] = None - completed_date_from: Optional[int] = None - completed_date_to: Optional[int] = None - performed_by: Optional[int] = None + asset_id: int | None = None + maintenance_type: str | None = None + scheduled_date_from: int | None = None + scheduled_date_to: int | None = None + completed_date_from: int | None = None + completed_date_to: int | None = None + performed_by: int | None = None class AssetWithMaintenanceList(BaseModel): - assets: List[AssetWithMaintenance] + assets: list[AssetWithMaintenance] total: int diff --git a/public_api/shared_schemas/audit.py b/public_api/shared_schemas/audit.py index 83e3a64..244b6f7 100644 --- a/public_api/shared_schemas/audit.py +++ b/public_api/shared_schemas/audit.py @@ -1,9 +1,6 @@ -# /server/app/shared_schemas/audit_log.py -from typing import Optional, List, Dict +from pydantic import BaseModel, ConfigDict -from pydantic import BaseModel - -from .user import User +from .user import UserSanitized class AuditLogBase(BaseModel): @@ -11,8 +8,8 @@ class AuditLogBase(BaseModel): action_type: str table_name: str record_id: int - old_value: Optional[str] = None - new_value: Optional[str] = None + old_value: str | None = None + new_value: str | None = None class AuditLogCreate(AuditLogBase): @@ -23,24 +20,22 @@ class AuditLog(AuditLogBase): id: int timestamp: int - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class AuditLogWithUser(AuditLog): - user: User + user: UserSanitized - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class AuditLogFilter(BaseModel): - user_id: Optional[int] = None - action_type: Optional[str] = None - table_name: Optional[str] = None - record_id: Optional[int] = None - date_from: Optional[int] = None - date_to: Optional[int] = None + user_id: int | None = None + action_type: str | None = None + table_name: str | None = None + record_id: int | None = None + date_from: int | None = None + date_to: int | None = None class UserActivitySummary(BaseModel): @@ -51,16 +46,16 @@ class UserActivitySummary(BaseModel): class AuditSummary(BaseModel): total_logs: int - logs_by_action: Dict[str, int] - logs_by_table: Dict[str, int] - most_active_users: List[UserActivitySummary] + logs_by_action: dict[str, int] + logs_by_table: dict[str, int] + most_active_users: list[UserActivitySummary] class AuditLogExport(BaseModel): - logs: List[AuditLog] + logs: list[AuditLog] export_timestamp: int class AuditLogList(BaseModel): - logs: List[AuditLog] + logs: list[AuditLog] total: int diff --git a/public_api/shared_schemas/inventory.py b/public_api/shared_schemas/inventory.py index 317aee1..f52f88c 100644 --- a/public_api/shared_schemas/inventory.py +++ b/public_api/shared_schemas/inventory.py @@ -1,12 +1,11 @@ # /server/app/shared_schemas/inventory.py -from typing import Optional, List, Dict, Union from pydantic import BaseModel, constr, Field class ProductCategoryBase(BaseModel): name: str - parent_category_id: Optional[int] = None + parent_category_id: int | None = None class ProductCategoryCreate(ProductCategoryBase): @@ -14,8 +13,8 @@ class ProductCategoryCreate(ProductCategoryBase): class ProductCategoryUpdate(BaseModel): - name: Optional[str] = None - parent_category_id: Optional[int] = None + name: str | None = None + parent_category_id: int | None = None class ProductCategory(ProductCategoryBase): @@ -28,12 +27,12 @@ class Config: class ProductBase(BaseModel): sku: str name: str - description: Optional[str] = None + description: str | None = None category_id: int - unit_of_measure: Optional[str] = None - weight: Optional[float] = None - dimensions: Optional[str] = None - barcode: Optional[str] = None + unit_of_measure: str | None = None + weight: float | None = None + dimensions: str | None = None + barcode: str | None = None price: float @@ -42,15 +41,15 @@ class ProductCreate(ProductBase): class ProductUpdate(BaseModel): - sku: Optional[str] = None - name: Optional[str] = None - description: Optional[str] = None - category_id: Optional[int] = None - unit_of_measure: Optional[str] = None - weight: Optional[float] = None - dimensions: Optional[str] = None - barcode: Optional[str] = None - price: Optional[float] = None + sku: str | None = None + name: str | None = None + description: str | None = None + category_id: int | None = None + unit_of_measure: str | None = None + weight: float | None = None + dimensions: str | None = None + barcode: str | None = None + price: float | None = None class Product(ProductBase): @@ -64,7 +63,7 @@ class InventoryBase(BaseModel): product_id: int location_id: int quantity: int - expiration_date: Optional[int] = None + expiration_date: int | None = None class InventoryCreate(InventoryBase): @@ -72,10 +71,10 @@ class InventoryCreate(InventoryBase): class InventoryUpdate(BaseModel): - product_id: Optional[int] = None - location_id: Optional[int] = None - quantity: Optional[int] = None - expiration_date: Optional[int] = None + product_id: int | None = None + location_id: int | None = None + quantity: int | None = None + expiration_date: int | None = None class Inventory(InventoryBase): @@ -88,11 +87,11 @@ class Config: class LocationBase(BaseModel): name: str = Field(..., max_length=100) - zone_id: Optional[int] = None - aisle: Optional[str] = Field(None, max_length=50) - rack: Optional[str] = Field(None, max_length=50) - shelf: Optional[str] = Field(None, max_length=50) - bin: Optional[str] = Field(None, max_length=50) + zone_id: int | None = None + aisle: str | None = Field(None, max_length=50) + rack: str | None = Field(None, max_length=50) + shelf: str | None = Field(None, max_length=50) + bin: str | None = Field(None, max_length=50) capacity: int = 0 @@ -101,13 +100,13 @@ class LocationCreate(LocationBase): class LocationUpdate(BaseModel): - name: Optional[str] = Field(None, max_length=100) - zone_id: Optional[int] = None - aisle: Optional[str] = Field(None, max_length=50) - rack: Optional[str] = Field(None, max_length=50) - shelf: Optional[str] = Field(None, max_length=50) - bin: Optional[str] = Field(None, max_length=50) - capacity: Optional[int] = 0 + name: str | None = Field(None, max_length=100) + zone_id: int | None = None + aisle: str | None = Field(None, max_length=50) + rack: str | None = Field(None, max_length=50) + shelf: str | None = Field(None, max_length=50) + bin: str | None = Field(None, max_length=50) + capacity: int | None = 0 class Location(LocationBase): @@ -119,7 +118,7 @@ class Config: class ZoneBase(BaseModel): name: str - description: Optional[str] = None + description: str | None = None class ZoneCreate(ZoneBase): @@ -127,8 +126,8 @@ class ZoneCreate(ZoneBase): class ZoneUpdate(BaseModel): - name: Optional[str] = None - description: Optional[str] = None + name: str | None = None + description: str | None = None class Zone(ZoneBase): @@ -139,38 +138,38 @@ class Config: class ProductWithInventory(Product): - inventory_items: List[Inventory] = [] + inventory_items: list[Inventory] = [] class Config: from_attributes = True class LocationWithInventory(Location): - inventory_items: List[Inventory] = [] + inventory_items: list[Inventory] = [] class ZoneWithLocations(Zone): - locations: List[Location] = [] + locations: list[Location] = [] class ProductFilter(BaseModel): - name: Optional[str] = None - category_id: Optional[int] = None - sku: Optional[str] = None - barcode: Optional[str] = None + name: str | None = None + category_id: int | None = None + sku: str | None = None + barcode: str | None = None class LocationFilter(BaseModel): - name: Optional[str] = None - zone_id: Optional[int] = None - aisle: Optional[str] = None - rack: Optional[str] = None - shelf: Optional[str] = None - bin: Optional[str] = None + name: str | None = None + zone_id: int | None = None + aisle: str | None = None + rack: str | None = None + shelf: str | None = None + bin: str | None = None class ZoneFilter(BaseModel): - name: Optional[str] = None + name: str | None = None class BarcodeData(BaseModel): @@ -184,19 +183,19 @@ class InventoryTransfer(BaseModel): class ProductWithCategoryAndInventory(ProductWithInventory): - category: Optional[ProductCategory] = None + category: ProductCategory | None = None pass class InventoryReport(BaseModel): total_products: int total_quantity: int - low_stock_items: List[ProductWithInventory] - out_of_stock_items: List[Product] + low_stock_items: list[ProductWithInventory] + out_of_stock_items: list[Product] class WarehouseLayout(BaseModel): - zones: List[ZoneWithLocations] + zones: list[ZoneWithLocations] class InventoryMovement(BaseModel): @@ -223,7 +222,7 @@ class StocktakeItem(BaseModel): class StocktakeCreate(BaseModel): location_id: int - items: List[StocktakeItem] + items: list[StocktakeItem] class StocktakeDiscrepancy(BaseModel): @@ -236,19 +235,19 @@ class StocktakeDiscrepancy(BaseModel): class StocktakeResult(BaseModel): location_id: int total_items: int - discrepancies: List[StocktakeDiscrepancy] + discrepancies: list[StocktakeDiscrepancy] accuracy_percentage: float class ABCCategory(BaseModel): category: str - products: List[Product] + products: list[Product] value_percentage: float item_percentage: float class ABCAnalysisResult(BaseModel): - categories: List[ABCCategory] + categories: list[ABCCategory] class InventoryLocationSuggestion(BaseModel): @@ -259,24 +258,24 @@ class InventoryLocationSuggestion(BaseModel): class BulkImportData(BaseModel): - items: List[InventoryCreate] + items: list[InventoryCreate] class BulkImportResult(BaseModel): success_count: int failure_count: int - errors: List[str] + errors: list[str] class StorageUtilization(BaseModel): total_capacity: float used_capacity: float utilization_percentage: float - zone_utilization: List[Dict[str, Union[str, int, float]]] + zone_utilization: list[dict[str, str | int | float]] class InventorySummary(BaseModel): - category_quantities: Dict[str, int] + category_quantities: dict[str, int] total_items: int total_categories: int @@ -290,17 +289,17 @@ class InventoryWithDetails(Inventory): class InventoryList(BaseModel): - items: List[InventoryWithDetails] + items: list[InventoryWithDetails] total: int class InventoryFilter(BaseModel): - product_id: Optional[int] = None - location_id: Optional[int] = None - sku: Optional[str] = None - name: Optional[str] = None - quantity_min: Optional[int] = None - quantity_max: Optional[int] = None + product_id: int | None = None + location_id: int | None = None + sku: str | None = None + name: str | None = None + quantity_min: int | None = None + quantity_max: int | None = None class InventoryTrendItem(BaseModel): diff --git a/public_api/shared_schemas/notification.py b/public_api/shared_schemas/notification.py index 6c32dcf..05ecefa 100644 --- a/public_api/shared_schemas/notification.py +++ b/public_api/shared_schemas/notification.py @@ -1,5 +1,4 @@ # /public_api/shared_schemas/notification.py -from typing import Optional from pydantic import BaseModel @@ -14,7 +13,7 @@ class NotificationCreate(NotificationBase): class NotificationUpdate(BaseModel): - is_read: Optional[bool] = None + is_read: bool | None = None class Notification(NotificationBase): diff --git a/public_api/shared_schemas/order.py b/public_api/shared_schemas/order.py index 25d47bd..69d7a2e 100644 --- a/public_api/shared_schemas/order.py +++ b/public_api/shared_schemas/order.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import Optional, List from pydantic import BaseModel @@ -25,10 +24,10 @@ class OrderItemCreate(OrderItemBase): class OrderItemUpdate(BaseModel): - id: Optional[int] = None - product_id: Optional[int] = None - quantity: Optional[int] = None - unit_price: Optional[float] = None + id: int | None = None + product_id: int | None = None + quantity: int | None = None + unit_price: float | None = None class OrderItem(OrderItemBase): @@ -43,39 +42,39 @@ class OrderBase(BaseModel): customer_id: int status: OrderStatus total_amount: float - shipping_name: Optional[str] = None - shipping_address_line1: Optional[str] = None - shipping_city: Optional[str] = None - shipping_state: Optional[str] = None - shipping_postal_code: Optional[str] = None - shipping_country: Optional[str] = None - shipping_phone: Optional[str] = None - ship_date: Optional[int] = None + shipping_name: str | None = None + shipping_address_line1: str | None = None + shipping_city: str | None = None + shipping_state: str | None = None + shipping_postal_code: str | None = None + shipping_country: str | None = None + shipping_phone: str | None = None + ship_date: int | None = None class OrderCreate(OrderBase): - items: List[OrderItemCreate] + items: list[OrderItemCreate] class OrderUpdate(BaseModel): - customer_id: Optional[int] = None - status: Optional[OrderStatus] = None - total_amount: Optional[float] = None - shipping_name: Optional[str] = None - shipping_address_line1: Optional[str] = None - shipping_city: Optional[str] = None - shipping_state: Optional[str] = None - shipping_postal_code: Optional[str] = None - shipping_country: Optional[str] = None - shipping_phone: Optional[str] = None - ship_date: Optional[int] = None - items: Optional[List[OrderItemUpdate]] = None + customer_id: int | None = None + status: OrderStatus | None = None + total_amount: float | None = None + shipping_name: str | None = None + shipping_address_line1: str | None = None + shipping_city: str | None = None + shipping_state: str | None = None + shipping_postal_code: str | None = None + shipping_country: str | None = None + shipping_phone: str | None = None + ship_date: int | None = None + items: list[OrderItemUpdate] | None = None class Order(OrderBase): id: int order_date: int - order_items: List[OrderItem] = [] + order_items: list[OrderItem] = [] class Config: from_attributes = True @@ -83,9 +82,9 @@ class Config: class CustomerBase(BaseModel): name: str - email: Optional[str] = None - phone: Optional[str] = None - address: Optional[str] = None + email: str | None = None + phone: str | None = None + address: str | None = None class CustomerCreate(CustomerBase): @@ -93,10 +92,10 @@ class CustomerCreate(CustomerBase): class CustomerUpdate(BaseModel): - name: Optional[str] = None - email: Optional[str] = None - phone: Optional[str] = None - address: Optional[str] = None + name: str | None = None + email: str | None = None + phone: str | None = None + address: str | None = None class Customer(CustomerBase): @@ -117,9 +116,9 @@ class POItemCreate(POItemBase): class POItemUpdate(BaseModel): - product_id: Optional[int] = None - quantity: Optional[int] = None - unit_price: Optional[float] = None + product_id: int | None = None + quantity: int | None = None + unit_price: float | None = None class POItem(POItemBase): @@ -133,24 +132,24 @@ class Config: class PurchaseOrderBase(BaseModel): supplier_id: int status: OrderStatus - expected_delivery_date: Optional[int] = None + expected_delivery_date: int | None = None class PurchaseOrderCreate(PurchaseOrderBase): - items: List[POItemCreate] + items: list[POItemCreate] class PurchaseOrderUpdate(BaseModel): - supplier_id: Optional[int] = None - status: Optional[OrderStatus] = None - expected_delivery_date: Optional[int] = None - items: Optional[List[POItemUpdate]] = None + supplier_id: int | None = None + status: OrderStatus | None = None + expected_delivery_date: int | None = None + items: list[POItemUpdate] | None = None class PurchaseOrder(PurchaseOrderBase): id: int order_date: int - po_items: List[POItem] = [] + po_items: list[POItem] = [] class Config: from_attributes = True @@ -158,10 +157,10 @@ class Config: class SupplierBase(BaseModel): name: str - contact_person: Optional[str] = None - email: Optional[str] = None - phone: Optional[str] = None - address: Optional[str] = None + contact_person: str | None = None + email: str | None = None + phone: str | None = None + address: str | None = None class SupplierCreate(SupplierBase): @@ -169,11 +168,11 @@ class SupplierCreate(SupplierBase): class SupplierUpdate(BaseModel): - name: Optional[str] = None - contact_person: Optional[str] = None - email: Optional[str] = None - phone: Optional[str] = None - address: Optional[str] = None + name: str | None = None + contact_person: str | None = None + email: str | None = None + phone: str | None = None + address: str | None = None class Supplier(SupplierBase): @@ -184,12 +183,12 @@ class Config: class OrderFilter(BaseModel): - customer_id: Optional[int] = None - status: Optional[OrderStatus] = None - order_date_from: Optional[int] = None - order_date_to: Optional[int] = None - ship_date_from: Optional[int] = None - ship_date_to: Optional[int] = None + customer_id: int | None = None + status: OrderStatus | None = None + order_date_from: int | None = None + order_date_to: int | None = None + ship_date_from: int | None = None + ship_date_to: int | None = None class OrderSummary(BaseModel): @@ -199,30 +198,30 @@ class OrderSummary(BaseModel): class CustomerFilter(BaseModel): - name: Optional[str] = None - email: Optional[str] = None + name: str | None = None + email: str | None = None class PurchaseOrderFilter(BaseModel): - supplier_id: Optional[int] = None - status: Optional[OrderStatus] = None - date_from: Optional[int] = None - date_to: Optional[int] = None + supplier_id: int | None = None + status: OrderStatus | None = None + date_from: int | None = None + date_to: int | None = None class SupplierFilter(BaseModel): - name: Optional[str] = None - contact_person: Optional[str] = None + name: str | None = None + contact_person: str | None = None class OrderWithDetails(Order): customer: Customer - order_items: List[OrderItem] + order_items: list[OrderItem] class PurchaseOrderWithDetails(PurchaseOrder): supplier: Supplier - po_items: List[POItem] + po_items: list[POItem] class ShippingInfo(BaseModel): @@ -246,10 +245,10 @@ class OrderProcessingTimes(BaseModel): class BulkOrderImportData(BaseModel): - orders: List[OrderCreate] + orders: list[OrderCreate] class BulkOrderImportResult(BaseModel): success_count: int failure_count: int - errors: List[str] + errors: list[str] diff --git a/public_api/shared_schemas/quality.py b/public_api/shared_schemas/quality.py index b9545ce..a975fc2 100644 --- a/public_api/shared_schemas/quality.py +++ b/public_api/shared_schemas/quality.py @@ -1,5 +1,4 @@ # /server/app/shared_schemas/quality.py -from typing import Optional from pydantic import BaseModel @@ -10,7 +9,7 @@ class QualityCheckBase(BaseModel): product_id: int performed_by: int result: str - notes: Optional[str] = None + notes: str | None = None class QualityCheckCreate(QualityCheckBase): @@ -18,10 +17,10 @@ class QualityCheckCreate(QualityCheckBase): class QualityCheckUpdate(BaseModel): - product_id: Optional[int] = None - performed_by: Optional[int] = None - result: Optional[str] = None - notes: Optional[str] = None + product_id: int | None = None + performed_by: int | None = None + result: str | None = None + notes: str | None = None class QualityCheck(QualityCheckBase): @@ -37,11 +36,11 @@ class QualityCheckWithProduct(QualityCheck): class QualityCheckFilter(BaseModel): - product_id: Optional[int] = None - performed_by: Optional[int] = None - result: Optional[str] = None - date_from: Optional[int] = None - date_to: Optional[int] = None + product_id: int | None = None + performed_by: int | None = None + result: str | None = None + date_from: int | None = None + date_to: int | None = None class QualityMetrics(BaseModel): @@ -61,9 +60,9 @@ class QualityStandardCreate(QualityStandardBase): class QualityStandardUpdate(BaseModel): - product_id: Optional[int] = None - criteria: Optional[str] = None - acceptable_range: Optional[str] = None + product_id: int | None = None + criteria: str | None = None + acceptable_range: str | None = None class QualityStandard(QualityStandardBase): @@ -84,16 +83,16 @@ class QualityAlertCreate(QualityAlertBase): class QualityAlertUpdate(BaseModel): - product_id: Optional[int] = None - alert_type: Optional[str] = None - description: Optional[str] = None - resolved_at: Optional[int] = None + product_id: int | None = None + alert_type: str | None = None + description: str | None = None + resolved_at: int | None = None class QualityAlert(QualityAlertBase): id: int created_at: int - resolved_at: Optional[int] = None + resolved_at: int | None = None class Config: from_attributes = True @@ -122,15 +121,15 @@ class ProductDefectRate(BaseModel): class QualityStandardFilter(BaseModel): - product_id: Optional[int] = None + product_id: int | None = None class QualityAlertFilter(BaseModel): - product_id: Optional[int] = None - alert_type: Optional[str] = None - resolved: Optional[bool] = None - date_from: Optional[int] = None - date_to: Optional[int] = None + product_id: int | None = None + alert_type: str | None = None + resolved: bool | None = None + date_from: int | None = None + date_to: int | None = None class QualityCheckSummary(BaseModel): diff --git a/public_api/shared_schemas/reports.py b/public_api/shared_schemas/reports.py index eb0be55..073b63c 100644 --- a/public_api/shared_schemas/reports.py +++ b/public_api/shared_schemas/reports.py @@ -1,6 +1,5 @@ # /server/app/shared_schemas/reports.py from enum import Enum -from typing import List, Optional from pydantic import BaseModel @@ -15,7 +14,7 @@ class InventoryItem(BaseModel): class InventorySummaryReport(BaseModel): total_items: int total_value: float - items: List[InventoryItem] + items: list[InventoryItem] class OrderSummary(BaseModel): @@ -39,7 +38,7 @@ class WarehousePerformanceMetric(BaseModel): class WarehousePerformanceReport(BaseModel): start_date: int end_date: int - metrics: List[WarehousePerformanceMetric] + metrics: list[WarehousePerformanceMetric] class TrendDirection(str, Enum): @@ -56,7 +55,7 @@ class KPIMetric(BaseModel): class KPIDashboard(BaseModel): date: int - metrics: List[KPIMetric] + metrics: list[KPIMetric] class ProductPerformance(BaseModel): @@ -69,7 +68,7 @@ class ProductPerformance(BaseModel): class TopSellingProductsReport(BaseModel): start_date: int end_date: int - products: List[ProductPerformance] + products: list[ProductPerformance] class SupplierPerformance(BaseModel): @@ -83,7 +82,7 @@ class SupplierPerformance(BaseModel): class SupplierPerformanceReport(BaseModel): start_date: int end_date: int - suppliers: List[SupplierPerformance] + suppliers: list[SupplierPerformance] class StockMovement(BaseModel): @@ -97,7 +96,7 @@ class StockMovement(BaseModel): class StockMovementReport(BaseModel): start_date: int end_date: int - movements: List[StockMovement] + movements: list[StockMovement] class PickingEfficiency(BaseModel): @@ -111,7 +110,7 @@ class PickingEfficiency(BaseModel): class PickingEfficiencyReport(BaseModel): start_date: int end_date: int - pickers: List[PickingEfficiency] + pickers: list[PickingEfficiency] class StorageUtilization(BaseModel): @@ -124,7 +123,7 @@ class StorageUtilization(BaseModel): class StorageUtilizationReport(BaseModel): date: int - zones: List[StorageUtilization] + zones: list[StorageUtilization] class ReturnsAnalysis(BaseModel): @@ -138,7 +137,7 @@ class ReturnsReport(BaseModel): end_date: int total_returns: int return_rate: float - reasons: List[ReturnsAnalysis] + reasons: list[ReturnsAnalysis] class CustomReport(BaseModel): @@ -147,7 +146,7 @@ class CustomReport(BaseModel): query: str parameters: dict created_at: int - last_run: Optional[int] + last_run: int | None class ReportFrequency(str, Enum): @@ -160,4 +159,4 @@ class ReportSchedule(BaseModel): report_id: int frequency: ReportFrequency next_run: int - recipients: List[str] + recipients: list[str] diff --git a/public_api/shared_schemas/task.py b/public_api/shared_schemas/task.py index 235fbcf..1f8e236 100644 --- a/public_api/shared_schemas/task.py +++ b/public_api/shared_schemas/task.py @@ -1,10 +1,9 @@ # /server/app/shared_schemas/task.py from enum import Enum -from typing import Optional, List from pydantic import BaseModel -from .user import User +from .user import UserSanitized class TaskStatus(str, Enum): @@ -41,12 +40,12 @@ class TaskCreate(TaskBase): class TaskUpdate(BaseModel): - task_type: Optional[TaskType] = None - description: Optional[str] = None - assigned_to: Optional[int] = None - due_date: Optional[int] = None - priority: Optional[TaskPriority] = None - status: Optional[TaskStatus] = None + task_type: TaskType | None = None + description: str | None = None + assigned_to: int | None = None + due_date: int | None = None + priority: TaskPriority | None = None + status: TaskStatus | None = None class Task(TaskBase): @@ -58,16 +57,16 @@ class Config: class TaskWithAssignee(Task): - assigned_user: Optional[User] = None + assigned_user: UserSanitized | None = None class TaskFilter(BaseModel): - task_type: Optional[TaskType] = None - assigned_to: Optional[int] = None - priority: Optional[TaskPriority] = None - status: Optional[TaskStatus] = None - due_date_from: Optional[int] = None - due_date_to: Optional[int] = None + task_type: TaskType | None = None + assigned_to: int | None = None + priority: TaskPriority | None = None + status: TaskStatus | None = None + due_date_from: int | None = None + due_date_to: int | None = None class TaskCommentBase(BaseModel): @@ -104,7 +103,7 @@ class UserTaskSummary(BaseModel): class TaskWithComments(Task): - comments: List[TaskComment] = [] + comments: list[TaskComment] = [] class TaskPriorityUpdate(BaseModel): @@ -154,14 +153,14 @@ class TaskAnalytics(BaseModel): total_tasks: int completion_rate: float average_completion_time: float - type_distribution: List[TaskTypeDistribution] + type_distribution: list[TaskTypeDistribution] class BulkTaskCreate(BaseModel): - tasks: List[TaskCreate] + tasks: list[TaskCreate] class BulkTaskCreateResult(BaseModel): success_count: int failure_count: int - errors: List[str] + errors: list[str] diff --git a/public_api/shared_schemas/user.py b/public_api/shared_schemas/user.py index 6215da7..8ebbffe 100644 --- a/public_api/shared_schemas/user.py +++ b/public_api/shared_schemas/user.py @@ -1,14 +1,13 @@ # /server/app/shared_schemas/user.py from enum import Enum -from typing import Optional, List -from pydantic import BaseModel, EmailStr +from pydantic import BaseModel, EmailStr, Field, ConfigDict class RoleName(str, Enum): - admin = "admin" - manager = "manager" - user = "user" + ADMIN = "admin" + MANAGER = "manager" + USER = "user" class PermissionBase(BaseModel): @@ -23,14 +22,13 @@ class PermissionCreate(PermissionBase): class PermissionUpdate(PermissionBase): - permission_name: Optional[str] = None + permission_name: str | None = None class Permission(PermissionBase): id: int - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class RoleBase(BaseModel): @@ -38,20 +36,19 @@ class RoleBase(BaseModel): class RoleCreate(RoleBase): - permissions: List[int] + permissions: list[int] class RoleUpdate(BaseModel): - role_name: Optional[RoleName] = None - permissions: Optional[List[int]] = None + role_name: RoleName | None = None + permissions: list[int] | None = None class Role(RoleBase): id: int - permissions: List[Permission] = [] + permissions: list[Permission] = [] - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class UserBase(BaseModel): @@ -61,45 +58,38 @@ class UserBase(BaseModel): role_id: int two_factor_auth_enabled: bool = False + model_config = ConfigDict(from_attributes=True, extra='ignore') + class UserCreate(UserBase): - password: str - two_factor_auth_secret: Optional[str] = None + password: str = Field(..., min_length=8) + two_factor_auth_secret: str | None = None class UserUpdate(BaseModel): - username: Optional[str] = None - email: Optional[EmailStr] = None - is_active: Optional[bool] = None - role_id: Optional[int] = None - password: Optional[str] = None - two_factor_auth_enabled: Optional[bool] = None - two_factor_auth_secret: Optional[str] = None + username: str | None = None + email: EmailStr | None = None + is_active: bool | None = None + role_id: int | None = None + password: str | None = Field(None, min_length=8) + two_factor_auth_enabled: bool | None = None + two_factor_auth_secret: str | None = None -class UserSanitizedWithRole(UserBase): +class UserSanitized(UserBase): id: int created_at: int - last_login: Optional[int] = None - two_factor_auth_enabled: bool + last_login: int | None = None role: Role - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) -class User(UserBase): - id: int - created_at: int - last_login: Optional[int] = None +class UserInDB(UserSanitized): password: str - password_reset_token: Optional[str] = None - password_reset_expiration: Optional[int] = None - two_factor_auth_enabled: bool - two_factor_auth_secret: Optional[str] = None - - class Config: - from_attributes = True + password_reset_token: str | None = None + password_reset_expiration: int | None = None + two_factor_auth_secret: str | None = None class TwoFactorLogin(BaseModel): @@ -108,94 +98,43 @@ class TwoFactorLogin(BaseModel): two_factor_code: str -class UserInDB(User): - password: str - - class Token(BaseModel): access_token: str + refresh_token: str token_type: str + expires_in: int # This is the number of seconds until the access token expires class TokenData(BaseModel): - username: Optional[str] = None + username: str | None = None class Message(BaseModel): message: str -class RolePermission(BaseModel): - role_id: int - permission_id: int - - class Config: - from_attributes = True - - -class UserWithRole(User): - role: Role - - -class PasswordReset(BaseModel): - email: EmailStr - - -class PasswordResetConfirm(BaseModel): - token: str - new_password: str - - class UserFilter(BaseModel): - username: Optional[str] = None - email: Optional[str] = None - is_active: Optional[bool] = None - role_id: Optional[int] = None - - -class UserActivity(BaseModel): - user_id: int - username: str - last_login: Optional[int] - total_logins: int - total_actions: int + username: str | None = None + email: str | None = None + is_active: bool | None = None + role_id: int | None = None -class RoleWithUsers(Role): - users: List[User] = [] - - -class UserPermissions(BaseModel): - user_id: int - username: str - permissions: List[Permission] - - -class BulkUserCreate(BaseModel): - users: List[UserCreate] - - -class BulkUserCreateResult(BaseModel): - success_count: int - failure_count: int - errors: List[str] +class UserWithPermissions(UserSanitized): + permissions: list[Permission] class UserPermissionUpdate(BaseModel): - user_id: int - permissions: List[int] - - -class UserWithPermissions(UserSanitizedWithRole): - permissions: List[Permission] - - class Config: - from_attributes = True + permissions: list[int] class AllRoles(BaseModel): - roles: List[Role] + roles: list[Role] class AllPermissions(BaseModel): - permissions: List[Permission] + permissions: list[Permission] + + +class RefreshTokenRequest(BaseModel): + refresh_token: str diff --git a/public_api/shared_schemas/warehouse.py b/public_api/shared_schemas/warehouse.py index aafb5d5..b594bfc 100644 --- a/public_api/shared_schemas/warehouse.py +++ b/public_api/shared_schemas/warehouse.py @@ -1,8 +1,10 @@ -from typing import Optional, List from enum import Enum + from pydantic import BaseModel + from public_api.shared_schemas import Order + class ShipmentStatus(str, Enum): PENDING = "Pending" IN_TRANSIT = "In Transit" @@ -21,10 +23,10 @@ class PickListItemCreate(PickListItemBase): class PickListItemUpdate(BaseModel): - product_id: Optional[int] = None - location_id: Optional[int] = None - quantity: Optional[int] = None - picked_quantity: Optional[int] = None + product_id: int | None = None + location_id: int | None = None + quantity: int | None = None + picked_quantity: int | None = None class PickListItem(PickListItemBase): @@ -41,20 +43,20 @@ class PickListBase(BaseModel): class PickListCreate(PickListBase): - items: List[PickListItemCreate] + items: list[PickListItemCreate] class PickListUpdate(BaseModel): - order_id: Optional[int] = None - status: Optional[str] = None - items: Optional[List[PickListItemUpdate]] = None + order_id: int | None = None + status: str | None = None + items: list[PickListItemUpdate] | None = None class PickList(PickListBase): pick_list_id: int created_at: int - completed_at: Optional[int] = None - items: List[PickListItem] = [] + completed_at: int | None = None + items: list[PickListItem] = [] class Config: from_attributes = True @@ -71,9 +73,9 @@ class ReceiptItemCreate(ReceiptItemBase): class ReceiptItemUpdate(BaseModel): - product_id: Optional[int] = None - quantity_received: Optional[int] = None - location_id: Optional[int] = None + product_id: int | None = None + quantity_received: int | None = None + location_id: int | None = None class ReceiptItem(ReceiptItemBase): @@ -90,19 +92,19 @@ class ReceiptBase(BaseModel): class ReceiptCreate(ReceiptBase): - items: List[ReceiptItemCreate] + items: list[ReceiptItemCreate] class ReceiptUpdate(BaseModel): - po_id: Optional[int] = None - status: Optional[str] = None - items: Optional[List[ReceiptItemUpdate]] = None + po_id: int | None = None + status: str | None = None + items: list[ReceiptItemUpdate] | None = None class Receipt(ReceiptBase): receipt_id: int received_date: int - items: List[ReceiptItem] = [] + items: list[ReceiptItem] = [] class Config: from_attributes = True @@ -111,29 +113,29 @@ class Config: class ShipmentBase(BaseModel): order_id: int carrier_id: int - tracking_number: Optional[str] = None + tracking_number: str | None = None status: ShipmentStatus - label_id: Optional[str] = None - label_download_url: Optional[str] = None + label_id: str | None = None + label_download_url: str | None = None class ShipmentCreate(ShipmentBase): - ship_date: Optional[int] = None + ship_date: int | None = None class ShipmentUpdate(BaseModel): - order_id: Optional[int] = None - carrier_id: Optional[int] = None - tracking_number: Optional[str] = None - status: Optional[ShipmentStatus] = None - ship_date: Optional[int] = None - label_id: Optional[str] = None - label_download_url: Optional[str] = None + order_id: int | None = None + carrier_id: int | None = None + tracking_number: str | None = None + status: ShipmentStatus | None = None + ship_date: int | None = None + label_id: str | None = None + label_download_url: str | None = None class Shipment(ShipmentBase): id: int - ship_date: Optional[int] = None + ship_date: int | None = None class Config: from_attributes = True @@ -141,7 +143,7 @@ class Config: class CarrierBase(BaseModel): name: str - contact_info: Optional[str] = None + contact_info: str | None = None class CarrierCreate(CarrierBase): @@ -149,8 +151,8 @@ class CarrierCreate(CarrierBase): class CarrierUpdate(BaseModel): - name: Optional[str] = None - contact_info: Optional[str] = None + name: str | None = None + contact_info: str | None = None class Carrier(CarrierBase): @@ -161,25 +163,25 @@ class Config: class PickListFilter(BaseModel): - status: Optional[str] = None - order_id: Optional[int] = None - date_from: Optional[int] = None - date_to: Optional[int] = None + status: str | None = None + order_id: int | None = None + date_from: int | None = None + date_to: int | None = None class ReceiptFilter(BaseModel): - status: Optional[str] = None - po_id: Optional[int] = None - date_from: Optional[int] = None - date_to: Optional[int] = None + status: str | None = None + po_id: int | None = None + date_from: int | None = None + date_to: int | None = None class ShipmentFilter(BaseModel): - status: Optional[ShipmentStatus] = None - order_id: Optional[int] = None - carrier_id: Optional[int] = None - date_from: Optional[int] = None - date_to: Optional[int] = None + status: ShipmentStatus | None = None + order_id: int | None = None + carrier_id: int | None = None + date_from: int | None = None + date_to: int | None = None class WarehouseStats(BaseModel): @@ -206,7 +208,7 @@ class LocationInventoryUpdate(BaseModel): class OptimizedPickingRoute(BaseModel): pick_list_id: int - optimized_route: List[PickListItem] + optimized_route: list[PickListItem] class PickingPerformance(BaseModel): @@ -218,7 +220,7 @@ class PickingPerformance(BaseModel): class QualityCheckCreate(BaseModel): product_id: int result: str - notes: Optional[str] = None + notes: str | None = None class ReceiptDiscrepancy(BaseModel): @@ -246,8 +248,8 @@ class ShipmentTracking(BaseModel): shipment_id: int tracking_number: str current_status: str - estimated_delivery_date: Optional[int] - tracking_history: List[dict] + estimated_delivery_date: int | None + tracking_history: list[dict] class InventoryMovementBase(BaseModel): @@ -301,10 +303,10 @@ class YardLocationCreate(YardLocationBase): class YardLocationUpdate(BaseModel): - name: Optional[str] = None - type: Optional[str] = None - status: Optional[str] = None - capacity: Optional[int] = None + name: str | None = None + type: str | None = None + status: str | None = None + capacity: int | None = None class YardLocation(YardLocationBase): @@ -327,36 +329,36 @@ class DockAppointmentCreate(DockAppointmentBase): class DockAppointmentUpdate(BaseModel): - yard_location_id: Optional[int] = None - appointment_time: Optional[int] = None - carrier_id: Optional[int] = None - type: Optional[str] = None - status: Optional[str] = None - actual_arrival_time: Optional[int] = None - actual_departure_time: Optional[int] = None + yard_location_id: int | None = None + appointment_time: int | None = None + carrier_id: int | None = None + type: str | None = None + status: str | None = None + actual_arrival_time: int | None = None + actual_departure_time: int | None = None class DockAppointment(DockAppointmentBase): appointment_id: int - actual_arrival_time: Optional[int] = None - actual_departure_time: Optional[int] = None + actual_arrival_time: int | None = None + actual_departure_time: int | None = None class Config: from_attributes = True class YardLocationFilter(BaseModel): - type: Optional[str] = None - status: Optional[str] = None + type: str | None = None + status: str | None = None class DockAppointmentFilter(BaseModel): - yard_location_id: Optional[int] = None - carrier_id: Optional[int] = None - type: Optional[str] = None - status: Optional[str] = None - date_from: Optional[int] = None - date_to: Optional[int] = None + yard_location_id: int | None = None + carrier_id: int | None = None + type: str | None = None + status: str | None = None + date_from: int | None = None + date_to: int | None = None class YardManagementStats(BaseModel): @@ -368,5 +370,5 @@ class YardManagementStats(BaseModel): class ShipmentWithDetails(Shipment): - order: Optional[Order] = None - carrier: Optional[Carrier] = None + order: Order | None = None + carrier: Carrier | None = None diff --git a/public_api/shared_schemas/yard.py b/public_api/shared_schemas/yard.py index be23bfc..d9399f3 100644 --- a/public_api/shared_schemas/yard.py +++ b/public_api/shared_schemas/yard.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import Optional, List from pydantic import BaseModel @@ -27,10 +26,10 @@ class YardLocationCreate(YardLocationBase): class YardLocationUpdate(BaseModel): - name: Optional[str] = None - type: Optional[YardLocationType] = None - status: Optional[YardLocationStatus] = None - capacity: Optional[int] = None + name: str | None = None + type: YardLocationType | None = None + status: YardLocationStatus | None = None + capacity: int | None = None class YardLocation(YardLocationBase): @@ -46,8 +45,8 @@ class DockAppointmentBase(BaseModel): carrier_id: int type: YardLocationType status: YardLocationStatus - actual_arrival_time: Optional[int] = None - actual_departure_time: Optional[int] = None + actual_arrival_time: int | None = None + actual_departure_time: int | None = None class DockAppointmentCreate(DockAppointmentBase): @@ -55,13 +54,13 @@ class DockAppointmentCreate(DockAppointmentBase): class DockAppointmentUpdate(BaseModel): - yard_location_id: Optional[int] = None - appointment_time: Optional[int] = None - carrier_id: Optional[int] = None - type: Optional[YardLocationType] = None - status: Optional[YardLocationStatus] = None - actual_arrival_time: Optional[int] = None - actual_departure_time: Optional[int] = None + yard_location_id: int | None = None + appointment_time: int | None = None + carrier_id: int | None = None + type: YardLocationType | None = None + status: YardLocationStatus | None = None + actual_arrival_time: int | None = None + actual_departure_time: int | None = None class DockAppointment(DockAppointmentBase): @@ -72,22 +71,22 @@ class Config: class YardLocationWithAppointments(YardLocation): - appointments: List[DockAppointment] = [] + appointments: list[DockAppointment] = [] class YardLocationFilter(BaseModel): - name: Optional[str] = None - type: Optional[YardLocationType] = None - status: Optional[YardLocationStatus] = None + name: str | None = None + type: YardLocationType | None = None + status: YardLocationStatus | None = None class DockAppointmentFilter(BaseModel): - yard_location_id: Optional[int] = None - carrier_id: Optional[int] = None - type: Optional[YardLocationType] = None - status: Optional[YardLocationStatus] = None - date_from: Optional[int] = None - date_to: Optional[int] = None + yard_location_id: int | None = None + carrier_id: int | None = None + type: YardLocationType | None = None + status: YardLocationStatus | None = None + date_from: int | None = None + date_to: int | None = None class YardStats(BaseModel): @@ -114,7 +113,7 @@ class YardUtilizationReport(BaseModel): total_capacity: int total_utilization: int utilization_percentage: float - location_breakdown: List[YardLocationCapacity] + location_breakdown: list[YardLocationCapacity] class CarrierPerformance(BaseModel): @@ -131,7 +130,7 @@ class YardLocationOccupancy(BaseModel): yard_location_id: int name: str occupied: bool - current_appointment: Optional[DockAppointment] = None + current_appointment: DockAppointment | None = None class YardOverview(BaseModel): @@ -139,18 +138,18 @@ class YardOverview(BaseModel): occupied_locations: int available_locations: int utilization_percentage: float - locations: List[YardLocationOccupancy] + locations: list[YardLocationOccupancy] class AppointmentScheduleConflict(BaseModel): - conflicting_appointments: List[DockAppointment] - suggested_time_slots: List[int] + conflicting_appointments: list[DockAppointment] + suggested_time_slots: list[int] class CarrierSchedule(BaseModel): carrier_id: int carrier_name: str - appointments: List[DockAppointment] + appointments: list[DockAppointment] class YardLocationTypeDistribution(BaseModel): @@ -162,16 +161,16 @@ class YardLocationTypeDistribution(BaseModel): class YardAnalytics(BaseModel): total_locations: int average_utilization: float - peak_hours: List[int] - type_distribution: List[YardLocationTypeDistribution] - carrier_performance: List[CarrierPerformance] + peak_hours: list[int] + type_distribution: list[YardLocationTypeDistribution] + carrier_performance: list[CarrierPerformance] class BulkAppointmentCreate(BaseModel): - appointments: List[DockAppointmentCreate] + appointments: list[DockAppointmentCreate] class BulkAppointmentCreateResult(BaseModel): success_count: int failure_count: int - errors: List[str] + errors: list[str] diff --git a/server/app/api/deps.py b/server/app/api/deps.py index 694c7f7..796a404 100644 --- a/server/app/api/deps.py +++ b/server/app/api/deps.py @@ -5,8 +5,8 @@ from pydantic import ValidationError from sqlalchemy.orm import Session -from public_api import shared_schemas from public_api.permission_manager import PermissionManager +from public_api.shared_schemas import user as user_schemas from server.app import crud, models from server.app.core.config import settings from server.app.db.database import get_db @@ -22,10 +22,10 @@ def get_current_user( payload = jwt.decode( token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] ) - token_data = shared_schemas.TokenData(**payload) + token_data = user_schemas.TokenData(**payload) except (jwt.JWTError, ValidationError): raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, + status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", ) user = crud.user.get_by_username(db, username=token_data.username) @@ -51,16 +51,18 @@ def get_current_admin( ) return current_user + def get_permission_manager( - db: Session = Depends(get_db), - current_user: models.User = Depends(get_current_active_user) + db: Session = Depends(get_db), + current_user: models.User = Depends(get_current_active_user) ) -> PermissionManager: permissions = crud.user.get_user_permissions(db, current_user.id) return PermissionManager(permissions) + def has_permission(name: str, action: str): def permission_checker( - permission_manager: PermissionManager = Depends(get_permission_manager) + permission_manager: PermissionManager = Depends(get_permission_manager) ): if not permission_manager.has_permission(name, action): raise HTTPException( diff --git a/server/app/api/v1/endpoints/assets.py b/server/app/api/v1/endpoints/assets.py index 19fe6c6..1bfaf6e 100644 --- a/server/app/api/v1/endpoints/assets.py +++ b/server/app/api/v1/endpoints/assets.py @@ -31,7 +31,7 @@ def read_assets( current_user: models.User = Depends(deps.has_permission("asset", "read")) ): assets = crud.asset.get_multi_with_filter(db, skip=skip, limit=limit, filter_params=asset_filter) - total = crud.asset.count_with_filter(db, filter_params=asset_filter) + total = len(assets) return {"assets": assets, "total": total} @@ -172,7 +172,7 @@ def read_asset( db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.has_permission("asset", "read")) ): - asset = crud.asset.get_with_maintenance(db, id=asset_id) + asset = crud.asset.get_with_maintenance(db, asset_id=asset_id) if asset is None: raise HTTPException(status_code=404, detail="Asset not found") return asset diff --git a/server/app/api/v1/endpoints/carriers.py b/server/app/api/v1/endpoints/carriers.py index b9b4a97..937af9c 100644 --- a/server/app/api/v1/endpoints/carriers.py +++ b/server/app/api/v1/endpoints/carriers.py @@ -28,8 +28,7 @@ def read_carriers( limit: int = 100, current_user: models.User = Depends(deps.get_current_active_user) ): - carriers = crud.carrier.get_multi(db, skip=skip, limit=limit) - return [Carrier.model_validate(carrier) for carrier in carriers] + return crud.carrier.get_multi(db, skip=skip, limit=limit) @router.get("/{carrier_id}", response_model=shared_schemas.Carrier) diff --git a/server/app/api/v1/endpoints/inventory.py b/server/app/api/v1/endpoints/inventory.py index 2ec4147..bbb0f93 100644 --- a/server/app/api/v1/endpoints/inventory.py +++ b/server/app/api/v1/endpoints/inventory.py @@ -84,7 +84,7 @@ def create_reorder_list( db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_active_user) ): - return crud.inventory.create_reorder_list(db, threshold=threshold) + return crud.inventory.get_low_stock_items(db, threshold=threshold) @router.get("/product_locations/{product_id}", response_model=List[shared_schemas.LocationWithInventory]) @@ -250,4 +250,4 @@ def adjust_inventory( db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_active_user) ): - return crud.inventory.adjust_quantity(db, id=id, adjustment=adjustment) + return crud.inventory.adjust_quantity(db, inventory_id=id, adjustment=adjustment) diff --git a/server/app/api/v1/endpoints/permissions.py b/server/app/api/v1/endpoints/permissions.py index 44e9014..7729cca 100644 --- a/server/app/api/v1/endpoints/permissions.py +++ b/server/app/api/v1/endpoints/permissions.py @@ -28,8 +28,7 @@ def read_permissions( db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_active_user) ): - permissions = crud.permission.get_multi(db, skip=skip, limit=limit) - return [shared_schemas.Permission.model_validate(perm) for perm in permissions] + return crud.permission.get_multi(db, skip=skip, limit=limit) @router.get("/{permission_id}", response_model=shared_schemas.Permission) diff --git a/server/app/api/v1/endpoints/pick_lists.py b/server/app/api/v1/endpoints/pick_lists.py index e04e0cc..a1e9a86 100644 --- a/server/app/api/v1/endpoints/pick_lists.py +++ b/server/app/api/v1/endpoints/pick_lists.py @@ -94,7 +94,7 @@ def start_pick_list( db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_active_user) ): - return crud.pick_list.start(db, pick_list_id=pick_list_id, user_id=current_user.user_id) + return crud.pick_list.start(db, pick_list_id=pick_list_id, user_id=current_user.id) @router.post("/{pick_list_id}/complete", response_model=shared_schemas.PickList) @@ -103,4 +103,4 @@ def complete_pick_list( db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_active_user) ): - return crud.pick_list.complete(db, pick_list_id=pick_list_id, user_id=current_user.user_id) + return crud.pick_list.complete(db, pick_list_id=pick_list_id, user_id=current_user.id) diff --git a/server/app/api/v1/endpoints/po_items.py b/server/app/api/v1/endpoints/po_items.py index 95e2ddd..f722cfb 100644 --- a/server/app/api/v1/endpoints/po_items.py +++ b/server/app/api/v1/endpoints/po_items.py @@ -18,8 +18,7 @@ def read_po_items( db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_active_user) ): - po_items = crud.po_item.get_multi(db, skip=skip, limit=limit) - return po_items + return crud.po_item.get_multi(db, skip=skip, limit=limit) @router.get("/by_product/{product_id}", response_model=List[shared_schemas.POItem]) @@ -30,8 +29,7 @@ def read_po_items_by_product( db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_active_user) ): - po_items = crud.po_item.get_by_product(db, product_id=product_id, skip=skip, limit=limit) - return po_items + return crud.po_item.get_by_product(db, product_id=product_id, skip=skip, limit=limit) @router.get("/pending_receipt", response_model=List[shared_schemas.POItem]) @@ -41,8 +39,7 @@ def read_pending_receipt_po_items( db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_active_user) ): - po_items = crud.po_item.get_pending_receipt(db, skip=skip, limit=limit) - return po_items + return crud.po_item.get_pending_receipt(db, skip=skip, limit=limit) @router.get("/{po_item_id}", response_model=shared_schemas.POItem) diff --git a/server/app/api/v1/endpoints/products.py b/server/app/api/v1/endpoints/products.py index 1b3a3f4..184cba8 100644 --- a/server/app/api/v1/endpoints/products.py +++ b/server/app/api/v1/endpoints/products.py @@ -5,7 +5,6 @@ from sqlalchemy.orm import Session, joinedload from public_api import shared_schemas -from public_api.shared_schemas import ProductWithCategoryAndInventory from .... import crud, models from ....api import deps from ....models import Product @@ -55,7 +54,7 @@ def get_product_by_barcode( product = crud.product.get(db, barcode=barcode_data.barcode, options=options) if product: - return ProductWithCategoryAndInventory.model_validate(product) + return product raise HTTPException(status_code=404, detail="Product not found") @@ -73,7 +72,7 @@ def read_product( product = crud.product.get(db, id=product_id, options=options) if product: - return ProductWithCategoryAndInventory.model_validate(product) + return product raise HTTPException(status_code=404, detail="Product not found") diff --git a/server/app/api/v1/endpoints/purchase_orders.py b/server/app/api/v1/endpoints/purchase_orders.py index 1ecb270..07aa370 100644 --- a/server/app/api/v1/endpoints/purchase_orders.py +++ b/server/app/api/v1/endpoints/purchase_orders.py @@ -4,8 +4,8 @@ from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session -from .... import crud, models from public_api import shared_schemas +from .... import crud, models from ....api import deps router = APIRouter() @@ -78,4 +78,4 @@ def receive_purchase_order( po = crud.purchase_order.get(db, id=po_id) if po is None: raise HTTPException(status_code=404, detail="Purchase Order not found") - return crud.purchase_order.receive(db, db_obj=po, received_items=received_items) \ No newline at end of file + return crud.purchase_order.receive(db, db_obj=po, received_items=received_items) diff --git a/server/app/api/v1/endpoints/quality.py b/server/app/api/v1/endpoints/quality.py index 357fc09..5ca4515 100644 --- a/server/app/api/v1/endpoints/quality.py +++ b/server/app/api/v1/endpoints/quality.py @@ -5,8 +5,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query, Path, Body from sqlalchemy.orm import Session -from .... import crud, models from public_api import shared_schemas +from .... import crud, models from ....api import deps router = APIRouter() @@ -224,7 +224,7 @@ def add_comment_to_quality_check( db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_active_user) ): - return crud.quality_check.add_comment(db, check_id=check_id, comment=comment, user_id=current_user.user_id) + return crud.quality_check.add_comment(db, check_id=check_id, comment=comment, user_id=current_user.id) @router.get("/reports/defect_rate", response_model=List[shared_schemas.ProductDefectRate]) diff --git a/server/app/api/v1/endpoints/roles.py b/server/app/api/v1/endpoints/roles.py index abe603e..8de1b22 100644 --- a/server/app/api/v1/endpoints/roles.py +++ b/server/app/api/v1/endpoints/roles.py @@ -18,8 +18,7 @@ def read_roles( db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_active_user) ): - roles = crud.role.get_multi(db, skip=skip, limit=limit) - return [shared_schemas.Role.model_validate(role) for role in roles] + return crud.role.get_multi(db, skip=skip, limit=limit) @router.get("/{role_id}", response_model=shared_schemas.Role) @@ -31,7 +30,7 @@ def read_role( role = crud.role.get(db, id=role_id) if not role: raise HTTPException(status_code=404, detail="Role not found") - return shared_schemas.Role.model_validate(role) + return role @router.put("/{role_id}", response_model=shared_schemas.Role) @@ -45,7 +44,7 @@ def update_role( if not role: raise HTTPException(status_code=404, detail="Role not found") updated_role = crud.role.update(db, db_obj=role, obj_in=role_in) - return shared_schemas.Role.model_validate(updated_role) + return updated_role @router.delete("/{role_id}", response_model=shared_schemas.Role) @@ -58,4 +57,4 @@ def delete_role( if not role: raise HTTPException(status_code=404, detail="Role not found") deleted_role = crud.role.remove(db, id=role_id) - return shared_schemas.Role.model_validate(deleted_role) + return deleted_role diff --git a/server/app/api/v1/endpoints/search.py b/server/app/api/v1/endpoints/search.py index 8684aba..23ba4ea 100644 --- a/server/app/api/v1/endpoints/search.py +++ b/server/app/api/v1/endpoints/search.py @@ -1,7 +1,3 @@ -# /server/app/api/v1/endpoints/search.py - -from typing import List, Optional - from fastapi import APIRouter, Depends, Query from sqlalchemy.orm import Session @@ -12,15 +8,15 @@ router = APIRouter() -@router.get("/products", response_model=List[shared_schemas.Product]) +@router.get("/products", response_model=list[shared_schemas.Product]) def search_products( - q: Optional[str] = Query(None, description="Search query string"), - category_id: Optional[int] = Query(None), - min_price: Optional[float] = Query(None), - max_price: Optional[float] = Query(None), - in_stock: Optional[bool] = Query(None), - sort_by: Optional[str] = Query(None), - sort_order: Optional[str] = Query("asc"), + q: str | None = Query(None, description="Search query string"), + category_id: int | None = Query(None), + min_price: float | None = Query(None), + max_price: float | None = Query(None), + in_stock: bool | None = Query(None), + sort_by: str | None = Query(None), + sort_order: str | None = Query("asc"), db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_active_user) ): @@ -31,16 +27,16 @@ def search_products( ) -@router.get("/orders", response_model=List[shared_schemas.Order]) +@router.get("/orders", response_model=list[shared_schemas.Order]) def search_orders( - q: Optional[str] = Query(None, description="Search query string"), - status: Optional[str] = Query(None), - min_total: Optional[float] = Query(None), - max_total: Optional[float] = Query(None), - start_date: Optional[int] = Query(None), - end_date: Optional[int] = Query(None), - sort_by: Optional[str] = Query(None), - sort_order: Optional[str] = Query("asc"), + q: str | None = Query(None, description="Search query string"), + status: str | None = Query(None), + min_total: float | None = Query(None), + max_total: float | None = Query(None), + start_date: int | None = Query(None), + end_date: int | None = Query(None), + sort_by: str | None = Query(None), + sort_order: str | None = Query("asc"), db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_active_user) ): diff --git a/server/app/api/v1/endpoints/tasks.py b/server/app/api/v1/endpoints/tasks.py index 5fb2fbf..00f32d2 100644 --- a/server/app/api/v1/endpoints/tasks.py +++ b/server/app/api/v1/endpoints/tasks.py @@ -133,7 +133,7 @@ def add_task_comment( db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_active_user) ): - return crud.task.add_comment(db, task_id=task_id, comment=comment, user_id=current_user.user_id) + return crud.task.add_comment(db, task_id=task_id, comment=comment, user_id=current_user.id) @router.get("/{task_id}/comments", response_model=List[shared_schemas.TaskComment]) diff --git a/server/app/api/v1/endpoints/users.py b/server/app/api/v1/endpoints/users.py index d99ead3..6827128 100644 --- a/server/app/api/v1/endpoints/users.py +++ b/server/app/api/v1/endpoints/users.py @@ -1,76 +1,103 @@ # /server/app/api/v1/endpoints/users.py -from typing import List +from datetime import timedelta import pyotp -from fastapi import APIRouter, Depends, HTTPException, Body, Query +from fastapi import APIRouter, Depends, HTTPException, Body, Query, status from fastapi.security import OAuth2PasswordRequestForm +from jose import JWTError, jwt from sqlalchemy.orm import Session -from public_api import shared_schemas -from ....core import security -from public_api.shared_schemas import UserFilter, UserWithPermissions, \ - UserPermissionUpdate -from .... import crud, models -from ....api import deps -from ....core.email import send_reset_password_email -from server.app.core.security import get_password_hash +from public_api.shared_schemas import user as user_schemas, RefreshTokenRequest +from server.app import crud, models +from server.app.api import deps +from server.app.core import security +from server.app.core.config import settings +from server.app.core.email import send_reset_password_email router = APIRouter() -@router.post("/login", response_model=shared_schemas.Token) +@router.post("/login", response_model=user_schemas.Token) def login( db: Session = Depends(deps.get_db), form_data: OAuth2PasswordRequestForm = Depends() ): user = crud.user.authenticate(db, email=form_data.username, password=form_data.password) if not user: - raise HTTPException(status_code=400, detail="Incorrect email or password") - elif not crud.user.is_active(user): - raise HTTPException(status_code=400, detail="Inactive user") + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Incorrect email or password") + if not crud.user.is_active(user): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user") if user.two_factor_auth_enabled: - # Return a special token to indicate 2FA is required - return {"access_token": "2FA_REQUIRED", "token_type": "bearer"} + return user_schemas.Token(access_token="2FA_REQUIRED", token_type="bearer", refresh_token="", expires_in=0) - # If 2FA is not enabled, proceed with normal login return create_token_for_user(user) -@router.post("/login/2fa", response_model=shared_schemas.Token) +@router.post("/login/2fa", response_model=user_schemas.Token) def login_2fa( db: Session = Depends(deps.get_db), - login_data: shared_schemas.TwoFactorLogin = Body(...) + login_data: user_schemas.TwoFactorLogin = Body(...) ): user = crud.user.authenticate(db, email=login_data.username, password=login_data.password) if not user or not user.two_factor_auth_enabled: - raise HTTPException(status_code=400, detail="Invalid credentials or 2FA not enabled") + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid credentials or 2FA not enabled") totp = pyotp.TOTP(user.two_factor_auth_secret) if not totp.verify(login_data.two_factor_code): - raise HTTPException(status_code=400, detail="Invalid 2FA code") + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid 2FA code") return create_token_for_user(user) -def create_token_for_user(user: models.User) -> shared_schemas.Token: - access_token = security.create_access_token(user.username) - return shared_schemas.Token(access_token=access_token, token_type="bearer") +def create_token_for_user(user: models.User) -> user_schemas.Token: + access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = security.create_access_token( + user.username, expires_delta=access_token_expires + ) + refresh_token = security.create_refresh_token(user.username) + return user_schemas.Token( + access_token=access_token, + refresh_token=refresh_token, + token_type="bearer", + expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES # * 60 + ) + + +@router.post("/refresh-token", response_model=user_schemas.Token) +def refresh_token( + refresh_data: RefreshTokenRequest, + db: Session = Depends(deps.get_db) +): + try: + payload = jwt.decode( + refresh_data.refresh_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] + ) + username: str = payload.get("sub") + if username is None: + raise HTTPException(status_code=400, detail="Invalid refresh token") + except JWTError: + raise HTTPException(status_code=400, detail="Invalid refresh token") + + user = crud.user.get_by_username(db, username=username) + if not user: + raise HTTPException(status_code=404, detail="User not found") + + return create_token_for_user(user) -@router.post("/register", response_model=shared_schemas.User) +@router.post("/register", response_model=user_schemas.UserSanitized) def register_user( - user: shared_schemas.UserCreate, + user: user_schemas.UserCreate, db: Session = Depends(deps.get_db) ): - db_user = crud.user.get(db, email=user.email) + db_user = crud.user.get_by_email(db, email=user.email) if db_user: raise HTTPException(status_code=400, detail="Email already registered") - new_user = crud.user.create(db=db, obj_in=user) - return shared_schemas.User.model_validate(new_user) + return crud.user.create(db=db, obj_in=user) -@router.post("/reset_password", response_model=shared_schemas.Message) +@router.post("/reset_password", response_model=user_schemas.Message) def reset_password( email: str = Body(..., embed=True), db: Session = Depends(deps.get_db) @@ -89,7 +116,7 @@ def reset_password( return {"message": result} -@router.post("/change_password", response_model=shared_schemas.Message) +@router.post("/change_password", response_model=user_schemas.Message) def change_user_password( current_password: str = Body(...), new_password: str = Body(...), @@ -103,53 +130,51 @@ def change_user_password( return {"message": "Password updated successfully"} -@router.put("/me", response_model=shared_schemas.User) +@router.put("/me", response_model=user_schemas.UserSanitized) def update_user_me( - user_in: shared_schemas.UserUpdate, + user_in: user_schemas.UserUpdate, current_user: models.User = Depends(deps.get_current_active_user), db: Session = Depends(deps.get_db) ): user = crud.user.update(db, db_obj=current_user, obj_in=user_in) - return shared_schemas.User.model_validate(user) + return user_schemas.UserSanitized.model_validate(user) -@router.get("/me", response_model=shared_schemas.UserSanitizedWithRole) +@router.get("/me", response_model=user_schemas.UserSanitized) def read_user_me( current_user: models.User = Depends(deps.get_current_active_user), ): - return current_user + return user_schemas.UserSanitized.model_validate(current_user) -@router.get("/permissions", response_model=shared_schemas.AllPermissions) +@router.get("/permissions", response_model=user_schemas.AllPermissions) def get_all_permissions( db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_admin) ): - permissions = crud.user.get_all_permissions(db) - return shared_schemas.AllPermissions(permissions=permissions) + return {"permissions": crud.user.get_all_permissions(db)} -@router.get("/my_permissions", response_model=shared_schemas.AllPermissions) +@router.get("/my_permissions", response_model=user_schemas.AllPermissions) def get_my_permissions( db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_active_user) ): - permissions = crud.user.get_user_with_permissions(db, current_user.id).permissions - return shared_schemas.AllPermissions(permissions=permissions) + return {"permissions": crud.user.get_user_with_permissions(db, current_user.id).permissions} -@router.get("/roles", response_model=shared_schemas.AllRoles) +@router.get("/roles", response_model=user_schemas.AllRoles) def get_all_roles( db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_admin) ): roles = crud.user.get_all_roles(db) - return shared_schemas.AllRoles(roles=roles) + return user_schemas.AllRoles(roles=roles) -@router.get("/", response_model=List[shared_schemas.UserSanitizedWithRole]) +@router.get("/", response_model=list[user_schemas.UserSanitized]) def read_users( - filter_params: UserFilter = Depends(), + filter_params: user_schemas.UserFilter = Depends(), skip: int = Query(0), limit: int = Query(100), db: Session = Depends(deps.get_db), @@ -161,26 +186,22 @@ def read_users( skip=skip, limit=limit ) - return [shared_schemas.UserSanitizedWithRole.model_validate(user) for user in users] + return users -@router.post("/", response_model=shared_schemas.UserSanitizedWithRole) +@router.post("/", response_model=user_schemas.UserSanitized) def create_user( - user: shared_schemas.UserCreate, + user: user_schemas.UserCreate, db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_admin) ): db_user = crud.user.get_by_email(db, email=user.email) if db_user: raise HTTPException(status_code=400, detail="Email already registered") + return crud.user.create(db=db, obj_in=user) - user.password = get_password_hash(user.password) - - new_user = crud.user.create(db=db, obj_in=user) - return shared_schemas.UserSanitizedWithRole.model_validate(new_user) - -@router.get("/{user_id}/permissions", response_model=UserWithPermissions) +@router.get("/{user_id}/permissions", response_model=user_schemas.UserWithPermissions) def get_user_permissions( user_id: int, db: Session = Depends(deps.get_db), @@ -188,22 +209,20 @@ def get_user_permissions( ): if not crud.user.is_admin(current_user) and current_user.id != user_id: raise HTTPException(status_code=403, detail="Not enough permissions") - user_with_permissions = crud.user.get_user_with_permissions(db, user_id) - return UserWithPermissions.model_validate(user_with_permissions) + return crud.user.get_user_with_permissions(db, user_id) -@router.put("/{user_id}/permissions", response_model=UserWithPermissions) +@router.put("/{user_id}/permissions", response_model=user_schemas.UserWithPermissions) def update_user_permissions( user_id: int, - permission_update: UserPermissionUpdate, + permission_update: user_schemas.UserPermissionUpdate, db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_admin), ): - updated_user = crud.user.update_user_permissions(db, user_id=user_id, permission_ids=permission_update.permissions) - return UserWithPermissions.model_validate(updated_user) + return crud.user.update_user_permissions(db, user_id=user_id, permission_ids=permission_update.permissions) -@router.get("/{user_id}", response_model=shared_schemas.UserSanitizedWithRole) +@router.get("/{user_id}", response_model=user_schemas.UserSanitized) def read_user( user_id: int, db: Session = Depends(deps.get_db), @@ -213,16 +232,16 @@ def read_user( if not user: raise HTTPException(status_code=404, detail="User not found") if user.id == current_user.id: - return shared_schemas.UserSanitizedWithRole.model_validate(user) + return user_schemas.UserSanitized.model_validate(user) if not crud.user.is_admin(current_user): raise HTTPException(status_code=400, detail="Not enough permissions") - return shared_schemas.UserSanitizedWithRole.model_validate(user) + return user_schemas.UserSanitized.model_validate(user) -@router.put("/{user_id}", response_model=shared_schemas.User) +@router.put("/{user_id}", response_model=user_schemas.UserSanitized) def update_user( user_id: int, - user_in: shared_schemas.UserUpdate, + user_in: user_schemas.UserUpdate, db: Session = Depends(deps.get_db), current_user: models.User = Depends(deps.get_current_admin) ): @@ -230,15 +249,11 @@ def update_user( if not user: raise HTTPException(status_code=404, detail="User not found") - update_data = user_in.model_dump(exclude_unset=True) - if 'password' in update_data: - update_data['password'] = get_password_hash(update_data['password']) - - updated_user = crud.user.update(db, db_obj=user, obj_in=update_data) - return shared_schemas.User.model_validate(updated_user) + updated_user = crud.user.update(db, db_obj=user, obj_in=user_in) + return user_schemas.UserSanitized.model_validate(updated_user) -@router.delete("/{user_id}", response_model=shared_schemas.User) +@router.delete("/{user_id}", response_model=user_schemas.UserSanitized) def delete_user( user_id: int, db: Session = Depends(deps.get_db), @@ -247,5 +262,4 @@ def delete_user( user = crud.user.get(db, id=user_id) if not user: raise HTTPException(status_code=404, detail="User not found") - user = crud.user.remove(db, id=user_id) - return shared_schemas.User.model_validate(user) + return crud.user.remove(db, id=user_id) diff --git a/server/app/core/config.py b/server/app/core/config.py index 3e106cd..2c90b73 100644 --- a/server/app/core/config.py +++ b/server/app/core/config.py @@ -9,6 +9,7 @@ class Settings(BaseSettings): SECRET_KEY: str = "SECRET" ALGORITHM: str = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 + REFRESH_TOKEN_EXPIRE_DAYS: int = 7 DATABASE_URL: str = "sqlite:///./nexusware.db" # SMTP Configuration diff --git a/server/app/core/security.py b/server/app/core/security.py index 91426b1..08a4388 100644 --- a/server/app/core/security.py +++ b/server/app/core/security.py @@ -37,3 +37,15 @@ def generate_password_reset_token(email: str) -> str: def get_password_hash(password: str) -> str: return pwd_context.hash(password) + + +def create_refresh_token( + subject: Union[str, Any], expires_delta: timedelta = None +) -> str: + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) + to_encode = {"exp": expire, "sub": str(subject)} + encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) + return encoded_jwt diff --git a/server/app/crud/asset.py b/server/app/crud/asset.py index 7e195cb..99fec0e 100644 --- a/server/app/crud/asset.py +++ b/server/app/crud/asset.py @@ -1,9 +1,7 @@ from datetime import datetime -from typing import Optional, List from sqlalchemy.orm import Session, joinedload -from server.app.models import Asset, AssetMaintenance, Location from public_api.shared_schemas import ( Asset as AssetSchema, Location as LocationSchema, @@ -12,20 +10,21 @@ AssetUpdate, AssetFilter ) +from server.app.models import Asset, AssetMaintenance, Location from .base import CRUDBase class CRUDAsset(CRUDBase[Asset, AssetCreate, AssetUpdate]): - def get_due_for_maintenance(self, db: Session) -> List[AssetWithMaintenanceSchema]: + def get_due_for_maintenance(self, db: Session) -> list[AssetSchema]: current_timestamp = int(datetime.now().timestamp()) assets = db.query(self.model).join(AssetMaintenance).filter( AssetMaintenance.scheduled_date <= current_timestamp, AssetMaintenance.completed_date.is_(None) ).all() - return [AssetWithMaintenanceSchema.model_validate(asset) for asset in assets] + return [AssetSchema.model_validate(asset) for asset in assets] - def transfer(self, db: Session, asset_id: int, new_location_id: int) -> Optional[AssetSchema]: + def transfer(self, db: Session, asset_id: int, new_location_id: int) -> Asset | None: current_asset = db.query(self.model).filter(self.model.id == asset_id).first() if current_asset: location = db.query(Location).filter(Location.id == new_location_id).first() @@ -36,24 +35,21 @@ def transfer(self, db: Session, asset_id: int, new_location_id: int) -> Optional return AssetSchema.model_validate(current_asset) return None - def get_asset_location(self, db: Session, asset_id: int) -> Optional[LocationSchema]: + def get_asset_location(self, db: Session, asset_id: int) -> LocationSchema | None: asset = db.query(self.model).filter(self.model.id == asset_id).first() if asset and asset.location: location = db.query(Location).filter(Location.name == asset.location).first() - if location: - return LocationSchema.model_validate(location) + return LocationSchema.model_validate(location) if location else None return None - def get_with_maintenance(self, db: Session, id: int) -> Optional[AssetWithMaintenanceSchema]: - asset = db.query(self.model).filter(self.model.id == id).options( + def get_with_maintenance(self, db: Session, asset_id: int) -> AssetWithMaintenanceSchema | None: + asset = db.query(self.model).filter(self.model.id == asset_id).options( joinedload(self.model.maintenance_records) ).first() - if asset: - return AssetWithMaintenanceSchema.model_validate(asset) - return None + return AssetWithMaintenanceSchema.model_validate(asset) if asset else None def get_multi_with_filter(self, db: Session, *, - skip: int = 0, limit: int = 100, filter_params: AssetFilter) -> List[AssetSchema]: + skip: int = 0, limit: int = 100, filter_params: AssetFilter) -> list[AssetSchema]: query = db.query(self.model) if filter_params.asset_type: query = query.filter(self.model.asset_type == filter_params.asset_type) @@ -69,23 +65,10 @@ def get_multi_with_filter(self, db: Session, *, assets = query.offset(skip).limit(limit).all() return [AssetSchema.model_validate(asset) for asset in assets] - def count_with_filter(self, db: Session, *, filter_params: AssetFilter) -> int: - query = db.query(self.model) - if filter_params.asset_type: - query = query.filter(self.model.asset_type == filter_params.asset_type) - if filter_params.status: - query = query.filter(self.model.status == filter_params.status) - if filter_params.purchase_date_from: - query = query.filter(self.model.purchase_date >= filter_params.purchase_date_from) - if filter_params.purchase_date_to: - query = query.filter(self.model.purchase_date <= filter_params.purchase_date_to) - - return query.count() - - def get_all_types(self, db: Session) -> List[str]: + def get_all_types(self, db: Session) -> list[str]: return [asset_type for (asset_type,) in db.query(self.model.asset_type).distinct().all()] - def get_all_statuses(self, db: Session) -> List[str]: + def get_all_statuses(self, db: Session) -> list[str]: return [status for (status,) in db.query(self.model.status).distinct().all()] diff --git a/server/app/crud/audit.py b/server/app/crud/audit.py index ac2e110..87f8c36 100644 --- a/server/app/crud/audit.py +++ b/server/app/crud/audit.py @@ -1,5 +1,3 @@ -from typing import Optional - from sqlalchemy import func, desc from sqlalchemy.orm import Session @@ -30,7 +28,7 @@ def get_multi_with_filter(self, db: Session, *, audit_logs = query.order_by(desc(AuditLog.timestamp)).offset(skip).limit(limit).all() return [AuditLogSchema.model_validate(audit_log) for audit_log in audit_logs] - def get_summary(self, db: Session, date_from: Optional[int], date_to: Optional[int]) -> AuditSummary: + def get_summary(self, db: Session, date_from: int | None, date_to: int | None) -> AuditSummary: query = db.query(self.model) if date_from: query = query.filter(AuditLog.timestamp >= date_from) diff --git a/server/app/crud/base.py b/server/app/crud/base.py index 0d747c6..6a44c23 100644 --- a/server/app/crud/base.py +++ b/server/app/crud/base.py @@ -1,5 +1,4 @@ -# /server/app/crud/base.py -from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union +from typing import Generic, Type, TypeVar from fastapi.encoders import jsonable_encoder from pydantic import BaseModel @@ -10,55 +9,88 @@ ModelType = TypeVar("ModelType", bound=Base) CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) +GetSchemaType = TypeVar("GetSchemaType", bound=BaseModel) class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): - # First in [] is the model type (imported from models), - # second is the create schema type, - # and third is the update schema type. def __init__(self, model: Type[ModelType]): self.model = model - def get(self, db: Session, id: Any, *, options: Optional[list] = None) -> Optional[ModelType]: + def get( + self, + db: Session, + id: any, + *, + options: list | None = None, + return_schema: Type[GetSchemaType] | None = None + ) -> ModelType | GetSchemaType | None: query = db.query(self.model) if options: - for option in options: - query = query.options(option) - return query.filter(self.model.id == id).first() + query = query.options(*options) + db_obj = query.filter(self.model.id == id).first() + if return_schema and db_obj: + return return_schema.model_validate(db_obj) + return db_obj - def get_multi(self, db: Session, *, skip: int = 0, limit: int = 100) -> List[ModelType]: - return db.query(self.model).offset(skip).limit(limit).all() + def get_multi( + self, + db: Session, + *, + skip: int = 0, + limit: int = 100, + return_schema: Type[GetSchemaType] | None = None + ) -> list[ModelType] | list[GetSchemaType]: + db_objs = db.query(self.model).offset(skip).limit(limit).all() + if return_schema: + return [return_schema.model_validate(obj) for obj in db_objs] + return db_objs - def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType: + def create( + self, + db: Session, + *, + obj_in: CreateSchemaType, + return_schema: Type[GetSchemaType] | None = None + ) -> ModelType | GetSchemaType: obj_in_data = obj_in.model_dump() db_obj = self.model(**obj_in_data) db.add(db_obj) db.commit() db.refresh(db_obj) + if return_schema: + return return_schema.model_validate(db_obj) return db_obj - def update(self, - db: Session, - *, - db_obj: ModelType, - obj_in: Union[UpdateSchemaType, Dict[str, Any]]) -> ModelType: - # WARNING: if date/datetime column in db_obj - encoding will change it to str, and will subsequently raise - # an error when trying to commit the changes. + def update( + self, + db: Session, + *, + db_obj: ModelType, + obj_in: UpdateSchemaType | dict[str, any], + return_schema: Type[GetSchemaType] | None = None + ) -> ModelType | GetSchemaType: obj_data = jsonable_encoder(db_obj) - if isinstance(obj_in, dict): - update_data = obj_in - else: - update_data = obj_in.model_dump(exclude_unset=True) + update_data = obj_in if isinstance(obj_in, dict) else obj_in.model_dump(exclude_unset=True) for field in obj_data: if field in update_data: setattr(db_obj, field, update_data[field]) db.add(db_obj) db.commit() db.refresh(db_obj) + if return_schema: + return return_schema.model_validate(db_obj) return db_obj - def remove(self, db: Session, *, id: int) -> ModelType: + def remove( + self, + db: Session, + *, + id: int, + return_schema: Type[GetSchemaType] | None = None + ) -> ModelType | GetSchemaType: obj = db.query(self.model).get(id) db.delete(obj) db.commit() + if return_schema: + return return_schema.model_validate(obj) return obj diff --git a/server/app/crud/dock_appointment.py b/server/app/crud/dock_appointment.py index 3226b29..ca98e73 100644 --- a/server/app/crud/dock_appointment.py +++ b/server/app/crud/dock_appointment.py @@ -1,21 +1,20 @@ -from typing import Optional, List from datetime import timedelta from sqlalchemy.orm import Session -from server.app.models import DockAppointment from public_api.shared_schemas import ( DockAppointment as DockAppointmentSchema, DockAppointmentCreate, DockAppointmentUpdate, DockAppointmentFilter, AppointmentConflict ) +from server.app.models import DockAppointment from .base import CRUDBase class CRUDDockAppointment(CRUDBase[DockAppointment, DockAppointmentCreate, DockAppointmentUpdate]): def get_multi_with_filter(self, db: Session, *, skip: int = 0, limit: int = 100, - filter_params: DockAppointmentFilter) -> List[DockAppointmentSchema]: + filter_params: DockAppointmentFilter) -> list[DockAppointmentSchema]: query = db.query(self.model) if filter_params.yard_location_id: query = query.filter(DockAppointment.yard_location_id == filter_params.yard_location_id) @@ -34,7 +33,7 @@ def get_multi_with_filter(self, db: Session, *, skip: int = 0, limit: int = 100, def check_conflicts(self, db: Session, appointment: DockAppointmentCreate, - exclude_id: Optional[int] = None) -> List[AppointmentConflict]: + exclude_id: int | None = None) -> list[AppointmentConflict]: conflicts = [] appointment_duration = timedelta(hours=1) # Assume 1-hour appointments query = db.query(DockAppointment).filter( diff --git a/server/app/crud/inventory.py b/server/app/crud/inventory.py index 3b44adb..6460a7c 100644 --- a/server/app/crud/inventory.py +++ b/server/app/crud/inventory.py @@ -1,7 +1,6 @@ import time from collections import defaultdict from datetime import timedelta, datetime -from typing import Optional, List, Dict, Tuple import numpy as np from fastapi import HTTPException @@ -39,12 +38,11 @@ def get_multi_with_products( skip: int = 0, limit: int = 100, filter_params: InventoryFilter - ) -> List[InventoryWithDetails]: + ) -> list[InventoryWithDetails]: query = db.query(Inventory).options( joinedload(Inventory.product), joinedload(Inventory.location) ) - if filter_params.product_id: query = query.filter(Inventory.product_id == filter_params.product_id) if filter_params.location_id: @@ -63,7 +61,7 @@ def get_multi_with_products( return [InventoryWithDetails.model_validate(item) for item in items] def get_multi_with_filter(self, db: Session, *, skip: int = 0, limit: int = 100, - filter_params: InventoryFilter) -> list[InventorySchema]: + filter_params: InventoryFilter) -> list[Inventory]: query = db.query(Inventory) if filter_params.product_id: query = query.filter(Inventory.product_id == filter_params.product_id) @@ -77,10 +75,11 @@ def get_multi_with_filter(self, db: Session, *, skip: int = 0, limit: int = 100, query = query.filter(Inventory.quantity >= filter_params.quantity_min) if filter_params.quantity_max is not None: query = query.filter(Inventory.quantity <= filter_params.quantity_max) + return [InventorySchema.model_validate(x) for x in query.offset(skip).limit(limit).all()] - def adjust_quantity(self, db: Session, id: int, adjustment: InventoryAdjustmentSchema) -> InventorySchema: - current_inventory = self.get(db, id=id) + def adjust_quantity(self, db: Session, inventory_id: int, adjustment: InventoryAdjustmentSchema) -> InventorySchema: + current_inventory = self.get(db, id=inventory_id) if not current_inventory: raise HTTPException(status_code=404, detail="Inventory item not found") current_inventory.quantity += adjustment.quantity_change @@ -99,7 +98,7 @@ def adjust_quantity(self, db: Session, id: int, adjustment: InventoryAdjustmentS db.commit() db.refresh(current_inventory) - return InventorySchema.model_validate(current_inventory) + return current_inventory def transfer(self, db: Session, transfer: InventoryTransfer) -> InventorySchema: from_inventory = db.query(Inventory).filter( @@ -192,8 +191,8 @@ def batch_update(self, db: Session, updates: list[InventoryUpdate]) -> list[Inve db.commit() return [InventorySchema.model_validate(inventory) for inventory in updated_items] - def get_movement_history(self, db: Session, product_id: int, start_date: Optional[int], - end_date: Optional[int]) -> list[InventoryMovementSchema]: + def get_movement_history(self, db: Session, product_id: int, start_date: int | None, + end_date: int | None) -> list[InventoryMovementSchema]: query = db.query(InventoryMovement).filter(InventoryMovement.product_id == product_id) if start_date: query = query.filter(InventoryMovement.timestamp >= start_date) @@ -411,13 +410,13 @@ def get_storage_utilization(self, db: Session) -> StorageUtilization: def advanced_search( self, db: Session, - q: Optional[str] = None, - category_id: Optional[int] = None, - min_price: Optional[float] = None, - max_price: Optional[float] = None, - in_stock: Optional[bool] = None, - sort_by: Optional[str] = None, - sort_order: Optional[str] = "asc" + q: str | None = None, + category_id: int | None = None, + min_price: float | None = None, + max_price: float | None = None, + in_stock: bool | None = None, + sort_by: str | None = None, + sort_order: str | None = "asc" ) -> list[ProductSchema]: query = db.query(Product) @@ -460,7 +459,7 @@ def advanced_search( products = query.all() return [ProductSchema.model_validate(product) for product in products] - def get_forecast_for_product_id(self, db: Session, product_id: int) -> Dict: + def get_forecast_for_product_id(self, db: Session, product_id: int) -> dict: history = db.query(InventoryMovement).filter(InventoryMovement.product_id == product_id).order_by( InventoryMovement.timestamp).all() @@ -491,7 +490,7 @@ def get_forecast_for_product_id(self, db: Session, product_id: int) -> Dict: return {"forecast": forecast} - def get_reorder_suggestions(self, db: Session) -> List[Dict]: + def get_reorder_suggestions(self, db: Session) -> list[dict]: products = db.query(Product).options(joinedload(Product.inventory_items)).all() suggestions = [] for product in products: @@ -531,7 +530,7 @@ def get_reorder_suggestions(self, db: Session) -> List[Dict]: return suggestions def get_inventory_trend_with_prediction(self, db: Session, days_past: int = 5, days_future: int = 5) \ - -> Tuple[List[InventoryTrendItem], List[InventoryTrendItem]]: + -> (list[InventoryTrendItem], list[InventoryTrendItem]): end_timestamp = int(time.time()) start_timestamp = end_timestamp - (days_past * 86400) diff --git a/server/app/crud/location.py b/server/app/crud/location.py index 92fca5b..a9de256 100644 --- a/server/app/crud/location.py +++ b/server/app/crud/location.py @@ -1,15 +1,14 @@ # /server/app/crud/location.py -from typing import Optional from sqlalchemy.orm import Session, joinedload -from server.app.models import ( - Location -) from public_api.shared_schemas import ( LocationCreate, LocationUpdate, LocationWithInventory as LocationWithInventorySchema, LocationFilter ) +from server.app.models import ( + Location +) from .base import CRUDBase @@ -36,7 +35,7 @@ def get_multi_with_inventory( locations = query.offset(skip).limit(limit).all() return [LocationWithInventorySchema.model_validate(location) for location in locations] - def get_with_inventory(self, db: Session, id: int) -> Optional[LocationWithInventorySchema]: + def get_with_inventory(self, db: Session, id: int) -> LocationWithInventorySchema | None: location = db.query(Location).filter(Location.id == id).options( joinedload(Location.inventory_items)).first() return LocationWithInventorySchema.model_validate(location) if location else None diff --git a/server/app/crud/order.py b/server/app/crud/order.py index 0dbadfe..11f0cc7 100644 --- a/server/app/crud/order.py +++ b/server/app/crud/order.py @@ -1,5 +1,3 @@ -from typing import Optional, List - from fastapi import HTTPException from sqlalchemy import func from sqlalchemy.orm import Session, joinedload @@ -50,7 +48,38 @@ def get_multi_with_details(self, db: Session, *, orders = query.offset(skip).limit(limit).all() return [OrderWithDetailsSchema.model_validate(x) for x in orders] - def get_summary(self, db: Session, date_from: Optional[int], date_to: Optional[int]) -> OrderSummary: + def advanced_search(self, db: Session, *, q: str = None, status: str = None, min_total: float = None, + max_total: float = None, start_date: int = None, end_date: int = None, + sort_by: str = None, sort_order: str = "asc") -> list[OrderSchema]: + query = db.query(self.model).options(joinedload(Order.customer)) + if q: + query = query.filter(Order.customer_id.ilike(f"%{q}%")) + if status: + query = query.filter(Order.status == status) + if min_total: + query = query.filter(Order.total_amount >= min_total) + if max_total: + query = query.filter(Order.total_amount <= max_total) + if start_date: + query = query.filter(Order.order_date >= start_date) + if end_date: + query = query.filter(Order.order_date <= end_date) + if sort_by: + if sort_by == "total_amount": + if sort_order == "asc": + query = query.order_by(Order.total_amount.asc()) + else: + query = query.order_by(Order.total_amount.desc()) + elif sort_by == "order_date": + if sort_order == "asc": + query = query.order_by(Order.order_date.asc()) + else: + query = query.order_by(Order.order_date.desc()) + + orders = query.all() + return [OrderSchema.model_validate(order) for order in orders] + + def get_summary(self, db: Session, date_from: int | None, date_to: int | None) -> OrderSummary: query = db.query(func.count(Order.id).label("total_orders"), func.sum(Order.total_amount).label("total_revenue")) if date_from: @@ -156,7 +185,7 @@ def get_processing_times(self, db: Session, *, start_date: int, end_date: int) - if processing_times.max_time else 0 ) - def update_items(self, db: Session, order_id: int, items: List[OrderItemUpdate]) -> Order: + def update_items(self, db: Session, order_id: int, items: list[OrderItemUpdate]) -> Order: db_order = db.query(Order).filter(Order.id == order_id).first() if not db_order: raise HTTPException(status_code=404, detail="Order not found") diff --git a/server/app/crud/permission.py b/server/app/crud/permission.py index 1487cac..391818b 100644 --- a/server/app/crud/permission.py +++ b/server/app/crud/permission.py @@ -1,16 +1,15 @@ # /server/app/crud/permission.py -from typing import Optional from sqlalchemy.orm import Session -from server.app.models import Permission from public_api.shared_schemas import PermissionCreate, PermissionUpdate, \ Permission as PermissionSchema +from server.app.models import Permission from .base import CRUDBase class CRUDPermission(CRUDBase[Permission, PermissionCreate, PermissionUpdate]): - def get_by_name(self, db: Session, *, name: str) -> Optional[PermissionSchema]: + def get_by_name(self, db: Session, *, name: str) -> PermissionSchema | None: current_permission = db.query(Permission).filter(Permission.permission_name == name).first() return PermissionSchema.model_validate(current_permission) if current_permission else None diff --git a/server/app/crud/product.py b/server/app/crud/product.py index 66197fa..9d8a659 100644 --- a/server/app/crud/product.py +++ b/server/app/crud/product.py @@ -1,6 +1,6 @@ # /server/app/crud/product.py -from sqlalchemy import func +from sqlalchemy import func, or_ from sqlalchemy.orm import Session, joinedload from public_api.shared_schemas import ( @@ -8,7 +8,7 @@ ProductFilter, ProductWithCategoryAndInventory ) from server.app.models import ( - Product + Product, Inventory ) from .base import CRUDBase @@ -39,5 +39,61 @@ def get_max_id(self, db: Session) -> int: max_id = db.query(func.max(Product.id)).scalar() return max_id if max_id is not None else 0 + def advanced_search( + self, + db: Session, + *, + q: str | None = None, + category_id: int | None = None, + min_price: float | None = None, + max_price: float | None = None, + in_stock: bool | None = None, + sort_by: str | None = None, + sort_order: str | None = "asc", + skip: int = 0, + limit: int = 100 + ) -> list[ProductWithCategoryAndInventory]: + query = db.query(Product).options( + joinedload(Product.category), + joinedload(Product.inventory_items) + ) + + # Apply filters + if q: + query = query.filter( + or_( + Product.name.ilike(f"%{q}%"), + Product.description.ilike(f"%{q}%"), + Product.sku.ilike(f"%{q}%") + ) + ) + + if category_id is not None: + query = query.filter(Product.category_id == category_id) + + if min_price is not None: + query = query.filter(Product.price >= min_price) + + if max_price is not None: + query = query.filter(Product.price <= max_price) + + if in_stock is not None: + if in_stock: + query = query.join(Inventory).filter(Inventory.quantity > 0) + else: + query = query.outerjoin(Inventory).group_by(Product.id).having(func.sum(Inventory.quantity) == 0) + + if sort_by: + sort_column = getattr(Product, sort_by, None) + if sort_column is not None: + if sort_order and sort_order.lower() == "desc": + sort_column = sort_column.desc() + query = query.order_by(sort_column) + + # Apply pagination + products = query.offset(skip).limit(limit).all() + + return [ProductWithCategoryAndInventory.model_validate(product) for product in products] + product = CRUDProduct(Product) diff --git a/server/app/crud/quality.py b/server/app/crud/quality.py index 2ba2012..4d468f0 100644 --- a/server/app/crud/quality.py +++ b/server/app/crud/quality.py @@ -1,5 +1,3 @@ -from typing import Optional - from sqlalchemy import func, case from sqlalchemy.orm import Session @@ -35,7 +33,7 @@ def get_multi_with_filter(self, db: Session, *, skip: int = 0, limit: int = 100, quality_checks = query.offset(skip).limit(limit).all() return [QualityCheckSchema.model_validate(check) for check in quality_checks] - def get_metrics(self, db: Session, date_from: Optional[int], date_to: Optional[int]) -> QualityMetrics: + def get_metrics(self, db: Session, date_from: int | None, date_to: int | None) -> QualityMetrics: query = db.query( func.count(QualityCheck.id).label("total_checks"), func.sum(case((QualityCheck.result == "pass", 1), else_=0)).label("passes"), @@ -57,7 +55,7 @@ def get_metrics(self, db: Session, date_from: Optional[int], date_to: Optional[i fail_rate=fails / total if total > 0 else 0 ) - def get_summary(self, db: Session, date_from: Optional[int], date_to: Optional[int]) -> dict[str, int]: + def get_summary(self, db: Session, date_from: int | None, date_to: int | None) -> dict[str, int]: query = db.query(QualityCheck.result, func.count(QualityCheck.id)) if date_from: query = query.filter(QualityCheck.check_date >= date_from) @@ -92,7 +90,7 @@ def add_comment(self, db: Session, *, check_id: int, comment: QualityCheckCommen return QualityCheckCommentSchema.model_validate(db_comment) def get_product_defect_rates(self, db: Session, - date_from: Optional[int], date_to: Optional[int]) -> list[ProductDefectRate]: + date_from: int | None, date_to: int | None) -> list[ProductDefectRate]: query = db.query( Product.id, Product.name.label("product_name"), diff --git a/server/app/crud/role.py b/server/app/crud/role.py index 98996c4..4861ff1 100644 --- a/server/app/crud/role.py +++ b/server/app/crud/role.py @@ -1,6 +1,3 @@ -# /server/app/crud/role.py -from typing import Optional, Any, Dict, Union - from sqlalchemy.orm import Session from public_api.shared_schemas import RoleCreate, RoleUpdate, Role as RoleSchema @@ -21,7 +18,7 @@ def create(self, db: Session, *, obj_in: RoleCreate) -> RoleSchema: return RoleSchema.model_validate(db_obj) - def update(self, db: Session, *, db_obj: Role, obj_in: Union[RoleUpdate, Dict[str, Any]]) -> RoleSchema: + def update(self, db: Session, *, db_obj: Role, obj_in: RoleUpdate | dict[str, any]) -> RoleSchema: if isinstance(obj_in, dict): update_data = obj_in else: @@ -36,7 +33,7 @@ def update(self, db: Session, *, db_obj: Role, obj_in: Union[RoleUpdate, Dict[st updated_role = super().update(db, db_obj=db_obj, obj_in=update_data) return RoleSchema.model_validate(updated_role) - def get_by_name(self, db: Session, *, name: str) -> Optional[RoleSchema]: + def get_by_name(self, db: Session, *, name: str) -> RoleSchema | None: current_role = db.query(Role).filter(Role.role_name == name).first() return RoleSchema.model_validate(current_role) if current_role else None diff --git a/server/app/crud/shipment.py b/server/app/crud/shipment.py index af28efb..83339a1 100644 --- a/server/app/crud/shipment.py +++ b/server/app/crud/shipment.py @@ -1,6 +1,3 @@ -# /server/app/crud/shipment.py -from typing import Optional, List - from fastapi import HTTPException from sqlalchemy.orm import Session @@ -29,7 +26,7 @@ def get_multi_with_filter(self, db: Session, *, shipments = query.offset(skip).limit(limit).all() return [ShipmentSchema.model_validate(shipment) for shipment in shipments] - def get_carrier_rates(self, db: Session, weight: float, dimensions: str, destination_zip: str) -> List[CarrierRate]: + def get_carrier_rates(self, db: Session, weight: float, dimensions: str, destination_zip: str) -> list[CarrierRate]: try: params = { "weight": weight, @@ -49,7 +46,7 @@ def get_carrier_rates(self, db: Session, weight: float, dimensions: str, destina raise HTTPException(status_code=401, detail="Invalid ShipEngine API key") raise e - def track(self, db: Session, *, shipment_id: int) -> Optional[ShipmentTracking]: + def track(self, db: Session, *, shipment_id: int) -> ShipmentTracking | None: shipment = self.get(db, id=shipment_id) if not shipment: raise HTTPException(status_code=404, detail="Shipment not found") diff --git a/server/app/crud/user.py b/server/app/crud/user.py index cdda966..b6a0f76 100644 --- a/server/app/crud/user.py +++ b/server/app/crud/user.py @@ -1,26 +1,27 @@ +# /server/app/crud/user.py +import time from datetime import timedelta, datetime -from typing import Optional, List, Any, Dict, Union from sqlalchemy.orm import Session, joinedload -from public_api.shared_schemas import UserCreate, UserUpdate, UserFilter +from public_api.shared_schemas import user as user_schemas from server.app.core.security import get_password_hash, verify_password from server.app.models import User, Permission, Role from .base import CRUDBase -class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): - - def get_by_email(self, db: Session, email: str) -> Optional[User]: - return db.query(User).filter(User.email == email).first() +class CRUDUser(CRUDBase[User, user_schemas.UserCreate, user_schemas.UserUpdate]): + def get_by_email(self, db: Session, email: str) -> user_schemas.UserSanitized | None: + user = db.query(User).filter(User.email == email).first() + return user_schemas.UserSanitized.model_validate(user) if user else None def get_multi_with_filters( self, db: Session, - filter_params: UserFilter, + filter_params: user_schemas.UserFilter, skip: int = 0, limit: int = 100 - ) -> List[User]: + ) -> list[user_schemas.UserSanitized]: query = db.query(User) if filter_params.role_id is not None: @@ -32,75 +33,82 @@ def get_multi_with_filters( if filter_params.is_active is not None: query = query.filter(User.is_active == filter_params.is_active) - return query.offset(skip).limit(limit).all() + users = query.offset(skip).limit(limit).all() + return [user_schemas.UserSanitized.model_validate(user) for user in users] - def get_by_username(self, db: Session, username: str) -> Optional[User]: - return db.query(User).filter(User.username == username).first() + def get_by_username(self, db: Session, username: str) -> user_schemas.UserSanitized | None: + user = db.query(User).filter(User.username == username).first() + return user_schemas.UserSanitized.model_validate(user) if user else None - def update(self, db: Session, *, db_obj: User, obj_in: Union[UserUpdate, Dict[str, Any]]) -> User: - if isinstance(obj_in, dict): - update_data = obj_in - else: - update_data = obj_in.model_dump(exclude_unset=True) + def update(self, db: Session, *, db_obj: User, obj_in: user_schemas.UserUpdate) -> user_schemas.UserSanitized: + update_data = obj_in.model_dump(exclude_unset=True) if update_data.get("password"): - update_data["password"] = get_password_hash(update_data["password"]) + hashed_password = get_password_hash(update_data["password"]) + del update_data["password"] + update_data["password"] = hashed_password return super().update(db, db_obj=db_obj, obj_in=update_data) - def authenticate(self, db: Session, *, email: str, password: str) -> Optional[User]: + def authenticate(self, db: Session, *, email: str, password: str) -> user_schemas.UserSanitized | None: user = db.query(User).filter(User.email == email).first() - if not user: + if not user or not verify_password(password, user.password): return None - if not verify_password(password, user.password): - return None - return user + + # updating last login timestamp + current_time = int(time.time()) + user.last_login = current_time + db.add(user) + db.commit() + db.refresh(user) + + return user_schemas.UserSanitized.model_validate(user) def is_active(self, user: User) -> bool: return user.is_active def is_admin(self, user: User) -> bool: - return user.role.role_name.lower() == "admin" + return user.role.role_name.lower() == user_schemas.RoleName.ADMIN.value - def change_role(self, db: Session, *, user_id: int, new_role_id: int) -> Optional[User]: + def change_role(self, db: Session, *, user_id: int, new_role_id: int) -> user_schemas.UserSanitized | None: user = self.get(db, id=user_id) if user: user.role_id = new_role_id db.commit() db.refresh(user) - return user + return user_schemas.UserSanitized.model_validate(user) if user else None - def set_reset_password_token(self, db: Session, *, user: User, token: str) -> User: + def set_reset_password_token(self, db: Session, *, user: User, token: str) -> user_schemas.UserSanitized: user.password_reset_token = token user.password_reset_expiration = int((datetime.utcnow() + timedelta(hours=1)).timestamp()) db.add(user) db.commit() db.refresh(user) - return user + return user_schemas.UserSanitized.model_validate(user) - def get_user_with_permissions(self, db: Session, user_id: int) -> Optional[User]: - return db.query(User).options( + def get_user_with_permissions(self, db: Session, user_id: int) -> user_schemas.UserWithPermissions | None: + user = db.query(User).options( joinedload(User.role).joinedload(Role.permissions) ).filter(User.id == user_id).first() + return user_schemas.UserWithPermissions.model_validate(user) if user else None - def update_user_permissions(self, db: Session, *, user_id: int, permission_ids: List[int]) -> User: + def update_user_permissions(self, + db: Session, + *, + user_id: int, + permission_ids: list[int]) -> user_schemas.UserWithPermissions: user = self.get(db, id=user_id) if not user: raise ValueError("User not found") - # Fetch the permissions based on the provided IDs permissions = db.query(Permission).filter(Permission.id.in_(permission_ids)).all() - - # Update the user's permissions user.permissions = permissions db.add(user) db.commit() db.refresh(user) - return user + return user_schemas.UserWithPermissions.model_validate(user) - def get_user_permissions(self, db: Session, user_id: int) -> List[Permission]: + def get_user_permissions(self, db: Session, user_id: int) -> list[Permission]: user = self.get_user_with_permissions(db, user_id) - if not user: - return [] - return user.permissions + return user.permissions if user else [] def check_permission(self, db: Session, user_id: int, name: str, action: str) -> bool: user_permissions = self.get_user_permissions(db, user_id) @@ -114,10 +122,10 @@ def check_permission(self, db: Session, user_id: int, name: str, action: str) -> return True return False - def get_all_permissions(self, db: Session) -> List[Permission]: + def get_all_permissions(self, db: Session) -> list[Permission]: return db.query(Permission).all() - def get_all_roles(self, db: Session) -> List[Role]: + def get_all_roles(self, db: Session) -> list[Role]: return db.query(Role).all() diff --git a/server/app/crud/yard_location.py b/server/app/crud/yard_location.py index 3db50b4..50e0b8c 100644 --- a/server/app/crud/yard_location.py +++ b/server/app/crud/yard_location.py @@ -1,20 +1,17 @@ -# /server/app/crud/yard_location.py -from typing import Optional, List - from sqlalchemy.orm import Session, selectinload -from server.app.models import YardLocation from public_api.shared_schemas import ( YardLocation as YardLocationSchema, YardLocationCreate, YardLocationUpdate, YardLocationFilter, YardLocationWithAppointments ) +from server.app.models import YardLocation from .base import CRUDBase class CRUDYardLocation(CRUDBase[YardLocation, YardLocationCreate, YardLocationUpdate]): def get_multi_with_filter(self, db: Session, *, skip: int = 0, limit: int = 100, - filter_params: YardLocationFilter) -> List[YardLocationSchema]: + filter_params: YardLocationFilter) -> list[YardLocationSchema]: query = db.query(self.model) if filter_params.name: query = query.filter(YardLocation.name.ilike(f"%{filter_params.name}%")) @@ -25,7 +22,7 @@ def get_multi_with_filter(self, db: Session, *, skip: int = 0, limit: int = 100, locations = query.offset(skip).limit(limit).all() return [YardLocationSchema.model_validate(location) for location in locations] - def get_with_appointments(self, db: Session, id: int) -> Optional[YardLocationWithAppointments]: + def get_with_appointments(self, db: Session, id: int) -> YardLocationWithAppointments | None: location = (db.query(self.model) .filter(self.model.id == id) .options(selectinload(YardLocation.appointments)) diff --git a/server/app/crud/zone.py b/server/app/crud/zone.py index 5484203..e33563d 100644 --- a/server/app/crud/zone.py +++ b/server/app/crud/zone.py @@ -1,14 +1,11 @@ -# /server/app/crud/zone.py -from typing import Optional - from sqlalchemy.orm import Session, joinedload -from server.app.models import ( - Zone -) from public_api.shared_schemas import ( ZoneCreate, ZoneUpdate, LocationFilter, ZoneWithLocations, WarehouseLayout ) +from server.app.models import ( + Zone +) from .base import CRUDBase @@ -31,7 +28,7 @@ def get_warehouse_layout(self, db: Session) -> WarehouseLayout: def get_multi_with_locations( self, db: Session, skip: int = 0, limit: int = 100, - filter_params: Optional[LocationFilter] = None) -> list[ZoneWithLocations]: + filter_params: LocationFilter | None = None) -> list[ZoneWithLocations]: query = db.query(Zone).options(joinedload(Zone.locations)) if filter_params: @@ -41,7 +38,7 @@ def get_multi_with_locations( zones = query.offset(skip).limit(limit).all() return [ZoneWithLocations.model_validate(zone) for zone in zones] - def get_with_locations(self, db: Session, id: int) -> Optional[ZoneWithLocations]: + def get_with_locations(self, db: Session, id: int) -> ZoneWithLocations | None: zone = db.query(Zone).filter(Zone.id == id).options(joinedload(Zone.locations)).first() return ZoneWithLocations.model_validate(zone) if zone else None diff --git a/server/nexusware.db b/server/nexusware.db index 7cd4864..e01dd06 100644 Binary files a/server/nexusware.db and b/server/nexusware.db differ