ref(rules): refactor delayed processing batching logic to prepare for workflows #83670

Merged · 5 commits · Jan 30, 2025
18 changes: 8 additions & 10 deletions src/sentry/buffer/base.py
@@ -8,6 +8,8 @@
 from sentry.tasks.process_buffer import process_incr
 from sentry.utils.services import Service
 
+BufferField = models.Model | str | int
+
 
 class Buffer(Service):
     """
@@ -50,14 +52,10 @@ def get(
         """
         return {col: 0 for col in columns}
 
-    def get_hash(
-        self, model: type[models.Model], field: dict[str, models.Model | str | int]
-    ) -> dict[str, str]:
+    def get_hash(self, model: type[models.Model], field: dict[str, BufferField]) -> dict[str, str]:
         return {}
 
-    def get_hash_length(
-        self, model: type[models.Model], field: dict[str, models.Model | str | int]
-    ) -> int:
+    def get_hash_length(self, model: type[models.Model], field: dict[str, BufferField]) -> int:
         raise NotImplementedError
 
     def get_sorted_set(self, key: str, min: float, max: float) -> list[tuple[int, datetime]]:
@@ -69,7 +67,7 @@ def push_to_sorted_set(self, key: str, value: list[int] | int) -> None:
     def push_to_hash(
         self,
         model: type[models.Model],
-        filters: dict[str, models.Model | str | int],
+        filters: dict[str, BufferField],
         field: str,
         value: str,
     ) -> None:
@@ -78,15 +76,15 @@ def push_to_hash_bulk(
     def push_to_hash_bulk(
         self,
         model: type[models.Model],
-        filters: dict[str, models.Model | str | int],
+        filters: dict[str, BufferField],
         data: dict[str, str],
     ) -> None:
         raise NotImplementedError
 
     def delete_hash(
         self,
         model: type[models.Model],
-        filters: dict[str, models.Model | str | int],
+        filters: dict[str, BufferField],
         fields: list[str],
     ) -> None:
         return None
@@ -98,7 +96,7 @@ def incr(
         self,
         model: type[models.Model],
         columns: dict[str, int],
-        filters: dict[str, models.Model | str | int],
+        filters: dict[str, BufferField],
         extra: dict[str, Any] | None = None,
         signal_only: bool | None = None,
     ) -> None:
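The new `BufferField` alias simply names the union that was previously repeated inline in every signature. A minimal caller-side sketch, assuming any Django model works as a filter value (the model and filter values here are illustrative, not from this PR):

    from sentry import buffer
    from sentry.buffer.base import BufferField
    from sentry.models.project import Project  # illustrative model choice

    # Filter values may be a model instance, a str, or an int; the alias
    # documents this once instead of restating the union per method.
    filters: dict[str, BufferField] = {"project_id": 42, "environment": "prod"}

    buffer.backend.incr(model=Project, columns={"times_seen": 1}, filters=filters)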
20 changes: 8 additions & 12 deletions src/sentry/buffer/redis.py
@@ -13,7 +13,7 @@
 from django.utils.encoding import force_bytes, force_str
 from rediscluster import RedisCluster
 
-from sentry.buffer.base import Buffer
+from sentry.buffer.base import Buffer, BufferField
 from sentry.db import models
 from sentry.tasks.process_buffer import process_incr
 from sentry.utils import json, metrics
@@ -235,7 +235,7 @@ def __init__(self, incr_batch_size: int = 2, **options: object):
     def validate(self) -> None:
         validate_dynamic_cluster(self.is_redis_cluster, self.cluster)
 
-    def _coerce_val(self, value: models.Model | str | int) -> bytes:
+    def _coerce_val(self, value: BufferField) -> bytes:
         if isinstance(value, models.Model):
             value = value.pk
         return force_bytes(value, errors="replace")
@@ -395,7 +395,7 @@ def delete_key(self, key: str, min: float, max: float) -> None:
     def delete_hash(
         self,
         model: type[models.Model],
-        filters: dict[str, models.Model | str | int],
+        filters: dict[str, BufferField],
         fields: list[str],
     ) -> None:
         key = self._make_key(model, filters)
@@ -408,7 +408,7 @@ def delete_hash(
     def push_to_hash(
         self,
         model: type[models.Model],
-        filters: dict[str, models.Model | str | int],
+        filters: dict[str, BufferField],
         field: str,
         value: str,
     ) -> None:
@@ -418,15 +418,13 @@ def push_to_hash(
     def push_to_hash_bulk(
         self,
         model: type[models.Model],
-        filters: dict[str, models.Model | str | int],
+        filters: dict[str, BufferField],
         data: dict[str, str],
     ) -> None:
         key = self._make_key(model, filters)
         self._execute_redis_operation(key, RedisOperation.HASH_ADD_BULK, data)
 
-    def get_hash(
-        self, model: type[models.Model], field: dict[str, models.Model | str | int]
-    ) -> dict[str, str]:
+    def get_hash(self, model: type[models.Model], field: dict[str, BufferField]) -> dict[str, str]:
         key = self._make_key(model, field)
         redis_hash = self._execute_redis_operation(key, RedisOperation.HASH_GET_ALL)
         decoded_hash = {}
@@ -439,9 +437,7 @@ def get_hash(
 
         return decoded_hash
 
-    def get_hash_length(
-        self, model: type[models.Model], field: dict[str, models.Model | str | int]
-    ) -> int:
+    def get_hash_length(self, model: type[models.Model], field: dict[str, BufferField]) -> int:
         key = self._make_key(model, field)
         return self._execute_redis_operation(key, RedisOperation.HASH_LENGTH)
@@ -455,7 +451,7 @@ def incr(
         self,
         model: type[models.Model],
         columns: dict[str, int],
-        filters: dict[str, models.Model | str | int],
+        filters: dict[str, BufferField],
         extra: dict[str, Any] | None = None,
         signal_only: bool | None = None,
     ) -> None:
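For orientation, here is a rough sketch of the hash round-trip these methods provide, which the batching code below builds on: a chunk is written under a `batch_key` filter with `push_to_hash_bulk`, and the processing task later reads it back with `get_hash` using the same filters. The model and data values are hypothetical stand-ins:

    import uuid

    from sentry import buffer
    from sentry.models.project import Project  # illustrative model choice

    batch_key = str(uuid.uuid4())
    filters = {"project_id": 1, "batch_key": batch_key}

    # write one chunk of serialized event data into its own redis hash
    buffer.backend.push_to_hash_bulk(
        model=Project,
        filters=filters,
        data={"rule:123:group:456": '{"event_id": "..."}'},
    )

    # later, the task spawned for this batch fetches it with the same filters
    chunk = buffer.backend.get_hash(model=Project, field=filters)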
6 changes: 6 additions & 0 deletions src/sentry/options/defaults.py
@@ -2820,6 +2820,12 @@
     default=10000,
     flags=FLAG_AUTOMATOR_MODIFIABLE,
 )
+register(
+    "delayed_processing.emit_logs",
+    type=Bool,
+    default=False,
+    flags=FLAG_AUTOMATOR_MODIFIABLE,
+)
 register(
     "celery_split_queue_task_rollout",
     default={},
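Because the option is read at runtime (as the new module below does on each call), log emission can be toggled without a deploy:

    from sentry import options

    # False by default; flip via the options automator to get the extra
    # delayed-processing diagnostics logged.
    should_emit_logs = options.get("delayed_processing.emit_logs")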
160 changes: 160 additions & 0 deletions src/sentry/rules/processing/buffer_processing.py
@@ -0,0 +1,160 @@ (new file)
import logging
import math
import uuid
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from itertools import islice
from typing import ClassVar

from celery import Task

from sentry import buffer, options
from sentry.buffer.base import BufferField
from sentry.buffer.redis import BufferHookEvent, redis_buffer_registry
from sentry.db import models
from sentry.utils import metrics
from sentry.utils.registry import NoRegistrationExistsError, Registry

logger = logging.getLogger("sentry.delayed_processing")


@dataclass
class FilterKeys:
    project_id: int


@dataclass
class BufferHashKeys:
    model: type[models.Model]
    filters: FilterKeys


class DelayedProcessingBase(ABC):
    buffer_key: ClassVar[str]

    def __init__(self, project_id: int):
        self.project_id = project_id

    @property
    @abstractmethod
    def hash_args(self) -> BufferHashKeys:
        raise NotImplementedError

    @property
    @abstractmethod
    def processing_task(self) -> Task:
        raise NotImplementedError


delayed_processing_registry = Registry[type[DelayedProcessingBase]]()


def fetch_group_to_event_data(
    project_id: int, model: type[models.Model], batch_key: str | None = None
) -> dict[str, str]:
    field: dict[str, models.Model | int | str] = {
        "project_id": project_id,
    }

    if batch_key:
        field["batch_key"] = batch_key

    return buffer.backend.get_hash(model=model, field=field)


def bucket_num_groups(num_groups: int) -> str:
    if num_groups > 1:
        magnitude = 10 ** int(math.log10(num_groups))
        return f">{magnitude}"
    return "1"


def process_in_batches(project_id: int, processing_type: str) -> None:
    """
    This will check the number of alertgroup_to_event_data items in the Redis buffer for a project.

    If the number is larger than the batch size, it will chunk the items and process them in batches.

    The batches are replicated into a new redis hash with a unique filter (a uuid) to identify the
    batch. We need to use a UUID because these batches can be created in multiple processes and we
    need to ensure uniqueness across all of them for the centralized redis buffer. The batches are
    stored in redis because we shouldn't pass objects that need to be pickled, and 10k items passed
    as celery task arguments could be problematic. Finally, we can't paginate over the data because
    redis doesn't maintain the sort order of the hash keys.

    `processing_task` will fetch the batch from redis and process the rules.
    """
    batch_size = options.get("delayed_processing.batch_size")
    should_emit_logs = options.get("delayed_processing.emit_logs")
    log_format = "{}.{}"

    try:
        processing_info = delayed_processing_registry.get(processing_type)(project_id)
    except NoRegistrationExistsError:
        logger.exception(log_format.format(processing_type, "no_registration"))
        return

    hash_args = processing_info.hash_args
    task = processing_info.processing_task
    filters: dict[str, BufferField] = asdict(hash_args.filters)

    event_count = buffer.backend.get_hash_length(model=hash_args.model, field=filters)
    metrics.incr(
        f"{processing_type}.num_groups", tags={"num_groups": bucket_num_groups(event_count)}
    )

    if event_count < batch_size:
        return task.delay(project_id)

    if should_emit_logs:
        logger.info(
            log_format.format(processing_type, "process_large_batch"),
            extra={"project_id": project_id, "count": event_count},
        )

    # if the dictionary is large, get the items and chunk them.
    alertgroup_to_event_data = fetch_group_to_event_data(project_id, hash_args.model)

    with metrics.timer(f"{processing_type}.process_batch.duration"):
        items = iter(alertgroup_to_event_data.items())

        while batch := dict(islice(items, batch_size)):
            batch_key = str(uuid.uuid4())

            buffer.backend.push_to_hash_bulk(
                model=hash_args.model,
                filters={**filters, "batch_key": batch_key},
                data=batch,
            )

            # remove the batched items from the project alertgroup_to_event_data
            buffer.backend.delete_hash(**asdict(hash_args), fields=list(batch.keys()))

            task.delay(project_id, batch_key)


def process_buffer() -> None:
    fetch_time = datetime.now(tz=timezone.utc)
    should_emit_logs = options.get("delayed_processing.emit_logs")

    for processing_type, handler in delayed_processing_registry.registrations.items():
        with metrics.timer(f"{processing_type}.process_all_conditions.duration"):
[Review comment — Contributor] should we have a consistent prefix on the metrics for delayed_processing?
[Reply — Member, Author] possibly, i can clean this up in a follow up
            project_ids = buffer.backend.get_sorted_set(
                handler.buffer_key, min=0, max=fetch_time.timestamp()
            )
            if should_emit_logs:
                log_str = ", ".join(
                    f"{project_id}: {timestamp}" for project_id, timestamp in project_ids
                )
                log_name = f"{processing_type}.project_id_list"
                logger.info(log_name, extra={"project_ids": log_str})

            for project_id, _ in project_ids:
                process_in_batches(project_id, processing_type)

            buffer.backend.delete_key(handler.buffer_key, min=0, max=fetch_time.timestamp())


if not redis_buffer_registry.has(BufferHookEvent.FLUSH):
    redis_buffer_registry.add_handler(BufferHookEvent.FLUSH, process_buffer)
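The `while batch := dict(islice(items, batch_size))` loop is the core of the batching refactor. A self-contained sketch of the idiom, with toy data:

    from itertools import islice

    data = {f"group:{i}": "event-data" for i in range(5)}
    items = iter(data.items())  # a single shared iterator

    # Each pass drains the next batch_size items; dict() of an exhausted
    # islice is empty (falsy), which terminates the walrus loop.
    while batch := dict(islice(items, 2)):
        print(batch)
    # -> {'group:0': 'event-data', 'group:1': 'event-data'}
    # -> {'group:2': 'event-data', 'group:3': 'event-data'}
    # -> {'group:4': 'event-data'}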
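Per the PR title, the point of the registry is that a future processing type (e.g. workflows) only needs a registered handler. A hedged sketch of what one might look like; the registry key, buffer key, task name, and model choice are all hypothetical, not part of this PR, and it assumes `Registry` exposes a `register(key)` decorator and `instrumented_task` behaves as elsewhere in sentry:

    from celery import Task

    from sentry.models.project import Project  # illustrative model for the hash key
    from sentry.tasks.base import instrumented_task


    @instrumented_task(name="sentry.workflow_engine.tasks.process_delayed_workflows")
    def process_delayed_workflows(project_id: int, batch_key: str | None = None) -> None:
        # hypothetical: fetch the (possibly batched) hash and evaluate workflows
        ...


    @delayed_processing_registry.register("delayed_workflows")
    class DelayedWorkflow(DelayedProcessingBase):
        buffer_key = "workflow_engine_delayed_processing_buffer"

        @property
        def hash_args(self) -> BufferHashKeys:
            return BufferHashKeys(model=Project, filters=FilterKeys(project_id=self.project_id))

        @property
        def processing_task(self) -> Task:
            return process_delayed_workflows

With a handler registered, `process_buffer` picks it up from `delayed_processing_registry.registrations` on the next buffer flush; no changes to the batching code are needed.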