Skip to content

Commit

Permalink
Exit with non-zero if 1 or more tests failed
Browse files Browse the repository at this point in the history
Signed-off-by: Radoslav Dimitrov <[email protected]>
  • Loading branch information
rdimitrov committed Jan 17, 2025
1 parent 19bbab7 commit 9779706
Showing 1 changed file with 30 additions and 9 deletions.
39 changes: 30 additions & 9 deletions tests/integration/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import os
import re
import sys
from typing import Optional

import requests
Expand All @@ -17,6 +18,7 @@
class CodegateTestRunner:
def __init__(self):
self.requester_factory = RequesterFactory()
self.failed_tests = [] # Track failed tests

def call_codegate(
self, url: str, headers: dict, data: dict, provider: str
Expand Down Expand Up @@ -119,7 +121,7 @@ def replacement(match):
pattern = r"ENV\w*"
return re.sub(pattern, replacement, input_string)

async def run_test(self, test: dict, test_headers: dict) -> None:
async def run_test(self, test: dict, test_headers: dict) -> bool:
test_name = test["name"]
url = test["url"]
data = json.loads(test["data"])
Expand All @@ -129,7 +131,7 @@ async def run_test(self, test: dict, test_headers: dict) -> None:
response = self.call_codegate(url, test_headers, data, provider)
if not response:
logger.error(f"Test {test_name} failed: No response received")
return
return False

# Debug response info
logger.debug(f"Response status: {response.status_code}")
Expand All @@ -142,22 +144,29 @@ async def run_test(self, test: dict, test_headers: dict) -> None:
checks = CheckLoader.load(test)

# Run all checks
passed = True
all_passed = True
for check in checks:
passed_check = await check.run_check(parsed_response, test)
if not passed_check:
passed = False
logger.info(f"Test {test_name} passed" if passed else f"Test {test_name} failed")
all_passed = False

if not all_passed:
self.failed_tests.append(test_name)

logger.info(f"Test {test_name} {'passed' if all_passed else 'failed'}")
return all_passed

except Exception as e:
logger.exception("Could not parse response: %s", e)
self.failed_tests.append(test_name)
return False

async def run_tests(
self,
testcases_file: str,
providers: Optional[list[str]] = None,
test_names: Optional[list[str]] = None,
) -> None:
) -> bool:
with open(testcases_file, "r") as f:
tests = yaml.safe_load(f)

Expand Down Expand Up @@ -187,7 +196,7 @@ async def run_tests(
if test_names:
filter_msg.append(f"test names: {', '.join(test_names)}")
logger.warning(f"No tests found for {' and '.join(filter_msg)}")
return
return True # No tests is not a failure

test_count = len(testcases)
filter_msg = []
Expand All @@ -201,12 +210,20 @@ async def run_tests(
+ (f" for {' and '.join(filter_msg)}" if filter_msg else "")
)

all_tests_passed = True
for test_id, test_data in testcases.items():
test_headers = headers.get(test_data["provider"], {})
test_headers = {
k: self.replace_env_variables(v, os.environ) for k, v in test_headers.items()
}
await self.run_test(test_data, test_headers)
test_passed = await self.run_test(test_data, test_headers)
if not test_passed:
all_tests_passed = False

if not all_tests_passed:
logger.error(f"The following tests failed: {', '.join(self.failed_tests)}")

return all_tests_passed


async def main():
Expand All @@ -225,10 +242,14 @@ async def main():
if test_names_env:
test_names = [t.strip() for t in test_names_env.split(",") if t.strip()]

await test_runner.run_tests(
all_tests_passed = await test_runner.run_tests(
"./tests/integration/testcases.yaml", providers=providers, test_names=test_names
)

# Exit with status code 1 if any tests failed
if not all_tests_passed:
sys.exit(1)


if __name__ == "__main__":
asyncio.run(main())

0 comments on commit 9779706

Please sign in to comment.