Skip to content

Commit

Permalink
wrap getone instead of __anext__
Browse files Browse the repository at this point in the history
  • Loading branch information
dimastbk committed Sep 17, 2024
1 parent 5890003 commit 5372078
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def async_consume_hook(span, record, args, kwargs):
from opentelemetry import trace
from opentelemetry.instrumentation.aiokafka.package import _instruments
from opentelemetry.instrumentation.aiokafka.utils import (
_wrap_anext,
_wrap_getone,
_wrap_send,
)
from opentelemetry.instrumentation.aiokafka.version import __version__
Expand Down Expand Up @@ -126,10 +126,10 @@ def _instrument(self, **kwargs):
)
wrap_function_wrapper(
aiokafka.AIOKafkaConsumer,
"__anext__",
_wrap_anext(tracer, async_consume_hook),
"getone",
_wrap_getone(tracer, async_consume_hook),
)

def _uninstrument(self, **kwargs):
unwrap(aiokafka.AIOKafkaProducer, "send")
unwrap(aiokafka.AIOKafkaConsumer, "__anext__")
unwrap(aiokafka.AIOKafkaConsumer, "getone")
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ async def _create_consumer_span(
context.detach(token)


def _wrap_anext(
def _wrap_getone(
tracer: Tracer, async_consume_hook: ConsumeHookT
) -> Callable[..., Awaitable[aiokafka.ConsumerRecord]]:
async def _traced_next(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import uuid
from typing import List, Sequence, Tuple
from typing import Any, List, Sequence, Tuple
from unittest import IsolatedAsyncioTestCase, mock

from aiokafka import (
Expand Down Expand Up @@ -53,16 +53,30 @@ def consumer_record_factory(
)

@staticmethod
def producer_factory() -> AIOKafkaProducer:
async def consumer_factory(**consumer_kwargs: Any) -> AIOKafkaConsumer:
consumer = AIOKafkaConsumer(**consumer_kwargs)

consumer._client.bootstrap = mock.AsyncMock()
consumer._client._wait_on_metadata = mock.AsyncMock()

await consumer.start()

consumer._fetcher.next_record = mock.AsyncMock()

return consumer

@staticmethod
async def producer_factory() -> AIOKafkaProducer:
producer = AIOKafkaProducer(api_version="1.0")

add_message_mock = mock.AsyncMock()
producer.client._wait_on_metadata = mock.AsyncMock()
producer.client.bootstrap = mock.AsyncMock()
producer._message_accumulator.add_message = add_message_mock
producer._message_accumulator.add_message = mock.AsyncMock()
producer._sender.start = mock.AsyncMock()
producer._partition = mock.Mock(return_value=1)

await producer.start()

return producer

def test_instrument_api(self) -> None:
Expand All @@ -73,24 +87,27 @@ def test_instrument_api(self) -> None:
isinstance(AIOKafkaProducer.send, BoundFunctionWrapper)
)
self.assertTrue(
isinstance(AIOKafkaConsumer.__anext__, BoundFunctionWrapper)
isinstance(AIOKafkaConsumer.getone, BoundFunctionWrapper)
)

instrumentation.uninstrument()
self.assertFalse(
isinstance(AIOKafkaProducer.send, BoundFunctionWrapper)
)
self.assertFalse(
isinstance(AIOKafkaConsumer.__anext__, BoundFunctionWrapper)
isinstance(AIOKafkaConsumer.getone, BoundFunctionWrapper)
)

async def test_anext(self) -> None:
async def test_getone(self) -> None:
AIOKafkaInstrumentor().uninstrument()
AIOKafkaInstrumentor().instrument(tracer_provider=self.tracer_provider)

client_id = str(uuid.uuid4())
group_id = str(uuid.uuid4())
consumer = AIOKafkaConsumer(client_id=client_id, group_id=group_id)
consumer = await self.consumer_factory(
client_id=client_id, group_id=group_id
)
next_record_mock: mock.AsyncMock = consumer._fetcher.next_record

expected_spans = [
{
Expand Down Expand Up @@ -130,10 +147,7 @@ async def test_anext(self) -> None:
]
self.memory_exporter.clear()

getone_mock = mock.AsyncMock()
consumer.getone = getone_mock

getone_mock.side_effect = [
next_record_mock.side_effect = [
self.consumer_record_factory(
1,
headers=(
Expand All @@ -146,22 +160,22 @@ async def test_anext(self) -> None:
self.consumer_record_factory(2, headers=()),
]

await consumer.__anext__()
getone_mock.assert_awaited_with()
await consumer.getone()
next_record_mock.assert_awaited_with(())

first_span = self.memory_exporter.get_finished_spans()[0]
self.assertEqual(
format_trace_id(first_span.get_span_context().trace_id),
"03afa25236b8cd948fa853d67038ac79",
)

await consumer.__anext__()
getone_mock.assert_awaited_with()
await consumer.getone()
next_record_mock.assert_awaited_with(())

span_list = self.memory_exporter.get_finished_spans()
self._compare_spans(span_list, expected_spans)

async def test_anext_baggage(self) -> None:
async def test_getone_baggage(self) -> None:
received_baggage = None

async def async_consume_hook(span, *_) -> None:
Expand All @@ -174,14 +188,12 @@ async def async_consume_hook(span, *_) -> None:
async_consume_hook=async_consume_hook,
)

consumer = AIOKafkaConsumer()
consumer = await self.consumer_factory()
next_record_mock: mock.AsyncMock = consumer._fetcher.next_record

self.memory_exporter.clear()

getone_mock = mock.AsyncMock()
consumer.getone = getone_mock

getone_mock.side_effect = [
next_record_mock.side_effect = [
self.consumer_record_factory(
1,
headers=(
Expand All @@ -194,12 +206,12 @@ async def async_consume_hook(span, *_) -> None:
),
]

await consumer.__anext__()
getone_mock.assert_awaited_with()
await consumer.getone()
next_record_mock.assert_awaited_with(())

self.assertEqual(received_baggage, {"foo": "bar"})

async def test_anext_consume_hook(self) -> None:
async def test_getone_consume_hook(self) -> None:
async_consume_hook_mock = mock.AsyncMock()

AIOKafkaInstrumentor().uninstrument()
Expand All @@ -208,28 +220,26 @@ async def test_anext_consume_hook(self) -> None:
async_consume_hook=async_consume_hook_mock,
)

consumer = AIOKafkaConsumer()

getone_mock = mock.AsyncMock()
consumer.getone = getone_mock
consumer = await self.consumer_factory()
next_record_mock: mock.AsyncMock = consumer._fetcher.next_record

getone_mock.side_effect = [self.consumer_record_factory(1, headers=())]
next_record_mock.side_effect = [
self.consumer_record_factory(1, headers=())
]

await consumer.__anext__()
await consumer.getone()

async_consume_hook_mock.assert_awaited_once()

async def test_send(self) -> None:
AIOKafkaInstrumentor().uninstrument()
AIOKafkaInstrumentor().instrument(tracer_provider=self.tracer_provider)

producer = self.producer_factory()
producer = await self.producer_factory()
add_message_mock: mock.AsyncMock = (
producer._message_accumulator.add_message
)

await producer.start()

tracer = self.tracer_provider.get_tracer(__name__)
with tracer.start_as_current_span("test_span") as span:
await producer.send("topic_1", b"value_1")
Expand Down Expand Up @@ -260,13 +270,11 @@ async def test_send_baggage(self) -> None:
AIOKafkaInstrumentor().uninstrument()
AIOKafkaInstrumentor().instrument(tracer_provider=self.tracer_provider)

producer = self.producer_factory()
producer = await self.producer_factory()
add_message_mock: mock.AsyncMock = (
producer._message_accumulator.add_message
)

await producer.start()

tracer = self.tracer_provider.get_tracer(__name__)
ctx = baggage.set_baggage("foo", "bar")
context.attach(ctx)
Expand All @@ -292,9 +300,7 @@ async def test_send_produce_hook(self) -> None:
async_produce_hook=async_produce_hook_mock,
)

producer = self.producer_factory()

await producer.start()
producer = await self.producer_factory()

await producer.send("topic_1", b"value_1")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
_create_consumer_span,
_extract_send_partition,
_get_span_name,
_wrap_anext,
_wrap_getone,
_wrap_send,
)
from opentelemetry.trace import SpanKind
Expand Down Expand Up @@ -187,7 +187,7 @@ async def test_wrap_next(
original_next_callback = mock.AsyncMock()
kafka_consumer = mock.MagicMock()

wrapped_next = _wrap_anext(tracer, consume_hook)
wrapped_next = _wrap_getone(tracer, consume_hook)
record = await wrapped_next(
original_next_callback, kafka_consumer, self.args, self.kwargs
)
Expand Down

0 comments on commit 5372078

Please sign in to comment.