Skip to content

Commit

Permalink
revert AAP-17357
Browse files Browse the repository at this point in the history
AAP-17357 is hard to maintain, covers only a subset of what AAP-24049 will do, and ultimately complicates
the refactoring of the Schema1 event management. See:
  #1147
  • Loading branch information
goneri committed Jul 10, 2024
1 parent c6bf2fd commit 379665a
Show file tree
Hide file tree
Showing 8 changed files with 10 additions and 186 deletions.
1 change: 0 additions & 1 deletion ansible_ai_connect/ai/api/data/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
class APIPayload(BaseModel):
model: str = ""
prompt: str = ""
original_prompt: str = ""
context: str = ""
userId: Optional[UUID] = None
suggestionId: Optional[UUID] = None
Expand Down
19 changes: 0 additions & 19 deletions ansible_ai_connect/ai/api/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,25 +368,6 @@ def get_task_names_from_tasks(tasks):
return names


def restore_original_task_names(output_yaml, prompt):
    """Swap the task names in *output_yaml* back to the names given in *prompt*.

    Anonymization may have rewritten the ``- name:`` lines of the suggestion;
    for multi-task prompts this restores the original wording, pairing the
    i-th ``- name:`` line of the YAML with the i-th task of the prompt.
    Single-task prompts and empty output are returned unchanged.
    """
    if not output_yaml or not is_multi_task_prompt(prompt):
        return output_yaml

    original_names = get_task_names_from_prompt(prompt)
    name_lines = re.finditer(r"^- name:\s+(.*)", output_yaml, re.M)
    for index, found in enumerate(name_lines):
        if index >= len(original_names):
            # The suggestion contains more tasks than the prompt requested;
            # nothing sensible to restore for the extras.
            logger.error(
                "There is no match for the enumerated prompt task in the suggestion yaml"
            )
            break
        whole_line = found.group(0)
        anonymized_name = found.group(1)
        restored_line = whole_line.replace(anonymized_name, original_names[index])
        output_yaml = output_yaml.replace(whole_line, restored_line)

    return output_yaml


# List of Task keywords to filter out during prediction results parsing.
ansible_task_keywords = None
# RegExp Pattern based on ARI sources, see ansible_risk_insight/finder.py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,5 @@ def process(self, context: CompletionContext) -> None:
)

payload = APIPayload(**request_serializer.validated_data)
payload.original_prompt = request.data.get("prompt", "")

context.payload = payload
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,10 @@ def completion_post_process(context: CompletionContext):
model_id = context.model_id
suggestion_id = context.payload.suggestionId
prompt = context.payload.prompt
original_prompt = context.payload.original_prompt
payload_context = context.payload.context
original_indent = context.original_indent
post_processed_predictions = context.anonymized_predictions.copy()
is_multi_task_prompt = fmtr.is_multi_task_prompt(original_prompt)
is_multi_task_prompt = fmtr.is_multi_task_prompt(prompt)

ari_caller = apps.get_app_config("ai").get_ari_caller()
if not ari_caller:
Expand All @@ -159,16 +158,11 @@ def completion_post_process(context: CompletionContext):
f"unexpected predictions array length {len(post_processed_predictions['predictions'])}"
)

anonymized_recommendation_yaml = post_processed_predictions["predictions"][0]
recommendation_yaml = post_processed_predictions["predictions"][0]

if not anonymized_recommendation_yaml:
raise PostprocessException(
f"unexpected prediction content {anonymized_recommendation_yaml}"
)
if not recommendation_yaml:
raise PostprocessException(f"unexpected prediction content {recommendation_yaml}")

