diff --git a/pyproject.toml b/pyproject.toml index 33abbd5d9..0c66e2f50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ dev = [ "aiohttp", "pytest", "pytest-cov", + "pytest-mock", "pytest-asyncio", "coverage", "jinja2", diff --git a/test_elasticsearch/test_dsl/conftest.py b/test_elasticsearch/test_dsl/conftest.py index 2e5fa91af..eb541ca6a 100644 --- a/test_elasticsearch/test_dsl/conftest.py +++ b/test_elasticsearch/test_dsl/conftest.py @@ -53,7 +53,7 @@ ELASTICSEARCH_URL = "http://localhost:9200" -def get_test_client(wait: bool = True, **kwargs: Any) -> Elasticsearch: +def get_test_client(url, wait: bool = True, **kwargs: Any) -> Elasticsearch: # construct kwargs from the environment kw: Dict[str, Any] = {"request_timeout": 30} @@ -61,7 +61,7 @@ def get_test_client(wait: bool = True, **kwargs: Any) -> Elasticsearch: kw["node_class"] = os.environ["PYTHON_CONNECTION_CLASS"] kw.update(kwargs) - client = Elasticsearch(ELASTICSEARCH_URL, **kw) + client = Elasticsearch(url, **kw) # wait for yellow status for tries_left in range(100 if wait else 1, 0, -1): @@ -76,7 +76,7 @@ def get_test_client(wait: bool = True, **kwargs: Any) -> Elasticsearch: raise SkipTest("Elasticsearch failed to start.") -async def get_async_test_client(wait: bool = True, **kwargs: Any) -> AsyncElasticsearch: +async def get_async_test_client(url, wait: bool = True, **kwargs: Any) -> AsyncElasticsearch: # construct kwargs from the environment kw: Dict[str, Any] = {"request_timeout": 30} @@ -84,7 +84,7 @@ async def get_async_test_client(wait: bool = True, **kwargs: Any) -> AsyncElasti kw["node_class"] = os.environ["PYTHON_CONNECTION_CLASS"] kw.update(kwargs) - client = AsyncElasticsearch(ELASTICSEARCH_URL, **kw) + client = AsyncElasticsearch(url, **kw) # wait for yellow status for tries_left in range(100 if wait else 1, 0, -1): @@ -100,36 +100,6 @@ async def get_async_test_client(wait: bool = True, **kwargs: Any) -> AsyncElasti raise SkipTest("Elasticsearch failed to start.") -class ElasticsearchTestCase(TestCase): - client: Elasticsearch - - @staticmethod - def _get_client() -> Elasticsearch: - return get_test_client() - - @classmethod - def setup_class(cls) -> None: - cls.client = cls._get_client() - - def teardown_method(self, _: Any) -> None: - # Hidden indices expanded in wildcards in ES 7.7 - expand_wildcards = ["open", "closed"] - if self.es_version() >= (7, 7): - expand_wildcards.append("hidden") - - self.client.indices.delete_data_stream( - name="*", expand_wildcards=expand_wildcards - ) - self.client.indices.delete(index="*", expand_wildcards=expand_wildcards) - self.client.indices.delete_template(name="*") - self.client.indices.delete_index_template(name="*") - - def es_version(self) -> Tuple[int, ...]: - if not hasattr(self, "_es_version"): - self._es_version = _get_version(self.client.info()["version"]["number"]) - return self._es_version - - def _get_version(version_string: str) -> Tuple[int, ...]: if "." not in version_string: return () @@ -138,9 +108,9 @@ def _get_version(version_string: str) -> Tuple[int, ...]: @fixture(scope="session") -def client() -> Elasticsearch: +def client(elasticsearch_url) -> Elasticsearch: try: - connection = get_test_client(wait="WAIT_FOR_ES" in os.environ) + connection = get_test_client(elasticsearc_url, wait="WAIT_FOR_ES" in os.environ) add_connection("default", connection) return connection except SkipTest: @@ -148,9 +118,9 @@ def client() -> Elasticsearch: @pytest_asyncio.fixture -async def async_client() -> AsyncGenerator[AsyncElasticsearch, None]: +async def async_client(elasticsearch_url) -> AsyncGenerator[AsyncElasticsearch, None]: try: - connection = await get_async_test_client(wait="WAIT_FOR_ES" in os.environ) + connection = await get_async_test_client(elasticsearch_url, wait="WAIT_FOR_ES" in os.environ) add_async_connection("default", connection) yield connection await connection.close()