diff --git a/python/restate/ext/tracing/_tracing.py b/python/restate/ext/tracing/_tracing.py index b272185..198fc93 100644 --- a/python/restate/ext/tracing/_tracing.py +++ b/python/restate/ext/tracing/_tracing.py @@ -9,8 +9,13 @@ # All spans created by this tracer are flat children of the Restate trace. """ -from opentelemetry.trace import INVALID_SPAN, use_span, Tracer, TracerProvider +from typing import Optional, Iterator, Sequence + +from opentelemetry import context as context_api +from opentelemetry.trace import INVALID_SPAN, Span, SpanKind, Tracer, TracerProvider, use_span, Link from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from opentelemetry.util import types +from opentelemetry.util._decorator import _agnosticcontextmanager from restate.server_context import ( current_context, get_extension_data, @@ -58,15 +63,15 @@ def _get_root_context(): def start_span( self, - name, - context=None, - kind=None, - attributes=None, - links=None, - start_time=None, - record_exception=True, - set_status_on_exception=True, - ): + name: str, + context: Optional[context_api.Context] = None, + kind: SpanKind = SpanKind.INTERNAL, + attributes: types.Attributes = None, + links: Optional[Sequence[Link]] = None, + start_time: Optional[int] = None, + record_exception: bool = True, + set_status_on_exception: bool = True, + ) -> Span: if restate_context_is_replaying.get(False): return INVALID_SPAN root = self._get_root_context() @@ -85,34 +90,38 @@ def start_span( self._track_span(span) return span + @_agnosticcontextmanager def start_as_current_span( self, - name, - context=None, - kind=None, - attributes=None, - links=None, - start_time=None, - record_exception=True, - set_status_on_exception=True, - end_on_exit=True, - ): + name: str, + context: Optional[context_api.Context] = None, + kind: SpanKind = SpanKind.INTERNAL, + attributes: types.Attributes = None, + links: Optional[Sequence[Link]] = None, + start_time: Optional[int] = None, + record_exception: bool = True, + set_status_on_exception: bool = True, + end_on_exit: bool = True, + ) -> Iterator[Span]: if restate_context_is_replaying.get(False): - return use_span(INVALID_SPAN, end_on_exit=False) - root = self._get_root_context() - if root is not None: - context = root - return self._tracer.start_as_current_span( - name, - context=context, - kind=kind, - attributes=attributes, - links=links, - start_time=start_time, - record_exception=record_exception, - set_status_on_exception=set_status_on_exception, - end_on_exit=end_on_exit, - ) + with use_span(INVALID_SPAN, end_on_exit=False) as span: + yield span + else: + root = self._get_root_context() + if root is not None: + context = root + with self._tracer.start_as_current_span( + name, + context=context, + kind=kind, + attributes=attributes, + links=links, + start_time=start_time, + record_exception=record_exception, + set_status_on_exception=set_status_on_exception, + end_on_exit=end_on_exit, + ) as span: + yield span @staticmethod def _track_span(span):