recommendation_yaml = fmtr.restore_original_task_names(
anonymized_recommendation_yaml, original_prompt
)
recommendation_problem = None
truncated_yaml = None
postprocessed_yaml = None
Expand Down Expand Up @@ -213,7 +207,7 @@ def completion_post_process(context: CompletionContext):
f"original recommendation: \n{recommendation_yaml}"
)
postprocessed_yaml, ari_results = ari_caller.postprocess(
recommendation_yaml, original_prompt, payload_context
recommendation_yaml, prompt, payload_context
)
logger.debug(
f"suggestion id: {suggestion_id}, "
Expand Down Expand Up @@ -259,7 +253,6 @@ def completion_post_process(context: CompletionContext):
f"rules_with_applied_changes: {tasks_with_applied_changes} "
f"recommendation_yaml: [{repr(recommendation_yaml)}] "
f"postprocessed_yaml: [{repr(postprocessed_yaml)}] "
f"original_prompt: [{repr(original_prompt)}] "
f"payload_context: [{repr(payload_context)}] "
f"postprocess_details: [{json.dumps(postprocess_details)}] "
)
Expand All @@ -279,7 +272,7 @@ def completion_post_process(context: CompletionContext):
write_to_segment(
user,
suggestion_id,
anonymized_recommendation_yaml,
recommendation_yaml,
truncated_yaml,
postprocessed_yaml,
postprocess_details,
Expand All @@ -298,9 +291,8 @@ def completion_post_process(context: CompletionContext):
input_yaml = postprocessed_yaml if postprocessed_yaml else recommendation_yaml
# Single task predictions are missing the `- name: ` line and fail linter schema check
if not is_multi_task_prompt:
input_yaml = (
f"{original_prompt.lstrip() if ari_caller else original_prompt}{input_yaml}"
)
prompt += "\n"
input_yaml = f"{prompt.lstrip() if ari_caller else prompt}{input_yaml}"
postprocessed_yaml = ansible_lint_caller.run_linter(input_yaml)
# Stripping the leading STRIP_YAML_LINE that was added by above processing
if postprocessed_yaml.startswith(STRIP_YAML_LINE):
Expand All @@ -318,7 +310,7 @@ def completion_post_process(context: CompletionContext):
)
finally:
anonymized_input_yaml = (
postprocessed_yaml if postprocessed_yaml else anonymized_recommendation_yaml
postprocessed_yaml if postprocessed_yaml else recommendation_yaml
)
write_to_segment(
user,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@

def completion_pre_process(context: CompletionContext):
prompt = context.payload.prompt
original_prompt, _ = fmtr.extract_prompt_and_context(context.payload.original_prompt)
payload_context = context.payload.context

# Additional context (variables) is supported when
Expand Down Expand Up @@ -70,17 +69,6 @@ def completion_pre_process(context: CompletionContext):
context.payload.context, context.payload.prompt = fmtr.preprocess(
payload_context, prompt, ansibleFileType, additionalContext
)
if not multi_task:
# We are currently more forgiving on leading spacing of single task
# prompts than multi task prompts. In order to use the "original"
# single task prompt successfull in post-processing, we need to
# ensure its spacing aligns with the normalized context we got
# back from preprocess. We can calculate the proper spacing from the
# normalized prompt.
normalized_indent = len(context.payload.prompt) - len(context.payload.prompt.lstrip())
normalized_original_prompt = fmtr.normalize_yaml(original_prompt)
original_prompt = " " * normalized_indent + normalized_original_prompt
context.payload.original_prompt = original_prompt


class PreProcessStage(PipelineElement):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,6 @@ def add_indents(vars, n):
@modify_settings()
class CompletionPreProcessTest(TestCase):
def call_completion_pre_process(self, payload, is_commercial_user, expected_context):
original_prompt = payload.get("prompt")
user = Mock(rh_user_has_seat=is_commercial_user)
request = Mock(user=user)
serializer = CompletionRequestSerializer(context={"request": request})
Expand All @@ -529,7 +528,6 @@ def call_completion_pre_process(self, payload, is_commercial_user, expected_cont
request=request,
payload=APIPayload(
prompt=data.get("prompt"),
original_prompt=original_prompt,
context=data.get("context"),
),
metadata=data.get("metadata"),
Expand Down
130 changes: 0 additions & 130 deletions ansible_ai_connect/ai/api/tests/test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,135 +319,6 @@ def test_insert_set_fact_task(self):
self.assertTrue("ansible.builtin.set_fact" in data[0])
self.assertEqual(data[0]["ansible.builtin.set_fact"], merged_vars)

def test_restore_original_task_names(self):
    """Exercise fmtr.restore_original_task_names() across prompt/YAML shapes.

    Covers: a plain multi-task prompt, prompts whose task text contains
    quoting and loops, a suggestion with more tasks than the prompt asked
    for (must log an ERROR), a single-task prompt (returned unchanged),
    and empty output (returned unchanged).
    """
    # Prompts used as the "original" user input.
    single_task_prompt = "- name: Install ssh\n"
    multi_task_prompt = "# Install Apache & say hello [email protected]\n"
    multi_task_prompt_with_loop = (
        "# Delete all virtual machines in my Azure resource group called 'melisa' that "
        "exists longer than 24 hours. Do not delete virtual machines that exists less "
        "than 24 hours."
    )
    multi_task_prompt_with_loop_extra_task = (
        "# Delete all virtual machines in my Azure resource group "
        "& say hello to [email protected]"
    )

    # Model suggestions (possibly anonymized) whose task names get restored.
    multi_task_yaml = (
        "- name: Install Apache\n ansible.builtin.apt:\n "
        "name: apache2\n state: latest\n- name: say hello [email protected]\n "
        "ansible.builtin.debug:\n msg: Hello there [email protected]\n"
    )
    # Same as above, plus one extra task not present in the prompt.
    multi_task_yaml_extra_task = (
        "- name: Install Apache\n ansible.builtin.apt:\n "
        "name: apache2\n state: latest\n- name: say hello [email protected]\n "
        "ansible.builtin.debug:\n msg: Hello there [email protected]"
        "\n- name: say hi [email protected]\n "
        "ansible.builtin.debug:\n msg: Hello there [email protected]\n"
    )
    # Note the task name says 'test' here; the prompt says 'melisa' — the
    # restore step is expected to put 'melisa' back (see expected_... below).
    multi_task_yaml_with_loop = (
        "- name: Delete all virtual machines in my "
        "Azure resource group called 'test' that exists longer than 24 hours. Do not "
        "delete virtual machines that exists less than 24 hours.\n"
        " azure.azcollection.azure_rm_virtualmachine:\n"
        ' name: "{{ _name_ }}"\n state: absent\n resource_group: myResourceGroup\n'
        " vm_size: Standard_A0\n"
        ' image: "{{ _image_ }}"\n loop:\n - name: "{{ vm_name }}"\n'
        ' password: "{{ _password_ }}"\n'
        ' user: "{{ vm_user }}"\n location: "{{ vm_location }}"\n'
    )
    multi_task_yaml_with_loop_extra_task = (
        "- name: Delete all virtual machines in my Azure resource group\n"
        " azure.azcollection.azure_rm_virtualmachine:\n"
        ' name: "{{ _name_ }}"\n state: absent\n resource_group: myResourceGroup\n'
        " vm_size: Standard_A0\n"
        ' image: "{{ _image_ }}"\n loop:\n - name: "{{ vm_name }}"\n'
        ' password: "{{ _password_ }}"\n'
        ' user: "{{ vm_user }}"\n location: "{{ vm_location }}"\n'
        "- name: say hello to [email protected]\n "
        "ansible.builtin.debug:\n msg: Hello there [email protected]\n"
    )
    # Single-task suggestions carry no '- name:' line, so nothing to restore.
    single_task_yaml = (
        " ansible.builtin.package:\n name: openssh-server\n state: present\n when:\n"
        " - enable_ssh | bool\n - ansible_distribution == 'Ubuntu'"
    )

    # Expected outputs after restoration.
    expected_multi_task_yaml = (
        "- name: Install Apache\n ansible.builtin.apt:\n "
        "name: apache2\n state: latest\n- name: say hello [email protected]\n "
        "ansible.builtin.debug:\n msg: Hello there [email protected]\n"
    )
    expected_multi_task_yaml_with_loop = (
        "- name: Delete all virtual machines in my "
        "Azure resource group called 'melisa' that exists longer than 24 hours. Do not "
        "delete virtual machines that exists less than 24 hours.\n"
        " azure.azcollection.azure_rm_virtualmachine:\n"
        ' name: "{{ _name_ }}"\n state: absent\n resource_group: myResourceGroup\n'
        " vm_size: Standard_A0\n"
        ' image: "{{ _image_ }}"\n loop:\n - name: "{{ vm_name }}"\n'
        ' password: "{{ _password_ }}"\n'
        ' user: "{{ vm_user }}"\n location: "{{ vm_location }}"\n'
    )
    expected_multi_task_yaml_with_loop_extra_task = (
        "- name: Delete all virtual machines in my Azure resource group\n"
        " azure.azcollection.azure_rm_virtualmachine:\n"
        ' name: "{{ _name_ }}"\n state: absent\n resource_group: myResourceGroup\n'
        " vm_size: Standard_A0\n"
        ' image: "{{ _image_ }}"\n loop:\n - name: "{{ vm_name }}"\n'
        ' password: "{{ _password_ }}"\n'
        ' user: "{{ vm_user }}"\n location: "{{ vm_location }}"\n'
        "- name: say hello to [email protected]\n "
        "ansible.builtin.debug:\n msg: Hello there [email protected]\n"
    )

    # Straightforward multi-task restoration.
    self.assertEqual(
        expected_multi_task_yaml,
        fmtr.restore_original_task_names(multi_task_yaml, multi_task_prompt),
    )

    # More tasks in the YAML than in the prompt -> ERROR is logged.
    # NOTE(review): the trailing comma below makes this line a 1-tuple
    # expression; pre-existing quirk, kept byte-identical.
    with self.assertLogs(logger="root", level="ERROR") as log:
        fmtr.restore_original_task_names(multi_task_yaml_extra_task, multi_task_prompt),
        self.assertInLog(
            "There is no match for the enumerated prompt task in the suggestion yaml", log
        )

    # Task names containing quotes/loops are restored verbatim.
    self.assertEqual(
        expected_multi_task_yaml_with_loop,
        fmtr.restore_original_task_names(
            multi_task_yaml_with_loop, multi_task_prompt_with_loop
        ),
    )

    self.assertEqual(
        expected_multi_task_yaml_with_loop_extra_task,
        fmtr.restore_original_task_names(
            multi_task_yaml_with_loop_extra_task, multi_task_prompt_with_loop_extra_task
        ),
    )

    # Single-task prompt: suggestion passes through unchanged.
    self.assertEqual(
        single_task_yaml,
        fmtr.restore_original_task_names(single_task_yaml, single_task_prompt),
    )

    # Empty suggestion: returned as-is.
    self.assertEqual(
        "",
        fmtr.restore_original_task_names("", multi_task_prompt),
    )

def test_restore_original_task_names_for_index_error(self):
    """An ERROR must be logged when the suggestion holds more tasks than the prompt."""
    # One requested task in the prompt...
    prompt_with_one_task = "# Install Apache\n"
    # ...but a suggestion that comes back with two tasks.
    yaml_with_two_tasks = (
        "- name: Install Apache\n ansible.builtin.apt:\n "
        "name: apache2\n state: latest\n- name: say hello [email protected]\n "
        "ansible.builtin.debug:\n msg: Hello there [email protected]\n"
    )

    with self.assertLogs(logger="root", level="ERROR") as log:
        fmtr.restore_original_task_names(yaml_with_two_tasks, prompt_with_one_task)
        self.assertInLog(
            "There is no match for the enumerated prompt task in the suggestion yaml", log
        )

def test_strip_task_preamble_from_multi_task_prompt_no_preamble_unchanged_multi(self):
prompt = " # install ffmpeg"
self.assertEqual(prompt, fmtr.strip_task_preamble_from_multi_task_prompt(prompt))
Expand Down Expand Up @@ -731,7 +602,6 @@ def test_get_fqcn_module_from_prediction_with_task_keywords(self):
tests.test_get_task_names_multi()
tests.test_load_and_merge_vars_in_context()
tests.test_insert_set_fact_task()
tests.test_restore_original_task_names()
tests.test_strip_task_preamble_from_multi_task_prompt_no_preamble_unchanged_multi()
tests.test_strip_task_preamble_from_multi_task_prompt_no_preamble_unchanged_single()
tests.test_strip_task_preamble_from_multi_task_prompt_one_preamble_changed()
Expand Down
5 changes: 1 addition & 4 deletions ansible_ai_connect/ai/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,7 @@ def __init__(
request = Mock(user=user)
serializer = CompletionRequestSerializer(context={"request": request})
data = serializer.validate(payload.copy())

api_payload = APIPayload(prompt=data.get("prompt"), context=data.get("context"))
api_payload.original_prompt = payload["prompt"]

context = CompletionContext(
request=request,
payload=api_payload,
Expand Down Expand Up @@ -861,7 +858,7 @@ def test_multi_task_prompt_commercial_with_pii(self):
r = self.client.post(reverse("completions"), payload)
self.assertEqual(r.status_code, HTTPStatus.OK)
self.assertIsNotNone(r.data["predictions"])
self.assertIn(pii_task.capitalize(), r.data["predictions"][0])
self.assertNotIn(pii_task.capitalize(), r.data["predictions"][0])
self.assertSegmentTimestamp(log)
segment_events = self.extractSegmentEventsFromLog(log)
self.assertTrue(len(segment_events) > 0)
Expand Down

0 comments on commit 379665a

Please sign in to comment.