Skip to content

Commit

Permalink
Merge pull request #249 from tigergraph/GML-1767-watsonx-integration
Browse files Browse the repository at this point in the history
GML-1767: WatsonX integration
  • Loading branch information
parkererickson-tg authored Jul 29, 2024
2 parents 3d9d255 + 3474ba2 commit a4e384c
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 13 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/build-test-ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ jobs:
echo "$LLM_CONFIG_GROQ_MIXTRAL" > configs/groq_mixtral_config.json
echo "$LLM_CONFIG_BEDROCK_CLAUDE3" > configs/bedrock_config.json
echo "$LLM_CONFIG_HUGGINGFACE_LLAMA70B" > configs/huggingface_llama70b_config.json
echo "$LLM_CONFIG_WATSONX_MISTRAL_LARGE" > configs/ibm_watsonx_config.json
echo "$MILVUS_CONFIG" > configs/milvus_config.json
env:
DB_CONFIG: ${{ secrets.DB_CONFIG }}
Expand All @@ -57,6 +58,7 @@ jobs:
LLM_CONFIG_HUGGINGFACE_PHI3: ${{ secrets.LLM_CONFIG_HUGGINGFACE_PHI3 }}
LLM_CONFIG_OPENAI_GPT4O: ${{ secrets.LLM_CONFIG_OPENAI_GPT4O }}
LLM_CONFIG_HUGGINGFACE_LLAMA70B: ${{ secrets.LLM_CONFIG_HUGGINGFACE_LLAMA70B }}
LLM_CONFIG_WATSONX_MISTRAL_LARGE: ${{ secrets.LLM_CONFIG_WATSONX_MISTRAL_LARGE }}
GCP_CREDS_CONFIG: ${{ secrets.GCP_CREDS_CONFIG }}
LLM_TEST_EVALUATOR: ${{ secrets.LLM_TEST_EVALUATOR }}
MILVUS_CONFIG: ${{ secrets.MILVUS_CONFIG }}
Expand Down
5 changes: 4 additions & 1 deletion common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
OpenAI,
Groq,
Ollama,
HuggingFaceEndpoint
HuggingFaceEndpoint,
IBMWatsonX
)
from common.session import SessionHandler
from common.status import StatusManager
Expand Down Expand Up @@ -121,6 +122,8 @@ def get_llm_service(llm_config):
return Ollama(llm_config["completion_service"])
elif llm_config["completion_service"]["llm_service"].lower() == "huggingface":
return HuggingFaceEndpoint(llm_config["completion_service"])
elif llm_config["completion_service"]["llm_service"].lower() == "watsonx":
return IBMWatsonX(llm_config["completion_service"])
else:
raise Exception("LLM Completion Service Not Supported")

Expand Down
1 change: 1 addition & 0 deletions common/llm_services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .groq_llm_service import Groq
from .ollama import Ollama
from .huggingface_endpoint import HuggingFaceEndpoint
from .ibm_watsonx_service import IBMWatsonX
50 changes: 50 additions & 0 deletions common/llm_services/ibm_watsonx_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import logging
import os

from common.llm_services import LLM_Model
from common.logs.log import req_id_cv
from common.logs.logwriter import LogWriter

logger = logging.getLogger(__name__)


class IBMWatsonX(LLM_Model):
    """LLM completion service backed by IBM watsonx.ai chat models (langchain-ibm)."""

    def __init__(self, config):
        super().__init__(config)

        # Export every authentication entry (e.g. WATSONX_URL, WATSONX_APIKEY)
        # as an environment variable so downstream libraries can discover them.
        for env_key, env_value in config["authentication_configuration"].items():
            os.environ[env_key] = env_value

        # Imported lazily so the langchain-ibm dependency is only required
        # when the watsonx service is actually configured.
        from langchain_ibm import ChatWatsonx

        auth_config = config["authentication_configuration"]
        kwargs = config["model_kwargs"]
        llm_model_name = config["llm_model"]

        self.llm = ChatWatsonx(
            params={
                "temperature": kwargs["temperature"],
                "max_new_tokens": kwargs["max_new_tokens"],
            },
            url=auth_config["WATSONX_URL"],
            apikey=auth_config["WATSONX_APIKEY"],
            model_id=llm_model_name,
            project_id=kwargs["project_id"],
        )
        self.prompt_path = config["prompt_path"]
        LogWriter.info(
            f"request_id={req_id_cv.get()} instantiated WatsonX model_name={llm_model_name}"
        )

    @property
    def map_question_schema_prompt(self):
        """Prompt for mapping a natural-language question onto the graph schema."""
        return self._read_prompt_file(self.prompt_path + "map_question_to_schema.txt")

    @property
    def generate_function_prompt(self):
        """Prompt for generating a pyTigerGraph function call."""
        return self._read_prompt_file(self.prompt_path + "generate_function.txt")

    @property
    def entity_relationship_extraction_prompt(self):
        """Prompt for extracting entities and relationships from text."""
        return self._read_prompt_file(
            self.prompt_path + "entity_relationship_extraction.txt"
        )

    @property
    def model(self):
        """The underlying LangChain chat model instance."""
        return self.llm
