Imp/update memory management (#918)
* poetry updates

* update memory management

* fixed tests

* bulk operations
ieaves authored May 2, 2024
1 parent 5354f1e commit 5d0e19c
Showing 11 changed files with 189 additions and 106 deletions.
1 change: 1 addition & 0 deletions grai-server/app/connections/adapters/base.py
@@ -133,6 +133,7 @@ def get_nodes_and_edges(self):
def run_update(self):
    nodes, edges = self.integration.get_nodes_and_edges()
    capture_quarantined_errors(self.integration, self.run)

    update(self.run.workspace, self.run.source, nodes)
    update(self.run.workspace, self.run.source, edges)
48 changes: 48 additions & 0 deletions (new Django migration under grai-server/app/connections/migrations/)
@@ -0,0 +1,48 @@
# Generated by Django 4.2.11 on 2024-05-01 00:12

from django.db import migrations, models


class Migration(migrations.Migration):

    dependencies = [
        ("connections", "0030_alter_connector_options_connector_priority_and_more"),
    ]

    operations = [
        migrations.AlterField(
            model_name="connector",
            name="slug",
            field=models.CharField(
                blank=True,
                choices=[
                    ("postgres", "Postgres"),
                    ("snowflake", "Snowflake"),
                    ("dbt", "dbt"),
                    ("dbt_cloud", "dbt Cloud"),
                    ("yaml_file", "YAML"),
                    ("mssql", "Microsoft SQL Server"),
                    ("bigquery", "Google BigQuery"),
                    ("fivetran", "Fivetran"),
                    ("mysql", "MySQL"),
                    ("redshift", "Amazon Redshift"),
                    ("metabase", "Metabase"),
                    ("looker", "Looker"),
                    ("openlineage", "OpenLineage"),
                    ("flatfile", "Flat File"),
                    ("cube", "Cube"),
                ],
                max_length=255,
                null=True,
            ),
        ),
        migrations.AlterField(
            model_name="run",
            name="status",
            field=models.CharField(
                choices=[("pending", "Pending"), ("running", "Running"), ("success", "Success"), ("error", "Error")],
                default="pending",
                max_length=255,
            ),
        ),
    ]
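
The status choices in this migration mirror the new RunStatus TextChoices class added to connections/models.py below. A minimal sketch, not part of the commit, of how a TextChoices class expands into the (value, label) pairs Django serializes into a migration:

    from django.db import models

    class RunStatus(models.TextChoices):
        PENDING = "pending", "Pending"
        RUNNING = "running", "Running"
        SUCCESS = "success", "Success"
        ERROR = "error", "Error"

    RunStatus.choices         # [("pending", "Pending"), ("running", "Running"), ...]
    str(RunStatus.PENDING)    # "pending" (members are str subclasses)
    RunStatus.PENDING.label   # "Pending"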
10 changes: 9 additions & 1 deletion grai-server/app/connections/models.py
@@ -5,6 +5,7 @@
from django.dispatch import receiver
from django.db.models.signals import pre_save
from django.utils import timezone
from enum import Enum


class ConnectorSlugs(models.TextChoices):
@@ -172,6 +173,13 @@ def save(self, *args, **kwargs):
        task.delete()


class RunStatus(models.TextChoices):
    PENDING = "pending", "Pending"
    RUNNING = "running", "Running"
    SUCCESS = "success", "Success"
    ERROR = "error", "Error"


class Run(TenantModel):
    TESTS = "tests"
    UPDATE = "update"
@@ -202,7 +210,7 @@ class Run(TenantModel):
        blank=True,
        null=True,
    )
    status = models.CharField(max_length=255)
    status = models.CharField(max_length=255, choices=RunStatus.choices, default=RunStatus.PENDING)
    metadata = models.JSONField(default=dict)
    workspace = models.ForeignKey(
        "workspaces.Workspace",
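With the choices and default in place, a Run created without an explicit status starts as "pending", and Django adds the usual conveniences for choice fields. A hedged sketch (the create() arguments are taken from the tests in this commit):

    run = Run.objects.create(connection=connection, workspace=workspace, source=source)
    assert run.status == RunStatus.PENDING        # "pending" by default
    assert run.get_status_display() == "Pending"  # label lookup generated by Django

    run.status = "bogus"
    run.full_clean()  # raises ValidationError: "bogus" is not a valid choice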
111 changes: 91 additions & 20 deletions grai-server/app/connections/task_helpers.py
@@ -6,6 +6,7 @@
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    Optional,
    Protocol,
@@ -16,7 +17,7 @@
    Union,
)
from uuid import UUID

from time import sleep
from django.contrib.postgres.aggregates import ArrayAgg
from django.db import models
from django.db.models import Q, Value
@@ -38,6 +39,8 @@
from .adapters.schemas import model_to_schema, schema_to_model
from itertools import islice
from pympler import asizeof
from functools import reduce
from django.db.models import Subquery


class NameNamespace(Protocol):
@@ -62,10 +65,17 @@ class SpecNameNamespace(Protocol):


def to_dict(instance):
    """
    Shallow conversion of a model instance to a dictionary.
    This is useful for merging model instances but should not be relied on for serialization.
    """
    opts = instance._meta
    data = {}
    for f in chain(opts.concrete_fields, opts.private_fields):
        data[f.name] = f.value_from_object(instance)
    data = {
        f.name: getattr(instance, f.name) if hasattr(instance, f.name) else f.value_from_object(instance)
        for f in chain(opts.concrete_fields, opts.private_fields)
    }
    for f in opts.many_to_many:
        data[f.name] = [i.id for i in f.value_from_object(instance)]
    return data
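
The switch from value_from_object to getattr changes what lands in the dict for relational fields: value_from_object returns a foreign key's stored primary key, while getattr returns the related instance itself (the hasattr guard falls back for unset relations, where attribute access would raise). That difference is what lets the merged dict be splatted back into the model constructor below. An illustrative comparison, assuming a node with a workspace foreign key:

    field = node._meta.get_field("workspace")
    field.value_from_object(node)  # the related Workspace's primary key (a UUID)
    getattr(node, "workspace")     # the Workspace instance itself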
@@ -89,7 +99,7 @@ def merge_node_dict(a: models.Model, b: Dict) -> models.Model:
@merge.register
def merge_node_node(a: models.Model, b: models.Model) -> models.Model:
    assert isinstance(a, type(b))
    return type(a)(merge(to_dict(a), to_dict(b)))
    return type(a)(**merge(to_dict(a), to_dict(b)))
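
The one-character fix above matters: Django model constructors accept positional arguments in field-definition order, so the old call bound the entire merged dict to the model's first field rather than spreading it across fields. Schematically, with an illustrative field dict:

    fields = {"name": "users", "namespace": "public"}
    NodeModel(fields)    # wrong: the whole dict becomes the value of the first field
    NodeModel(**fields)  # right: each key is passed as a keyword argument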


def get_node(workspace: Workspace, grai_type: NameNamespaceDict) -> NodeModel:
@@ -296,6 +306,37 @@ def create_batches(data: list, threshold_size=500 * 1024 * 1024) -> list:
        yield batch


def create_dict_batches(data: Iterable, threshold_size=500 * 1024 * 1024) -> Iterable[Dict]:
    batch = {}
    current_batch_size = 0
    for item in data:
        item_size = asizeof.asizeof(item)
        if current_batch_size + item_size > threshold_size and batch:
            yield batch
            batch = {}
            current_batch_size = 0
        batch[(item.name, item.namespace)] = item
        current_batch_size += item_size
    if batch:
        yield batch
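
create_dict_batches is the core of the memory-management change: rather than materializing every node or edge at once, items are grouped into dicts keyed by (name, namespace) whose approximate deep size, measured with pympler's asizeof (which follows references, unlike sys.getsizeof), stays under the threshold. A usage sketch, composed with valid_items below; the consumer function is hypothetical:

    items = (schema_to_model(item, workspace) for item in raw_items)  # stays lazy
    for batch in create_dict_batches(valid_items(items, workspace), 200 * 1024 * 1024):
        # each batch maps (name, namespace) -> model instance and holds
        # roughly 200 MiB of objects at most
        persist(batch)  # hypothetical consumer, e.g. the bulk logic in update()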


def valid_items(items: List[NodeModel | EdgeModel], workspace: Workspace) -> Iterable[NodeModel | EdgeModel]:
    seen_keys = set()
    for item in items:
        if item.workspace != workspace:
            raise ValueError(
                "Items in the batch must all belong to the same workspace. "
                f"Expected workspace id {workspace.id}, got {item.workspace.id}"
            )
        key = (item.name, item.namespace)
        if key in seen_keys:
            warnings.warn(f"Multiple {type(item)} items with the same (name, namespace) {key} detected in batch.")
        else:
            seen_keys.add(key)
        yield item


def update(
    workspace: Workspace,
    source: Source,
@@ -310,27 +351,57 @@ def update(
    is_node = items[0].type in ["Node", "SourceNode"]
    Model = NodeModel if is_node else EdgeModel
    relationship = source.nodes if is_node else source.edges
    through_label = "node_id" if is_node else "edge_id"
    threshold_bytes = 200 * 1024 * 1024

    items = (schema_to_model(item, workspace) for item in items)
    found_items = []
    for batch in create_dict_batches(valid_items(items, workspace), threshold_bytes):
        # Update existing items
        updated_item_keys = set()
        existing_item_filter = reduce(
            lambda q, key: q | Q(name=key[0], namespace=key[1], workspace=workspace), batch.keys(), Q()
        )
        updated_items = [
            merge(item, batch[(item.name, item.namespace)])
            for item in Model.objects.filter(existing_item_filter).iterator()
        ]
        del existing_item_filter

        Model.objects.bulk_update(updated_items, ["metadata", "display_name"])
        for item in updated_items:
            batch[(item.name, item.namespace)] = item
            updated_item_keys.add((item.name, item.namespace))

        # Create new items
        new_items = (item for item in batch.values() if (item.name, item.namespace) not in updated_item_keys)
        for item in Model.objects.bulk_create(new_items):
            batch[(item.name, item.namespace)] = item

        # Create foreign keys to source
        through_items = (
            relationship.through(source_id=source.id, **{through_label: item.id}) for item in batch.values()
        )
        relationship.through.objects.bulk_create(through_items, ignore_conflicts=True)

    new_items, deactivated_items, updated_items = process_updates(workspace, source, items, active_items)

    # relationship creation can be improved with a switch to a bulk_create on the through entity
    # https://stackoverflow.com/questions/68422898/efficiently-bulk-updating-many-manytomany-fields
    for batch in create_batches(new_items):
        Model.objects.bulk_create(batch)
    for batch in create_batches(new_items):
        Model.objects.bulk_update(batch, ["metadata"])
        found_items.extend([item.id for item in batch.values()])
        del batch

    with transaction.atomic():
        relationship.add(*new_items, *updated_items)
    # Remove old source relations.
    num_deleted, _ = (
        relationship.through.objects.filter(source_id=source.id)
        .exclude(**{f"{through_label}__in": found_items})
        .delete()
    )

    if len(deactivated_items) > 0:
        relationship.remove(*deactivated_items)
    if num_deleted > 0:
        empty_source_query = Q(workspace=workspace, data_sources=None)

        deletable_nodes = NodeModel.objects.filter(empty_source_query)
        deleted_edge_query = Q(source__in=deletable_nodes) | Q(destination__in=deletable_nodes) | empty_source_query
        deletable_nodes_subquery = Subquery(deletable_nodes.values("id"))
        EdgeModel.objects.filter(
            Q(source__in=deletable_nodes_subquery) | Q(destination__in=deletable_nodes_subquery) | empty_source_query
        ).delete()

        EdgeModel.objects.filter(deleted_edge_query).delete()
        deletable_nodes.delete()
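
The existing-row lookup in the loop above folds every key in the batch into a single OR'd Q object with functools.reduce, so one query fetches all rows matching any (name, namespace) pair in the batch. A standalone sketch of the pattern with made-up keys:

    from functools import reduce
    from django.db.models import Q

    keys = [("users", "public"), ("orders", "public")]
    q = reduce(lambda acc, k: acc | Q(name=k[0], namespace=k[1]), keys, Q())
    # equivalent to Q(name="users", namespace="public") | Q(name="orders", namespace="public")
    existing = NodeModel.objects.filter(q, workspace=workspace).iterator()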


1 change: 1 addition & 0 deletions grai-server/app/connections/tasks.py
@@ -148,6 +148,7 @@ def execute_run(run: Run):
    )

    run.status = "success"

    run.finished_at = timezone.now()
    run.save()
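
The status is still assigned as a raw string here; with the new choices on the model, the enum member is interchangeable and less typo-prone. A sketch, assuming the app's import path:

    from connections.models import RunStatus

    run.status = RunStatus.SUCCESS  # stores the string "success"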

4 changes: 2 additions & 2 deletions grai-server/app/connections/tests/test_tasks.py
@@ -1,7 +1,7 @@
import os
import uuid
from datetime import date

from time import sleep
import pytest
from decouple import config
from django.conf import settings
@@ -218,10 +218,10 @@ def test_run_update_server_postgres(self, test_workspace, test_postgres_connecto
            },
            secrets={"password": "grai"},
        )

        run = Run.objects.create(connection=connection, workspace=test_workspace, source=test_source)

        process_run(str(run.id))

        run.refresh_from_db()

        assert run.status == "success"