diff --git a/src/everest/detached/__init__.py b/src/everest/detached/__init__.py index caf8f5b2ace..580dc80a48c 100644 --- a/src/everest/detached/__init__.py +++ b/src/everest/detached/__init__.py @@ -64,13 +64,14 @@ async def start_server(config: EverestConfig, debug: bool = False) -> Driver: return driver -def stop_server(server_context: tuple[str, str, tuple[str, str]], retries: int = 5): +def stop_server( + server_context: tuple[str, str, tuple[str, str]], retries: int = 5 +) -> bool: """ Stop server if found and it is running. """ for retry in range(retries): try: - print("stopping server") url, cert, auth = server_context stop_endpoint = "/".join([url, STOP_ENDPOINT]) response = requests.post( @@ -90,21 +91,25 @@ def stop_server(server_context: tuple[str, str, tuple[str, str]], retries: int = def start_experiment( server_context: tuple[str, str, tuple[str, str]], config: EverestConfig, + retries: int = 5, ) -> None: - print("Starting experiment") - try: - url, cert, auth = server_context - start_endpoint = "/".join([url, START_EXPERIMENT_ENDPOINT]) - response = requests.post( - start_endpoint, - verify=cert, - auth=auth, - proxies=PROXY, # type: ignore - json=config.to_dict(), - ) - response.raise_for_status() - except Exception as e: - raise ValueError("Failed to start experiment.") from e + for retry in range(retries): + try: + url, cert, auth = server_context + start_endpoint = "/".join([url, START_EXPERIMENT_ENDPOINT]) + response = requests.post( + start_endpoint, + verify=cert, + auth=auth, + proxies=PROXY, # type: ignore + json=config.to_dict(), + ) + response.raise_for_status() + return + except: + logging.debug(traceback.format_exc()) + time.sleep(retry) + raise ValueError("Failed to start experiment") def extract_errors_from_file(path: str): @@ -113,6 +118,26 @@ def extract_errors_from_file(path: str): return re.findall(r"(Error \w+.*)", content) +def wait_for_server_simple( + url: str, cert: str, auth: tuple[str, str], timeout: int +) -> None: + """ + Checks everest server has started _HTTP_REQUEST_RETRY times. Waits + progressively longer between each check. + + Raise an exception when the timeout is reached. + """ + sleep_time_increment = float(timeout) / (2**_HTTP_REQUEST_RETRY - 1) + for retry_count in range(_HTTP_REQUEST_RETRY): + try: + requests.get(url + "/", verify=cert, auth=auth, proxies=PROXY) # type: ignore + return + except Exception: + sleep_time = sleep_time_increment * (2**retry_count) + time.sleep(sleep_time) + raise RuntimeError("Failed to get reply from server within configured timeout.") + + def wait_for_server(output_dir: str, timeout: int) -> None: """ Checks everest server has started _HTTP_REQUEST_RETRY times. Waits diff --git a/src/everest/detached/jobs/everserver.py b/src/everest/detached/jobs/everserver.py index fbd5ad59783..865c0e4d2f6 100755 --- a/src/everest/detached/jobs/everserver.py +++ b/src/everest/detached/jobs/everserver.py @@ -42,6 +42,7 @@ ServerStatus, get_opt_status, update_everserver_status, + wait_for_server_simple, ) from everest.export import check_for_errors from everest.plugins.everest_plugin_manager import EverestPluginManager @@ -184,8 +185,6 @@ def stop( ) -> Response: _log(request) _check_user(credentials) - print(f"STOP ENDPOINT {shared_data}") - shared_data[STOP_ENDPOINT] = True return Response("Raise STOP flag succeeded. Everest initiates shutdown..", 200) @@ -228,7 +227,6 @@ def get_experiment_status( ) -> Response: _log(request) _check_user(credentials) - if shared_data[STOP_ENDPOINT]: return Response(f"{EverestExitCode.USER_ABORT}", 200) if runner is None: @@ -397,20 +395,13 @@ def main(): return try: - # add timeout - is_running = False - while not is_running: - try: - requests.get(url + "/", verify=cert, auth=auth, proxies=PROXY) # type: ignore - is_running = True - except: - time.sleep(1) + wait_for_server_simple(url, cert, auth, 60) update_everserver_status(status_path, ServerStatus.running) - # add timeout is_done = False exit_code = None + # loop unil the optimization is done while not is_done: response = requests.get( "/".join([url, EXPERIMENT_STATUS_ENDPOINT]), @@ -426,8 +417,8 @@ def main(): else: time.sleep(1) - response: requests.Response = requests.get( - url + "/" + SHARED_DATA_ENDPOINT, + response = requests.get( + "/".join([url, SHARED_DATA_ENDPOINT]), verify=cert, auth=auth, proxies=PROXY, # type: ignore @@ -441,7 +432,7 @@ def main(): if status != ServerStatus.completed: update_everserver_status(status_path, status, message) return - except Exception: + except: if shared_data[STOP_ENDPOINT]: update_everserver_status( status_path, @@ -476,7 +467,7 @@ def main(): data_frame=export_with_progress(config, export_ecl), export_path=config.export_path, ) - except Exception: + except: update_everserver_status( status_path, ServerStatus.failed, diff --git a/tests/everest/test_everserver.py b/tests/everest/test_everserver.py index 0f24bddc4eb..3a3b40dcb59 100644 --- a/tests/everest/test_everserver.py +++ b/tests/everest/test_everserver.py @@ -5,9 +5,8 @@ from unittest.mock import patch import pytest -from fastapi import Response from fastapi.encoders import jsonable_encoder -from fastapi.responses import JSONResponse, PlainTextResponse +from fastapi.responses import JSONResponse, PlainTextResponse, Response from seba_sqlite.snapshot import SebaSnapshot from ert.run_models.everest_run_model import EverestExitCode