Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

25143 - remove product / add previously approved product support #3200

Merged
merged 3 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions auth-api/src/auth_api/exceptions/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
class Error(Enum):
"""Error Codes."""

INVALID_ORG = "The organization ID is in an incorrect format.", HTTPStatus.BAD_REQUEST
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this consolidate duplicate code in the org_products route for validating orgId format

INVALID_INPUT = "Invalid input, please check.", HTTPStatus.BAD_REQUEST
DATA_NOT_FOUND = "No matching record found.", HTTPStatus.NOT_FOUND
DATA_ALREADY_EXISTS = "The data you want to insert already exists.", HTTPStatus.BAD_REQUEST
Expand Down
3 changes: 2 additions & 1 deletion auth-api/src/auth_api/models/product_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

The ProductSubscription object connects Org models to one or more ProductSubscription models.
"""
from typing import Self

from sql_versioning import Versioned
from sqlalchemy import Column, ForeignKey, Integer, and_
Expand Down Expand Up @@ -45,7 +46,7 @@ def find_by_org_ids(cls, org_ids, valid_statuses=VALID_SUBSCRIPTION_STATUSES):
).all()

@classmethod
def find_by_org_id_product_code(cls, org_id: int, product_code, valid_statuses=VALID_SUBSCRIPTION_STATUSES):
def find_by_org_id_product_code(cls, org_id: int, product_code, valid_statuses=VALID_SUBSCRIPTION_STATUSES) -> Self:
"""Find an product subscription instance that matches the provided id."""
return cls.query.filter(
and_(
Expand Down
21 changes: 19 additions & 2 deletions auth-api/src/auth_api/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""This model manages a Task item in the Auth Service."""
import datetime as dt
from typing import Self

import pytz
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, text
Expand Down Expand Up @@ -93,14 +94,14 @@ def fetch_tasks(cls, task_search: TaskSearch):
return pagination.items, pagination.total

@classmethod
def find_by_task_id(cls, task_id: int):
def find_by_task_id(cls, task_id: int) -> Self:
"""Find a task instance that matches the provided id."""
return db.session.query(Task).filter_by(id=int(task_id or -1)).first()

