AAP-36732: ModelPipelines: Configuration improvements: Remove env vars
manstis committed Jan 10, 2025
1 parent d7c5ab0 commit 3051acc
Showing 34 changed files with 482 additions and 333 deletions.
40 changes: 40 additions & 0 deletions README-ANSIBLE_AI_MODEL_MESH_CONFIG.md
@@ -0,0 +1,40 @@
# Example `ANSIBLE_AI_MODEL_MESH_CONFIG` configuration

Pay close attention to the formatting of the blocks.

Each block ends with `}},`; otherwise, conversion of the multi-line setting to a `str` can fail (see the validation sketch after the example).

```text
ANSIBLE_AI_MODEL_MESH_CONFIG="{
"ModelPipelineCompletions": {
"provider": "ollama",
"config": {
"inference_url": "http://host.containers.internal:11434",
"model_id": "mistral:instruct"}},
"ModelPipelineContentMatch": {
"provider": "ollama",
"config": {
"inference_url": "http://host.containers.internal:11434",
"model_id": "mistral:instruct"}},
"ModelPipelinePlaybookGeneration": {
"provider": "ollama",
"config": {
"inference_url": "http://host.containers.internal:11434",
"model_id": "mistral:instruct"}},
"ModelPipelineRoleGeneration": {
"provider": "ollama",
"config": {
"inference_url": "http://host.containers.internal:11434",
"model_id": "mistral:instruct"}},
"ModelPipelinePlaybookExplanation": {
"provider": "ollama",
"config": {
"inference_url": "http://host.containers.internal:11434",
"model_id": "mistral:instruct"}},
"ModelPipelineChatBot": {
"provider": "http",
"config": {
"inference_url": "http://localhost:8000",
"model_id": "granite3-8b"}}
}"
```
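
A quick way to check an edited configuration is to run the assembled value through `json.loads`. The sketch below is only an illustration of that check, not the service's actual loading code, and assumes the value is plain JSON once the surrounding shell quoting is resolved:

```python
import json

# Paste the value between the outer quotes of ANSIBLE_AI_MODEL_MESH_CONFIG here.
# Only one pipeline block is shown; the names mirror the example above.
raw = """
{
    "ModelPipelineCompletions": {
        "provider": "ollama",
        "config": {
            "inference_url": "http://host.containers.internal:11434",
            "model_id": "mistral:instruct"}}
}
"""

config = json.loads(raw)  # raises json.JSONDecodeError if a block is malformed
print(config["ModelPipelineCompletions"]["provider"])  # -> ollama
```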
16 changes: 7 additions & 9 deletions README.md
@@ -26,10 +26,9 @@ SECRET_KEY="somesecretvalue"
ENABLE_ARI_POSTPROCESS="False"
WCA_SECRET_BACKEND_TYPE="dummy"
# configure model server
ANSIBLE_AI_MODEL_MESH_API_URL="http://host.containers.internal:11434"
ANSIBLE_AI_MODEL_MESH_API_TYPE="ollama"
ANSIBLE_AI_MODEL_MESH_MODEL_ID="mistral:instruct"
ANSIBLE_AI_MODEL_MESH_CONFIG="..."
```
See the example [ANSIBLE_AI_MODEL_MESH_CONFIG](./docs/config/examples/README-ANSIBLE_AI_MODEL_MESH_CONFIG.md).

### Start service and dependencies

@@ -108,9 +107,9 @@ command line the variable `DEBUG=True`.

The Django service listens on <http://127.0.0.1:8000>.

Note that there is no pytorch service defined in the docker-compose
file. You should adjust the `ANSIBLE_AI_MODEL_MESH_API_URL`
configuration key to point on an existing service.
Note that there is no pytorch service defined in the `docker-compose`
file. You should adjust the `ANSIBLE_AI_MODEL_MESH_CONFIG`
configuration to point to an existing service.

## <a name="aws-config">Use the WCA API Keys Manager</a>

@@ -460,11 +459,10 @@ To connect to the Mistral 7b Instruct model running locally on [llama.cpp](htt
```
1. Set the appropriate environment variables
```bash
ANSIBLE_AI_MODEL_MESH_API_URL=http://$YOUR_REAL_IP:8080
ANSIBLE_AI_MODEL_MESH_API_TYPE=llamacpp
ANSIBLE_AI_MODEL_MESH_MODEL_ID=mistral-7b-instruct-v0.2.Q5_K_M.gguf
ANSIBLE_AI_MODEL_MESH_CONFIG="..."
ENABLE_ARI_POSTPROCESS=False
```
See the example [ANSIBLE_AI_MODEL_MESH_CONFIG](./docs/config/examples/README-ANSIBLE_AI_MODEL_MESH_CONFIG.md).
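
For reference, a llama.cpp-flavoured configuration might look like the block below. This is an illustrative sketch only: it assumes the `llamacpp` provider takes the same `inference_url`/`model_id` shape as the `ollama` blocks; the linked example file remains the authoritative format.

```text
ANSIBLE_AI_MODEL_MESH_CONFIG="{
    "ModelPipelineCompletions": {
        "provider": "llamacpp",
        "config": {
            "inference_url": "http://$YOUR_REAL_IP:8080",
            "model_id": "mistral-7b-instruct-v0.2.Q5_K_M.gguf"}}
    }"
```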

# Testing

@@ -134,7 +134,7 @@ def __init__(self, config: DummyConfiguration):
super().__init__(config=config)

def invoke(self, params: CompletionsParameters) -> CompletionsResponse:
logger.debug("!!!! settings.ANSIBLE_AI_MODEL_MESH_API_TYPE == 'dummy' !!!!")
logger.debug("!!!! ModelPipelineCompletions.provider == 'dummy' !!!!")
logger.debug("!!!! Mocking Model response !!!!")
if self.config.latency_use_jitter:
jitter: float = secrets.randbelow(1000) * 0.001
@@ -66,4 +66,4 @@ def __init__(self, **kwargs):

@Register(api_type="http")
class HttpConfigurationSerializer(BaseConfigSerializer):
verify_ssl = serializers.BooleanField(required=False, default=False)
verify_ssl = serializers.BooleanField(required=False, default=True)
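
Note the behavioral change here: clients that omit `verify_ssl` now get SSL verification enabled, and the same default flip appears in the `llamacpp` and WCA serializers below. A self-contained DRF sketch (not the project's code) showing how `default=` fills an omitted field:

```python
from django.conf import settings as dj_settings

if not dj_settings.configured:
    dj_settings.configure()  # bare-script setup; not needed inside the project

from rest_framework import serializers


class ExampleConfigSerializer(serializers.Serializer):
    # DRF applies `default` only when the key is absent from the input payload.
    verify_ssl = serializers.BooleanField(required=False, default=True)


s = ExampleConfigSerializer(data={})  # "verify_ssl" omitted entirely
s.is_valid(raise_exception=True)
print(s.validated_data)  # {'verify_ssl': True}

s = ExampleConfigSerializer(data={"verify_ssl": False})  # explicit opt-out still works
s.is_valid(raise_exception=True)
print(s.validated_data)  # {'verify_ssl': False}
```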
@@ -66,4 +66,4 @@ def __init__(self, **kwargs):

@Register(api_type="llamacpp")
class LlamaCppConfigurationSerializer(BaseConfigSerializer):
verify_ssl = serializers.BooleanField(required=False, default=False)
verify_ssl = serializers.BooleanField(required=False, default=True)
@@ -1325,7 +1325,6 @@ def test_codematch_empty_response(self):
self.assertEqual(e.exception.model_id, model_id)


@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID=None)
class TestDummySecretManager(TestCase):
def setUp(self):
super().setUp()
@@ -77,7 +77,7 @@ def __init__(self, provider: t_model_mesh_api_type, config: WCABaseConfiguration

class WCABaseConfigurationSerializer(BaseConfigSerializer):
api_key = serializers.CharField(required=False, allow_null=True, allow_blank=True)
verify_ssl = serializers.BooleanField(required=False, default=False)
verify_ssl = serializers.BooleanField(required=False, default=True)
retry_count = serializers.IntegerField(required=False, default=4)
enable_ari_postprocessing = serializers.BooleanField(required=False, default=False)
health_check_api_key = serializers.CharField(required=True)
@@ -108,8 +108,8 @@ def __init__(self, config: WCAOnPremConfiguration):
raise WcaUsernameNotFound
if not self.config.api_key:
raise WcaKeyNotFound
# ANSIBLE_AI_MODEL_MESH_MODEL_ID cannot be validated until runtime. The
# User may provide an override value if the Environment Variable is not set.
# WCAOnPremConfiguration.model_id cannot be validated until runtime. The
# User may provide an override value if the setting is not defined.

def get_request_headers(
self, api_key: str, identifier: Optional[str]
@@ -144,7 +144,7 @@ def get_api_key(self, user, organization_id: Optional[int]) -> str:

if organization_id is None:
logger.error(
"User does not have an organization and no ANSIBLE_AI_MODEL_MESH_API_KEY is set"
"User does not have an organization and WCASaaSConfiguration.api_key is not set"
)
raise WcaKeyNotFound

2 changes: 1 addition & 1 deletion ansible_ai_connect/ai/api/permissions.py
@@ -164,4 +164,4 @@ class IsWCASaaSModelPipeline(permissions.BasePermission):
message = "User doesn't have access to the IBM watsonx Code Assistant."

def has_permission(self, request, view):
return CONTINUE if settings.ANSIBLE_AI_MODEL_MESH_API_TYPE == "wca" else BLOCK
return CONTINUE if settings.DEPLOYMENT_MODE == "saas" else BLOCK
@@ -19,7 +19,6 @@

from ansible_anonymizer import anonymizer
from django.apps import apps
from django.conf import settings
from django_prometheus.conf import NAMESPACE
from prometheus_client import Histogram

@@ -124,7 +123,7 @@ def process(self, context: CompletionContext) -> None:
except ModelTimeoutError as e:
exception = e
logger.warning(
f"model timed out after {settings.ANSIBLE_AI_MODEL_MESH_API_TIMEOUT} "
f"model timed out after {model_mesh_client.config.timeout} "
f"seconds (per task) for suggestion {suggestion_id}"
)
raise ModelTimeoutException(cause=e)
4 changes: 2 additions & 2 deletions ansible_ai_connect/ai/api/tests/test_permissions.py
@@ -209,10 +209,10 @@ def test_ensure_trial_user_can_pass_through_despite_trial_disabled(self):

class TestBlockUserWithoutWCASaaSConfiguration(WisdomAppsBackendMocking):

@override_settings(ANSIBLE_AI_MODEL_MESH_API_TYPE="wca")
@override_settings(DEPLOYMENT_MODE="saas")
def test_wca_saas_enabled(self):
self.assertEqual(IsWCASaaSModelPipeline().has_permission(Mock(), None), CONTINUE)

@override_settings(ANSIBLE_AI_MODEL_MESH_API_TYPE="wca-onprem")
@override_settings(DEPLOYMENT_MODE="onprem")
def test_wca_saas_not_enabled(self):
self.assertEqual(IsWCASaaSModelPipeline().has_permission(Mock(), None), BLOCK)
42 changes: 23 additions & 19 deletions ansible_ai_connect/ai/api/tests/test_views.py
@@ -27,7 +27,6 @@

import requests
from django.apps import apps
from django.conf import settings
from django.contrib.auth import get_user_model
from django.test import modify_settings, override_settings
from django.urls import reverse
@@ -1018,7 +1017,7 @@ def test_full_payload(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"model_id": "a-model-id",
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
self.client.force_authenticate(user=self.user)
@@ -1040,7 +1039,7 @@ def test_multi_task_prompt_commercial(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"model_id": "a-model-id",
"predictions": [
"- name: Install Apache\n ansible.builtin.apt:\n name: apache2\n state: latest\n- name: start Apache\n ansible.builtin.service:\n name: apache2\n state: started\n enabled: yes\n" # noqa: E501
],
@@ -1088,7 +1087,7 @@ def test_multi_task_prompt_commercial_with_pii(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"model_id": "a-model-id",
"predictions": [
" - name: Install Apache\n ansible.builtin.apt:\n name: apache2\n state: latest\n - name: say hello [email protected]\n ansible.builtin.debug:\n msg: Hello there [email protected]\n" # noqa: E501
],
@@ -1127,7 +1126,7 @@ def test_rate_limit(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"model_id": "a-model-id",
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
self.client.force_authenticate(user=self.user)
@@ -1151,7 +1150,7 @@ def test_missing_prompt(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"model_id": "a-model-id",
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
self.client.force_authenticate(user=self.user)
@@ -1172,7 +1171,7 @@ def test_authentication_error(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"model_id": "a-model-id",
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
# self.client.force_authenticate(user=self.user)
@@ -1201,7 +1200,7 @@ def test_completions_preprocessing_error(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"model_id": "a-model-id",
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
self.client.force_authenticate(user=self.user)
@@ -1227,7 +1226,7 @@ def test_completions_preprocessing_error_without_name_prompt(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"model_id": "a-model-id",
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
self.client.force_authenticate(user=self.user)
@@ -1251,7 +1250,7 @@ def test_full_payload_without_ARI(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"model_id": "a-model-id",
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
self.client.force_authenticate(user=self.user)
@@ -1276,7 +1275,7 @@ def test_full_payload_with_recommendation_with_broken_last_line(self):
}
# quotation in the last line is not closed, but the truncate function can handle this.
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"model_id": "a-model-id",
"predictions": [
' ansible.builtin.apt:\n name: apache2\n register: "test'
],
@@ -1303,7 +1302,7 @@ def test_completions_postprocessing_error_for_invalid_yaml(self):
}
# this prediction has indentation problem with the prompt above
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"model_id": "a-model-id",
"predictions": [" ansible.builtin.apt:\n garbage name: apache2"],
}
self.client.force_authenticate(user=self.user)
@@ -1375,7 +1374,7 @@ def test_full_payload_without_ansible_lint_with_commercial_user(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"model_id": "a-model-id",
"predictions": [" ansible.builtin.apt:\n name: apache2"],
}
self.client.force_authenticate(user=self.user)
@@ -1558,7 +1557,7 @@ def test_completions_pii_clean_up(self):
"suggestionId": str(uuid.uuid4()),
}
response_data = {
"model_id": settings.ANSIBLE_AI_MODEL_MESH_MODEL_ID,
"model_id": "a-model-id",
"predictions": [""],
}
self.client.force_authenticate(user=self.user)
@@ -4044,9 +4043,7 @@ def json(self):
json_response["response"] = input
return MockResponse(json_response, status_code)

@override_settings(CHATBOT_URL="http://localhost:8080")
@override_settings(CHATBOT_DEFAULT_PROVIDER="wisdom")
@override_settings(CHATBOT_DEFAULT_MODEL="granite-8b")
@mock.patch(
"requests.post",
side_effect=mocked_requests_post,
@@ -4058,7 +4055,6 @@ def query_with_no_error(self, payload, mock_post):
"requests.post",
side_effect=mocked_requests_post,
)
@override_settings(CHATBOT_URL="")
def query_without_chat_config(self, payload, mock_post):
return self.client.post(reverse("chat"), payload, format="json")

@@ -4178,7 +4174,11 @@ def test_operational_telemetry(self):
patch.object(
apps.get_app_config("ai"),
"get_model_pipeline",
Mock(return_value=HttpChatBotPipeline(mock_pipeline_config("http"))),
Mock(
return_value=HttpChatBotPipeline(
mock_pipeline_config("http", model_id="granite-8b")
)
),
),
self.assertLogs(logger="root", level="DEBUG") as log,
):
@@ -4285,7 +4285,11 @@ def test_operational_telemetry_with_system_prompt_override(self):
patch.object(
apps.get_app_config("ai"),
"get_model_pipeline",
Mock(return_value=HttpChatBotPipeline(mock_pipeline_config("http"))),
Mock(
return_value=HttpChatBotPipeline(
mock_pipeline_config("http", model_id="granite-8b")
)
),
),
self.assertLogs(logger="root", level="DEBUG") as log,
):
18 changes: 9 additions & 9 deletions ansible_ai_connect/ai/api/views.py
@@ -601,7 +601,7 @@ def perform_content_matching(
except ModelTimeoutError as e:
exception = e
logger.warn(
f"model timed out after {settings.ANSIBLE_AI_MODEL_MESH_API_TIMEOUT} seconds"
f"model timed out after {model_mesh_client.config.timeout} seconds"
f" for suggestion {suggestion_id}"
)
raise ModelTimeoutException(cause=e)
@@ -1135,10 +1135,14 @@ class ChatEndpointThrottle(EndpointRateThrottle):
throttle_classes = [ChatEndpointThrottle]
schema1_event = schema1.ChatBotOperationalEvent

llm: ModelPipelineChatBot

def __init__(self):
self.llm = apps.get_app_config("ai").get_model_pipeline(ModelPipelineChatBot)

self.chatbot_enabled = (
settings.CHATBOT_URL
and settings.CHATBOT_DEFAULT_MODEL
self.llm.config.inference_url
and self.llm.config.model_id
and settings.CHATBOT_DEFAULT_PROVIDER
)
if self.chatbot_enabled:
@@ -1178,7 +1182,7 @@ def post(self, request) -> Response:
req_model_id = (
request_serializer.validated_data["model"]
if "model" in request_serializer.validated_data
else settings.CHATBOT_DEFAULT_MODEL
else self.llm.config.model_id
)
req_provider = (
request_serializer.validated_data["provider"]
@@ -1199,11 +1203,7 @@
self.event.conversation_id = conversation_id
self.event.modelName = req_model_id

llm: ModelPipelineChatBot = apps.get_app_config("ai").get_model_pipeline(
ModelPipelineChatBot
)

data = llm.invoke(
data = self.llm.invoke(
ChatBotParameters.init(
request=request,
query=req_query,
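
The net effect of the views.py changes: chatbot enablement is now derived from the resolved pipeline's own configuration rather than the global `CHATBOT_URL`/`CHATBOT_DEFAULT_MODEL` settings. A rough sketch of the shape this implies, using hypothetical stand-in classes (the real ones come from the app's pipeline registry):

```python
from dataclasses import dataclass


@dataclass
class PipelineConfiguration:  # hypothetical stand-in for the pipeline's config class
    inference_url: str
    model_id: str


@dataclass
class ChatBotPipeline:  # hypothetical stand-in for ModelPipelineChatBot
    config: PipelineConfiguration


llm = ChatBotPipeline(PipelineConfiguration("http://localhost:8000", "granite3-8b"))

# Mirrors the __init__ check in the diff (minus settings.CHATBOT_DEFAULT_PROVIDER).
chatbot_enabled = bool(llm.config.inference_url and llm.config.model_id)
print(chatbot_enabled)  # True
```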
