AAP-36732: ModelPipelines: Configuration improvements: Remove env vars #1487

Merged
merged 1 commit on Jan 17, 2025
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yml
@@ -30,7 +30,7 @@ jobs:
- name: Set up python3
uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: '3.11'
Contributor Author
This aligns the Python version with the other GHA definition.

3.11 is needed to support ExceptionGroup (used in this PR). It is also the version used to build the container.


- name: Install dependencies
run: |
40 changes: 40 additions & 0 deletions README-ANSIBLE_AI_MODEL_MESH_CONFIG.md
@@ -0,0 +1,40 @@
# Example `ANSIBLE_AI_MODEL_MESH_CONFIG` configuration

Pay close attention to the formatting of the blocks.

Each block ends with `}},`; otherwise conversion of the multi-line setting to a `str` can fail.
manstis marked this conversation as resolved.

```text
ANSIBLE_AI_MODEL_MESH_CONFIG="{
"ModelPipelineCompletions": {
"provider": "ollama",
"config": {
"inference_url": "http://host.containers.internal:11434",
"model_id": "mistral:instruct"}},
"ModelPipelineContentMatch": {
"provider": "ollama",
"config": {
"inference_url": "http://host.containers.internal:11434",
"model_id": "mistral:instruct"}},
"ModelPipelinePlaybookGeneration": {
"provider": "ollama",
"config": {
"inference_url": "http://host.containers.internal:11434",
"model_id": "mistral:instruct"}},
"ModelPipelineRoleGeneration": {
"provider": "ollama",
"config": {
"inference_url": "http://host.containers.internal:11434",
"model_id": "mistral:instruct"}},
"ModelPipelinePlaybookExplanation": {
"provider": "ollama",
"config": {
"inference_url": "http://host.containers.internal:11434",
"model_id": "mistral:instruct"}},
"ModelPipelineChatBot": {
"provider": "http",
"config": {
"inference_url": "http://localhost:8000",
"model_id": "granite3-8b"}}
}"
```
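As a quick sanity check (illustrative, not part of the repository), the value between the outer quotes is plain JSON, so `json.loads` can confirm the `}},` block endings are balanced. A sketch, abridged to two of the six pipelines:

```python
import json

# The multi-line value assigned to ANSIBLE_AI_MODEL_MESH_CONFIG is JSON;
# parsing it verifies the block endings are balanced.
config_text = """
{
  "ModelPipelineCompletions": {
    "provider": "ollama",
    "config": {
      "inference_url": "http://host.containers.internal:11434",
      "model_id": "mistral:instruct"}},
  "ModelPipelineChatBot": {
    "provider": "http",
    "config": {
      "inference_url": "http://localhost:8000",
      "model_id": "granite3-8b"}}
}
"""

config = json.loads(config_text)
print(sorted(config))  # the pipeline names present in the config
```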
16 changes: 7 additions & 9 deletions README.md
@@ -26,10 +26,9 @@ SECRET_KEY="somesecretvalue"
ENABLE_ARI_POSTPROCESS="False"
WCA_SECRET_BACKEND_TYPE="dummy"
# configure model server
ANSIBLE_AI_MODEL_MESH_API_URL="http://host.containers.internal:11434"
ANSIBLE_AI_MODEL_MESH_API_TYPE="ollama"
ANSIBLE_AI_MODEL_MESH_MODEL_ID="mistral:instruct"
ANSIBLE_AI_MODEL_MESH_CONFIG="..."
```
See the example [ANSIBLE_AI_MODEL_MESH_CONFIG](./docs/config/examples/README-ANSIBLE_AI_MODEL_MESH_CONFIG.md).

### Start service and dependencies

@@ -108,9 +107,9 @@ command line the variable `DEBUG=True`.

The Django service listens on <http://127.0.0.1:8000>.

Note that there is no pytorch service defined in the docker-compose
file. You should adjust the `ANSIBLE_AI_MODEL_MESH_API_URL`
configuration key to point on an existing service.
Note that there is no pytorch service defined in the `docker-compose`
file. You should adjust the `ANSIBLE_AI_MODEL_MESH_CONFIG`
configuration to point to an existing service.

## <a name="aws-config">Use the WCA API Keys Manager</a>

@@ -460,11 +459,10 @@ To connect to the Mistral 7b Instruct model running locally on [llama.cpp](htt
```
1. Set the appropriate environment variables
```bash
ANSIBLE_AI_MODEL_MESH_API_URL=http://$YOUR_REAL_IP:8080
ANSIBLE_AI_MODEL_MESH_API_TYPE=llamacpp
ANSIBLE_AI_MODEL_MESH_MODEL_ID=mistral-7b-instruct-v0.2.Q5_K_M.gguf
ANSIBLE_AI_MODEL_MESH_CONFIG="..."
ENABLE_ARI_POSTPROCESS=False
```
See the example [ANSIBLE_AI_MODEL_MESH_CONFIG](./docs/config/examples/README-ANSIBLE_AI_MODEL_MESH_CONFIG.md).

# Testing

43 changes: 41 additions & 2 deletions ansible_ai_connect/ai/api/model_pipelines/config_loader.py
@@ -12,18 +12,57 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from json import JSONDecodeError

import yaml
from django.conf import settings
from yaml import YAMLError

from ansible_ai_connect.ai.api.model_pipelines.config_providers import Configuration
from ansible_ai_connect.ai.api.model_pipelines.config_serializers import (
ConfigurationSerializer,
)

logger = logging.getLogger(__name__)


def load_config() -> Configuration:
    # yaml.safe_load(..) seems to also support loading JSON. Nice.
    # However, try to load JSON with the correct _loader_ first in case of corner cases
    errors: list[Exception] = []
    result = load_json()
    if isinstance(result, Exception):
        errors.append(result)
        result = load_yaml()
        if isinstance(result, Exception):
            errors.append(result)
        else:
            errors = []

    if len(errors) > 0:
        raise ExceptionGroup("Unable to parse ANSIBLE_AI_MODEL_MESH_CONFIG", errors)

    serializer = ConfigurationSerializer(data=result)
    serializer.is_valid(raise_exception=True)
    serializer.save()
    return serializer.instance


def load_json() -> dict | Exception:
    try:
        logger.info("Attempting to parse ANSIBLE_AI_MODEL_MESH_CONFIG as JSON...")
        return json.loads(settings.ANSIBLE_AI_MODEL_MESH_CONFIG)
    except JSONDecodeError as e:
        logger.exception(f"An error occurred parsing ANSIBLE_AI_MODEL_MESH_CONFIG as JSON:\n{e}")
        return e


def load_yaml() -> dict | Exception:
    try:
        logger.info("Attempting to parse ANSIBLE_AI_MODEL_MESH_CONFIG as YAML...")
        return yaml.safe_load(settings.ANSIBLE_AI_MODEL_MESH_CONFIG)
    except YAMLError as e:
        logger.exception(f"An error occurred parsing ANSIBLE_AI_MODEL_MESH_CONFIG as YAML:\n{e}")
        return e
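The loader's comment that `yaml.safe_load` also accepts JSON can be verified directly: YAML 1.2 is essentially a superset of JSON, so both spellings of a config parse to the same structure. A small sketch, not from the PR (assumes PyYAML is installed):

```python
import yaml

# Equivalent configuration expressed as JSON and as YAML.
json_text = '{"MetaData": {"provider": "dummy"}}'
yaml_text = "MetaData:\n  provider: dummy\n"

# yaml.safe_load parses both forms to the same dict.
from_json = yaml.safe_load(json_text)
from_yaml = yaml.safe_load(yaml_text)
print(from_json == from_yaml)  # True
```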
@@ -35,7 +35,8 @@ class PipelineConfigurationSerializer(serializers.Serializer):

def to_internal_value(self, data):
provider_part = super().to_internal_value(data)
serializer = REGISTRY[provider_part["provider"]][Serializer](data=data["config"])
config_part = data["config"] if "config" in data else {}
serializer = REGISTRY[provider_part["provider"]][Serializer](data=config_part)
serializer.is_valid(raise_exception=True)

return {**provider_part, "config": serializer.validated_data}
@@ -134,7 +134,7 @@ def __init__(self, config: DummyConfiguration):
super().__init__(config=config)

def invoke(self, params: CompletionsParameters) -> CompletionsResponse:
logger.debug("!!!! settings.ANSIBLE_AI_MODEL_MESH_API_TYPE == 'dummy' !!!!")
logger.debug("!!!! ModelPipelineCompletions.provider == 'dummy' !!!!")
logger.debug("!!!! Mocking Model response !!!!")
if self.config.latency_use_jitter:
jitter: float = secrets.randbelow(1000) * 0.001
@@ -66,4 +66,4 @@ def __init__(self, **kwargs):

@Register(api_type="http")
class HttpConfigurationSerializer(BaseConfigSerializer):
verify_ssl = serializers.BooleanField(required=False, default=False)
verify_ssl = serializers.BooleanField(required=False, default=True)
@@ -66,4 +66,4 @@ def __init__(self, **kwargs):

@Register(api_type="llamacpp")
class LlamaCppConfigurationSerializer(BaseConfigSerializer):
verify_ssl = serializers.BooleanField(required=False, default=False)
verify_ssl = serializers.BooleanField(required=False, default=True)
@@ -0,0 +1,85 @@
# Copyright Red Hat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from json import JSONDecodeError

import yaml
from django.test import override_settings
from rest_framework.exceptions import ValidationError
from yaml import YAMLError

from ansible_ai_connect.ai.api.model_pipelines.config_loader import load_config
from ansible_ai_connect.ai.api.model_pipelines.config_providers import Configuration
from ansible_ai_connect.ai.api.model_pipelines.pipelines import MetaData
from ansible_ai_connect.ai.api.model_pipelines.registry import REGISTRY_ENTRY
from ansible_ai_connect.ai.api.model_pipelines.tests import mock_config
from ansible_ai_connect.test_utils import WisdomTestCase

EMPTY = {
    "MetaData": {
        "provider": "dummy",
    },
}


def _convert_json_to_yaml(json_config: str):
    yaml_config = yaml.safe_load(json_config)
    return yaml.safe_dump(yaml_config)


class TestConfigLoader(WisdomTestCase):

    def assert_config(self):
        config: Configuration = load_config()
        pipelines = [i for i in REGISTRY_ENTRY.keys() if issubclass(i, MetaData)]
        for k in pipelines:
            self.assertTrue(k.__name__ in config)

    def assert_invalid_config(self):
        with self.assertRaises(ExceptionGroup) as e:
            load_config()
        exceptions = e.exception.exceptions
        self.assertEqual(len(exceptions), 2)
        self.assertIsInstance(exceptions[0], JSONDecodeError)
        self.assertIsInstance(exceptions[1], YAMLError)

    @override_settings(ANSIBLE_AI_MODEL_MESH_CONFIG=None)
    def test_config_undefined(self):
        with self.assertRaises(TypeError):
            load_config()

    @override_settings(ANSIBLE_AI_MODEL_MESH_CONFIG=json.dumps(EMPTY))
    def test_config_empty(self):
        self.assert_config()

    @override_settings(ANSIBLE_AI_MODEL_MESH_CONFIG="")
    def test_config_empty_string(self):
        with self.assertRaises(ValidationError):
            self.assert_config()

    @override_settings(ANSIBLE_AI_MODEL_MESH_CONFIG='{"MetaData" : {')
    def test_config_invalid_json(self):
        self.assert_invalid_config()

    @override_settings(ANSIBLE_AI_MODEL_MESH_CONFIG="MetaData:\nbanana")
    def test_config_invalid_yaml(self):
        self.assert_invalid_config()

    @override_settings(ANSIBLE_AI_MODEL_MESH_CONFIG=mock_config("ollama"))
    def test_config_json(self):
        self.assert_config()

    @override_settings(ANSIBLE_AI_MODEL_MESH_CONFIG=_convert_json_to_yaml(mock_config("ollama")))
    def test_config_yaml(self):
        self.assert_config()
@@ -1325,7 +1325,6 @@ def test_codematch_empty_response(self):
self.assertEqual(e.exception.model_id, model_id)


@override_settings(ANSIBLE_AI_MODEL_MESH_MODEL_ID=None)
class TestDummySecretManager(TestCase):
def setUp(self):
super().setUp()
@@ -77,7 +77,7 @@ def __init__(self, provider: t_model_mesh_api_type, config: WCABaseConfiguration

class WCABaseConfigurationSerializer(BaseConfigSerializer):
api_key = serializers.CharField(required=False, allow_null=True, allow_blank=True)
verify_ssl = serializers.BooleanField(required=False, default=False)
verify_ssl = serializers.BooleanField(required=False, default=True)
retry_count = serializers.IntegerField(required=False, default=4)
enable_ari_postprocessing = serializers.BooleanField(required=False, default=False)
health_check_api_key = serializers.CharField(required=True)
@@ -108,8 +108,8 @@ def __init__(self, config: WCAOnPremConfiguration):
raise WcaUsernameNotFound
if not self.config.api_key:
raise WcaKeyNotFound
# ANSIBLE_AI_MODEL_MESH_MODEL_ID cannot be validated until runtime. The
# User may provide an override value if the Environment Variable is not set.
# WCAOnPremConfiguration.model_id cannot be validated until runtime. The
# User may provide an override value if the setting is not defined.

def get_request_headers(
self, api_key: str, identifier: Optional[str]
@@ -144,7 +144,7 @@ def get_api_key(self, user, organization_id: Optional[int]) -> str:

if organization_id is None:
logger.error(
"User does not have an organization and no ANSIBLE_AI_MODEL_MESH_API_KEY is set"
"User does not have an organization and WCASaaSConfiguration.api_key is not set"
)
raise WcaKeyNotFound

2 changes: 1 addition & 1 deletion ansible_ai_connect/ai/api/permissions.py
@@ -164,4 +164,4 @@ class IsWCASaaSModelPipeline(permissions.BasePermission):
message = "User doesn't have access to the IBM watsonx Code Assistant."

def has_permission(self, request, view):
return CONTINUE if settings.ANSIBLE_AI_MODEL_MESH_API_TYPE == "wca" else BLOCK
return CONTINUE if settings.DEPLOYMENT_MODE == "saas" else BLOCK
@@ -19,7 +19,6 @@

from ansible_anonymizer import anonymizer
from django.apps import apps
from django.conf import settings
from django_prometheus.conf import NAMESPACE
from prometheus_client import Histogram

@@ -124,7 +123,7 @@ def process(self, context: CompletionContext) -> None:
except ModelTimeoutError as e:
exception = e
logger.warning(
f"model timed out after {settings.ANSIBLE_AI_MODEL_MESH_API_TIMEOUT} "
f"model timed out after {model_mesh_client.config.timeout} "
f"seconds (per task) for suggestion {suggestion_id}"
)
raise ModelTimeoutException(cause=e)
4 changes: 2 additions & 2 deletions ansible_ai_connect/ai/api/tests/test_permissions.py
@@ -209,10 +209,10 @@ def test_ensure_trial_user_can_pass_through_despite_trial_disabled(self):

class TestBlockUserWithoutWCASaaSConfiguration(WisdomAppsBackendMocking):

@override_settings(ANSIBLE_AI_MODEL_MESH_API_TYPE="wca")
@override_settings(DEPLOYMENT_MODE="saas")
def test_wca_saas_enabled(self):
self.assertEqual(IsWCASaaSModelPipeline().has_permission(Mock(), None), CONTINUE)

@override_settings(ANSIBLE_AI_MODEL_MESH_API_TYPE="wca-onprem")
@override_settings(DEPLOYMENT_MODE="onprem")
def test_wca_saas_not_enabled(self):
self.assertEqual(IsWCASaaSModelPipeline().has_permission(Mock(), None), BLOCK)