diff --git a/ansible_ai_connect/ai/api/data/data_model.py b/ansible_ai_connect/ai/api/data/data_model.py index 68eae79e2..a72682dda 100644 --- a/ansible_ai_connect/ai/api/data/data_model.py +++ b/ansible_ai_connect/ai/api/data/data_model.py @@ -28,7 +28,6 @@ class APIPayload(BaseModel): model: str = "" prompt: str = "" - original_prompt: str = "" context: str = "" userId: Optional[UUID] = None suggestionId: Optional[UUID] = None diff --git a/ansible_ai_connect/ai/api/formatter.py b/ansible_ai_connect/ai/api/formatter.py index 847d9c7a6..475812943 100644 --- a/ansible_ai_connect/ai/api/formatter.py +++ b/ansible_ai_connect/ai/api/formatter.py @@ -368,25 +368,6 @@ def get_task_names_from_tasks(tasks): return names -def restore_original_task_names(output_yaml, prompt): - if output_yaml and is_multi_task_prompt(prompt): - prompt_tasks = get_task_names_from_prompt(prompt) - matches = re.finditer(r"^- name:\s+(.*)", output_yaml, re.M) - for i, match in enumerate(matches): - try: - task_line = match.group(0) - task = match.group(1) - restored_task_line = task_line.replace(task, prompt_tasks[i]) - output_yaml = output_yaml.replace(task_line, restored_task_line) - except IndexError: - logger.error( - "There is no match for the enumerated prompt task in the suggestion yaml" - ) - break - - return output_yaml - - # List of Task keywords to filter out during prediction results parsing. ansible_task_keywords = None # RegExp Pattern based on ARI sources, see ansible_risk_insight/finder.py diff --git a/ansible_ai_connect/ai/api/pipelines/completion_stages/deserialise.py b/ansible_ai_connect/ai/api/pipelines/completion_stages/deserialise.py index bb03efaa1..675fb9228 100644 --- a/ansible_ai_connect/ai/api/pipelines/completion_stages/deserialise.py +++ b/ansible_ai_connect/ai/api/pipelines/completion_stages/deserialise.py @@ -62,6 +62,5 @@ def process(self, context: CompletionContext) -> None: ) payload = APIPayload(**request_serializer.validated_data) - payload.original_prompt = request.data.get("prompt", "") context.payload = payload diff --git a/ansible_ai_connect/ai/api/pipelines/completion_stages/post_process.py b/ansible_ai_connect/ai/api/pipelines/completion_stages/post_process.py index d253f1d0b..245a52e12 100644 --- a/ansible_ai_connect/ai/api/pipelines/completion_stages/post_process.py +++ b/ansible_ai_connect/ai/api/pipelines/completion_stages/post_process.py @@ -139,11 +139,10 @@ def completion_post_process(context: CompletionContext): model_id = context.model_id suggestion_id = context.payload.suggestionId prompt = context.payload.prompt - original_prompt = context.payload.original_prompt payload_context = context.payload.context original_indent = context.original_indent post_processed_predictions = context.anonymized_predictions.copy() - is_multi_task_prompt = fmtr.is_multi_task_prompt(original_prompt) + is_multi_task_prompt = fmtr.is_multi_task_prompt(prompt) ari_caller = apps.get_app_config("ai").get_ari_caller() if not ari_caller: @@ -159,16 +158,11 @@ def completion_post_process(context: CompletionContext): f"unexpected predictions array length {len(post_processed_predictions['predictions'])}" ) - anonymized_recommendation_yaml = post_processed_predictions["predictions"][0] + recommendation_yaml = post_processed_predictions["predictions"][0] - if not anonymized_recommendation_yaml: - raise PostprocessException( - f"unexpected prediction content {anonymized_recommendation_yaml}" - ) + if not recommendation_yaml: + raise PostprocessException(f"unexpected prediction content {recommendation_yaml}") - recommendation_yaml = fmtr.restore_original_task_names( - anonymized_recommendation_yaml, original_prompt - ) recommendation_problem = None truncated_yaml = None postprocessed_yaml = None @@ -213,7 +207,7 @@ def completion_post_process(context: CompletionContext): f"original recommendation: \n{recommendation_yaml}" ) postprocessed_yaml, ari_results = ari_caller.postprocess( - recommendation_yaml, original_prompt, payload_context + recommendation_yaml, prompt, payload_context ) logger.debug( f"suggestion id: {suggestion_id}, " @@ -259,7 +253,6 @@ def completion_post_process(context: CompletionContext): f"rules_with_applied_changes: {tasks_with_applied_changes} " f"recommendation_yaml: [{repr(recommendation_yaml)}] " f"postprocessed_yaml: [{repr(postprocessed_yaml)}] " - f"original_prompt: [{repr(original_prompt)}] " f"payload_context: [{repr(payload_context)}] " f"postprocess_details: [{json.dumps(postprocess_details)}] " ) @@ -279,7 +272,7 @@ def completion_post_process(context: CompletionContext): write_to_segment( user, suggestion_id, - anonymized_recommendation_yaml, + recommendation_yaml, truncated_yaml, postprocessed_yaml, postprocess_details, @@ -298,9 +291,8 @@ def completion_post_process(context: CompletionContext): input_yaml = postprocessed_yaml if postprocessed_yaml else recommendation_yaml # Single task predictions are missing the `- name: ` line and fail linter schema check if not is_multi_task_prompt: - input_yaml = ( - f"{original_prompt.lstrip() if ari_caller else original_prompt}{input_yaml}" - ) + prompt += "\n" + input_yaml = f"{prompt.lstrip() if ari_caller else prompt}{input_yaml}" postprocessed_yaml = ansible_lint_caller.run_linter(input_yaml) # Stripping the leading STRIP_YAML_LINE that was added by above processing if postprocessed_yaml.startswith(STRIP_YAML_LINE): @@ -318,7 +310,7 @@ def completion_post_process(context: CompletionContext): ) finally: anonymized_input_yaml = ( - postprocessed_yaml if postprocessed_yaml else anonymized_recommendation_yaml + postprocessed_yaml if postprocessed_yaml else recommendation_yaml ) write_to_segment( user, diff --git a/ansible_ai_connect/ai/api/pipelines/completion_stages/pre_process.py b/ansible_ai_connect/ai/api/pipelines/completion_stages/pre_process.py index 461c17598..c720e9942 100644 --- a/ansible_ai_connect/ai/api/pipelines/completion_stages/pre_process.py +++ b/ansible_ai_connect/ai/api/pipelines/completion_stages/pre_process.py @@ -38,7 +38,6 @@ def completion_pre_process(context: CompletionContext): prompt = context.payload.prompt - original_prompt, _ = fmtr.extract_prompt_and_context(context.payload.original_prompt) payload_context = context.payload.context # Additional context (variables) is supported when @@ -70,17 +69,6 @@ def completion_pre_process(context: CompletionContext): context.payload.context, context.payload.prompt = fmtr.preprocess( payload_context, prompt, ansibleFileType, additionalContext ) - if not multi_task: - # We are currently more forgiving on leading spacing of single task - # prompts than multi task prompts. In order to use the "original" - # single task prompt successfull in post-processing, we need to - # ensure its spacing aligns with the normalized context we got - # back from preprocess. We can calculate the proper spacing from the - # normalized prompt. - normalized_indent = len(context.payload.prompt) - len(context.payload.prompt.lstrip()) - normalized_original_prompt = fmtr.normalize_yaml(original_prompt) - original_prompt = " " * normalized_indent + normalized_original_prompt - context.payload.original_prompt = original_prompt class PreProcessStage(PipelineElement): diff --git a/ansible_ai_connect/ai/api/pipelines/completion_stages/tests/test_pre_process.py b/ansible_ai_connect/ai/api/pipelines/completion_stages/tests/test_pre_process.py index 2a7b9c092..5d31a40f6 100644 --- a/ansible_ai_connect/ai/api/pipelines/completion_stages/tests/test_pre_process.py +++ b/ansible_ai_connect/ai/api/pipelines/completion_stages/tests/test_pre_process.py @@ -520,7 +520,6 @@ def add_indents(vars, n): @modify_settings() class CompletionPreProcessTest(TestCase): def call_completion_pre_process(self, payload, is_commercial_user, expected_context): - original_prompt = payload.get("prompt") user = Mock(rh_user_has_seat=is_commercial_user) request = Mock(user=user) serializer = CompletionRequestSerializer(context={"request": request}) @@ -529,7 +528,6 @@ def call_completion_pre_process(self, payload, is_commercial_user, expected_cont request=request, payload=APIPayload( prompt=data.get("prompt"), - original_prompt=original_prompt, context=data.get("context"), ), metadata=data.get("metadata"), diff --git a/ansible_ai_connect/ai/api/tests/test_formatter.py b/ansible_ai_connect/ai/api/tests/test_formatter.py index 9c2dd231f..4e0bb65d5 100644 --- a/ansible_ai_connect/ai/api/tests/test_formatter.py +++ b/ansible_ai_connect/ai/api/tests/test_formatter.py @@ -319,135 +319,6 @@ def test_insert_set_fact_task(self): self.assertTrue("ansible.builtin.set_fact" in data[0]) self.assertEqual(data[0]["ansible.builtin.set_fact"], merged_vars) - def test_restore_original_task_names(self): - single_task_prompt = "- name: Install ssh\n" - multi_task_prompt = "# Install Apache & say hello fred@redhat.com\n" - multi_task_prompt_with_loop = ( - "# Delete all virtual machines in my Azure resource group called 'melisa' that " - "exists longer than 24 hours. Do not delete virtual machines that exists less " - "than 24 hours." - ) - multi_task_prompt_with_loop_extra_task = ( - "# Delete all virtual machines in my Azure resource group " - "& say hello to ada@anemail.com" - ) - - multi_task_yaml = ( - "- name: Install Apache\n ansible.builtin.apt:\n " - "name: apache2\n state: latest\n- name: say hello test@example.com\n " - "ansible.builtin.debug:\n msg: Hello there olivia1@example.com\n" - ) - multi_task_yaml_extra_task = ( - "- name: Install Apache\n ansible.builtin.apt:\n " - "name: apache2\n state: latest\n- name: say hello test@example.com\n " - "ansible.builtin.debug:\n msg: Hello there olivia1@example.com" - "\n- name: say hi test@example.com\n " - "ansible.builtin.debug:\n msg: Hello there olivia1@example.com\n" - ) - multi_task_yaml_with_loop = ( - "- name: Delete all virtual machines in my " - "Azure resource group called 'test' that exists longer than 24 hours. Do not " - "delete virtual machines that exists less than 24 hours.\n" - " azure.azcollection.azure_rm_virtualmachine:\n" - ' name: "{{ _name_ }}"\n state: absent\n resource_group: myResourceGroup\n' - " vm_size: Standard_A0\n" - ' image: "{{ _image_ }}"\n loop:\n - name: "{{ vm_name }}"\n' - ' password: "{{ _password_ }}"\n' - ' user: "{{ vm_user }}"\n location: "{{ vm_location }}"\n' - ) - multi_task_yaml_with_loop_extra_task = ( - "- name: Delete all virtual machines in my Azure resource group\n" - " azure.azcollection.azure_rm_virtualmachine:\n" - ' name: "{{ _name_ }}"\n state: absent\n resource_group: myResourceGroup\n' - " vm_size: Standard_A0\n" - ' image: "{{ _image_ }}"\n loop:\n - name: "{{ vm_name }}"\n' - ' password: "{{ _password_ }}"\n' - ' user: "{{ vm_user }}"\n location: "{{ vm_location }}"\n' - "- name: say hello to ada@anemail.com\n " - "ansible.builtin.debug:\n msg: Hello there olivia1@example.com\n" - ) - single_task_yaml = ( - " ansible.builtin.package:\n name: openssh-server\n state: present\n when:\n" - " - enable_ssh | bool\n - ansible_distribution == 'Ubuntu'" - ) - expected_multi_task_yaml = ( - "- name: Install Apache\n ansible.builtin.apt:\n " - "name: apache2\n state: latest\n- name: say hello fred@redhat.com\n " - "ansible.builtin.debug:\n msg: Hello there olivia1@example.com\n" - ) - expected_multi_task_yaml_with_loop = ( - "- name: Delete all virtual machines in my " - "Azure resource group called 'melisa' that exists longer than 24 hours. Do not " - "delete virtual machines that exists less than 24 hours.\n" - " azure.azcollection.azure_rm_virtualmachine:\n" - ' name: "{{ _name_ }}"\n state: absent\n resource_group: myResourceGroup\n' - " vm_size: Standard_A0\n" - ' image: "{{ _image_ }}"\n loop:\n - name: "{{ vm_name }}"\n' - ' password: "{{ _password_ }}"\n' - ' user: "{{ vm_user }}"\n location: "{{ vm_location }}"\n' - ) - expected_multi_task_yaml_with_loop_extra_task = ( - "- name: Delete all virtual machines in my Azure resource group\n" - " azure.azcollection.azure_rm_virtualmachine:\n" - ' name: "{{ _name_ }}"\n state: absent\n resource_group: myResourceGroup\n' - " vm_size: Standard_A0\n" - ' image: "{{ _image_ }}"\n loop:\n - name: "{{ vm_name }}"\n' - ' password: "{{ _password_ }}"\n' - ' user: "{{ vm_user }}"\n location: "{{ vm_location }}"\n' - "- name: say hello to ada@anemail.com\n " - "ansible.builtin.debug:\n msg: Hello there olivia1@example.com\n" - ) - - self.assertEqual( - expected_multi_task_yaml, - fmtr.restore_original_task_names(multi_task_yaml, multi_task_prompt), - ) - - with self.assertLogs(logger="root", level="ERROR") as log: - fmtr.restore_original_task_names(multi_task_yaml_extra_task, multi_task_prompt), - self.assertInLog( - "There is no match for the enumerated prompt task in the suggestion yaml", log - ) - - self.assertEqual( - expected_multi_task_yaml_with_loop, - fmtr.restore_original_task_names( - multi_task_yaml_with_loop, multi_task_prompt_with_loop - ), - ) - - self.assertEqual( - expected_multi_task_yaml_with_loop_extra_task, - fmtr.restore_original_task_names( - multi_task_yaml_with_loop_extra_task, multi_task_prompt_with_loop_extra_task - ), - ) - - self.assertEqual( - single_task_yaml, - fmtr.restore_original_task_names(single_task_yaml, single_task_prompt), - ) - - self.assertEqual( - "", - fmtr.restore_original_task_names("", multi_task_prompt), - ) - - def test_restore_original_task_names_for_index_error(self): - # The following prompt simulates a mismatch between requested tasks and received tasks - multi_task_prompt = "# Install Apache\n" - multi_task_yaml = ( - "- name: Install Apache\n ansible.builtin.apt:\n " - "name: apache2\n state: latest\n- name: say hello test@example.com\n " - "ansible.builtin.debug:\n msg: Hello there olivia1@example.com\n" - ) - - with self.assertLogs(logger="root", level="ERROR") as log: - fmtr.restore_original_task_names(multi_task_yaml, multi_task_prompt) - self.assertInLog( - "There is no match for the enumerated prompt task in the suggestion yaml", log - ) - def test_strip_task_preamble_from_multi_task_prompt_no_preamble_unchanged_multi(self): prompt = " # install ffmpeg" self.assertEqual(prompt, fmtr.strip_task_preamble_from_multi_task_prompt(prompt)) @@ -731,7 +602,6 @@ def test_get_fqcn_module_from_prediction_with_task_keywords(self): tests.test_get_task_names_multi() tests.test_load_and_merge_vars_in_context() tests.test_insert_set_fact_task() - tests.test_restore_original_task_names() tests.test_strip_task_preamble_from_multi_task_prompt_no_preamble_unchanged_multi() tests.test_strip_task_preamble_from_multi_task_prompt_no_preamble_unchanged_single() tests.test_strip_task_preamble_from_multi_task_prompt_one_preamble_changed() diff --git a/ansible_ai_connect/ai/api/tests/test_views.py b/ansible_ai_connect/ai/api/tests/test_views.py index db0834cd4..fbb931d06 100644 --- a/ansible_ai_connect/ai/api/tests/test_views.py +++ b/ansible_ai_connect/ai/api/tests/test_views.py @@ -128,10 +128,7 @@ def __init__( request = Mock(user=user) serializer = CompletionRequestSerializer(context={"request": request}) data = serializer.validate(payload.copy()) - api_payload = APIPayload(prompt=data.get("prompt"), context=data.get("context")) - api_payload.original_prompt = payload["prompt"] - context = CompletionContext( request=request, payload=api_payload, @@ -861,7 +858,7 @@ def test_multi_task_prompt_commercial_with_pii(self): r = self.client.post(reverse("completions"), payload) self.assertEqual(r.status_code, HTTPStatus.OK) self.assertIsNotNone(r.data["predictions"]) - self.assertIn(pii_task.capitalize(), r.data["predictions"][0]) + self.assertNotIn(pii_task.capitalize(), r.data["predictions"][0]) self.assertSegmentTimestamp(log) segment_events = self.extractSegmentEventsFromLog(log) self.assertTrue(len(segment_events) > 0)