Skip to content

Commit

Permalink
mypy-compliant typing (#9)
Browse files Browse the repository at this point in the history
* mypy-compliant type hints
* added tests for python 3.12
* disabled python 3.6 support
  • Loading branch information
s4v4g3 authored Nov 23, 2023
1 parent 743c762 commit a5cc393
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 59 deletions.
10 changes: 4 additions & 6 deletions .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,21 @@ jobs:
build:
env:
# We use these variables to convert between tox and GHA version literals
py36: 3.6
py37: 3.7
py38: 3.8
py39: 3.9
py310: "3.10"
py311: "3.11"
py312: "3.12"
pypy3: pypy-3.7
RUN_MATRIX_COMBINATION: ${{ matrix.python-version }}-${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false # ensures the entire test matrix is run, even if one permutation fails
matrix:
python-version: [ py36, py37, py38, py39, py310, py311 ]
python-version: [ py37, py38, py39, py310, py311, py312 ]
os: [ ubuntu-20.04, windows-2019 ]
exclude:
- os: windows-2019
python-version: py36

steps:
- name: Checkout Core Repo @ SHA - ${{ github.sha }}
uses: actions/checkout@v2
Expand All @@ -30,7 +28,7 @@ jobs:
python-version: ${{ env[matrix.python-version] }}
architecture: 'x64'
- name: Install Tox
run: pip install tox==3.27.1 -U tox-factor
run: pip install tox==4.* -U
- name: Cache tox environment
# Preserves .tox directory between runs for faster installs
uses: actions/cache@v2
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
venv*/

# PyInstaller
# Usually these files are written by a python script from a template
Expand Down
6 changes: 3 additions & 3 deletions example/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@


@instrumented
def foobar():
def foobar() -> None:
# automatically creates a "foobar" span due to "instrumented" decorator
time.sleep(random.randint(0, 10) / 10.0)


@instrumented(span_name="my custom span name")
def bar():
def bar() -> None:
# automatically creates a span due to "instrumented" decorator, using a custom name

# set an attribute for the span
Expand All @@ -24,7 +24,7 @@ def bar():


@instrumented
def foo():
def foo() -> None:
# automatically creates a span due to "instrumented" decorator
for _ in range(0, random.randint(10, 20)):
bar()
Expand Down
118 changes: 73 additions & 45 deletions otel_extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from typing import Callable, Optional, Dict
from typing import Callable, Optional, Dict, Any, TYPE_CHECKING, Union, cast, TypeVar, overload
import os
from functools import wraps
import logging
import inspect
from opentelemetry import context, trace
from opentelemetry.trace import Tracer
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
import importlib
import warnings
from opentelemetry.util.types import AttributeValue as SpanAttributeValue

CallableType = Callable[..., Any]
DecoratedFuncType = TypeVar("DecoratedFuncType", bound=CallableType)

__all__ = [
"TelemetryOptions",
"TraceEventLogHandler",
Expand All @@ -21,9 +26,8 @@
"inject_context_to_env",
]

global_tracer_provider: Optional[object] = None
tracer_providers_by_service_name: Dict[str, object] = {}
span_processors = []
global_tracer_provider: Optional[TracerProvider] = None
tracer_providers_by_service_name: Dict[str, TracerProvider] = {}


class TelemetryOptions:
Expand All @@ -37,7 +41,7 @@ class TelemetryOptions:
OTEL_PROCESSOR_TYPE: str = "batch"
TRACEPARENT: Optional[str] = None

def __init__(self, *_args, **kwargs):
def __init__(self, *_args: Any, **kwargs: Any) -> None:
all_attrs = [attr for attr in dir(self.__class__) if not attr.startswith("_")]
# set default values from env
for attr in all_attrs:
Expand All @@ -55,76 +59,81 @@ class TraceContextCarrier:

traceparent_var = "TRACEPARENT"

def __init__(self, carrier: Optional[dict] = None):
self.token = None
self.carrier = carrier
def __init__(self, carrier: Optional[Dict[str, str]] = None):
self.token: Optional[object] = None
if carrier is None:
self.carrier = {}
TraceContextTextMapPropagator().inject(self.carrier)
carrier = {}
TraceContextTextMapPropagator().inject(carrier)
self.carrier: Dict[str, str] = carrier

def __enter__(self):
def __enter__(self) -> "TraceContextCarrier":
if self.token is None:
self.token = self.__attach(self.carrier)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.detach()

@classmethod
def attach_from_env(cls):
def attach_from_env(cls) -> "TraceContextCarrier":
traceparent = os.environ.get(cls.traceparent_var)
carrier = TraceContextCarrier(carrier={"traceparent": traceparent} if traceparent is not None else {})
carrier.attach()
return carrier

@classmethod
def attach_from_options(cls, options: TelemetryOptions):
def attach_from_options(cls, options: TelemetryOptions) -> "TraceContextCarrier":
traceparent = options.TRACEPARENT
carrier = TraceContextCarrier(
carrier={"traceparent": traceparent} if traceparent is not None else {}
)
carrier = TraceContextCarrier(carrier={"traceparent": traceparent} if traceparent is not None else {})
carrier.attach()
return carrier

@classmethod
def inject_to_env(cls):
def inject_to_env(cls) -> None:
ctx = TraceContextCarrier()
if "traceparent" in ctx.carrier:
os.environ[cls.traceparent_var] = ctx.carrier["traceparent"]

def attach(self):
def attach(self) -> None:
self.token = self.__attach(self.carrier)

def detach(self):
def detach(self) -> None:
if self.token is not None:
context.detach(self.token)
self.token = None

def __eq__(self, other):
return self.carrier == other.carrier
def __eq__(self, other: object) -> bool:
return isinstance(other, TraceContextCarrier) and self.carrier == other.carrier

@classmethod
def __attach(cls, carrier):
def __attach(cls, carrier: Dict[str, str]) -> object:
token = context.attach(TraceContextTextMapPropagator().extract(carrier=carrier))
return token


class TraceEventLogHandler(logging.StreamHandler):
if TYPE_CHECKING:
BaseStreamHandler = logging.StreamHandler["TraceEventLogHandler"]
else:
BaseStreamHandler = logging.StreamHandler


class TraceEventLogHandler(BaseStreamHandler):
"""log handler class that adds log messages as events in the current span"""

def __init__(self):
def __init__(self) -> None:
super().__init__(stream=self)
self.name = "TraceEventLogHandler"

def write(self, msg: str):
def write(self, msg: str) -> None:
if msg != self.terminator:
current_span = trace.get_current_span()
current_span.add_event(msg)

def flush(self):
def flush(self) -> None:
"""no need to flush"""


def get_tracer(module_name: str, service_name: str = None):
def get_tracer(module_name: str, service_name: Optional[str] = None) -> Tracer:
"""
Get the `Tracer` for the specified module and service name
Args:
Expand All @@ -143,7 +152,7 @@ def get_tracer(module_name: str, service_name: str = None):
return trace.get_tracer(module_name, tracer_provider=tracer_provider)


def init_telemetry_provider(options: TelemetryOptions = None, **resource_attrs):
def init_telemetry_provider(options: Optional[TelemetryOptions] = None, **resource_attrs: Any) -> None:
"""
Initialize telemetry collection for a service, and inherits any trace context
set from the TRACEPARENT environment variable
Expand All @@ -167,7 +176,7 @@ def init_telemetry_provider(options: TelemetryOptions = None, **resource_attrs):
TraceContextCarrier.attach_from_env()


def flush_telemetry_data():
def flush_telemetry_data() -> None:
"""Forces a flush of all span exporters attached to trace providers"""
global global_tracer_provider, tracer_providers_by_service_name
if global_tracer_provider is not None:
Expand All @@ -177,7 +186,7 @@ def flush_telemetry_data():
provider.force_flush() # noqa


def _try_load_trace_provider(options: TelemetryOptions, **resource_attrs):
def _try_load_trace_provider(options: TelemetryOptions, **resource_attrs: Any) -> None:
global global_tracer_provider, tracer_providers_by_service_name
try:
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
Expand Down Expand Up @@ -230,16 +239,20 @@ def _try_load_trace_provider(options: TelemetryOptions, **resource_attrs):
pass


def _get_traces_endpoint(options: TelemetryOptions):
path = "v1/traces" if options.OTEL_EXPORTER_OTLP_ENDPOINT.endswith("/") else "/v1/traces"
def _get_traces_endpoint(options: TelemetryOptions) -> str:
path = (
"v1/traces"
if options.OTEL_EXPORTER_OTLP_ENDPOINT and options.OTEL_EXPORTER_OTLP_ENDPOINT.endswith("/")
else "/v1/traces"
)
endpoint = f"{options.OTEL_EXPORTER_OTLP_ENDPOINT}{path}"
return endpoint


class ContextInjector:
def __call__(self, wrapped_function: Callable) -> Callable:
def __call__(self, wrapped_function: CallableType) -> CallableType:
@wraps(wrapped_function)
def new_f(*args, **kwargs):
def new_f(*args: Any, **kwargs: Any) -> Any:
prev_env = os.environ.get(TraceContextCarrier.traceparent_var)
TraceContextCarrier.inject_to_env()
try:
Expand All @@ -251,23 +264,23 @@ def new_f(*args, **kwargs):
return new_f


def inject_context_to_env(wrapped_function: Callable):
def inject_context_to_env(wrapped_function: CallableType) -> CallableType:
injector = ContextInjector()
return injector(wrapped_function)


class Instrumented:
def __init__(
self,
span_name: str = None,
service_name: str = None,
span_name: Optional[str] = None,
service_name: Optional[str] = None,
span_attributes: Optional[Dict[str, SpanAttributeValue]] = None,
):
) -> None:
self.span_name = span_name
self.service_name = service_name
self.span_attributes = span_attributes if span_attributes is not None else {}

def __call__(self, wrapped_function: Callable) -> Callable:
def __call__(self, wrapped_function: DecoratedFuncType) -> DecoratedFuncType:
module = inspect.getmodule(wrapped_function)
is_async = inspect.iscoroutinefunction(wrapped_function)
module_name = __name__
Expand All @@ -276,27 +289,42 @@ def __call__(self, wrapped_function: Callable) -> Callable:
span_name = self.span_name or wrapped_function.__qualname__

@wraps(wrapped_function)
def new_f(*args, **kwargs):
def new_f(*args: Any, **kwargs: Any) -> Any:
with get_tracer(module_name, service_name=self.service_name).start_as_current_span(span_name) as span:
span.set_attributes(self.span_attributes)
return wrapped_function(*args, **kwargs)

@wraps(wrapped_function)
async def new_f_async(*args, **kwargs):
async def new_f_async(*args: Any, **kwargs: Any) -> Any:
with get_tracer(module_name, service_name=self.service_name).start_as_current_span(span_name) as span:
span.set_attributes(self.span_attributes)
return await wrapped_function(*args, **kwargs)

return new_f_async if is_async else new_f
return cast(DecoratedFuncType, new_f_async) if is_async else cast(DecoratedFuncType, new_f)


@overload
def instrumented(wrapped_function: DecoratedFuncType) -> DecoratedFuncType:
...


@overload
def instrumented(
*,
span_name: Optional[str] = None,
service_name: Optional[str] = None,
span_attributes: Optional[Dict[str, SpanAttributeValue]] = None,
) -> Instrumented:
...


def instrumented(
wrapped_function: Optional[Callable] = None,
wrapped_function: Optional[DecoratedFuncType] = None,
*,
span_name: Optional[str] = None,
service_name: Optional[str] = None,
span_attributes: Optional[Dict[str, SpanAttributeValue]] = None,
):
) -> Union[DecoratedFuncType, Instrumented]:
"""
Decorator to enable opentelemetry instrumentation on a function.
Expand Down
6 changes: 3 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ classifiers =
Operating System :: OS Independent
Programming Language :: Python :: 3
Programming Language :: Python :: 3 :: Only
Programming Language :: Python :: 3.6
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
Topic :: Software Development :: Libraries
Topic :: Software Development :: Testing
Topic :: Utilities
Expand All @@ -35,7 +35,7 @@ packages = find:
install_requires =
opentelemetry-api
opentelemetry-sdk
python_requires = >=3.6
python_requires = >=3.7
zip_safe = True

[options.packages.find]
Expand Down Expand Up @@ -81,7 +81,7 @@ xfail_strict = True
junit_family = xunit2

[mypy]
python_version = 3.6
python_version = 3.7
disallow_any_generics = True
disallow_subclassing_any = True
disallow_untyped_calls = True
Expand Down
1 change: 1 addition & 0 deletions test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pytest
pytest-cov
opentelemetry-distro[otlp]
pytest-asyncio
mypy
Loading

0 comments on commit a5cc393

Please sign in to comment.