Skip to content

Commit

Permalink
dsl testing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Jan 14, 2025
1 parent 5fae88a commit 70bb7e0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 38 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ dev = [
"aiohttp",
"pytest",
"pytest-cov",
"pytest-mock",
"pytest-asyncio",
"coverage",
"jinja2",
Expand Down
46 changes: 8 additions & 38 deletions test_elasticsearch/test_dsl/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@
ELASTICSEARCH_URL = "http://localhost:9200"


def get_test_client(wait: bool = True, **kwargs: Any) -> Elasticsearch:
def get_test_client(url, wait: bool = True, **kwargs: Any) -> Elasticsearch:
# construct kwargs from the environment
kw: Dict[str, Any] = {"request_timeout": 30}

if "PYTHON_CONNECTION_CLASS" in os.environ:
kw["node_class"] = os.environ["PYTHON_CONNECTION_CLASS"]

kw.update(kwargs)
client = Elasticsearch(ELASTICSEARCH_URL, **kw)
client = Elasticsearch(url, **kw)

# wait for yellow status
for tries_left in range(100 if wait else 1, 0, -1):
Expand All @@ -76,15 +76,15 @@ def get_test_client(wait: bool = True, **kwargs: Any) -> Elasticsearch:
raise SkipTest("Elasticsearch failed to start.")


async def get_async_test_client(wait: bool = True, **kwargs: Any) -> AsyncElasticsearch:
async def get_async_test_client(url, wait: bool = True, **kwargs: Any) -> AsyncElasticsearch:
# construct kwargs from the environment
kw: Dict[str, Any] = {"request_timeout": 30}

if "PYTHON_CONNECTION_CLASS" in os.environ:
kw["node_class"] = os.environ["PYTHON_CONNECTION_CLASS"]

kw.update(kwargs)
client = AsyncElasticsearch(ELASTICSEARCH_URL, **kw)
client = AsyncElasticsearch(url, **kw)

# wait for yellow status
for tries_left in range(100 if wait else 1, 0, -1):
Expand All @@ -100,36 +100,6 @@ async def get_async_test_client(wait: bool = True, **kwargs: Any) -> AsyncElasti
raise SkipTest("Elasticsearch failed to start.")


class ElasticsearchTestCase(TestCase):
client: Elasticsearch

@staticmethod
def _get_client() -> Elasticsearch:
return get_test_client()

@classmethod
def setup_class(cls) -> None:
cls.client = cls._get_client()

def teardown_method(self, _: Any) -> None:
# Hidden indices expanded in wildcards in ES 7.7
expand_wildcards = ["open", "closed"]
if self.es_version() >= (7, 7):
expand_wildcards.append("hidden")

self.client.indices.delete_data_stream(
name="*", expand_wildcards=expand_wildcards
)
self.client.indices.delete(index="*", expand_wildcards=expand_wildcards)
self.client.indices.delete_template(name="*")
self.client.indices.delete_index_template(name="*")

def es_version(self) -> Tuple[int, ...]:
if not hasattr(self, "_es_version"):
self._es_version = _get_version(self.client.info()["version"]["number"])
return self._es_version


def _get_version(version_string: str) -> Tuple[int, ...]:
if "." not in version_string:
return ()
Expand All @@ -138,19 +108,19 @@ def _get_version(version_string: str) -> Tuple[int, ...]:


@fixture(scope="session")
def client() -> Elasticsearch:
def client(elasticsearch_url) -> Elasticsearch:
try:
connection = get_test_client(wait="WAIT_FOR_ES" in os.environ)
connection = get_test_client(elasticsearc_url, wait="WAIT_FOR_ES" in os.environ)
add_connection("default", connection)
return connection
except SkipTest:
skip()


@pytest_asyncio.fixture
async def async_client() -> AsyncGenerator[AsyncElasticsearch, None]:
async def async_client(elasticsearch_url) -> AsyncGenerator[AsyncElasticsearch, None]:
try:
connection = await get_async_test_client(wait="WAIT_FOR_ES" in os.environ)
connection = await get_async_test_client(elasticsearch_url, wait="WAIT_FOR_ES" in os.environ)
add_async_connection("default", connection)
yield connection
await connection.close()
Expand Down

0 comments on commit 70bb7e0

Please sign in to comment.