4 changes: 4 additions & 0 deletions copilot/app/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
HuggingFaceEndpoint,
Ollama,
OpenAI,
IBMWatsonX
)
from common.llm_services.base_llm import LLM_Model
from common.logs.log import req_id_cv
Expand Down Expand Up @@ -187,6 +188,9 @@ def make_agent(graphname, conn, use_cypher, ws: WebSocket = None, supportai_retr
elif llm_config["completion_service"]["llm_service"].lower() == "huggingface":
llm_service_name = "huggingface"
llm_provider = HuggingFaceEndpoint(llm_config["completion_service"])
elif llm_config["completion_service"]["llm_service"].lower() == "watsonx":
llm_service_name = "watsonx"
llm_provider = IBMWatsonX(llm_config["completion_service"])
else:
LogWriter.error(
f"/{graphname}/query_with_history request_id={req_id_cv.get()} agent creation failed due to invalid llm_service"
Expand Down
37 changes: 25 additions & 12 deletions copilot/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@ cryptography==42.0.5
dataclasses-json==0.5.14
distro==1.8.0
docker-pycreds==0.4.0
docstring_parser==0.16
emoji==2.8.0
environs==9.5.0
exceptiongroup==1.1.3
fastapi==0.103.1
filelock==3.15.4
filetype==1.2.0
frozenlist==1.4.0
fsspec==2024.6.0
gitdb==4.0.11
GitPython==3.1.40
google-api-core==2.14.0
Expand All @@ -51,24 +54,32 @@ h11==0.14.0
httpcore==0.18.0
httptools==0.6.0
httpx==0.25.0
huggingface_hub==0.23.0
huggingface-hub==0.23.0
ibm-cos-sdk==2.13.6
ibm-cos-sdk-core==2.13.6
ibm-cos-sdk-s3transfer==2.13.6
ibm_watsonx_ai==1.0.11
idna==3.4
importlib_metadata==8.0.0
iniconfig==2.0.0
isodate==0.6.1
jmespath==1.0.1
joblib==1.3.2
jq==1.6.0
jsonpatch==1.33
jsonpointer==2.4
langchain==0.1.12
langchain-community==0.0.28
langchain-core==0.1.49
langchain-experimental==0.0.54
langchain-groq==0.1.3
langchain-text-splitters==0.0.1
langchain==0.2.8
langchain-community==0.2.7
langchain-core==0.2.19
langchain-experimental==0.0.62
langchain-groq==0.1.6
langchain-ibm==0.1.10
langchain-text-splitters==0.2.2
langchainhub==0.1.14
langdetect==1.0.9
langgraph==0.0.40
langsmith==0.1.24
langgraph==0.1.8
langsmith==0.1.86
lomond==0.3.3
lxml==4.9.3
marshmallow==3.20.1
matplotlib==3.9.1
Expand All @@ -82,6 +93,7 @@ orjson==3.9.15
packaging==23.2
pandas==2.1.1
pathtools==0.1.2
pluggy==1.5.0
prometheus_client==0.20.0
proto-plus==1.22.3
protobuf==4.24.4
Expand All @@ -95,18 +107,18 @@ pydantic==2.3.0
pydantic_core==2.6.3
pygit2==1.13.2
pymilvus==2.3.6
python-dateutil==2.8.2
pytest==8.2.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.0
python-iso639==2023.6.15
python-magic==0.4.27
pytest==8.2.0
pyTigerDriver==1.0.15
pyTigerGraph==1.6.2
pytz==2023.3.post1
PyYAML==6.0.1
rapidfuzz==3.4.0
regex==2023.10.3
requests==2.31.0
requests==2.32.2
rsa==4.9
s3transfer==0.7.0
scikit-learn==1.5.1
Expand Down Expand Up @@ -138,3 +150,4 @@ wandb==0.15.12
watchfiles==0.20.0
websockets==11.0.3
yarl==1.9.2
zipp==3.19.2
7 changes: 7 additions & 0 deletions copilot/tests/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ aws_bedrock_config="./configs/bedrock_config.json"
huggingface_llama3_script="test_huggingface_llama70b.py"
huggingface_llama3_config="./configs/huggingface_llama70b_config.json"

watsonx_mistral_large_script="test_watsonx_mistral-large.py"
watsonx_mistral_large_config="./configs/ibm_watsonx_config.json"

# Function to execute a service
execute_service() {
local service="$1"
Expand Down Expand Up @@ -100,6 +103,9 @@ case "$llm_service" in
"huggingface_llama3")
execute_service "$huggingface_llama3_script" "$huggingface_llama3_config"
;;
"watsonx_mistral_large")
execute_service "$watsonx_mistral_large_script" "$watsonx_mistral_large_config"
;;
"all")
echo "Executing all services..."
for service_script_pair in "$azure_gpt35_script $azure_gpt35_config" \
Expand All @@ -110,6 +116,7 @@ case "$llm_service" in
"$aws_bedrock_script $aws_bedrock_config" \
"$openai_gpt4o_script $openai_gpt4o_config" \
"$huggingface_llama3_script $huggingface_llama3_config" \
"$watsonx_mistral_large_script $watsonx_mistral_large_config" \
"$huggingface_phi3_script $huggingface_phi3_config"; do
execute_service $service_script_pair
done
Expand Down
59 changes: 59 additions & 0 deletions copilot/tests/test_watsonx_mistral-large.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os
import unittest

