Skip to content

Commit

Permalink
Add AzureOpenAI model provider (#949)
Browse files Browse the repository at this point in the history
* Add AzureOpenAI model provider

* Update tests

* exclude test paths from ruff

* set test env vars

* Add support for parsing logprobs
  • Loading branch information
nihit authored Jan 13, 2025
1 parent 32bc80a commit 43ace07
Show file tree
Hide file tree
Showing 7 changed files with 312 additions and 113 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ clean-build:
clean: clean-pyc clean-test clean-build

test: clean
OPENAI_API_KEY=test_key ANTHROPIC_API_KEY=test_key REFUEL_API_KEY=test_key pytest
OPENAI_API_KEY=test_key ANTHROPIC_API_KEY=test_key REFUEL_API_KEY=test_key AZURE_OPENAI_API_KEY=test_key AZURE_OPENAI_ENDPOINT=test_key AZURE_OPENAI_API_VERSION=test_key pytest

check: test

Expand Down
3 changes: 3 additions & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,6 @@ combine-as-imports = true
default = ["third-party"]
first_party = ["first-party"]
standard_library = ["stdlib"]

[format]
exclude = ["*.pyi", "tests/unit/*"]
2 changes: 2 additions & 0 deletions src/autolabel/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from autolabel.models.openai import OpenAILLM
from autolabel.models.openai_vision import OpenAIVisionLLM
from autolabel.models.vllm import VLLMModel
from autolabel.models.azure_openai import AzureOpenAILLM

MODEL_REGISTRY = {

Check failure on line 24 in src/autolabel/models/__init__.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (I001)

src/autolabel/models/__init__.py:13:1: I001 Import block is un-sorted or un-formatted
ModelProvider.OPENAI: OpenAILLM,
Expand All @@ -30,6 +31,7 @@
ModelProvider.HUGGINGFACE_PIPELINE_VISION: HFPipelineMultimodal,
ModelProvider.GOOGLE: GoogleLLM,
ModelProvider.VLLM: VLLMModel,
ModelProvider.AZURE_OPENAI: AzureOpenAILLM,
}


Expand Down
200 changes: 200 additions & 0 deletions src/autolabel/models/azure_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@

import json
import logging
import os
from functools import partial
from time import time
from typing import Dict, List, Optional

from langchain.schema import Generation
from transformers import AutoTokenizer

from autolabel.cache import BaseCache
from autolabel.configs import AutolabelConfig
from autolabel.models import BaseModel
from autolabel.schema import ErrorType, LabelingError, RefuelLLMResult

logger = logging.getLogger(__name__)


class AzureOpenAILLM(BaseModel):
DEFAULT_MODEL = "gpt-4o-mini"
DEFAULT_PARAMS = {
"max_tokens": 1000,
"temperature": 0.0,
"timeout": 30,
"logprobs": True,
"stream": False
}

# Reference: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
COST_PER_PROMPT_TOKEN = {
"gpt-35-turbo": 1 / 1_000_000,
"gpt-4-turbo-2024-04-09": 10 / 1_000_000,
"gpt-4o": 2.50 / 1_000_000,
"gpt-4o-mini": 0.15 / 1_000_000,
}
COST_PER_COMPLETION_TOKEN = {
"gpt-35-turbo": 2 / 1_000_000,
"gpt-4-turbo-2024-04-09": 30 / 1_000_000,
"gpt-4o": 10 / 1_000_000,
"gpt-4o-mini": 0.60 / 1_000_000,
}

MODELS_WITH_TOKEN_PROBS = set(
[
"gpt-35-turbo",
"gpt-4",
"gpt-4o",
"gpt-4o-mini"
]
)

MODELS_WITH_STRUCTURED_OUTPUTS = set(
[
"gpt-4o-mini",
"gpt-4o",
],
)

ERROR_TYPE_MAPPING = {
"context_length_exceeded": ErrorType.CONTEXT_LENGTH_ERROR,
"rate_limit_exceeded": ErrorType.RATE_LIMIT_ERROR,
}

def __init__(
self,
config: AutolabelConfig,
cache: BaseCache = None,
tokenizer: Optional[AutoTokenizer] = None,
) -> None:
super().__init__(config, cache, tokenizer)
try:
from openai import AzureOpenAI
import tiktoken
except ImportError:
raise ImportError(
"openai is required to use the AzureOpenAILLM. Please install it with: pip install 'refuel-autolabel[openai]'"
)

self.tiktoken = tiktoken
self.model_name = config.model_name() or self.DEFAULT_MODEL
model_params = config.model_params()
self.model_params = {**self.DEFAULT_PARAMS, **model_params}

required_env_vars = [
"AZURE_OPENAI_API_KEY",
"AZURE_OPENAI_ENDPOINT",
"AZURE_OPENAI_API_VERSION"
]
for var in required_env_vars:
if os.getenv(var) is None:
raise ValueError(f"{var} environment variable not set")

self.client = AzureOpenAI(
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT")
)
self.llm = partial(
self.client.chat.completions.create,
model=self.model_name,
**self.model_params
)


def _label(self, prompts: List[str], output_schema: Dict) -> RefuelLLMResult:
generations = []
errors = []
latencies = []
for prompt in prompts:
content = [{"type": "text", "text": prompt}]
start_time = time()
try:
if (
output_schema is not None
and self.model_name in self.MODELS_WITH_STRUCTURED_OUTPUTS
):
result = self.llm(
messages=[{"role": "user", "content": content}],
response_format={
"type": "json_schema",
"json_schema": {
"name": "response_format",
"schema": output_schema,
"strict": True,
},
},
).to_dict()
else:
result = self.llm(
messages=[{"role": "user", "content": content}],
).to_dict()

result = result['choices'][0]
result_completion = result['message']['content']
parsed_logprobs = {"top_logprobs": []}
if 'logprobs' in result:
result_logprobs = result['logprobs']['content'] # List of dicts with token and logprob
for curr_token in result_logprobs:
parsed_logprobs["top_logprobs"].append(
{curr_token["token"]: curr_token["logprob"]},
)

generations.append(
[
Generation(
text=result_completion,
generation_info={"logprobs": parsed_logprobs},
),
],
)
errors.append(None)
except Exception as e:
logger.error(f"Error generating label: {e}")
generations.append(
[
Generation(
text="",
generation_info=None,
),
],
)
errors.append(
LabelingError(
error_type=ErrorType.LLM_PROVIDER_ERROR, error_message=str(e),
),
)
end_time = time()
latencies.append(end_time - start_time)

return RefuelLLMResult(
generations=generations,
errors=errors,
latencies=latencies,
)

def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
encoding = self.tiktoken.encoding_for_model(self.model_name)
num_prompt_tokens = len(encoding.encode(prompt))

if label:
num_completion_tokens = len(encoding.encode(label))
else:
num_completion_tokens = self.model_params["max_tokens"]

return (
num_prompt_tokens * self.COST_PER_PROMPT_TOKEN[self.model_name] +
num_completion_tokens * self.COST_PER_COMPLETION_TOKEN[self.model_name]
)

def returns_token_probs(self) -> bool:
return (
self.model_name is not None
and self.model_name in self.MODELS_WITH_TOKEN_PROBS
)

def get_num_tokens(self, prompt: str) -> int:
encoding = self.tiktoken.encoding_for_model(self.model_name)
return len(encoding.encode(prompt))

1 change: 1 addition & 0 deletions src/autolabel/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class ModelProvider(str, Enum):
CUSTOM = "custom"
TGI = "tgi"
VLLM = "vllm"
AZURE_OPENAI = "azure_openai"


class TaskType(str, Enum):
Expand Down
99 changes: 99 additions & 0 deletions tests/assets/banking/config_banking_azureopenai.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
{
"task_name": "BankingComplaintsClassification",
"task_type": "classification",
"dataset": {
"label_column": "label",
"delimiter": ","
},
"model": {
"provider": "azure_openai",
"name": "gpt-4o-mini"
},
"prompt": {
"task_guidelines": "You are an expert at understanding bank customers support complaints and queries.\nYour job is to correctly classify the provided input example into one of the following categories.\nCategories:\n{labels}",
"output_guidelines": "You will answer with just the the correct output label and nothing else.",
"labels": [
"activate_my_card",
"age_limit",
"apple_pay_or_google_pay",
"atm_support",
"automatic_top_up",
"balance_not_updated_after_bank_transfer",
"balance_not_updated_after_cheque_or_cash_deposit",
"beneficiary_not_allowed",
"cancel_transfer",
"card_about_to_expire",
"card_acceptance",
"card_arrival",
"card_delivery_estimate",
"card_linking",
"card_not_working",
"card_payment_fee_charged",
"card_payment_not_recognised",
"card_payment_wrong_exchange_rate",
"card_swallowed",
"cash_withdrawal_charge",
"cash_withdrawal_not_recognised",
"change_pin",
"compromised_card",
"contactless_not_working",
"country_support",
"declined_card_payment",
"declined_cash_withdrawal",
"declined_transfer",
"direct_debit_payment_not_recognised",
"disposable_card_limits",
"edit_personal_details",
"exchange_charge",
"exchange_rate",
"exchange_via_app",
"extra_charge_on_statement",
"failed_transfer",
"fiat_currency_support",
"get_disposable_virtual_card",
"get_physical_card",
"getting_spare_card",
"getting_virtual_card",
"lost_or_stolen_card",
"lost_or_stolen_phone",
"order_physical_card",
"passcode_forgotten",
"pending_card_payment",
"pending_cash_withdrawal",
"pending_top_up",
"pending_transfer",
"pin_blocked",
"receiving_money",
"Refund_not_showing_up",
"request_refund",
"reverted_card_payment?",
"supported_cards_and_currencies",
"terminate_account",
"top_up_by_bank_transfer_charge",
"top_up_by_card_charge",
"top_up_by_cash_or_cheque",
"top_up_failed",
"top_up_limits",
"top_up_reverted",
"topping_up_by_card",
"transaction_charged_twice",
"transfer_fee_charged",
"transfer_into_account",
"transfer_not_received_by_recipient",
"transfer_timing",
"unable_to_verify_identity",
"verify_my_identity",
"verify_source_of_funds",
"verify_top_up",
"virtual_card_not_working",
"visa_or_mastercard",
"why_verify_identity",
"wrong_amount_of_cash_received",
"wrong_exchange_rate_for_cash_withdrawal"
],
"few_shot_examples": "seed.csv",
"few_shot_selection": "semantic_similarity",
"few_shot_num": 5,
"example_template": "Input: {example}\nOutput: {label}"
}
}
Loading

0 comments on commit 43ace07

Please sign in to comment.