Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nihit committed Jan 11, 2025
1 parent 6d9a039 commit 2753853
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 112 deletions.
99 changes: 99 additions & 0 deletions tests/assets/banking/config_banking_azureopenai.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
{
"task_name": "BankingComplaintsClassification",
"task_type": "classification",
"dataset": {
"label_column": "label",
"delimiter": ","
},
"model": {
"provider": "azure_openai",
"name": "gpt-4o-mini"
},
"prompt": {
"task_guidelines": "You are an expert at understanding bank customer support complaints and queries.\nYour job is to correctly classify the provided input example into one of the following categories.\nCategories:\n{labels}",
"output_guidelines": "You will answer with just the correct output label and nothing else.",
"labels": [
"activate_my_card",
"age_limit",
"apple_pay_or_google_pay",
"atm_support",
"automatic_top_up",
"balance_not_updated_after_bank_transfer",
"balance_not_updated_after_cheque_or_cash_deposit",
"beneficiary_not_allowed",
"cancel_transfer",
"card_about_to_expire",
"card_acceptance",
"card_arrival",
"card_delivery_estimate",
"card_linking",
"card_not_working",
"card_payment_fee_charged",
"card_payment_not_recognised",
"card_payment_wrong_exchange_rate",
"card_swallowed",
"cash_withdrawal_charge",
"cash_withdrawal_not_recognised",
"change_pin",
"compromised_card",
"contactless_not_working",
"country_support",
"declined_card_payment",
"declined_cash_withdrawal",
"declined_transfer",
"direct_debit_payment_not_recognised",
"disposable_card_limits",
"edit_personal_details",
"exchange_charge",
"exchange_rate",
"exchange_via_app",
"extra_charge_on_statement",
"failed_transfer",
"fiat_currency_support",
"get_disposable_virtual_card",
"get_physical_card",
"getting_spare_card",
"getting_virtual_card",
"lost_or_stolen_card",
"lost_or_stolen_phone",
"order_physical_card",
"passcode_forgotten",
"pending_card_payment",
"pending_cash_withdrawal",
"pending_top_up",
"pending_transfer",
"pin_blocked",
"receiving_money",
"Refund_not_showing_up",
"request_refund",
"reverted_card_payment?",
"supported_cards_and_currencies",
"terminate_account",
"top_up_by_bank_transfer_charge",
"top_up_by_card_charge",
"top_up_by_cash_or_cheque",
"top_up_failed",
"top_up_limits",
"top_up_reverted",
"topping_up_by_card",
"transaction_charged_twice",
"transfer_fee_charged",
"transfer_into_account",
"transfer_not_received_by_recipient",
"transfer_timing",
"unable_to_verify_identity",
"verify_my_identity",
"verify_source_of_funds",
"verify_top_up",
"virtual_card_not_working",
"visa_or_mastercard",
"why_verify_identity",
"wrong_amount_of_cash_received",
"wrong_exchange_rate_for_cash_withdrawal"
],
"few_shot_examples": "seed.csv",
"few_shot_selection": "semantic_similarity",
"few_shot_num": 5,
"example_template": "Input: {example}\nOutput: {label}"
}
}
118 changes: 6 additions & 112 deletions tests/unit/llm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

from autolabel.configs import AutolabelConfig
from autolabel.models.anthropic import AnthropicLLM
from autolabel.models.azure_openai import AzureOpenAILLM
from autolabel.models.openai import OpenAILLM
from autolabel.models.openai_vision import OpenAIVisionLLM
from autolabel.models.refuelV2 import RefuelLLMV2


################### ANTHROPIC TESTS #######################
Expand Down Expand Up @@ -198,115 +198,9 @@ def test_gpt4V_return_probs():

################### OPENAI GPT 4V TESTS #######################


################### REFUEL TESTS #######################
def test_refuel_initialization():
    """Smoke test: constructing RefuelLLMV2 from the banking refuel config.

    There is no explicit assertion; the test passes as long as construction
    completes without raising.
    """
    refuel_config = AutolabelConfig(
        config="tests/assets/banking/config_banking_refuel.json",
    )
    RefuelLLMV2(config=refuel_config)