import pytest
from fastapi.testclient import TestClient
from test_service import CommonTests
import wandb
import parse_test_config
import sys


@pytest.mark.skip(reason="All tests in this class are currently skipped by the pipeline, but used by the LLM regression tests.")
class TestWithWatsonX(CommonTests, unittest.TestCase):
    """Regression suite exercising CoPilot against an IBM watsonx-hosted model."""

    @classmethod
    def setUpClass(cls) -> None:
        # Import here so the FastAPI app is only constructed when the suite runs.
        from main import app

        cls.client = TestClient(app)
        cls.llm_service = "mistralai/mistral-large"
        # NOTE(review): USE_WANDB and columns are module-level globals defined
        # only inside the __main__ guard below — assumes this suite is always
        # launched as a script; confirm before enabling under plain pytest.
        if USE_WANDB:
            cls.table = wandb.Table(columns=columns)

    def test_config_read(self):
        """The service root endpoint should report the expected config name."""
        response = self.client.get("/")
        self.assertEqual(response.json()["config"], "mistral-large")


if __name__ == "__main__":
    # Parse the shared test-harness CLI flags (wandb logging, schema choice).
    parser = parse_test_config.create_parser()

    # parse_known_args so unittest's own CLI flags are tolerated on argv.
    args = parser.parse_known_args()[0]

    # Module-level globals read by TestWithWatsonX.setUpClass above.
    USE_WANDB = args.wandb

    schema = args.schema

    if USE_WANDB:
        # Column layout of the Weights & Biases results table.
        columns = [
            "LLM_Service",
            "Dataset",
            "Question Type",
            "Question Theme",
            "Question",
            "True Answer",
            "True Function Call",
            "Retrieved Natural Language Answer",
            "Retrieved Answer",
            "Answer Source",
            "Answer Correct",
            "Answered Question",
            "Response Time (seconds)",
        ]
    CommonTests.setUpClass(schema)

    # clean up args before unittesting — unittest.main() would otherwise try to
    # interpret the harness flags as test names
    del sys.argv[1:]
    unittest.main()

0 comments on commit a4e384c

Please sign in to comment.