Skip to content

Commit

Permalink
dsl testing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Jan 14, 2025
1 parent 5fae88a commit 4115c26
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 39 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ dev = [
"aiohttp",
"pytest",
"pytest-cov",
"pytest-mock",
"pytest-asyncio",
"coverage",
"jinja2",
Expand Down
54 changes: 15 additions & 39 deletions test_elasticsearch/test_dsl/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import time
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, Generator, Tuple, cast
from unittest import SkipTest, TestCase
from unittest import SkipTest
from unittest.mock import AsyncMock, Mock

import pytest_asyncio
Expand Down Expand Up @@ -53,15 +53,15 @@
ELASTICSEARCH_URL = "http://localhost:9200"


def get_test_client(wait: bool = True, **kwargs: Any) -> Elasticsearch:
def get_test_client(url, wait: bool = True, **kwargs: Any) -> Elasticsearch:
# construct kwargs from the environment
kw: Dict[str, Any] = {"request_timeout": 30}

if "PYTHON_CONNECTION_CLASS" in os.environ:
kw["node_class"] = os.environ["PYTHON_CONNECTION_CLASS"]

kw.update(kwargs)
client = Elasticsearch(ELASTICSEARCH_URL, **kw)
client = Elasticsearch(url, **kw)

# wait for yellow status
for tries_left in range(100 if wait else 1, 0, -1):
Expand All @@ -76,15 +76,17 @@ def get_test_client(wait: bool = True, **kwargs: Any) -> Elasticsearch:
raise SkipTest("Elasticsearch failed to start.")


async def get_async_test_client(wait: bool = True, **kwargs: Any) -> AsyncElasticsearch:
async def get_async_test_client(
url, wait: bool = True, **kwargs: Any
) -> AsyncElasticsearch:
# construct kwargs from the environment
kw: Dict[str, Any] = {"request_timeout": 30}

if "PYTHON_CONNECTION_CLASS" in os.environ:
kw["node_class"] = os.environ["PYTHON_CONNECTION_CLASS"]

kw.update(kwargs)
client = AsyncElasticsearch(ELASTICSEARCH_URL, **kw)
client = AsyncElasticsearch(url, **kw)

# wait for yellow status
for tries_left in range(100 if wait else 1, 0, -1):
Expand All @@ -100,36 +102,6 @@ async def get_async_test_client(wait: bool = True, **kwargs: Any) -> AsyncElasti
raise SkipTest("Elasticsearch failed to start.")


class ElasticsearchTestCase(TestCase):
client: Elasticsearch

@staticmethod
def _get_client() -> Elasticsearch:
return get_test_client()

@classmethod
def setup_class(cls) -> None:
cls.client = cls._get_client()

def teardown_method(self, _: Any) -> None:
# Hidden indices expanded in wildcards in ES 7.7
expand_wildcards = ["open", "closed"]
if self.es_version() >= (7, 7):
expand_wildcards.append("hidden")

self.client.indices.delete_data_stream(
name="*", expand_wildcards=expand_wildcards
)
self.client.indices.delete(index="*", expand_wildcards=expand_wildcards)
self.client.indices.delete_template(name="*")
self.client.indices.delete_index_template(name="*")

def es_version(self) -> Tuple[int, ...]:
if not hasattr(self, "_es_version"):
self._es_version = _get_version(self.client.info()["version"]["number"])
return self._es_version


def _get_version(version_string: str) -> Tuple[int, ...]:
if "." not in version_string:
return ()
Expand All @@ -138,19 +110,23 @@ def _get_version(version_string: str) -> Tuple[int, ...]:


@fixture(scope="session")
def client() -> Elasticsearch:
def client(elasticsearch_url) -> Elasticsearch:
try:
connection = get_test_client(wait="WAIT_FOR_ES" in os.environ)
connection = get_test_client(
elasticsearch_url, wait="WAIT_FOR_ES" in os.environ
)
add_connection("default", connection)
return connection
except SkipTest:
skip()


@pytest_asyncio.fixture
async def async_client() -> AsyncGenerator[AsyncElasticsearch, None]:
async def async_client(elasticsearch_url) -> AsyncGenerator[AsyncElasticsearch, None]:
try:
connection = await get_async_test_client(wait="WAIT_FOR_ES" in os.environ)
connection = await get_async_test_client(
elasticsearch_url, wait="WAIT_FOR_ES" in os.environ
)
add_async_connection("default", connection)
yield connection
await connection.close()
Expand Down

0 comments on commit 4115c26

Please sign in to comment.