diff --git a/instrumentation/opentelemetry-instrumentation-aiokafka/tests/test_instrumentation.py b/instrumentation/opentelemetry-instrumentation-aiokafka/tests/test_instrumentation.py index 43d1eac508..92739ad554 100644 --- a/instrumentation/opentelemetry-instrumentation-aiokafka/tests/test_instrumentation.py +++ b/instrumentation/opentelemetry-instrumentation-aiokafka/tests/test_instrumentation.py @@ -13,7 +13,7 @@ # limitations under the License. import uuid -from typing import List, Tuple +from typing import List, Sequence, Tuple from unittest import IsolatedAsyncioTestCase, mock from aiokafka import ( @@ -24,12 +24,13 @@ ) from wrapt import BoundFunctionWrapper +from opentelemetry import baggage, context from opentelemetry.instrumentation.aiokafka import AIOKafkaInstrumentor -from opentelemetry.sdk.trace import Span +from opentelemetry.sdk.trace import ReadableSpan from opentelemetry.semconv._incubating.attributes import messaging_attributes from opentelemetry.semconv.attributes import server_attributes from opentelemetry.test.test_base import TestBase -from opentelemetry.trace import SpanKind, format_trace_id +from opentelemetry.trace import SpanKind, format_trace_id, set_span_in_context class TestAIOKafka(TestBase, IsolatedAsyncioTestCase): @@ -51,6 +52,19 @@ def consumer_record_factory( headers=headers, ) + @staticmethod + def producer_factory() -> AIOKafkaProducer: + producer = AIOKafkaProducer(api_version="1.0") + + add_message_mock = mock.AsyncMock() + producer.client._wait_on_metadata = mock.AsyncMock() + producer.client.bootstrap = mock.AsyncMock() + producer._message_accumulator.add_message = add_message_mock + producer._sender.start = mock.AsyncMock() + producer._partition = mock.Mock(return_value=1) + + return producer + def test_instrument_api(self) -> None: instrumentation = AIOKafkaInstrumentor() @@ -147,7 +161,46 @@ async def test_anext(self) -> None: span_list = self.memory_exporter.get_finished_spans() self._compare_spans(span_list, expected_spans) - async def test_anext_consumer_hook(self) -> None: + async def test_anext_baggage(self) -> None: + received_baggage = None + + async def async_consume_hook(span, *_) -> None: + nonlocal received_baggage + received_baggage = baggage.get_all(set_span_in_context(span)) + + AIOKafkaInstrumentor().uninstrument() + AIOKafkaInstrumentor().instrument( + tracer_provider=self.tracer_provider, + async_consume_hook=async_consume_hook, + ) + + consumer = AIOKafkaConsumer() + + self.memory_exporter.clear() + + getone_mock = mock.AsyncMock() + consumer.getone = getone_mock + + getone_mock.side_effect = [ + self.consumer_record_factory( + 1, + headers=( + ( + "traceparent", + b"00-03afa25236b8cd948fa853d67038ac79-405ff022e8247c46-01", + ), + ("baggage", b"foo=bar"), + ), + ), + self.consumer_record_factory(2, headers=()), + ] + + await consumer.__anext__() + getone_mock.assert_awaited_with() + + self.assertEqual(received_baggage, {"foo": "bar"}) + + async def test_anext_consume_hook(self) -> None: async_consume_hook_mock = mock.AsyncMock() AIOKafkaInstrumentor().uninstrument() @@ -171,14 +224,10 @@ async def test_send(self) -> None: AIOKafkaInstrumentor().uninstrument() AIOKafkaInstrumentor().instrument(tracer_provider=self.tracer_provider) - producer = AIOKafkaProducer(api_version="1.0") - - add_message_mock = mock.AsyncMock() - producer.client._wait_on_metadata = mock.AsyncMock() - producer.client.bootstrap = mock.AsyncMock() - producer._message_accumulator.add_message = add_message_mock - producer._sender.start = mock.AsyncMock() - producer._partition = mock.Mock(return_value=1) + producer = self.producer_factory() + add_message_mock: mock.AsyncMock = ( + producer._message_accumulator.add_message + ) await producer.start() @@ -208,6 +257,33 @@ async def test_send(self) -> None: headers=[("traceparent", mock.ANY)], ) + async def test_send_baggage(self) -> None: + AIOKafkaInstrumentor().uninstrument() + AIOKafkaInstrumentor().instrument(tracer_provider=self.tracer_provider) + + producer = self.producer_factory() + add_message_mock: mock.AsyncMock = ( + producer._message_accumulator.add_message + ) + + await producer.start() + + tracer = self.tracer_provider.get_tracer(__name__) + ctx = baggage.set_baggage("foo", "bar") + context.attach(ctx) + + with tracer.start_as_current_span("test_span", context=ctx): + await producer.send("topic_1", b"value_1") + + add_message_mock.assert_awaited_with( + TopicPartition(topic="topic_1", partition=1), + None, + b"value_1", + 40.0, + timestamp_ms=None, + headers=[("traceparent", mock.ANY), ("baggage", b"foo=bar")], + ) + async def test_send_produce_hook(self) -> None: async_produce_hook_mock = mock.AsyncMock() @@ -217,13 +293,7 @@ async def test_send_produce_hook(self) -> None: async_produce_hook=async_produce_hook_mock, ) - producer = AIOKafkaProducer(api_version="1.0") - - producer.client._wait_on_metadata = mock.AsyncMock() - producer.client.bootstrap = mock.AsyncMock() - producer._message_accumulator.add_message = mock.AsyncMock() - producer._sender.start = mock.AsyncMock() - producer._partition = mock.Mock(return_value=1) + producer = self.producer_factory() await producer.start() @@ -232,7 +302,7 @@ async def test_send_produce_hook(self) -> None: async_produce_hook_mock.assert_awaited_once() def _compare_spans( - self, spans: List[Span], expected_spans: List[dict] + self, spans: Sequence[ReadableSpan], expected_spans: List[dict] ) -> None: self.assertEqual(len(spans), len(expected_spans)) for span, expected_span in zip(spans, expected_spans):