api/view: improve the collection of Schema1 events #1147

Draft · wants to merge 23 commits into base: main
2 changes: 1 addition & 1 deletion .github/workflows/code_coverage.yml
@@ -73,7 +73,7 @@ jobs:

- name: Running Unit Tests (Python)
run: |
- coverage run --rcfile=setup.cfg -m ansible_ai_connect.manage test ansible_ai_connect
+ coverage run --rcfile=setup.cfg -m ansible_ai_connect.manage test ansible_ai_connect --failfast
coverage xml
coverage report --rcfile=setup.cfg --format=markdown > code-coverage-results.md

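The new --failfast flag is the stock Django test-runner option that aborts the run at the first failing test, so CI stops early instead of spending the rest of the coverage run on a known-broken suite. The same invocation can be reproduced locally:

coverage run --rcfile=setup.cfg -m ansible_ai_connect.manage test ansible_ai_connect --failfast
coverage report --rcfile=setup.cfg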
1 change: 0 additions & 1 deletion ansible_ai_connect/ai/api/data/data_model.py
@@ -28,7 +28,6 @@
class APIPayload(BaseModel):
model: str = ""
prompt: str = ""
- original_prompt: str = ""
context: str = ""
userId: Optional[UUID] = None
suggestionId: Optional[UUID] = None
7 changes: 0 additions & 7 deletions ansible_ai_connect/ai/api/exceptions.py
@@ -155,13 +155,6 @@ class InternalServerError(BaseWisdomAPIException):
default_detail = "An error occurred attempting to complete the request."


- class FeedbackValidationException(WisdomBadRequest):
-     default_code = "error__feedback_validation"
-
-     def __init__(self, detail, *args, **kwargs):
-         super().__init__(detail, *args, **kwargs)


class FeedbackInternalServerException(BaseWisdomAPIException):
status_code = 500
default_code = "error__feedback_internal_server"
19 changes: 0 additions & 19 deletions ansible_ai_connect/ai/api/formatter.py
@@ -368,25 +368,6 @@ def get_task_names_from_tasks(tasks):
return names


- def restore_original_task_names(output_yaml, prompt):
-     if output_yaml and is_multi_task_prompt(prompt):
-         prompt_tasks = get_task_names_from_prompt(prompt)
-         matches = re.finditer(r"^- name:\s+(.*)", output_yaml, re.M)
-         for i, match in enumerate(matches):
-             try:
-                 task_line = match.group(0)
-                 task = match.group(1)
-                 restored_task_line = task_line.replace(task, prompt_tasks[i])
-                 output_yaml = output_yaml.replace(task_line, restored_task_line)
-             except IndexError:
-                 logger.error(
-                     "There is no match for the enumerated prompt task in the suggestion yaml"
-                 )
-                 break
-
-     return output_yaml


# List of Task keywords to filter out during prediction results parsing.
ansible_task_keywords = None
# RegExp Pattern based on ARI sources, see ansible_risk_insight/finder.py
@@ -30,13 +30,16 @@
class DeserializeStage(PipelineElement):
def process(self, context: CompletionContext) -> None:
request = context.request
+ # NOTE: This line is probably useless
request._request._suggestion_id = request.data.get("suggestionId")

request_serializer = CompletionRequestSerializer(
data=request.data, context={"request": request}
)

try:
+ # TODO: is_valid() is already called in ai/api/views.py and we should
+ # reuse the validated_data here
request_serializer.is_valid(raise_exception=True)
request._request._suggestion_id = str(request_serializer.validated_data["suggestionId"])
request._request._ansible_extension_version = str(
@@ -62,6 +65,5 @@ def process(self, context: CompletionContext) -> None:
)

payload = APIPayload(**request_serializer.validated_data)
- payload.original_prompt = request.data.get("prompt", "")

context.payload = payload
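On the TODO above: one way to drop the duplicate validation would be for the view to hand its already-validated data to the pipeline instead of the raw request body. A minimal sketch, assuming a hypothetical validated_data attribute on CompletionContext that the view would populate after its own is_valid() call (serializer.validated_data itself is standard DRF; the context plumbing is not part of this PR):

from ansible_ai_connect.ai.api.data.data_model import APIPayload
from ansible_ai_connect.ai.api.pipelines.common import PipelineElement
from ansible_ai_connect.ai.api.pipelines.completion_context import CompletionContext
from ansible_ai_connect.ai.api.serializers import CompletionRequestSerializer


class DeserializeStage(PipelineElement):
    def process(self, context: CompletionContext) -> None:
        # Hypothetical: the view stores serializer.validated_data on the
        # context before running the pipeline, so validation happens once.
        data = getattr(context, "validated_data", None)
        if data is None:
            # Fallback: validate here, as the current code does.
            serializer = CompletionRequestSerializer(
                data=context.request.data, context={"request": context.request}
            )
            serializer.is_valid(raise_exception=True)
            data = serializer.validated_data
        context.payload = APIPayload(**data)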
@@ -169,11 +169,6 @@ def process(self, context: CompletionContext) -> None:
event_name = "trialExpired"
raise WcaUserTrialExpiredException(cause=e)

- except Exception as e:
-     exception = e
-     logger.exception(f"error requesting completion for suggestion {suggestion_id}")
-     raise ServiceUnavailable(cause=e)

finally:
duration = round((time.time() - start_time) * 1000, 2)
completions_hist.observe(duration / 1000) # millisec back to seconds
@@ -22,14 +22,15 @@
from prometheus_client import Histogram
from yaml.error import MarkedYAMLError

+ import ansible_ai_connect.ai.api.telemetry.schema1 as schema1
from ansible_ai_connect.ai.api import formatter as fmtr
from ansible_ai_connect.ai.api.exceptions import (
PostprocessException,
process_error_count,
)
from ansible_ai_connect.ai.api.pipelines.common import PipelineElement
from ansible_ai_connect.ai.api.pipelines.completion_context import CompletionContext
- from ansible_ai_connect.ai.api.utils.segment import send_segment_event
+ from ansible_ai_connect.ai.api.utils.segment import send_schema1_event

logger = logging.getLogger(__name__)

@@ -78,32 +78,28 @@ def write_to_segment(
if isinstance(exception, MarkedYAMLError)
else str(exception) if str(exception) else exception.__class__.__name__
)

if event_type == "ARI":
-     event_name = "postprocess"
-     event = {
-         "exception": exception is not None,
-         "problem": problem,
-         "duration": duration,
-         "recommendation": recommendation_yaml,
-         "truncated": truncated_yaml,
-         "postprocessed": postprocessed_yaml,
-         "details": postprocess_detail,
-         "suggestionId": str(suggestion_id) if suggestion_id else None,
-     }
- if event_type == "ansible-lint":
-     event_name = "postprocessLint"
-     event = {
-         "exception": exception is not None,
-         "problem": problem,
-         "duration": duration,
-         "recommendation": recommendation_yaml,
-         "postprocessed": postprocessed_yaml,
-         "suggestionId": str(suggestion_id) if suggestion_id else None,
-     }
+     schema1_event = schema1.Postprocess()
+     schema1_event.details = postprocess_detail
+     schema1_event.truncated = truncated_yaml
+
+ elif event_type == "ansible-lint":
+
+     schema1_event = schema1.PostprocessLint()
+     schema1_event.postprocessed = postprocessed_yaml
Review comment (Contributor): Is this line needed? You do the same on L#97
+ schema1_event.set_user(user)
+ schema1_event.set_exception(exception)
+ schema1_event.duration = duration
+ schema1_event.postprocessed = postprocessed_yaml
+ schema1_event.problem = problem
+ schema1_event.recommendation = recommendation_yaml
+ schema1_event.suggestionId = str(suggestion_id) if suggestion_id else ""

if model_id:
event["modelName"] = model_id
send_segment_event(event, event_name, user)
schema1_event.modelName = model_id
send_schema1_event(schema1_event)


def trim_whitespace_lines(input: str):
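For orientation, the Schema1 events used above replace the previous free-form dictionaries with typed objects. The real classes live in ansible_ai_connect/ai/api/telemetry/schema1.py; the dataclass below is only a rough sketch inferred from the attribute accesses in write_to_segment, not the actual definition:

from dataclasses import dataclass, field


# Illustrative sketch of a Schema1 event, inferred from usage in
# write_to_segment above; the real schema1.Postprocess differs in detail.
@dataclass
class Postprocess:
    exception: bool = False
    problem: str = ""
    duration: float = 0.0
    recommendation: str = ""
    truncated: str = ""
    postprocessed: str = ""
    details: list = field(default_factory=list)
    suggestionId: str = ""
    modelName: str = ""

    def set_user(self, user):
        # Copy whatever user metadata Schema1 requires onto the event.
        self.rh_user_has_seat = getattr(user, "rh_user_has_seat", False)

    def set_exception(self, exception):
        self.exception = exception is not None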
@@ -139,11 +136,10 @@ def completion_post_process(context: CompletionContext):
model_id = context.model_id
suggestion_id = context.payload.suggestionId
prompt = context.payload.prompt
- original_prompt = context.payload.original_prompt
payload_context = context.payload.context
original_indent = context.original_indent
post_processed_predictions = context.anonymized_predictions.copy()
- is_multi_task_prompt = fmtr.is_multi_task_prompt(original_prompt)
+ is_multi_task_prompt = fmtr.is_multi_task_prompt(prompt)

ari_caller = apps.get_app_config("ai").get_ari_caller()
if not ari_caller:
@@ -159,16 +155,11 @@
f"unexpected predictions array length {len(post_processed_predictions['predictions'])}"
)

- anonymized_recommendation_yaml = post_processed_predictions["predictions"][0]
+ recommendation_yaml = post_processed_predictions["predictions"][0]

- if not anonymized_recommendation_yaml:
-     raise PostprocessException(
-         f"unexpected prediction content {anonymized_recommendation_yaml}"
-     )
+ if not recommendation_yaml:
+     raise PostprocessException(f"unexpected prediction content {recommendation_yaml}")

- recommendation_yaml = fmtr.restore_original_task_names(
-     anonymized_recommendation_yaml, original_prompt
- )
recommendation_problem = None
truncated_yaml = None
postprocessed_yaml = None
@@ -213,7 +204,7 @@ def completion_post_process(context: CompletionContext):
f"original recommendation: \n{recommendation_yaml}"
)
postprocessed_yaml, ari_results = ari_caller.postprocess(
-     recommendation_yaml, original_prompt, payload_context
+     recommendation_yaml, prompt, payload_context
)
logger.debug(
f"suggestion id: {suggestion_id}, "
@@ -259,7 +250,7 @@
f"rules_with_applied_changes: {tasks_with_applied_changes} "
f"recommendation_yaml: [{repr(recommendation_yaml)}] "
f"postprocessed_yaml: [{repr(postprocessed_yaml)}] "
f"original_prompt: [{repr(original_prompt)}] "
f"prompt: [{repr(prompt)}] "
f"payload_context: [{repr(payload_context)}] "
f"postprocess_details: [{json.dumps(postprocess_details)}] "
)
@@ -279,7 +270,7 @@
write_to_segment(
user,
suggestion_id,
-     anonymized_recommendation_yaml,
+     recommendation_yaml,
truncated_yaml,
postprocessed_yaml,
postprocess_details,
@@ -298,9 +289,8 @@
input_yaml = postprocessed_yaml if postprocessed_yaml else recommendation_yaml
# Single task predictions are missing the `- name: ` line and fail linter schema check
if not is_multi_task_prompt:
-     input_yaml = (
-         f"{original_prompt.lstrip() if ari_caller else original_prompt}{input_yaml}"
-     )
+     prompt += "\n"
+     input_yaml = f"{prompt.lstrip() if ari_caller else prompt}{input_yaml}"
postprocessed_yaml = ansible_lint_caller.run_linter(input_yaml)
# Stripping the leading STRIP_YAML_LINE that was added by above processing
if postprocessed_yaml.startswith(STRIP_YAML_LINE):
@@ -318,7 +308,7 @@
)
finally:
anonymized_input_yaml = (
-     postprocessed_yaml if postprocessed_yaml else anonymized_recommendation_yaml
+     postprocessed_yaml if postprocessed_yaml else recommendation_yaml
)
write_to_segment(
user,
@@ -358,6 +348,8 @@ def completion_post_process(context: CompletionContext):
logger.debug(f"suggestion id: {suggestion_id}, indented recommendation: \n{indented_yaml}")

# gather data for completion segment event
+ # WARNING: the block below does an in-place transformation of 'tasks'; we should
+ # refactor the code to avoid that.
for i, task in enumerate(tasks):
if fmtr.is_multi_task_prompt(prompt):
task["prediction"] = fmtr.extract_task(
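One shape the refactor requested by the WARNING could take: build enriched copies of the task dicts rather than mutating the anonymized list in place. The helper below is purely illustrative; its name and the compute_prediction callback are hypothetical:

# Hypothetical non-mutating variant of the block above: callers get new
# dicts and the original `tasks` list is left untouched.
def tasks_with_predictions(tasks, compute_prediction):
    return [
        {**task, "prediction": compute_prediction(i, task)}
        for i, task in enumerate(tasks)
    ]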
@@ -38,7 +38,6 @@

def completion_pre_process(context: CompletionContext):
prompt = context.payload.prompt
- original_prompt, _ = fmtr.extract_prompt_and_context(context.payload.original_prompt)
payload_context = context.payload.context

# Additional context (variables) is supported when
@@ -70,17 +69,6 @@ def completion_pre_process(context: CompletionContext):
context.payload.context, context.payload.prompt = fmtr.preprocess(
payload_context, prompt, ansibleFileType, additionalContext
)
- if not multi_task:
-     # We are currently more forgiving on leading spacing of single task
-     # prompts than multi task prompts. In order to use the "original"
-     # single task prompt successfull in post-processing, we need to
-     # ensure its spacing aligns with the normalized context we got
-     # back from preprocess. We can calculate the proper spacing from the
-     # normalized prompt.
-     normalized_indent = len(context.payload.prompt) - len(context.payload.prompt.lstrip())
-     normalized_original_prompt = fmtr.normalize_yaml(original_prompt)
-     original_prompt = " " * normalized_indent + normalized_original_prompt
-     context.payload.original_prompt = original_prompt


class PreProcessStage(PipelineElement):
@@ -63,6 +63,7 @@ def process(self, context: CompletionContext) -> None:
# Note: Currently we return an array of predictions, but there's only ever one.
# The tasks array added to the completion event is representative of the first (only)
# entry in the predictions array
+ # https://github.com/ansible/ansible-ai-connect-service/blob/0e083a83fab57e6567197697bad60d306c6e06eb/ansible_ai_connect/ai/api/pipelines/completion_stages/response.py#L64
Review comment (Contributor): Is this comment meant to be here?
response.tasks = tasks_results

context.response = response
@@ -520,7 +520,6 @@ def add_indents(vars, n):
@modify_settings()
class CompletionPreProcessTest(TestCase):
def call_completion_pre_process(self, payload, is_commercial_user, expected_context):
- original_prompt = payload.get("prompt")
user = Mock(rh_user_has_seat=is_commercial_user)
request = Mock(user=user)
serializer = CompletionRequestSerializer(context={"request": request})
@@ -529,7 +528,6 @@ def call_completion_pre_process(self, payload, is_commercial_user, expected_context):
request=request,
payload=APIPayload(
prompt=data.get("prompt"),
- original_prompt=original_prompt,
context=data.get("context"),
),
metadata=data.get("metadata"),
5 changes: 1 addition & 4 deletions ansible_ai_connect/ai/api/serializers.py
@@ -38,9 +38,6 @@


class Metadata(serializers.Serializer):
- class Meta:
-     fields = ["ansibleExtensionVersion"]
-
ansibleExtensionVersion = serializers.RegexField(
r"v?\d+\.\d+\.\d+",
required=False,
@@ -118,7 +115,6 @@ def validate_extracted_prompt(prompt, user):
raise serializers.ValidationError(
{"prompt": "requested prompt format is not supported"}
)

if "&&" in prompt:
raise serializers.ValidationError(
{"prompt": "multiple task requests should be separated by a single '&'"}
@@ -155,6 +151,7 @@ def validate(self, data):
data = super().validate(data)

data["prompt"], data["context"] = fmtr.extract_prompt_and_context(data["prompt"])

CompletionRequestSerializer.validate_extracted_prompt(
data["prompt"], self.context.get("request").user
)
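For context, extract_prompt_and_context (from ai/api/formatter.py) splits the incoming payload so the trailing prompt line is separated from the YAML above it. A behavioural sketch, assuming the usual last-line-is-the-prompt convention; the authoritative implementation may handle edge cases differently:

# Assumed behaviour only; see fmtr.extract_prompt_and_context for the real code.
def extract_prompt_and_context(data: str):
    prompt, context = "", ""
    if data:
        lines = data.splitlines(keepends=True)
        prompt = lines[-1]
        context = "".join(lines[:-1])
    return prompt, context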
@@ -23,13 +23,13 @@
from rest_framework.response import Response
from rest_framework.status import HTTP_200_OK, HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST

+ from ansible_ai_connect.ai.api.exceptions import InternalServerError, ServiceUnavailable
from ansible_ai_connect.ai.api.permissions import (
IsOrganisationAdministrator,
IsOrganisationLightspeedSubscriber,
)
from ansible_ai_connect.ai.api.serializers import TelemetrySettingsRequestSerializer
from ansible_ai_connect.ai.api.utils.segment import send_segment_event
- from ansible_ai_connect.ai.api.views import InternalServerError, ServiceUnavailable
from ansible_ai_connect.users.signals import user_set_telemetry_settings

logger = logging.getLogger(__name__)