@classmethod
def find_by_task_relationship_id(
cls, relationship_id: int, task_relationship_type: str, task_status: str = TaskStatus.OPEN.value
):
) -> Self:
"""Find a task instance that related to the relationship id ( may be an ORG or a PRODUCT."""
return (
db.session.query(Task)
Expand All @@ -112,6 +113,22 @@ def find_by_task_relationship_id(
.first()
)

@classmethod
def find_by_incomplete_task_relationship_id(
cls, relationship_id: int, task_relationship_type: str, relationship_status: str = None
) -> Self:
"""Find a task instance that related to the relationship id ( may be an ORG or a PRODUCT) that is incomplete."""
query = db.session.query(Task).filter(
Task.relationship_id == int(relationship_id or -1),
Task.relationship_type == task_relationship_type,
Task.status.in_((TaskStatus.OPEN.value, TaskStatus.HOLD.value)),
)

if relationship_status is not None:
query = query.filter(Task.relationship_status == relationship_status)

return query.first()

@classmethod
def find_by_task_for_account(cls, org_id: int, status):
"""Find a task instance that matches the provided id."""
Expand Down
34 changes: 24 additions & 10 deletions auth-api/src/auth_api/resources/v1/org_products.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from flask import Blueprint, g, request
from flask_cors import cross_origin

from auth_api.exceptions import BusinessException
from auth_api.exceptions import BusinessException, Error
from auth_api.schemas import utils as schema_utils
from auth_api.services import Product as ProductService
from auth_api.utils.auth import jwt as _jwt
Expand All @@ -34,10 +34,8 @@
def get_org_product_subscriptions(org_id):
"""GET a new product subscription to the org using the request body."""

if not org_id or org_id == "None" or not org_id.isdigit() or int(org_id) < 0:
return {"message": "The organization ID is in an incorrect format."}, HTTPStatus.BAD_REQUEST

try:
validate_organization(org_id)
include_hidden = request.args.get("include_hidden", None) == "true" # used by NDS
response, status = (
json.dumps(ProductService.get_all_product_subscription(org_id=int(org_id), include_hidden=include_hidden)),
Expand All @@ -54,15 +52,13 @@ def get_org_product_subscriptions(org_id):
def post_org_product_subscription(org_id):
"""Post a new product subscription to the org using the request body."""

if not org_id or org_id == "None" or not org_id.isdigit() or int(org_id) < 0:
return {"message": "The organization ID is in an incorrect format."}, HTTPStatus.BAD_REQUEST

request_json = request.get_json()
valid_format, errors = schema_utils.validate(request_json, "org_product_subscription")
if not valid_format:
return {"message": schema_utils.serialize(errors)}, HTTPStatus.BAD_REQUEST

try:
validate_organization(org_id)
roles = g.jwt_oidc_token_info.get("realm_access").get("roles")
subscriptions = ProductService.create_product_subscription(
int(org_id), request_json, skip_auth=Role.SYSTEM.value in roles, auto_approve=Role.SYSTEM.value in roles
Expand All @@ -80,17 +76,35 @@ def post_org_product_subscription(org_id):
def patch_org_product_subscription(org_id):
"""Patch existing product subscription to resubmit it for review."""

if not org_id or org_id == "None" or not org_id.isdigit() or int(org_id) < 0:
return {"message": "The organization ID is in an incorrect format."}, HTTPStatus.BAD_REQUEST

request_json = request.get_json()
valid_format, errors = schema_utils.validate(request_json, "org_product_subscription")
if not valid_format:
return {"message": schema_utils.serialize(errors)}, HTTPStatus.BAD_REQUEST

try:
validate_organization(org_id)
subscriptions = ProductService.resubmit_product_subscription(int(org_id), request_json)
response, status = {"subscriptions": subscriptions}, HTTPStatus.OK
except BusinessException as exception:
response, status = {"code": exception.code, "message": exception.message}, exception.status_code
return response, status


@bp.route("/<string:product_code>", methods=["DELETE", "OPTIONS"])
@cross_origin(origins="*", methods=["DELETE"])
@_jwt.has_one_of_roles([Role.STAFF_CREATE_ACCOUNTS.value, Role.PUBLIC_USER.value, Role.SYSTEM.value])
def delete_product_subscription(org_id, product_code):
"""Delete existing product subscription."""

try:
validate_organization(org_id)
subscriptions = ProductService.remove_product_subscription(int(org_id), product_code)
response, status = {"subscriptions": subscriptions}, HTTPStatus.OK
except BusinessException as exception:
response, status = {"code": exception.code, "message": exception.message}, exception.status_code
return response, status


def validate_organization(org_id):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mini helper added so routes can re-use it

if not org_id or org_id == "None" or not org_id.isdigit() or int(org_id) < 0:
raise BusinessException(Error.INVALID_ORG, None)
70 changes: 65 additions & 5 deletions auth-api/src/auth_api/services/products.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,26 @@ def resubmit_product_subscription(org_id, subscription_data: Dict[str, Any], ski

return Product.get_all_product_subscription(org_id=org_id, skip_auth=True)

@staticmethod
def _is_previously_approved(org_id: int, product_code: str):
"""Check if this product has a task that was previously approved."""
inactive_sub = ProductSubscriptionModel.find_by_org_id_product_code(
org_id=org_id, product_code=product_code, valid_statuses=(ProductSubscriptionStatus.INACTIVE.value,)
)
if not inactive_sub:
return False, None

task = TaskModel.find_by_task_relationship_id(
inactive_sub.id, TaskRelationshipType.PRODUCT.value, TaskStatus.COMPLETED.value
)
if task is None or (
task.relationship_status != TaskRelationshipStatus.ACTIVE.value
and task.action == TaskAction.PRODUCT_REVIEW.value
):
return False, None

return True, inactive_sub

@staticmethod
def create_product_subscription(
org_id,
Expand Down Expand Up @@ -181,10 +201,13 @@ def create_product_subscription(
and org.type_code not in PREMIUM_ORG_TYPES
):
continue
previously_approved, inactive_sub = Product._is_previously_approved(org_id, product_code)
if previously_approved:
auto_approve = True

subscription_status = Product.find_subscription_status(org, product_model, auto_approve)
product_subscription = Product._subscribe_and_publish_activity(
org_id, product_code, subscription_status, product_model.description
org_id, product_code, subscription_status, product_model.description, inactive_sub
)

# If there is a linked product, add subscription to that too.
Expand Down Expand Up @@ -229,6 +252,32 @@ def create_product_subscription(

return Product.get_all_product_subscription(org_id=org_id, skip_auth=True)

@staticmethod
def remove_product_subscription(org_id: int, product_code: str, skip_auth=False):
"""Deactivate org product subscription by code."""
org: OrgModel = OrgModel.find_by_org_id(org_id)
if not org:
raise BusinessException(Error.DATA_NOT_FOUND, None)

if not skip_auth:
check_auth(one_of_roles=(*CLIENT_ADMIN_ROLES, STAFF), org_id=org_id)

existing_sub = ProductSubscriptionModel.find_by_org_id_product_code(org_id, product_code)

if existing_sub:
existing_sub.status_code = ProductSubscriptionStatus.INACTIVE.value
existing_sub.save()

pending_task = TaskModel.find_by_incomplete_task_relationship_id(
relationship_id=existing_sub.id,
task_relationship_type=TaskRelationshipType.PRODUCT.value,
relationship_status=ProductSubscriptionStatus.PENDING_STAFF_REVIEW.value,
)
if pending_task:
pending_task.delete()

return Product.get_all_product_subscription(org_id=org_id, skip_auth=True)

@staticmethod
def _send_product_subscription_confirmation(product_notification_info: ProductNotificationInfo, org_id: int):
admin_emails = UserService.get_admin_emails_for_org(org_id)
Expand Down Expand Up @@ -256,11 +305,22 @@ def _update_parent_subscription(org_id, sub_product_model, subscription_status):

@staticmethod
def _subscribe_and_publish_activity(
org_id: int, product_code: str, status_code: str, product_model_description: str
org_id: int,
product_code: str,
status_code: str,
product_model_description: str,
inactive_sub: ProductSubscriptionModel = None,
):
subscription = ProductSubscriptionModel(
org_id=org_id, product_code=product_code, status_code=status_code
).flush()
subscription = None
if inactive_sub:
subscription = inactive_sub
subscription.status_code = status_code
subscription.flush()
else:
subscription = ProductSubscriptionModel(
org_id=org_id, product_code=product_code, status_code=status_code
).flush()

if status_code == ProductSubscriptionStatus.ACTIVE.value:
ActivityLogPublisher.publish_activity(
Activity(org_id, ActivityAction.ADD_PRODUCT_AND_SERVICE.value, name=product_model_description)
Expand Down
4 changes: 4 additions & 0 deletions auth-api/tests/unit/api/test_cors_preflight.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ def test_preflight_org_products(app, client, jwt, session):
assert rv.status_code == HTTPStatus.OK
assert_access_control_headers(rv, "*", "GET, PATCH, POST")

rv = client.options("/api/v1/orgs/1/products/ABC", headers={"Access-Control-Request-Method": "DELETE"})
assert rv.status_code == HTTPStatus.OK
assert_access_control_headers(rv, "*", "DELETE")


def test_preflight_org_permissions(app, client, jwt, session):
"""Assert preflight responses for org permissions are correct."""
Expand Down
Loading
Loading