async def test_refuel_label(mocker):
    """Label two prompts while ``requests.post`` is patched to return HTTP 200.

    Checks that every prompt yields the generated text carried by the mocked
    payload and that the summed cost is zero.
    """

    class PostRequestMockResponse:
        """Minimal stand-in for the object returned by ``requests.post``."""

        def __init__(self, resp, status_code):
            self.status_code = status_code
            self.resp = resp

        def raise_for_status(self):
            # Deliberate no-op: this fake never raises on its status code.
            return None

        def json(self):
            # Hands back the payload exactly as it was supplied (a str here).
            return self.resp

    # Build the model before patching, matching the original ordering.
    llm = RefuelLLMV2(
        config=AutolabelConfig(
            config="tests/assets/banking/config_banking_refuel.json",
        ),
    )
    batch = ["test1", "test2"]

    fake_response = PostRequestMockResponse(
        resp='{"generated_text": "Answers"}',
        status_code=200,
    )
    mocker.patch("requests.post", return_value=fake_response)

    result = await llm.label(batch)

    first_texts = [generation[0].text for generation in result.generations]
    assert first_texts == ["Answers", "Answers"]
    assert sum(result.costs) == 0


# Exercises the Refuel label call when requests.post is patched to return a
# status-422 response carrying an error payload.
# NOTE(review): this span comes from a diff view that lost its +/- markers and
# indentation; the removed Refuel test and the added Azure OpenAI test are
# interleaved below, so the span is not valid Python as shown.
async def test_refuel_label_non_retryable(mocker):
# Minimal fake for the object returned by requests.post.
class PostRequestMockResponse:
def __init__(self, resp, status_code):
self.resp = resp
self.status_code = status_code
# The same payload is also exposed as `.text`.
self.text = resp

def json(self):
# Returns the payload exactly as supplied (a str in this test).
return self.resp

def raise_for_status(self):
# No-op: the fake never raises on its status code.
pass

model = RefuelLLMV2(
config=AutolabelConfig(config="tests/assets/banking/config_banking_refuel.json"),
)
prompts = ["test1", "test2"]
mocker.patch(
"requests.post",
return_value=PostRequestMockResponse(
resp='{"error_message": "Error123"}', status_code=422,
),
# NOTE(review): the banner, def and AzureOpenAILLM(...) construction that follow
# belong to test_azure_openai_initialization. The diff view placed them inside
# the still-open mocker.patch( call above; the closing parenthesis after the
# config line appears to be shared by both versions of the file.
################### AZURE OPENAI TESTS #######################
def test_azure_openai_initialization():
model = AzureOpenAILLM(
config=AutolabelConfig(config="tests/assets/banking/config_banking_azureopenai.json"),
)
# NOTE(review): from here on, the lines are the tail of
# test_refuel_label_non_retryable again.
x = await model.label(prompts)
# Both prompts are expected to produce an empty generation text.
assert [i[0].text for i in x.generations] == ["", ""]
# Every reported error must be tagged as non-retryable.
for error in x.errors:
assert "NonRetryable Error:" in error.error_message
assert sum(x.costs) == 0


async def test_refuel_label_retryable(mocker):
    """Label two prompts while ``requests.post`` is patched to return HTTP 500.

    Checks that each generation text is empty, that no reported error message
    carries the "NonRetryable Error:" marker, and that the summed cost is zero.
    """

    class PostRequestMockResponse:
        """Minimal stand-in for the object returned by ``requests.post``."""

        def __init__(self, resp, status_code):
            self.status_code = status_code
            # The payload is exposed both via ``.resp`` and ``.text``.
            self.resp = self.text = resp

        def raise_for_status(self):
            # Deliberate no-op: this fake never raises on its status code.
            return None

        def json(self):
            # Hands back the payload exactly as it was supplied (a str here).
            return self.resp

    # Build the model before patching, matching the original ordering.
    llm = RefuelLLMV2(
        config=AutolabelConfig(
            config="tests/assets/banking/config_banking_refuel.json",
        ),
    )
    batch = ["test1", "test2"]

    failing_response = PostRequestMockResponse(
        resp='{"error_message": "Error123"}',
        status_code=500,
    )
    mocker.patch("requests.post", return_value=failing_response)

    result = await llm.label(batch)

    first_texts = [generation[0].text for generation in result.generations]
    assert first_texts == ["", ""]
    for failure in result.errors:
        assert "NonRetryable Error:" not in failure.error_message
    assert sum(result.costs) == 0


def test_refuel_get_cost():
    """``get_cost`` on the Refuel model must return 0 for a sample prompt."""
    refuel_config = AutolabelConfig(
        config="tests/assets/banking/config_banking_refuel.json",
    )
    llm = RefuelLLMV2(config=refuel_config)
    assert llm.get_cost("TestingExamplePrompt") == 0


def test_refuel_return_probs():
    """``returns_token_probs`` on the Refuel model must be exactly ``True``."""
    refuel_config = AutolabelConfig(
        config="tests/assets/banking/config_banking_refuel.json",
    )
    llm = RefuelLLMV2(config=refuel_config)
    supports_token_probs = llm.returns_token_probs()
    assert supports_token_probs is True


################### REFUEL TESTS #######################
assert model.model_name == "gpt-4o-mini"

0 comments on commit 2753853

Please sign in to comment.