fix: DIA-1584: Use send_and_wait + batches for output topic (#245)
Co-authored-by: hakan458 <[email protected]>
hakan458 authored Nov 7, 2024
1 parent c6da54b commit 0ac4d2e
Showing 5 changed files with 62 additions and 65 deletions.
adala/environments/kafka.py (18 changes: 6 additions & 12 deletions)
@@ -55,9 +55,9 @@ async def initialize(self):
             self.kafka_input_topic,
             bootstrap_servers=self.kafka_bootstrap_servers,
             value_deserializer=lambda v: json.loads(v.decode("utf-8")),
-            enable_auto_commit=False,  # True by default which causes messages to be missed when using getmany()
             auto_offset_reset="earliest",
-            group_id=self.kafka_input_topic,  # ensuring unique group_id to not mix up offsets between topics
+            # enable_auto_commit=False,  # Turned off as it's not supported without a group ID
+            # group_id=output_topic_name,  # No longer using a group ID as of DIA-1584 - unclear details, but it causes problems
         )
         await self.consumer.start()
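For context, a minimal sketch of the group-less consumer pattern this hunk switches to, assuming aiokafka ^0.11; the topic and broker names are placeholders, not the repo's config. Without a group_id, partitions are assigned directly to this consumer, offsets are tracked only in memory, and commit() is unsupported, so auto_offset_reset decides where a fresh consumer starts reading.

import asyncio
import json

from aiokafka import AIOKafkaConsumer

async def consume_one_batch():
    consumer = AIOKafkaConsumer(
        "input_topic",  # placeholder; the real name comes from environment config
        bootstrap_servers="localhost:9092",
        value_deserializer=lambda v: json.loads(v.decode("utf-8")),
        auto_offset_reset="earliest",
        # No group_id: offsets live only in this consumer instance,
        # so there is nothing to commit and nothing to mix up across topics.
    )
    await consumer.start()
    try:
        # getmany() returns {TopicPartition: [ConsumerRecord, ...]}
        batch = await consumer.getmany(timeout_ms=1000, max_records=100)
        for tp, messages in batch.items():
            print(tp.topic, [msg.value for msg in messages])
    finally:
        await consumer.stop()

asyncio.run(consume_one_batch())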

@@ -95,12 +95,9 @@ async def message_sender(
     ):
         record_no = 0
         try:
-            for record in data:
-                await producer.send(topic, value=record)
-                record_no += 1
-                # print_text(f"Sent message: {record} to {topic=}")
+            await producer.send_and_wait(topic, value=data)
             logger.info(
-                f"The number of records sent to topic:{topic}, record_no:{record_no}"
+                f"The number of records sent to topic:{topic}, record_no:{len(data)}"
             )
         finally:
             pass
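The behavioral difference this hunk relies on, as a hedged standalone sketch (placeholder broker and topic; a JSON value_serializer is assumed): producer.send() only enqueues a message and returns a future, while producer.send_and_wait() does not return until the broker acknowledges delivery. Passing the whole list as one value also turns N messages into a single message.

import asyncio
import json

from aiokafka import AIOKafkaProducer

async def send_records():
    producer = AIOKafkaProducer(
        bootstrap_servers="localhost:9092",
        value_serializer=lambda v: json.dumps(v).encode("utf-8"),
    )
    await producer.start()
    try:
        records = [{"task": 1}, {"task": 2}]
        # Old pattern: one Kafka message per record; send() only enqueues,
        # so delivery is not confirmed when the await returns.
        for record in records:
            await producer.send("output_topic", value=record)
        # New pattern: the whole batch as a single message; send_and_wait()
        # awaits the broker acknowledgement before returning.
        await producer.send_and_wait("output_topic", value=records)
    finally:
        await producer.stop()

asyncio.run(send_records())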
@@ -110,7 +107,6 @@ async def get_data_batch(self, batch_size: Optional[int]) -> InternalDataFrame:
         batch = await self.consumer.getmany(
             timeout_ms=self.timeout_ms, max_records=batch_size
         )
-        await self.consumer.commit()

         if len(batch) == 0:
             batch_data = []
@@ -129,7 +125,5 @@ async def get_data_batch(self, batch_size: Optional[int]) -> InternalDataFrame:
         return InternalDataFrame(batch_data)

     async def set_predictions(self, predictions: InternalDataFrame):
-        predictions_iter = (r.to_dict() for _, r in predictions.iterrows())
-        await self.message_sender(
-            self.producer, predictions_iter, self.kafka_output_topic
-        )
+        predictions = [r.to_dict() for _, r in predictions.iterrows()]
+        await self.message_sender(self.producer, predictions, self.kafka_output_topic)
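A short sketch of the new payload shape, with plain pandas standing in for InternalDataFrame and made-up rows: the batch is materialized as a list of row dicts and shipped as one message, where the old generator fed one message per row.

import pandas as pd

predictions = pd.DataFrame([{"task": 1, "label": "A"}, {"task": 2, "label": "B"}])

# New: one list, one Kafka message.
payload = [r.to_dict() for _, r in predictions.iterrows()]
assert payload == [{"task": 1, "label": "A"}, {"task": 2, "label": "B"}]

# An equivalent one-liner, if a follow-up ever wants to skip iterrows():
assert predictions.to_dict(orient="records") == payload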
poetry.lock (71 changes: 36 additions & 35 deletions)

Some generated files are not rendered by default.

pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -25,7 +25,7 @@ gspread = "^5.12.3"
 datasets = "^2.16.1"
 aiohttp = "^3.9.3"
 boto3 = "^1.34.38"
-aiokafka = "^0.10.0"
+aiokafka = "^0.11.0"
 # these are for the server
 # they would be installed as `extras` if poetry supported version strings for extras, but it doesn't
 # https://github.com/python-poetry/poetry/issues/834
server/tasks/stream_inference.py (27 changes: 15 additions & 12 deletions)
@@ -71,7 +71,7 @@ async def run_streaming(
     task_time_limit=settings.task_time_limit_sec,
 )
 def streaming_parent_task(
-    self, agent: Agent, result_handler: ResultHandler, batch_size: int = 10
+    self, agent: Agent, result_handler: ResultHandler, batch_size: int = 1
 ):
     """
     This task is used to launch the two tasks that are doing the real work, so that
@@ -140,9 +140,9 @@ async def async_process_streaming_output(
         output_topic_name,
         bootstrap_servers=settings.kafka_bootstrap_servers,
         value_deserializer=lambda v: json.loads(v.decode("utf-8")),
-        enable_auto_commit=False,  # True by default which causes messages to be missed when using getmany()
         auto_offset_reset="earliest",
-        group_id=output_topic_name,  # ensuring unique group_id to not mix up offsets between topics
+        # enable_auto_commit=False,  # Turned off as it's not supported without a group ID
+        # group_id=output_topic_name,  # No longer using a group ID as of DIA-1584 - unclear details, but it causes problems
     )
     await consumer.start()
     logger.info(f"consumer started {output_topic_name=}")
@@ -158,18 +158,21 @@ async def async_process_streaming_output(
     try:
         while not input_done.is_set():
             data = await consumer.getmany(timeout_ms=timeout_ms, max_records=batch_size)
-            await consumer.commit()
             for topic_partition, messages in data.items():
                 topic = topic_partition.topic
                 # messages is a list of ConsumerRecord
                 if messages:
-                    logger.info(
-                        f"Processing messages in output job {topic=} number of messages: {len(messages)}"
-                    )
-                    data = [msg.value for msg in messages]
-                    result_handler(data)
-                    logger.info(
-                        f"Processed messages in output job {topic=} number of messages: {len(messages)}"
-                    )
+                    # batches is a list of lists
+                    batches = [msg.value for msg in messages]
+                    # records is a list of records to send to LSE
+                    for records in batches:
+                        logger.info(
+                            f"Processing messages in output job {topic=} number of messages: {len(records)}"
+                        )
+                        result_handler(records)
+                        logger.info(
+                            f"Processed messages in output job {topic=} number of messages: {len(records)}"
+                        )
                 else:
                     logger.info(f"Consumer pulled data, but no messages in {topic=}")
tests/test_stream_inference.py (9 changes: 4 additions & 5 deletions)
@@ -89,7 +89,7 @@ async def getmany_side_effect(*args, **kwargs):
         await PRODUCER_SENT_DATA.wait()
         return {
             AsyncMock(topic="output_topic_partition"): [
-                AsyncMock(value=row) for row in TEST_OUTPUT_DATA
+                AsyncMock(value=TEST_OUTPUT_DATA)
             ]
         }

@@ -159,11 +159,10 @@ async def test_run_streaming(
     await run_streaming(
         agent=agent,
         result_handler=result_handler,
-        batch_size=10,
+        batch_size=1,
         output_topic_name="output_topic",
     )

     # Verify that producer is called with the correct amount of send_and_wait calls and data
-    assert mock_kafka_producer.send.call_count == 1
-    for row in TEST_OUTPUT_DATA:
-        mock_kafka_producer.send.assert_any_call("output_topic", value=row)
+    assert mock_kafka_producer.send_and_wait.call_count == 1
+    mock_kafka_producer.send_and_wait.assert_any_call("output_topic", value=TEST_OUTPUT_DATA)
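The assertion pattern above, reduced to a self-contained sketch with a stand-in producer; the data and names here are illustrative, not the repo's fixtures.

import asyncio
from unittest.mock import AsyncMock

TEST_OUTPUT_DATA = [{"task": 1, "label": "A"}, {"task": 2, "label": "B"}]

async def produce(producer):
    await producer.send_and_wait("output_topic", value=TEST_OUTPUT_DATA)

mock_producer = AsyncMock()
asyncio.run(produce(mock_producer))

# Exactly one send_and_wait() call, carrying the whole batch as its value.
assert mock_producer.send_and_wait.call_count == 1
mock_producer.send_and_wait.assert_any_call("output_topic", value=TEST_OUTPUT_DATA)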
