Create and clean up kafka topics manually to reduce memory burden for Adala server (#107)
matt-bernstein authored May 21, 2024
1 parent 33b2a9b commit d163a69
Showing 7 changed files with 103 additions and 31 deletions.
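At a high level, the commit moves topic management off the broker's auto-create path and into the Celery parent task: per-job input/output topics are created explicitly with a short retention.ms before the child jobs start, and deleted once both jobs finish. A minimal sketch of that lifecycle, using the helpers introduced in server/utils.py below; the try/finally shape and the elided dispatch step are illustrative, not the committed task body:

```python
# Sketch of the per-job topic lifecycle introduced in this commit.
# ensure_topic / delete_topic / get_*_topic_name are the helpers added in server/utils.py;
# the try/finally structure here is illustrative, not the actual Celery task body.
from server.utils import (
    ensure_topic,
    delete_topic,
    get_input_topic_name,
    get_output_topic_name,
)


def run_parent_job(parent_job_id: str) -> None:
    input_topic = get_input_topic_name(parent_job_id)
    output_topic = get_output_topic_name(parent_job_id)

    # Create both topics up front with the configured retention.ms,
    # since KAFKA_CFG_AUTO_CREATE_TOPICS_ENABLE is now false on the broker.
    ensure_topic(input_topic)
    ensure_topic(output_topic)

    try:
        # ... dispatch the inference and result-handler tasks and wait for both ...
        pass
    finally:
        # Drop the short-lived per-job topics so they don't accumulate on the broker.
        delete_topic(input_topic)
        delete_topic(output_topic)
```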
4 changes: 3 additions & 1 deletion docker-compose.yml
@@ -20,6 +20,7 @@ services:
- ALLOW_PLAINTEXT_LISTENER=yes
- KAFKA_CFG_NODE_ID=1
- KAFKA_KRAFT_CLUSTER_ID=MkU3OEVBNTcwNTJENDM2Qk
- KAFKA_CFG_AUTO_CREATE_TOPICS_ENABLE=false
app:
build:
context: .
@@ -30,8 +31,9 @@ services:
redis:
condition: service_healthy
environment:
- KAFKA_BOOTSTRAP_SERVERS=kafka:9093
- REDIS_URL=redis://redis:6379/0
- KAFKA_BOOTSTRAP_SERVERS=kafka:9093 # TODO pull from .env
- KAFKA_RETENTION_MS=180000 # TODO pull from .env
command:
["poetry", "run", "uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
worker:
17 changes: 16 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ celery = {version = "^5.3.6", extras = ["redis"]}
uvicorn = "*"
pydantic-settings = "^2.2.1"
label-studio-sdk = "^0.0.32"
kafka-python = "^2.0.2"

[tool.poetry.dev-dependencies]
pytest = "^7.4.3"
3 changes: 3 additions & 0 deletions server/.env.example
@@ -1 +1,4 @@
KAFKA_BOOTSTRAP_SERVERS="localhost:9093"

# this value is only for local dev. In our deployments, it is not set here, but in another place: https://github.com/HumanSignal/infra/pull/67
KAFKA_RETENTION_MS=180000 # 3 minutes
4 changes: 2 additions & 2 deletions server/app.py
@@ -25,7 +25,7 @@
process_streaming_output,
streaming_parent_task,
)
from utils import get_input_topic, Settings
from utils import get_input_topic_name, Settings
from server.handlers.result_handlers import ResultHandler


@@ -230,7 +230,7 @@ async def submit_batch(batch: BatchData):
Response: Generic response indicating status of request
"""

topic = get_input_topic(batch.job_id)
topic = get_input_topic_name(batch.job_id)
producer = AIOKafkaProducer(
bootstrap_servers=settings.kafka_bootstrap_servers,
value_serializer=lambda v: json.dumps(v).encode("utf-8"),
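For context, submit_batch publishes each record of a batch to the job's input topic with aiokafka. A self-contained sketch of that flow under the renamed helper follows; the records parameter and its field names are illustrative, since the BatchData schema is not shown in this diff:

```python
# Minimal sketch of publishing a batch to the per-job input topic with aiokafka.
# get_input_topic_name and Settings come from server/utils.py; the "records"
# parameter is illustrative — the real BatchData schema is not shown here.
import asyncio
import json

from aiokafka import AIOKafkaProducer

from server.utils import get_input_topic_name, Settings


async def publish_batch(job_id: str, records: list[dict]) -> None:
    settings = Settings()
    topic = get_input_topic_name(job_id)

    producer = AIOKafkaProducer(
        bootstrap_servers=settings.kafka_bootstrap_servers,
        value_serializer=lambda v: json.dumps(v).encode("utf-8"),
    )
    await producer.start()
    try:
        for record in records:
            # send_and_wait blocks until the broker acknowledges the message
            await producer.send_and_wait(topic, record)
    finally:
        await producer.stop()


# e.g. asyncio.run(publish_batch("job-123", [{"text": "hello"}]))
```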
56 changes: 33 additions & 23 deletions server/tasks/process_file.py
@@ -11,7 +11,13 @@
from aiokafka.errors import UnknownTopicOrPartitionError
from celery import Celery, states
from celery.exceptions import Ignore
from server.utils import get_input_topic, get_output_topic, Settings
from server.utils import (
get_input_topic_name,
get_output_topic_name,
ensure_topic,
delete_topic,
Settings,
)
from server.handlers.result_handlers import ResultHandler


@@ -54,21 +60,29 @@ def streaming_parent_task(
# Parent job ID is used for input/output topic names
parent_job_id = self.request.id

# Override kafka_bootstrap_servers with value from settings
# create kafka topics
input_topic_name = get_input_topic_name(parent_job_id)
ensure_topic(input_topic_name)
output_topic_name = get_output_topic_name(parent_job_id)
ensure_topic(output_topic_name)

# Override default agent kafka settings
settings = Settings()
agent.environment.kafka_bootstrap_servers = settings.kafka_bootstrap_servers
agent.environment.kafka_input_topic = input_topic_name
agent.environment.kafka_output_topic = output_topic_name

inference_task = process_file_streaming
logger.info(f"Submitting task {inference_task.name} with agent {agent}")
input_result = inference_task.delay(agent=agent, parent_job_id=parent_job_id)
input_result = inference_task.delay(agent=agent)
input_job_id = input_result.id
logger.info(f"Task {inference_task.name} submitted with job_id {input_job_id}")

result_handler_task = process_streaming_output
logger.info(f"Submitting task {result_handler_task.name}")
output_result = result_handler_task.delay(
input_job_id=input_job_id,
parent_job_id=parent_job_id,
output_topic_name=output_topic_name,
result_handler=result_handler,
batch_size=batch_size,
)
@@ -95,6 +109,10 @@ def streaming_parent_task(
):
time.sleep(1)

# clean up kafka topics
delete_topic(input_topic_name)
delete_topic(output_topic_name)

logger.info("Both input and output jobs complete")

# Update parent task status to SUCCESS and pass metadata again
@@ -109,45 +127,40 @@ def streaming_parent_task(
raise Ignore()


@app.task(
name="process_file_streaming", track_started=True, bind=True, serializer="pickle"
)
def process_file_streaming(self, agent: Agent, parent_job_id: str):
# Set input and output topics using parent job ID
agent.environment.kafka_input_topic = get_input_topic(parent_job_id)
agent.environment.kafka_output_topic = get_output_topic(parent_job_id)
@app.task(name="process_file_streaming", track_started=True, serializer="pickle")
def process_file_streaming(agent: Agent):
# agent's kafka_bootstrap_servers and kafka topics should be set in parent task

# Run the agent
asyncio.run(agent.arun())


async def async_process_streaming_output(
input_job_id: str,
parent_job_id: str,
output_topic_name,
result_handler: ResultHandler,
batch_size: int,
):
logger.info(f"Polling for results {parent_job_id=}")
logger.info(f"Polling for results {output_topic_name=}")

topic = get_output_topic(parent_job_id)
settings = Settings()

# Retry to workaround race condition of topic creation
retries = 5
while retries > 0:
try:
consumer = AIOKafkaConsumer(
topic,
output_topic_name,
bootstrap_servers=settings.kafka_bootstrap_servers,
value_deserializer=lambda v: json.loads(v.decode("utf-8")),
auto_offset_reset="earliest",
)
await consumer.start()
logger.info(f"consumer started {parent_job_id=}")
logger.info(f"consumer started {output_topic_name=}")
break
except UnknownTopicOrPartitionError as e:
logger.error(msg=e)
logger.info(f"Retrying to create consumer with topic {topic}")
logger.info(f"Retrying to create consumer with topic {output_topic_name}")

await consumer.stop()
retries -= 1
@@ -183,20 +196,17 @@ async def async_process_streaming_output(
await consumer.stop()


@app.task(
name="process_streaming_output", track_started=True, bind=True, serializer="pickle"
)
@app.task(name="process_streaming_output", track_started=True, serializer="pickle")
def process_streaming_output(
self,
input_job_id: str,
parent_job_id: str,
output_topic_name: str,
result_handler: ResultHandler,
batch_size: int,
):
try:
asyncio.run(
async_process_streaming_output(
input_job_id, parent_job_id, result_handler, batch_size
input_job_id, output_topic_name, result_handler, batch_size
)
)
except Exception as e:
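The main polling loop of async_process_streaming_output sits in the gap between the hunks above. Below is a sketch of how such a loop could batch records from the output topic and hand them to the result handler via aiokafka's getmany; the timeout, flush points, and exit condition are assumptions, not the committed logic:

```python
# Sketch of a consumer loop that batches output records for a result handler.
# Approximates the elided body of async_process_streaming_output; the timeout,
# flush points, and exit condition are assumptions, not the committed code.
from typing import Callable, List

from aiokafka import AIOKafkaConsumer


async def poll_output(
    consumer: AIOKafkaConsumer,
    result_handler: Callable[[List[dict]], None],
    batch_size: int,
) -> None:
    batch: List[dict] = []
    while True:
        # getmany returns {TopicPartition: [records]}; an empty dict means
        # nothing arrived within the timeout window.
        data = await consumer.getmany(timeout_ms=3000, max_records=batch_size)
        if not data:
            break  # assumption: stop once the producer side has gone quiet
        for records in data.values():
            for record in records:
                batch.append(record.value)
                if len(batch) >= batch_size:
                    result_handler(batch)
                    batch = []
    if batch:
        result_handler(batch)
```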
49 changes: 45 additions & 4 deletions server/utils.py
@@ -1,6 +1,8 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import List, Union
from pathlib import Path
from kafka.admin import KafkaAdminClient, NewTopic
from kafka.errors import TopicAlreadyExistsError


class Settings(BaseSettings):
@@ -10,16 +12,55 @@ class Settings(BaseSettings):
"""

kafka_bootstrap_servers: Union[str, List[str]]
kafka_retention_ms: int

model_config = SettingsConfigDict(
# have to use an absolute path here so celery workers can find it
env_file=(Path(__file__).parent / ".env"),
)


def get_input_topic(job_id: str):
return f"adala-input-{job_id}"
def get_input_topic_name(job_id: str):
topic_name = f"adala-input-{job_id}"

return topic_name

def get_output_topic(job_id: str):
return f"adala-output-{job_id}"

def get_output_topic_name(job_id: str):
topic_name = f"adala-output-{job_id}"

return topic_name


def ensure_topic(topic_name: str):
settings = Settings()
bootstrap_servers = settings.kafka_bootstrap_servers
retention_ms = settings.kafka_retention_ms

admin_client = KafkaAdminClient(
bootstrap_servers=bootstrap_servers, client_id="topic_creator"
)

topic = NewTopic(
name=topic_name,
num_partitions=1,
replication_factor=1,
topic_configs={"retention.ms": str(retention_ms)},
)

try:
admin_client.create_topics(new_topics=[topic])
except TopicAlreadyExistsError:
# we shouldn't hit this case when KAFKA_CFG_AUTO_CREATE_TOPICS_ENABLE=false unless there is a legitimate name collision, so this should raise here after testing
pass


def delete_topic(topic_name: str):
settings = Settings()
bootstrap_servers = settings.kafka_bootstrap_servers

admin_client = KafkaAdminClient(
bootstrap_servers=bootstrap_servers, client_id="topic_deleter"
)

admin_client.delete_topics(topics=[topic_name])
