From 66ea7e29fb838efb18c7c2f6af5cd7f7cccff5b2 Mon Sep 17 00:00:00 2001 From: Michael Anstis <manstis@redhat.com> Date: Fri, 21 Jun 2024 09:57:53 +0100 Subject: [PATCH 1/6] AAP-25479: Gracefully handle 400 Bad Request from WCA requests --- ansible_ai_connect/ai/api/api_wrapper.py | 119 +++++++++++++++ .../pipelines/completion_stages/inference2.py | 139 ++++++++++++++++++ .../ai/api/pipelines/completions.py | 6 +- ansible_ai_connect/ai/api/tests/test_views.py | 4 +- 4 files changed, 263 insertions(+), 5 deletions(-) create mode 100644 ansible_ai_connect/ai/api/api_wrapper.py create mode 100644 ansible_ai_connect/ai/api/pipelines/completion_stages/inference2.py diff --git a/ansible_ai_connect/ai/api/api_wrapper.py b/ansible_ai_connect/ai/api/api_wrapper.py new file mode 100644 index 000000000..ff1456496 --- /dev/null +++ b/ansible_ai_connect/ai/api/api_wrapper.py @@ -0,0 +1,119 @@ +# Copyright Red Hat +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import json +import logging +from typing import Optional +from uuid import UUID + +from django.conf import settings + +from ansible_ai_connect.ai.api.exceptions import ( + ModelTimeoutException, + ServiceUnavailable, + WcaBadRequestException, + WcaCloudflareRejectionException, + WcaEmptyResponseException, + WcaInvalidModelIdException, + WcaKeyNotFoundException, + WcaModelIdNotFoundException, + WcaNoDefaultModelIdException, + WcaSuggestionIdCorrelationFailureException, + WcaUserTrialExpiredException, +) +from ansible_ai_connect.ai.api.model_client.exceptions import ( + ModelTimeoutError, + WcaBadRequest, + WcaCloudflareRejection, + WcaEmptyResponse, + WcaInvalidModelId, + WcaKeyNotFound, + WcaModelIdNotFound, + WcaNoDefaultModelId, + WcaSuggestionIdCorrelationFailure, + WcaUserTrialExpired, +) + +logger = logging.getLogger(__name__) + + +def call(api_type: str, identifier: Optional[UUID]): + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + value = func(*args, **kwargs) + return value + except ModelTimeoutError as e: + logger.warning( + f"model timed out after {settings.ANSIBLE_AI_MODEL_MESH_API_TIMEOUT} " + f"seconds (per task) for {api_type}: {identifier}" + ) + raise ModelTimeoutException(cause=e) + + except WcaBadRequest as e: + logger.error( + f"bad request from WCA for completion for {api_type}: {identifier}:" + f" {json.dumps(e.json_response)}" + ) + raise WcaBadRequestException(cause=e) + + except WcaInvalidModelId as e: + logger.info(f"WCA Model ID is invalid for {api_type}: {identifier}") + raise WcaInvalidModelIdException(cause=e) + + except WcaKeyNotFound as e: + logger.info( + f"A WCA Api Key was expected but not found for {api_type}: {identifier}" + ) + raise WcaKeyNotFoundException(cause=e) + + except WcaNoDefaultModelId as e: + logger.info(f"No default WCA Model ID was found for {api_type}: {identifier}") + raise WcaNoDefaultModelIdException(cause=e) + + except WcaModelIdNotFound as e: + logger.info( + f"A WCA Model ID was expected but not found for {api_type}: {identifier}" + ) 
+ raise WcaModelIdNotFoundException(cause=e) + + except WcaSuggestionIdCorrelationFailure as e: + logger.info( + f"WCA Request/Response SuggestionId correlation failed for " + f"{api_type}: {identifier} and x_request_id: {e.x_request_id}" + ) + raise WcaSuggestionIdCorrelationFailureException(cause=e) + + except WcaEmptyResponse as e: + logger.info(f"WCA returned an empty response for suggestion {identifier}") + raise WcaEmptyResponseException(cause=e) + + except WcaCloudflareRejection as e: + logger.exception(f"Cloudflare rejected the request for {api_type}: {identifier}") + raise WcaCloudflareRejectionException(cause=e) + + except WcaUserTrialExpired as e: + logger.exception(f"User trial expired, when requesting {api_type}: {identifier}") + raise WcaUserTrialExpiredException(cause=e) + + except Exception as e: + logger.exception(f"error requesting completion for {api_type}: {identifier}") + raise ServiceUnavailable(cause=e) + + return wrapper + + return decorator diff --git a/ansible_ai_connect/ai/api/pipelines/completion_stages/inference2.py b/ansible_ai_connect/ai/api/pipelines/completion_stages/inference2.py new file mode 100644 index 000000000..90a4da264 --- /dev/null +++ b/ansible_ai_connect/ai/api/pipelines/completion_stages/inference2.py @@ -0,0 +1,139 @@ +# Copyright Red Hat +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import time +from string import Template + +from ansible_anonymizer import anonymizer +from django.apps import apps +from django_prometheus.conf import NAMESPACE +from prometheus_client import Histogram + +from ansible_ai_connect.ai.api.api_wrapper import call +from ansible_ai_connect.ai.api.data.data_model import ModelMeshPayload +from ansible_ai_connect.ai.api.exceptions import ( + BaseWisdomAPIException, + WcaUserTrialExpiredException, + process_error_count, +) +from ansible_ai_connect.ai.api.pipelines.common import PipelineElement +from ansible_ai_connect.ai.api.pipelines.completion_context import CompletionContext +from ansible_ai_connect.ai.api.utils.segment import send_segment_event +from ansible_ai_connect.ai.feature_flags import FeatureFlags + +logger = logging.getLogger(__name__) + +feature_flags = FeatureFlags() + +completions_hist = Histogram( + "model_prediction_latency_seconds", + "Histogram of model prediction processing time", + namespace=NAMESPACE, +) + + +class InferenceStage2(PipelineElement): + def process(self, context: CompletionContext) -> None: + + payload = context.payload + suggestion_id = payload.suggestionId + + @call("suggestions", suggestion_id) + def get_predictions() -> None: + request = context.request + model_mesh_client = apps.get_app_config("ai").model_mesh_client + # We have a little inconsistency of the "model" term throughout the application: + # - FeatureFlags use 'model_name' + # - ModelMeshClient uses 'model_id' + # - Public completion API uses 'model' + # - Segment Events use 'modelName' + model_id = payload.model + + model_mesh_payload = ModelMeshPayload( + instances=[ + { + "prompt": payload.prompt, + "context": payload.context, + "suggestionId": str(suggestion_id), + } + ] + ) + data = model_mesh_payload.dict() + logger.debug(f"input to inference for suggestion id {suggestion_id}:\n{data}") + + predictions = None + exception = None + event = None + event_name = None + start_time = time.time() + + try: + predictions = model_mesh_client.infer( + request, data, model_id=model_id, suggestion_id=suggestion_id + ) + model_id = predictions.get("model_id", model_id) + + except WcaUserTrialExpiredException as e: + exception = e + event = { + "type": "prediction", + "modelName": model_id, + "suggestionId": str(suggestion_id), + } + event_name = "trialExpired" + raise + + except Exception as e: + exception = e + raise + finally: + duration = round((time.time() - start_time) * 1000, 2) + completions_hist.observe(duration / 1000) # millisec back to seconds + anonymized_predictions = anonymizer.anonymize_struct( + predictions, value_template=Template("{{ _${variable_name}_ }}") + ) + # If an exception was thrown during the backend call, try to get the model ID + # that is contained in the exception. 
+ if exception: + process_error_count.labels(stage="prediction").inc() + model_id_in_exception = BaseWisdomAPIException.get_model_id_from_exception( + exception + ) + if model_id_in_exception: + model_id = model_id_in_exception + if event: + event["modelName"] = model_id + else: + event = { + "duration": duration, + "exception": exception is not None, + "modelName": model_id, + "problem": None if exception is None else exception.__class__.__name__, + "request": data, + "response": anonymized_predictions, + "suggestionId": str(suggestion_id), + } + event_name = event_name if event_name else "prediction" + send_segment_event(event, event_name, request.user) + + logger.debug( + f"response from inference for suggestion id {suggestion_id}:\n{predictions}" + ) + + context.model_id = model_id + context.predictions = predictions + context.anonymized_predictions = anonymized_predictions + + get_predictions() diff --git a/ansible_ai_connect/ai/api/pipelines/completions.py b/ansible_ai_connect/ai/api/pipelines/completions.py index 17c8bc09a..838537fae 100644 --- a/ansible_ai_connect/ai/api/pipelines/completions.py +++ b/ansible_ai_connect/ai/api/pipelines/completions.py @@ -22,8 +22,8 @@ from ansible_ai_connect.ai.api.pipelines.completion_stages.deserialise import ( DeserializeStage, ) -from ansible_ai_connect.ai.api.pipelines.completion_stages.inference import ( - InferenceStage, +from ansible_ai_connect.ai.api.pipelines.completion_stages.inference2 import ( + InferenceStage2, ) from ansible_ai_connect.ai.api.pipelines.completion_stages.post_process import ( PostProcessStage, @@ -45,7 +45,7 @@ def __init__(self, request: Request): [ DeserializeStage(), PreProcessStage(), - InferenceStage(), + InferenceStage2(), PostProcessStage(), ResponseStage(), ], diff --git a/ansible_ai_connect/ai/api/tests/test_views.py b/ansible_ai_connect/ai/api/tests/test_views.py index 673ce7652..b5f3423db 100644 --- a/ansible_ai_connect/ai/api/tests/test_views.py +++ b/ansible_ai_connect/ai/api/tests/test_views.py @@ -680,8 +680,8 @@ def test_wca_completion_request_id_correlation_failure(self): properties = event["properties"] self.assertTrue(properties["exception"]) self.assertEqual(properties["problem"], "WcaSuggestionIdCorrelationFailure") - self.assertInLog(f"suggestion_id: '{DEFAULT_SUGGESTION_ID}'", log) - self.assertInLog(f"x_request_id: '{x_request_id}'", log) + self.assertInLog(f"suggestions: {DEFAULT_SUGGESTION_ID}", log) + self.assertInLog(f"x_request_id: {x_request_id}", log) @override_settings(WCA_SECRET_DUMMY_SECRETS="1:valid") @override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE") From bd91c7200983dd62bf2304a6c75e6a29ea573dca Mon Sep 17 00:00:00 2001 From: Michael Anstis <manstis@redhat.com> Date: Fri, 21 Jun 2024 11:10:06 +0100 Subject: [PATCH 2/6] ContentMatch refactor --- ansible_ai_connect/ai/api/api_wrapper.py | 4 +- .../pipelines/completion_stages/inference2.py | 2 +- ansible_ai_connect/ai/api/tests/test_views.py | 78 ++++- ansible_ai_connect/ai/api/urls.py | 5 + ansible_ai_connect/ai/api/views2.py | 277 ++++++++++++++++++ 5 files changed, 358 insertions(+), 8 deletions(-) create mode 100644 ansible_ai_connect/ai/api/views2.py diff --git a/ansible_ai_connect/ai/api/api_wrapper.py b/ansible_ai_connect/ai/api/api_wrapper.py index ff1456496..21de1f808 100644 --- a/ansible_ai_connect/ai/api/api_wrapper.py +++ b/ansible_ai_connect/ai/api/api_wrapper.py @@ -15,8 +15,6 @@ import functools import json import logging -from typing import Optional -from uuid import UUID from django.conf import settings @@ -49,7 
+47,7 @@ logger = logging.getLogger(__name__) -def call(api_type: str, identifier: Optional[UUID]): +def call(api_type: str, identifier: str): def decorator(func): @functools.wraps(func) diff --git a/ansible_ai_connect/ai/api/pipelines/completion_stages/inference2.py b/ansible_ai_connect/ai/api/pipelines/completion_stages/inference2.py index 90a4da264..9224fb9fe 100644 --- a/ansible_ai_connect/ai/api/pipelines/completion_stages/inference2.py +++ b/ansible_ai_connect/ai/api/pipelines/completion_stages/inference2.py @@ -50,7 +50,7 @@ def process(self, context: CompletionContext) -> None: payload = context.payload suggestion_id = payload.suggestionId - @call("suggestions", suggestion_id) + @call("suggestions", str(suggestion_id)) def get_predictions() -> None: request = context.request model_mesh_client = apps.get_app_config("ai").model_mesh_client diff --git a/ansible_ai_connect/ai/api/tests/test_views.py b/ansible_ai_connect/ai/api/tests/test_views.py index b5f3423db..5e34b043b 100644 --- a/ansible_ai_connect/ai/api/tests/test_views.py +++ b/ansible_ai_connect/ai/api/tests/test_views.py @@ -2186,8 +2186,78 @@ def setUp(self): self.model_client.get_token = Mock(return_value={"access_token": "abc"}) self.model_client.get_api_key = Mock(return_value="org-api-key") +<<<<<<< HEAD:ansible_ai_connect/ai/api/tests/test_views.py +======= + self.search_response = { + "attributions": [ + { + "repo_name": repo_name, + "repo_url": repo_url, + "path": path, + "license": license, + "data_source": DataSource.GALAXY_R, + "ansible_type": AnsibleType.UNKNOWN, + "score": 0.0, + }, + ], + "meta": { + "encode_duration": 1000, + "search_duration": 2000, + }, + } + + @override_settings(ANSIBLE_AI_ENABLE_TECH_PREVIEW=True) + @override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE") + @patch("ansible_ai_connect.ai.api.views2.send_segment_event") + @patch("ansible_ai_connect.ai.search.search") + def test_wca_contentmatch_segment_events_with_unseated_user( + self, mock_search, mock_send_segment_event + ): + self.user.rh_user_has_seat = False + + mock_search.return_value = self.search_response + + r = self.client.post(reverse("contentmatches"), self.payload) + self.assertEqual(r.status_code, HTTPStatus.OK) + + event = { + "exception": False, + "modelName": "", + "problem": None, + "response": { + "contentmatches": [ + { + "contentmatch": [ + { + "repo_name": "robertdebock.nginx", + "repo_url": "https://galaxy.ansible.com/robertdebock/nginx", + "path": "tasks/main.yml", + "license": "apache-2.0", + "score": 0.0, + "data_source_description": "Ansible Galaxy roles", + } + ] + } + ] + }, + "metadata": [{"encode_duration": 1000, "search_duration": 2000}], + } + + event_request = { + "suggestions": [ + "\n - name: install nginx on RHEL\n become: true\n " + "ansible.builtin.package:\n name: nginx\n state: present\n" + ] + } + + actual_event = mock_send_segment_event.call_args_list[0][0][0] + + self.assertTrue(event.items() <= actual_event.items()) + self.assertTrue(event_request.items() <= actual_event.get("request").items()) + +>>>>>>> 4a5b54df (ContentMatch refactor):ansible_wisdom/ai/api/tests/test_views.py @override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE") - @patch("ansible_ai_connect.ai.api.views.send_segment_event") + @patch("ansible_ai_connect.ai.api.views2.send_segment_event") def test_wca_contentmatch_segment_events_with_seated_user(self, mock_send_segment_event): self.user.rh_user_has_seat = True self.model_client.get_model_id = Mock(return_value="model-id") @@ -2234,7 +2304,7 @@ def 
test_wca_contentmatch_segment_events_with_seated_user(self, mock_send_segmen self.assertTrue(event_request.items() <= actual_event.get("request").items()) @override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE") - @patch("ansible_ai_connect.ai.api.views.send_segment_event") + @patch("ansible_ai_connect.ai.api.views2.send_segment_event") def test_wca_contentmatch_segment_events_with_invalid_modelid_error( self, mock_send_segment_event ): @@ -2276,7 +2346,7 @@ def test_wca_contentmatch_segment_events_with_invalid_modelid_error( self.assertTrue(event_request.items() <= actual_event.get("request").items()) @override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE") - @patch("ansible_ai_connect.ai.api.views.send_segment_event") + @patch("ansible_ai_connect.ai.api.views2.send_segment_event") def test_wca_contentmatch_segment_events_with_empty_response_error( self, mock_send_segment_event ): @@ -2321,7 +2391,7 @@ def test_wca_contentmatch_segment_events_with_empty_response_error( self.assertTrue(event_request.items() <= actual_event.get("request").items()) @override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE") - @patch("ansible_ai_connect.ai.api.views.send_segment_event") + @patch("ansible_ai_connect.ai.api.views2.send_segment_event") def test_wca_contentmatch_segment_events_with_key_error(self, mock_send_segment_event): self.user.rh_user_has_seat = True self.model_client.get_api_key = Mock(side_effect=WcaKeyNotFound) diff --git a/ansible_ai_connect/ai/api/urls.py b/ansible_ai_connect/ai/api/urls.py index a52201ffd..fa9d26a31 100644 --- a/ansible_ai_connect/ai/api/urls.py +++ b/ansible_ai_connect/ai/api/urls.py @@ -14,7 +14,12 @@ from django.urls import path +<<<<<<< HEAD:ansible_ai_connect/ai/api/urls.py from .views import Completions, ContentMatches, Explanation, Feedback, Generation +======= +from .views import Attributions, Completions, Explanation, Feedback, Generation +from .views2 import ContentMatches +>>>>>>> 4a5b54df (ContentMatch refactor):ansible_wisdom/ai/api/urls.py urlpatterns = [ path("completions/", Completions.as_view(), name="completions"), diff --git a/ansible_ai_connect/ai/api/views2.py b/ansible_ai_connect/ai/api/views2.py new file mode 100644 index 000000000..341cf009a --- /dev/null +++ b/ansible_ai_connect/ai/api/views2.py @@ -0,0 +1,277 @@ +# Copyright Red Hat +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import time +from http import HTTPStatus + +from django.apps import apps +from django.conf import settings +from drf_spectacular.utils import OpenApiResponse, extend_schema +from oauth2_provider.contrib.rest_framework import IsAuthenticatedOrTokenHasScope +from rest_framework import permissions +from rest_framework import status as rest_framework_status +from rest_framework.exceptions import ValidationError +from rest_framework.generics import GenericAPIView +from rest_framework.response import Response + +from ansible_ai_connect.ai.api.api_wrapper import call +from ansible_ai_connect.ai.api.exceptions import ( + BaseWisdomAPIException, + InternalServerError, + WcaUserTrialExpiredException, + process_error_count, +) +from ansible_ai_connect.users.models import User + +from .. import search as ai_search +from .data.data_model import ( + AttributionsResponseDto, + ContentMatchPayloadData, + ContentMatchResponseDto, +) +from .permissions import AcceptedTermsPermission, BlockUserWithoutSeat, IsAAPLicensed +from .serializers import ContentMatchRequestSerializer, ContentMatchResponseSerializer +from .utils.segment import send_segment_event +from .views import contentmatch_encoding_hist, contentmatch_search_hist + +logger = logging.getLogger(__name__) + + +class ContentMatches(GenericAPIView): + """ + Returns content matches that were the highest likelihood sources for a given code suggestion. + """ + + serializer_class = ContentMatchRequestSerializer + + permission_classes = ( + [ + permissions.IsAuthenticated, + IsAuthenticatedOrTokenHasScope, + IsAAPLicensed, + ] + if settings.DEPLOYMENT_MODE == "onprem" + else [ + permissions.IsAuthenticated, + IsAuthenticatedOrTokenHasScope, + AcceptedTermsPermission, + BlockUserWithoutSeat, + ] + ) + + required_scopes = ["read", "write"] + + throttle_cache_key_suffix = "_contentmatches" + + @extend_schema( + request=ContentMatchRequestSerializer, + responses={ + 200: ContentMatchResponseSerializer, + 400: OpenApiResponse(description="Bad Request"), + 401: OpenApiResponse(description="Unauthorized"), + 429: OpenApiResponse(description="Request was throttled"), + 503: OpenApiResponse(description="Service Unavailable"), + }, + summary="Code suggestion attributions", + ) + def post(self, request) -> Response: + request_serializer = self.get_serializer(data=request.data) + request_serializer.is_valid(raise_exception=True) + + request_data = request_serializer.validated_data + suggestion_id = str(request_data.get("suggestionId", "")) + model_id = str(request_data.get("model", "")) + + try: + if request.user.rh_user_has_seat: + response_serializer = self.perform_content_matching( + model_id, suggestion_id, request.user, request_data + ) + else: + response_serializer = self.perform_search(request_data, request.user) + return Response(response_serializer.data, status=rest_framework_status.HTTP_200_OK) + except Exception: + logger.exception("Error requesting content matches") + raise + + def perform_content_matching( + self, + model_id: str, + suggestion_id: str, + user: User, + request_data, + ): + _model_id = model_id + + @call("suggestions", suggestion_id) + def get_content_matches() -> ContentMatchResponseSerializer: + __model_id = _model_id + model_mesh_client = apps.get_app_config("ai").model_mesh_client + user_id = user.uuid + content_match_data: ContentMatchPayloadData = { + "suggestions": request_data.get("suggestions", []), + "user_id": str(user_id) if user_id else None, + "rh_user_has_seat": user.rh_user_has_seat, + "organization_id": 
user.org_id, + "suggestionId": suggestion_id, + } + logger.debug( + f"input to content matches for suggestion id {suggestion_id}:\n{content_match_data}" + ) + + exception = None + event = None + event_name = None + start_time = time.time() + response_serializer = None + metadata = [] + + try: + __model_id, client_response = model_mesh_client.codematch( + content_match_data, __model_id + ) + + response_data = {"contentmatches": []} + + for response_item in client_response: + content_match_dto = ContentMatchResponseDto(**response_item) + response_data["contentmatches"].append(content_match_dto.content_matches) + metadata.append(content_match_dto.meta) + + contentmatch_encoding_hist.observe(content_match_dto.encode_duration / 1000) + contentmatch_search_hist.observe(content_match_dto.search_duration / 1000) + + response_serializer = ContentMatchResponseSerializer(data=response_data) + response_serializer.is_valid(raise_exception=True) + + except ValidationError: + process_error_count.labels( + stage="contentmatch-response_serialization_validation" + ).inc() + logger.exception(f"error serializing final response for suggestion {suggestion_id}") + raise InternalServerError + + except WcaUserTrialExpiredException as e: + exception = e + event = { + "type": "prediction", + "modelName": __model_id, + "suggestionId": str(suggestion_id), + } + event_name = "trialExpired" + raise + + except Exception as e: + exception = e + raise + + finally: + duration = round((time.time() - start_time) * 1000, 2) + if exception: + model_id_in_exception = BaseWisdomAPIException.get_model_id_from_exception( + exception + ) + if model_id_in_exception: + __model_id = model_id_in_exception + if event: + event["modelName"] = __model_id + send_segment_event(event, event_name, user) + else: + self.write_to_segment( + request_data, + duration, + exception, + metadata, + __model_id, + response_serializer.data if response_serializer else {}, + suggestion_id, + user, + ) + + return response_serializer + + return get_content_matches() + + def perform_search(self, request_data, user: User): + suggestion_id = str(request_data.get("suggestionId", "")) + response_serializer = None + + exception = None + start_time = time.time() + metadata = [] + model_name = "" + + try: + suggestion = request_data["suggestions"][0] + response_item = ai_search.search(suggestion) + + attributions_dto = AttributionsResponseDto(**response_item) + response_data = {"contentmatches": []} + response_data["contentmatches"].append(attributions_dto.content_matches) + metadata.append(attributions_dto.meta) + + try: + response_serializer = ContentMatchResponseSerializer(data=response_data) + response_serializer.is_valid(raise_exception=True) + except Exception: + process_error_count.labels(stage="attr-response_serialization_validation").inc() + logger.exception(f"Error serializing final response for suggestion {suggestion_id}") + raise InternalServerError + + except Exception as e: + exception = e + logger.exception("Failed to search for attributions for content matching") + return Response( + {"message": "Unable to complete the request"}, status=HTTPStatus.SERVICE_UNAVAILABLE + ) + finally: + duration = round((time.time() - start_time) * 1000, 2) + self.write_to_segment( + request_data, + duration, + exception, + metadata, + model_name, + response_serializer.data if response_serializer else {}, + suggestion_id, + user, + ) + + return response_serializer + + def write_to_segment( + self, + request_data, + duration, + exception, + metadata, + model_id, + 
response_data, + suggestion_id, + user, + ): + event = { + "duration": duration, + "exception": exception is not None, + "modelName": model_id, + "problem": None if exception is None else exception.__class__.__name__, + "request": request_data, + "response": response_data, + "suggestionId": str(suggestion_id), + "rh_user_has_seat": user.rh_user_has_seat, + "rh_user_org_id": user.org_id, + "metadata": metadata, + } + send_segment_event(event, "contentmatch", user) From 73a3aeafc8cb9c167f36acf6250c7d6bafc03f8d Mon Sep 17 00:00:00 2001 From: Michael Anstis <manstis@redhat.com> Date: Fri, 21 Jun 2024 12:03:16 +0100 Subject: [PATCH 3/6] Generations refactor --- ansible_ai_connect/ai/api/api_wrapper.py | 12 +- .../ai/api/model_client/wca_client.py | 3 + .../pipelines/completion_stages/inference2.py | 2 +- ansible_ai_connect/ai/api/urls.py | 6 + ansible_ai_connect/ai/api/views2.py | 2 +- ansible_ai_connect/ai/api/views3.py | 165 ++++++++++++++++++ 6 files changed, 187 insertions(+), 3 deletions(-) create mode 100644 ansible_ai_connect/ai/api/views3.py diff --git a/ansible_ai_connect/ai/api/api_wrapper.py b/ansible_ai_connect/ai/api/api_wrapper.py index 21de1f808..006ed9ce5 100644 --- a/ansible_ai_connect/ai/api/api_wrapper.py +++ b/ansible_ai_connect/ai/api/api_wrapper.py @@ -15,8 +15,10 @@ import functools import json import logging +from typing import Callable from django.conf import settings +from rest_framework.exceptions import ValidationError from ansible_ai_connect.ai.api.exceptions import ( ModelTimeoutException, @@ -47,12 +49,13 @@ logger = logging.getLogger(__name__) -def call(api_type: str, identifier: str): +def call(api_type: str, identifier_provider: Callable[[], str]): def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): try: + identifier = identifier_provider() value = func(*args, **kwargs) return value except ModelTimeoutError as e: @@ -108,6 +111,13 @@ def wrapper(*args, **kwargs): logger.exception(f"User trial expired, when requesting {api_type}: {identifier}") raise WcaUserTrialExpiredException(cause=e) + except ValidationError as e: + logger.exception( + f"An exception {e.__class__} occurred " + f"during validation of {api_type}: {identifier}" + ) + raise + except Exception as e: logger.exception(f"error requesting completion for {api_type}: {identifier}") raise ServiceUnavailable(cause=e) diff --git a/ansible_ai_connect/ai/api/model_client/wca_client.py b/ansible_ai_connect/ai/api/model_client/wca_client.py index 44e2bc377..73191a4ca 100644 --- a/ansible_ai_connect/ai/api/model_client/wca_client.py +++ b/ansible_ai_connect/ai/api/model_client/wca_client.py @@ -491,7 +491,10 @@ def generate_playbook( headers=headers, json=data, ) + context = Context(model_id, result, False) + InferenceResponseChecks().run_checks(context) result.raise_for_status() + response = json.loads(result.text) playbook = response["playbook"] diff --git a/ansible_ai_connect/ai/api/pipelines/completion_stages/inference2.py b/ansible_ai_connect/ai/api/pipelines/completion_stages/inference2.py index 9224fb9fe..50fc236b5 100644 --- a/ansible_ai_connect/ai/api/pipelines/completion_stages/inference2.py +++ b/ansible_ai_connect/ai/api/pipelines/completion_stages/inference2.py @@ -50,7 +50,7 @@ def process(self, context: CompletionContext) -> None: payload = context.payload suggestion_id = payload.suggestionId - @call("suggestions", str(suggestion_id)) + @call("suggestions", lambda: str(suggestion_id)) def get_predictions() -> None: request = context.request model_mesh_client = 
apps.get_app_config("ai").model_mesh_client diff --git a/ansible_ai_connect/ai/api/urls.py b/ansible_ai_connect/ai/api/urls.py index fa9d26a31..bcb080eaa 100644 --- a/ansible_ai_connect/ai/api/urls.py +++ b/ansible_ai_connect/ai/api/urls.py @@ -14,12 +14,18 @@ from django.urls import path +<<<<<<< HEAD:ansible_ai_connect/ai/api/urls.py <<<<<<< HEAD:ansible_ai_connect/ai/api/urls.py from .views import Completions, ContentMatches, Explanation, Feedback, Generation ======= from .views import Attributions, Completions, Explanation, Feedback, Generation from .views2 import ContentMatches >>>>>>> 4a5b54df (ContentMatch refactor):ansible_wisdom/ai/api/urls.py +======= +from .views import Attributions, Completions, Explanation, Feedback +from .views2 import ContentMatches +from .views3 import Generation +>>>>>>> 75363e32 (Generations refactor):ansible_wisdom/ai/api/urls.py urlpatterns = [ path("completions/", Completions.as_view(), name="completions"), diff --git a/ansible_ai_connect/ai/api/views2.py b/ansible_ai_connect/ai/api/views2.py index 341cf009a..436d89fa0 100644 --- a/ansible_ai_connect/ai/api/views2.py +++ b/ansible_ai_connect/ai/api/views2.py @@ -115,7 +115,7 @@ def perform_content_matching( ): _model_id = model_id - @call("suggestions", suggestion_id) + @call("suggestions", lambda: suggestion_id) def get_content_matches() -> ContentMatchResponseSerializer: __model_id = _model_id model_mesh_client = apps.get_app_config("ai").model_mesh_client diff --git a/ansible_ai_connect/ai/api/views3.py b/ansible_ai_connect/ai/api/views3.py new file mode 100644 index 000000000..2ada798f6 --- /dev/null +++ b/ansible_ai_connect/ai/api/views3.py @@ -0,0 +1,165 @@ +# Copyright Red Hat +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from string import Template + +from ansible_anonymizer import anonymizer +from django.apps import apps +from drf_spectacular.utils import OpenApiResponse, extend_schema +from rest_framework import status as rest_framework_status +from rest_framework.response import Response +from rest_framework.views import APIView + +from ansible_ai_connect.ai.api.api_wrapper import call +from ansible_ai_connect.ai.api.aws.exceptions import WcaSecretManagerError +from ansible_ai_connect.ai.api.model_client.exceptions import ( + WcaModelIdNotFound, + WcaNoDefaultModelId, +) + +from .permissions import ( + AcceptedTermsPermission, + BlockUserWithoutSeat, + BlockUserWithoutSeatAndWCAReadyOrg, + BlockUserWithSeatButWCANotReady, +) +from .serializers import GenerationRequestSerializer, GenerationResponseSerializer +from .utils.segment import send_segment_event + +logger = logging.getLogger(__name__) + + +class Generation(APIView): + """ + Returns a playbook based on a text input. 
+ """ + + from oauth2_provider.contrib.rest_framework import IsAuthenticatedOrTokenHasScope + from rest_framework import permissions + + permission_classes = [ + permissions.IsAuthenticated, + IsAuthenticatedOrTokenHasScope, + AcceptedTermsPermission, + BlockUserWithoutSeat, + BlockUserWithoutSeatAndWCAReadyOrg, + BlockUserWithSeatButWCANotReady, + ] + required_scopes = ["read", "write"] + + throttle_cache_key_suffix = "_generation" + + @extend_schema( + request=GenerationRequestSerializer, + responses={ + 200: GenerationResponseSerializer, + 204: OpenApiResponse(description="Empty response"), + 400: OpenApiResponse(description="Bad Request"), + 401: OpenApiResponse(description="Unauthorized"), + 429: OpenApiResponse(description="Request was throttled"), + 503: OpenApiResponse(description="Service Unavailable"), + }, + summary="Inline code suggestions", + ) + def post(self, request) -> Response: + + # This isn't ideal... but I need generation_id for loggin in the decorator + def generation_id_provider(): + request_serializer = GenerationRequestSerializer(data=request.data) + request_serializer.is_valid(raise_exception=False) + return str(request_serializer.data.get("generationId", "")) + + @call("generation", generation_id_provider) + def get_generation() -> Response: + exception = None + wizard_id = None + duration = None + create_outline = None + anonymized_playbook = "" + request_serializer = GenerationRequestSerializer(data=request.data) + + try: + generation_id = generation_id_provider() + request_serializer.is_valid(raise_exception=True) + create_outline = request_serializer.validated_data["createOutline"] + outline = str(request_serializer.validated_data.get("outline", "")) + text = request_serializer.validated_data["text"] + wizard_id = str(request_serializer.validated_data.get("wizardId", "")) + + llm = apps.get_app_config("ai").model_mesh_client + start_time = time.time() + playbook, outline = llm.generate_playbook(request, text, create_outline, outline) + duration = round((time.time() - start_time) * 1000, 2) + + # Anonymize responses + # Anonymized in the View to be consistent with where Completions are anonymized + anonymized_playbook = anonymizer.anonymize_struct( + playbook, value_template=Template("{{ _${variable_name}_ }}") + ) + anonymized_outline = anonymizer.anonymize_struct( + outline, value_template=Template("{{ _${variable_name}_ }}") + ) + + answer = { + "playbook": anonymized_playbook, + "outline": anonymized_outline, + "format": "plaintext", + "generationId": generation_id, + } + + finally: + self.write_to_segment( + request.user, + generation_id, + wizard_id, + exception, + duration, + create_outline, + playbook_length=len(anonymized_playbook), + ) + + return Response( + answer, + status=rest_framework_status.HTTP_200_OK, + ) + + return get_generation() + + def write_to_segment( + self, user, generation_id, wizard_id, exception, duration, create_outline, playbook_length + ): + model_name = "" + try: + model_mesh_client = apps.get_app_config("ai").model_mesh_client + model_name = model_mesh_client.get_model_id(user.org_id, "") + except (WcaNoDefaultModelId, WcaModelIdNotFound, WcaSecretManagerError): + pass + event = { + "create_outline": create_outline, + "duration": duration, + "exception": exception is not None, + "generationId": generation_id, + "modelName": model_name, + "playbook_length": playbook_length, + "wizardId": wizard_id, + } + if exception: + event["response"] = ( + { + "exception": str(exception), + }, + ) + send_segment_event(event, 
"codegenPlaybook", user) From fea9416fa7c6629a0014600bdc4701c07a217644 Mon Sep 17 00:00:00 2001 From: Michael Anstis <manstis@redhat.com> Date: Fri, 21 Jun 2024 12:38:45 +0100 Subject: [PATCH 4/6] Explanation refactor --- ansible_ai_connect/ai/api/api_wrapper.py | 12 +- .../ai/api/model_client/wca_client.py | 3 + ansible_ai_connect/ai/api/tests/test_views.py | 5 +- ansible_ai_connect/ai/api/urls.py | 7 + ansible_ai_connect/ai/api/views3.py | 6 +- ansible_ai_connect/ai/api/views4.py | 156 ++++++++++++++++++ 6 files changed, 184 insertions(+), 5 deletions(-) create mode 100644 ansible_ai_connect/ai/api/views4.py diff --git a/ansible_ai_connect/ai/api/api_wrapper.py b/ansible_ai_connect/ai/api/api_wrapper.py index 006ed9ce5..4a50346d6 100644 --- a/ansible_ai_connect/ai/api/api_wrapper.py +++ b/ansible_ai_connect/ai/api/api_wrapper.py @@ -21,6 +21,7 @@ from rest_framework.exceptions import ValidationError from ansible_ai_connect.ai.api.exceptions import ( + FeatureNotAvailable, ModelTimeoutException, ServiceUnavailable, WcaBadRequestException, @@ -118,8 +119,17 @@ def wrapper(*args, **kwargs): ) raise + except FeatureNotAvailable: + logger.exception( + f"The requested feature is unavailable for {api_type}: {identifier}" + ) + raise + except Exception as e: - logger.exception(f"error requesting completion for {api_type}: {identifier}") + logger.exception( + f"An unhandled exception {e.__class__} occurred " + f"during processing of {api_type}: {identifier}" + ) raise ServiceUnavailable(cause=e) return wrapper diff --git a/ansible_ai_connect/ai/api/model_client/wca_client.py b/ansible_ai_connect/ai/api/model_client/wca_client.py index 73191a4ca..3642dc1ca 100644 --- a/ansible_ai_connect/ai/api/model_client/wca_client.py +++ b/ansible_ai_connect/ai/api/model_client/wca_client.py @@ -520,7 +520,10 @@ def explain_playbook(self, request, content: str) -> str: headers=headers, json=data, ) + context = Context(model_id, result, False) + InferenceResponseChecks().run_checks(context) result.raise_for_status() + response = json.loads(result.text) return response["explanation"] diff --git a/ansible_ai_connect/ai/api/tests/test_views.py b/ansible_ai_connect/ai/api/tests/test_views.py index 5e34b043b..a7cb83907 100644 --- a/ansible_ai_connect/ai/api/tests/test_views.py +++ b/ansible_ai_connect/ai/api/tests/test_views.py @@ -2592,9 +2592,8 @@ def test_service_unavailable(self, invoke): } self.client.force_authenticate(user=self.user) - with self.assertRaises(Exception): - r = self.client.post(reverse("explanations"), payload, format="json") - self.assertEqual(r.status_code, HTTPStatus.SERVICE_UNAVAILABLE) + r = self.client.post(reverse("explanations"), payload, format="json") + self.assertEqual(r.status_code, HTTPStatus.SERVICE_UNAVAILABLE) @override_settings(ANSIBLE_AI_MODEL_MESH_API_TYPE="dummy") diff --git a/ansible_ai_connect/ai/api/urls.py b/ansible_ai_connect/ai/api/urls.py index bcb080eaa..d3c12e8a1 100644 --- a/ansible_ai_connect/ai/api/urls.py +++ b/ansible_ai_connect/ai/api/urls.py @@ -14,6 +14,7 @@ from django.urls import path +<<<<<<< HEAD:ansible_ai_connect/ai/api/urls.py <<<<<<< HEAD:ansible_ai_connect/ai/api/urls.py <<<<<<< HEAD:ansible_ai_connect/ai/api/urls.py from .views import Completions, ContentMatches, Explanation, Feedback, Generation @@ -26,6 +27,12 @@ from .views2 import ContentMatches from .views3 import Generation >>>>>>> 75363e32 (Generations refactor):ansible_wisdom/ai/api/urls.py +======= +from .views import Attributions, Completions, Feedback +from .views2 import ContentMatches 
+from .views3 import Generation +from .views4 import Explanation +>>>>>>> 38ad467f (Explanation refactor):ansible_wisdom/ai/api/urls.py urlpatterns = [ path("completions/", Completions.as_view(), name="completions"), diff --git a/ansible_ai_connect/ai/api/views3.py b/ansible_ai_connect/ai/api/views3.py index 2ada798f6..8878b644d 100644 --- a/ansible_ai_connect/ai/api/views3.py +++ b/ansible_ai_connect/ai/api/views3.py @@ -76,7 +76,7 @@ class Generation(APIView): ) def post(self, request) -> Response: - # This isn't ideal... but I need generation_id for loggin in the decorator + # This isn't ideal... but I need generation_id for logging in the decorator def generation_id_provider(): request_serializer = GenerationRequestSerializer(data=request.data) request_serializer.is_valid(raise_exception=False) @@ -120,6 +120,10 @@ def get_generation() -> Response: "generationId": generation_id, } + except Exception as exc: + exception = exc + raise + finally: self.write_to_segment( request.user, diff --git a/ansible_ai_connect/ai/api/views4.py b/ansible_ai_connect/ai/api/views4.py new file mode 100644 index 000000000..02c90ca73 --- /dev/null +++ b/ansible_ai_connect/ai/api/views4.py @@ -0,0 +1,156 @@ +# Copyright Red Hat +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from string import Template + +from ansible_anonymizer import anonymizer +from django.apps import apps +from drf_spectacular.utils import OpenApiResponse, extend_schema +from oauth2_provider.contrib.rest_framework import IsAuthenticatedOrTokenHasScope +from rest_framework import permissions +from rest_framework import status as rest_framework_status +from rest_framework.response import Response +from rest_framework.views import APIView + +from ansible_ai_connect.ai.api.api_wrapper import call +from ansible_ai_connect.ai.api.aws.exceptions import WcaSecretManagerError +from ansible_ai_connect.ai.api.model_client.exceptions import ( + WcaModelIdNotFound, + WcaNoDefaultModelId, +) + +from .permissions import ( + AcceptedTermsPermission, + BlockUserWithoutSeat, + BlockUserWithoutSeatAndWCAReadyOrg, + BlockUserWithSeatButWCANotReady, +) +from .serializers import ExplanationRequestSerializer, ExplanationResponseSerializer +from .utils.segment import send_segment_event + +logger = logging.getLogger(__name__) + + +class Explanation(APIView): + """ + Returns a text that explains a playbook. 
+ """ + + permission_classes = [ + permissions.IsAuthenticated, + IsAuthenticatedOrTokenHasScope, + AcceptedTermsPermission, + BlockUserWithoutSeat, + BlockUserWithoutSeatAndWCAReadyOrg, + BlockUserWithSeatButWCANotReady, + ] + required_scopes = ["read", "write"] + + throttle_cache_key_suffix = "_explanation" + + @extend_schema( + request=ExplanationRequestSerializer, + responses={ + 200: ExplanationResponseSerializer, + 204: OpenApiResponse(description="Empty response"), + 400: OpenApiResponse(description="Bad Request"), + 401: OpenApiResponse(description="Unauthorized"), + 429: OpenApiResponse(description="Request was throttled"), + 503: OpenApiResponse(description="Service Unavailable"), + }, + summary="Inline code suggestions", + ) + def post(self, request) -> Response: + + # This isn't ideal... but I need explanation_id for logging in the decorator + def explanation_id_provider(): + request_serializer = ExplanationRequestSerializer(data=request.data) + request_serializer.is_valid(raise_exception=False) + return str(request_serializer.data.get("explanationId", "")) + + @call("explanation", explanation_id_provider) + def get_explanation() -> Response: + duration = None + exception = None + explanation_id = None + playbook = "" + request_serializer = ExplanationRequestSerializer(data=request.data) + + try: + request_serializer.is_valid(raise_exception=True) + explanation_id = explanation_id_provider() + playbook = request_serializer.validated_data.get("content") + + llm = apps.get_app_config("ai").model_mesh_client + start_time = time.time() + explanation = llm.explain_playbook(request, playbook) + duration = round((time.time() - start_time) * 1000, 2) + + # Anonymize response + # Anonymized in the View to be consistent with where Completions are anonymized + anonymized_explanation = anonymizer.anonymize_struct( + explanation, value_template=Template("{{ _${variable_name}_ }}") + ) + + answer = { + "content": anonymized_explanation, + "format": "markdown", + "explanationId": explanation_id, + } + + except Exception as exc: + exception = exc + raise + + finally: + self.write_to_segment( + request.user, + explanation_id, + exception, + duration, + playbook_length=len(playbook), + ) + + return Response( + answer, + status=rest_framework_status.HTTP_200_OK, + ) + + return get_explanation() + + def write_to_segment(self, user, explanation_id, exception, duration, playbook_length): + model_name = "" + try: + model_mesh_client = apps.get_app_config("ai").model_mesh_client + model_name = model_mesh_client.get_model_id(user.org_id, "") + except (WcaNoDefaultModelId, WcaModelIdNotFound, WcaSecretManagerError): + pass + + event = { + "duration": duration, + "exception": exception is not None, + "explanationId": explanation_id, + "modelName": model_name, + "playbook_length": playbook_length, + "rh_user_org_id": user.org_id, + } + if exception: + event["response"] = ( + { + "exception": str(exception), + }, + ) + send_segment_event(event, "explainPlaybook", user) From 773806bb24f9af88df7ded46b6949c67f6f77416 Mon Sep 17 00:00:00 2001 From: Michael Anstis <manstis@redhat.com> Date: Fri, 21 Jun 2024 12:54:41 +0100 Subject: [PATCH 5/6] Fix tests. 
--- ansible_ai_connect/ai/api/model_client/tests/test_wca_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ansible_ai_connect/ai/api/model_client/tests/test_wca_client.py b/ansible_ai_connect/ai/api/model_client/tests/test_wca_client.py index fdc35182d..450b0dfbf 100644 --- a/ansible_ai_connect/ai/api/model_client/tests/test_wca_client.py +++ b/ansible_ai_connect/ai/api/model_client/tests/test_wca_client.py @@ -260,6 +260,7 @@ def setUp(self): wca_client.session = Mock() response = Mock response.text = '{"playbook": "Oh!", "outline": "Ahh!", "explanation": "!Óh¡"}' + response.status_code = 200 response.raise_for_status = Mock() wca_client.session.post.return_value = response self.wca_client = wca_client From c46411fb186bba2486e0f3a179137e42b0567db1 Mon Sep 17 00:00:00 2001 From: Michael Anstis <manstis@redhat.com> Date: Wed, 26 Jun 2024 11:02:45 +0100 Subject: [PATCH 6/6] Rebase with main --- ansible_ai_connect/ai/api/tests/test_views.py | 70 ------------------- ansible_ai_connect/ai/api/urls.py | 17 +---- ansible_ai_connect/ai/api/views2.py | 64 ++--------------- 3 files changed, 5 insertions(+), 146 deletions(-) diff --git a/ansible_ai_connect/ai/api/tests/test_views.py b/ansible_ai_connect/ai/api/tests/test_views.py index a7cb83907..aca0476ec 100644 --- a/ansible_ai_connect/ai/api/tests/test_views.py +++ b/ansible_ai_connect/ai/api/tests/test_views.py @@ -2186,76 +2186,6 @@ def setUp(self): self.model_client.get_token = Mock(return_value={"access_token": "abc"}) self.model_client.get_api_key = Mock(return_value="org-api-key") -<<<<<<< HEAD:ansible_ai_connect/ai/api/tests/test_views.py -======= - self.search_response = { - "attributions": [ - { - "repo_name": repo_name, - "repo_url": repo_url, - "path": path, - "license": license, - "data_source": DataSource.GALAXY_R, - "ansible_type": AnsibleType.UNKNOWN, - "score": 0.0, - }, - ], - "meta": { - "encode_duration": 1000, - "search_duration": 2000, - }, - } - - @override_settings(ANSIBLE_AI_ENABLE_TECH_PREVIEW=True) - @override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE") - @patch("ansible_ai_connect.ai.api.views2.send_segment_event") - @patch("ansible_ai_connect.ai.search.search") - def test_wca_contentmatch_segment_events_with_unseated_user( - self, mock_search, mock_send_segment_event - ): - self.user.rh_user_has_seat = False - - mock_search.return_value = self.search_response - - r = self.client.post(reverse("contentmatches"), self.payload) - self.assertEqual(r.status_code, HTTPStatus.OK) - - event = { - "exception": False, - "modelName": "", - "problem": None, - "response": { - "contentmatches": [ - { - "contentmatch": [ - { - "repo_name": "robertdebock.nginx", - "repo_url": "https://galaxy.ansible.com/robertdebock/nginx", - "path": "tasks/main.yml", - "license": "apache-2.0", - "score": 0.0, - "data_source_description": "Ansible Galaxy roles", - } - ] - } - ] - }, - "metadata": [{"encode_duration": 1000, "search_duration": 2000}], - } - - event_request = { - "suggestions": [ - "\n - name: install nginx on RHEL\n become: true\n " - "ansible.builtin.package:\n name: nginx\n state: present\n" - ] - } - - actual_event = mock_send_segment_event.call_args_list[0][0][0] - - self.assertTrue(event.items() <= actual_event.items()) - self.assertTrue(event_request.items() <= actual_event.get("request").items()) - ->>>>>>> 4a5b54df (ContentMatch refactor):ansible_wisdom/ai/api/tests/test_views.py @override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE") @patch("ansible_ai_connect.ai.api.views2.send_segment_event") def 
test_wca_contentmatch_segment_events_with_seated_user(self, mock_send_segment_event): diff --git a/ansible_ai_connect/ai/api/urls.py b/ansible_ai_connect/ai/api/urls.py index d3c12e8a1..ec32acc3c 100644 --- a/ansible_ai_connect/ai/api/urls.py +++ b/ansible_ai_connect/ai/api/urls.py @@ -14,25 +14,10 @@ from django.urls import path -<<<<<<< HEAD:ansible_ai_connect/ai/api/urls.py -<<<<<<< HEAD:ansible_ai_connect/ai/api/urls.py -<<<<<<< HEAD:ansible_ai_connect/ai/api/urls.py -from .views import Completions, ContentMatches, Explanation, Feedback, Generation -======= -from .views import Attributions, Completions, Explanation, Feedback, Generation -from .views2 import ContentMatches ->>>>>>> 4a5b54df (ContentMatch refactor):ansible_wisdom/ai/api/urls.py -======= -from .views import Attributions, Completions, Explanation, Feedback -from .views2 import ContentMatches -from .views3 import Generation ->>>>>>> 75363e32 (Generations refactor):ansible_wisdom/ai/api/urls.py -======= -from .views import Attributions, Completions, Feedback +from .views import Completions, Feedback from .views2 import ContentMatches from .views3 import Generation from .views4 import Explanation ->>>>>>> 38ad467f (Explanation refactor):ansible_wisdom/ai/api/urls.py urlpatterns = [ path("completions/", Completions.as_view(), name="completions"), diff --git a/ansible_ai_connect/ai/api/views2.py b/ansible_ai_connect/ai/api/views2.py index 436d89fa0..0ca027437 100644 --- a/ansible_ai_connect/ai/api/views2.py +++ b/ansible_ai_connect/ai/api/views2.py @@ -14,7 +14,6 @@ import logging import time -from http import HTTPStatus from django.apps import apps from django.conf import settings @@ -35,12 +34,7 @@ ) from ansible_ai_connect.users.models import User -from .. import search as ai_search -from .data.data_model import ( - AttributionsResponseDto, - ContentMatchPayloadData, - ContentMatchResponseDto, -) +from .data.data_model import ContentMatchPayloadData, ContentMatchResponseDto from .permissions import AcceptedTermsPermission, BlockUserWithoutSeat, IsAAPLicensed from .serializers import ContentMatchRequestSerializer, ContentMatchResponseSerializer from .utils.segment import send_segment_event @@ -95,12 +89,9 @@ def post(self, request) -> Response: model_id = str(request_data.get("model", "")) try: - if request.user.rh_user_has_seat: - response_serializer = self.perform_content_matching( - model_id, suggestion_id, request.user, request_data - ) - else: - response_serializer = self.perform_search(request_data, request.user) + response_serializer = self.perform_content_matching( + model_id, suggestion_id, request.user, request_data + ) return Response(response_serializer.data, status=rest_framework_status.HTTP_200_OK) except Exception: logger.exception("Error requesting content matches") @@ -204,53 +195,6 @@ def get_content_matches() -> ContentMatchResponseSerializer: return get_content_matches() - def perform_search(self, request_data, user: User): - suggestion_id = str(request_data.get("suggestionId", "")) - response_serializer = None - - exception = None - start_time = time.time() - metadata = [] - model_name = "" - - try: - suggestion = request_data["suggestions"][0] - response_item = ai_search.search(suggestion) - - attributions_dto = AttributionsResponseDto(**response_item) - response_data = {"contentmatches": []} - response_data["contentmatches"].append(attributions_dto.content_matches) - metadata.append(attributions_dto.meta) - - try: - response_serializer = ContentMatchResponseSerializer(data=response_data) - 
response_serializer.is_valid(raise_exception=True) - except Exception: - process_error_count.labels(stage="attr-response_serialization_validation").inc() - logger.exception(f"Error serializing final response for suggestion {suggestion_id}") - raise InternalServerError - - except Exception as e: - exception = e - logger.exception("Failed to search for attributions for content matching") - return Response( - {"message": "Unable to complete the request"}, status=HTTPStatus.SERVICE_UNAVAILABLE - ) - finally: - duration = round((time.time() - start_time) * 1000, 2) - self.write_to_segment( - request_data, - duration, - exception, - metadata, - model_name, - response_serializer.data if response_serializer else {}, - suggestion_id, - user, - ) - - return response_serializer - def write_to_segment( self, request_data,
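
[Editor's note — not part of the patch series] The six patches above converge on one error-handling shape: a call(api_type, identifier_provider) decorator factory in ansible_ai_connect/ai/api/api_wrapper.py that wraps each backend call, logs a message containing the API type and a lazily resolved identifier, and re-raises a service-level exception; the views then move their backend work into a small inner function decorated with @call. The Python sketch below is a minimal illustration of that shape only. BackendError, ServiceUnavailableError, the "example-suggestion-id" value and the get_predictions() stub are illustrative placeholders, not the project's real classes or data; the real decorator maps many WCA-specific exceptions and reads Django settings.

# Minimal, self-contained sketch of the decorator-factory pattern introduced by
# ansible_ai_connect.ai.api.api_wrapper.call. Placeholder names are noted above.
import functools
import logging
from typing import Callable

logger = logging.getLogger(__name__)


class BackendError(Exception):
    """Stand-in for a model-client exception such as WcaBadRequest."""


class ServiceUnavailableError(Exception):
    """Stand-in for the API-level exception surfaced to the caller."""

    def __init__(self, cause: Exception):
        super().__init__(str(cause))
        self.cause = cause


def call(api_type: str, identifier_provider: Callable[[], str]):
    """Wrap a backend call; log context, then re-raise a mapped exception."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Resolved lazily so the identifier can come from request data that
            # is only validated inside the enclosing view, not at decoration time.
            identifier = identifier_provider()
            try:
                return func(*args, **kwargs)
            except BackendError as e:
                logger.error(f"backend rejected the request for {api_type}: {identifier}")
                raise ServiceUnavailableError(cause=e)

        return wrapper

    return decorator


@call("suggestions", lambda: "example-suggestion-id")
def get_predictions() -> str:
    # A real caller would invoke the model-mesh client here.
    return "prediction"


if __name__ == "__main__":
    print(get_predictions())

The lazy identifier_provider mirrors patches 3 and 4, where the plain string identifier became a Callable so the Generation and Explanation views could defer resolving their IDs until after the request serializer has run.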