Skip to content

Commit

Permalink
refactor: use asyncio Events over threading Events (#6115)
Browse files Browse the repository at this point in the history
Signed-off-by: Joan Martinez <[email protected]>
Co-authored-by: Joan Martinez <[email protected]>
  • Loading branch information
JoanFM and JoanFM authored Nov 24, 2023
1 parent 6cd3311 commit 7aab7e6
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 74 deletions.
83 changes: 43 additions & 40 deletions jina/orchestrate/pods/container.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import argparse
import asyncio
import copy
import multiprocessing
import os
import platform
import re
import signal
import threading
import time
from typing import TYPE_CHECKING, Dict, Optional, Union

Expand All @@ -25,21 +23,21 @@
from jina.parsers import set_gateway_parser

if TYPE_CHECKING: # pragma: no cover
import threading
from docker.client import DockerClient


def _docker_run(
client: 'DockerClient',
args: 'argparse.Namespace',
container_name: str,
envs: Dict,
net_mode: Optional[str],
logger: 'JinaLogger',
client: 'DockerClient',
args: 'argparse.Namespace',
container_name: str,
envs: Dict,
net_mode: Optional[str],
logger: 'JinaLogger',
):
# important to notice, that client is not assigned as instance member to avoid potential
# heavy copy into new process memory space
import warnings

import docker

docker_version = client.version().get('Version')
Expand Down Expand Up @@ -116,7 +114,7 @@ def _docker_run(

_volumes = {}
if not getattr(args, 'disable_auto_volume', None) and not getattr(
args, 'volumes', None
args, 'volumes', None
):
(
generated_volumes,
Expand Down Expand Up @@ -177,16 +175,16 @@ def _docker_run(


def run(
args: 'argparse.Namespace',
name: str,
container_name: str,
net_mode: Optional[str],
runtime_ctrl_address: str,
envs: Dict,
is_started: Union['multiprocessing.Event', 'threading.Event'],
is_shutdown: Union['multiprocessing.Event', 'threading.Event'],
is_ready: Union['multiprocessing.Event', 'threading.Event'],
is_signal_handlers_installed: Union['multiprocessing.Event', 'threading.Event'],
args: 'argparse.Namespace',
name: str,
container_name: str,
net_mode: Optional[str],
runtime_ctrl_address: str,
envs: Dict,
is_started: Union['multiprocessing.Event', 'threading.Event'],
is_shutdown: Union['multiprocessing.Event', 'threading.Event'],
is_ready: Union['multiprocessing.Event', 'threading.Event'],
is_signal_handlers_installed: Union['multiprocessing.Event', 'threading.Event'],
):
"""Method to be run in a process that stream logs from a Container
Expand All @@ -213,33 +211,38 @@ def run(
:param is_ready: concurrency event to communicate runtime is ready to receive messages
"""
import docker
import asyncio

log_kwargs = copy.deepcopy(vars(args))
log_kwargs['log_config'] = 'docker'
logger = JinaLogger(name, **log_kwargs)

cancel = threading.Event()
fail_to_start = threading.Event()
cancel = False
fail_to_start = False

def _set_cancel(*args, **kwargs):
cancel = True

if not __windows__:

try:
for signame in {signal.SIGINT, signal.SIGTERM}:
signal.signal(signame, lambda *args, **kwargs: cancel.set())
signal.signal(signame, _set_cancel)
except (ValueError, RuntimeError) as exc:
logger.warning(
f'The process starting the container for {name} will not be able to handle termination signals. '
f' {repr(exc)}'
)
else:
with ImportExtensions(
required=True,
logger=logger,
help_text='''If you see a 'DLL load failed' error, please reinstall `pywin32`.
required=True,
logger=logger,
help_text='''If you see a 'DLL load failed' error, please reinstall `pywin32`.
If you're using conda, please use the command `conda install -c anaconda pywin32`''',
):
import win32api

win32api.SetConsoleCtrlHandler(lambda *args, **kwargs: cancel.set(), True)
win32api.SetConsoleCtrlHandler(_set_cancel, True)

is_signal_handlers_installed.set()
client = docker.from_env()
Expand Down Expand Up @@ -272,23 +275,23 @@ def _is_container_alive(container) -> bool:

async def _check_readiness(container):
while (
_is_container_alive(container)
and not _is_ready()
and not cancel.is_set()
_is_container_alive(container)
and not _is_ready()
and not cancel
):
await asyncio.sleep(0.1)
if _is_container_alive(container):
is_started.set()
is_ready.set()
else:
fail_to_start.set()
fail_to_start = True

async def _stream_starting_logs(container):
for line in container.logs(stream=True):
if (
not is_started.is_set()
and not fail_to_start.is_set()
and not cancel.is_set()
not is_started.is_set()
and not fail_to_start
and not cancel
):
await asyncio.sleep(0.01)
msg = line.decode().rstrip() # type: str
Expand Down Expand Up @@ -318,9 +321,9 @@ class ContainerPod(BasePod):
def __init__(self, args: 'argparse.Namespace'):
super().__init__(args)
if (
self.args.docker_kwargs
and 'extra_hosts' in self.args.docker_kwargs
and __docker_host__ in self.args.docker_kwargs['extra_hosts']
self.args.docker_kwargs
and 'extra_hosts' in self.args.docker_kwargs
and __docker_host__ in self.args.docker_kwargs['extra_hosts']
):
self.args.docker_kwargs.pop('extra_hosts')
self._net_mode = None
Expand All @@ -336,9 +339,9 @@ def _get_control_address(self):
network = get_docker_network(client)

if (
self.args.docker_kwargs
and 'extra_hosts' in self.args.docker_kwargs
and __docker_host__ in self.args.docker_kwargs['extra_hosts']
self.args.docker_kwargs
and 'extra_hosts' in self.args.docker_kwargs
and __docker_host__ in self.args.docker_kwargs['extra_hosts']
):
ctrl_host = __docker_host__
elif network:
Expand Down
3 changes: 0 additions & 3 deletions jina/serve/networking/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import time
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, Union, AsyncGenerator

import grpc
Expand All @@ -24,8 +23,6 @@
from jina.types.request.data import SingleDocumentRequest

if TYPE_CHECKING: # pragma: no cover
import threading

from grpc.aio._interceptor import ClientInterceptor
from opentelemetry.instrumentation.grpc._client import (
OpenTelemetryClientInterceptor,
Expand Down
2 changes: 1 addition & 1 deletion jina/serve/runtimes/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import argparse
import asyncio
import signal
import threading
import time
from typing import TYPE_CHECKING, Optional, Union

Expand All @@ -21,6 +20,7 @@
from jina.serve.runtimes.servers import BaseServer

if TYPE_CHECKING: # pragma: no cover
import threading
import multiprocessing

HANDLED_SIGNALS = (
Expand Down
2 changes: 0 additions & 2 deletions jina/serve/runtimes/gateway/request_handling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import asyncio
import itertools
import threading
from typing import TYPE_CHECKING, AsyncIterator, Dict

from jina.enums import ProtocolType
Expand Down
1 change: 0 additions & 1 deletion jina/serve/runtimes/gateway/streamer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import json
import os
import threading
from typing import (
TYPE_CHECKING,
AsyncIterator,
Expand Down
3 changes: 1 addition & 2 deletions jina/serve/runtimes/head/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import asyncio
import json
import os
import threading
from collections import defaultdict
from typing import TYPE_CHECKING, AsyncIterator, Dict, List, Optional, Tuple, Any

Expand Down Expand Up @@ -140,7 +139,7 @@ def __init__(
self._gathering_endpoints = False
self.runtime_name = runtime_name
self._pydantic_models_by_endpoint = None
self.endpoints_discovery_stop_event = threading.Event()
self.endpoints_discovery_stop_event = asyncio.Event()
self.endpoints_discovery_task = None
if docarray_v2:
self.endpoints_discovery_task = asyncio.create_task(
Expand Down
56 changes: 31 additions & 25 deletions jina/serve/runtimes/servers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
import threading
import time
import asyncio
from types import SimpleNamespace
from typing import TYPE_CHECKING, Dict, Optional, Union

Expand All @@ -12,6 +12,7 @@

if TYPE_CHECKING:
import multiprocessing
import threading

from jina.serve.runtimes.gateway.request_handling import GatewayRequestHandler
from jina.serve.runtimes.worker.request_handling import WorkerRequestHandler
Expand All @@ -23,18 +24,23 @@ class BaseServer(MonitoringMixin, InstrumentationMixin):
"""

def __init__(
self,
name: Optional[str] = 'gateway',
runtime_args: Optional[Dict] = None,
req_handler_cls=None,
req_handler=None,
is_cancel=None,
**kwargs,
self,
name: Optional[str] = 'gateway',
runtime_args: Optional[Dict] = None,
req_handler_cls=None,
req_handler=None,
is_cancel=None,
**kwargs,
):
self.name = name or ''
self.runtime_args = runtime_args
self.works_as_load_balancer = False
self.is_cancel = is_cancel or threading.Event()
try:
self.is_cancel = is_cancel or asyncio.Event()
except:
# in some unit tests we instantiate the server without an asyncio Loop
import threading
self.is_cancel = threading.Event()
if isinstance(runtime_args, Dict):
self.works_as_load_balancer = runtime_args.get(
'gateway_load_balancer', False
Expand Down Expand Up @@ -186,11 +192,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):

@staticmethod
def is_ready(
ctrl_address: str,
protocol: Optional[str] = 'grpc',
timeout: float = 1.0,
logger=None,
**kwargs,
ctrl_address: str,
protocol: Optional[str] = 'grpc',
timeout: float = 1.0,
logger=None,
**kwargs,
) -> bool:
"""
Check if status is ready.
Expand All @@ -213,11 +219,11 @@ def is_ready(

@staticmethod
async def async_is_ready(
ctrl_address: str,
protocol: Optional[str] = 'grpc',
timeout: float = 1.0,
logger=None,
**kwargs,
ctrl_address: str,
protocol: Optional[str] = 'grpc',
timeout: float = 1.0,
logger=None,
**kwargs,
) -> bool:
"""
Check if status is ready.
Expand All @@ -240,12 +246,12 @@ async def async_is_ready(

@classmethod
def wait_for_ready_or_shutdown(
cls,
timeout: Optional[float],
ready_or_shutdown_event: Union['multiprocessing.Event', 'threading.Event'],
ctrl_address: str,
health_check: bool = False,
**kwargs,
cls,
timeout: Optional[float],
ready_or_shutdown_event: Union['multiprocessing.Event', 'threading.Event', 'asyncio.Event'],
ctrl_address: str,
health_check: bool = False,
**kwargs,
):
"""
Check if the runtime has successfully started
Expand Down

0 comments on commit 7aab7e6

Please sign in to comment.