Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIKE] feat: refresh record status column using SQLAlchemy event listeners #5076

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions argilla-server/src/argilla_server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ async def create_response(
)

await db.flush([response])
await distribution.refresh_record_status(db, record, autocommit=False)
# await distribution.refresh_record_status(db, record, autocommit=False)
await _touch_dataset_last_activity_at(db, record.dataset)
await search_engine.index_records(record.dataset, [record])
await search_engine.update_record_response(response)
Expand All @@ -961,7 +961,7 @@ async def update_response(
)

await _load_users_from_responses(response)
await distribution.refresh_record_status(db, response.record, autocommit=False)
# await distribution.refresh_record_status(db, response.record, autocommit=False)
await _touch_dataset_last_activity_at(db, response.record.dataset)
await search_engine.index_records(response.record.dataset, [response.record])
await search_engine.update_record_response(response)
Expand Down Expand Up @@ -992,7 +992,7 @@ async def upsert_response(
)

await _load_users_from_responses(response)
await distribution.refresh_record_status(db, response.record, autocommit=False)
# await distribution.refresh_record_status(db, response.record, autocommit=False)
await _touch_dataset_last_activity_at(db, response.record.dataset)
await search_engine.index_records(response.record.dataset, [response.record])
await search_engine.update_record_response(response)
Expand All @@ -1006,7 +1006,7 @@ async def delete_response(db: AsyncSession, search_engine: SearchEngine, respons
async with db.begin_nested():
response = await response.delete(db, autocommit=False)
await _load_users_from_responses(response)
await distribution.refresh_record_status(db, response.record, autocommit=False)
# await distribution.refresh_record_status(db, response.record, autocommit=False)
await _touch_dataset_last_activity_at(db, response.record.dataset)
await search_engine.index_records(response.record.dataset, [response.record])
await search_engine.delete_record_response(response)
Expand Down
41 changes: 41 additions & 0 deletions argilla-server/src/argilla_server/models/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,3 +493,44 @@ def __repr__(self):
f"username={self.username!r}, role={self.role.value!r}, "
f"inserted_at={str(self.inserted_at)!r}, updated_at={str(self.updated_at)!r})"
)


from sqlalchemy import event
from sqlalchemy.orm import Session


def refresh_record_status(db: Session, record: Record) -> None:
    """Recompute ``record.status`` using the distribution strategy of its dataset.

    Dispatches on ``record.dataset.distribution_strategy``. Only the ``overlap``
    strategy is handled here.

    Args:
        db: Synchronous SQLAlchemy session handed on to the strategy helper.
        record: Record whose status is recalculated.

    Raises:
        NotImplementedError: If the dataset uses any strategy other than
            ``DatasetDistributionStrategy.overlap``.
    """
    strategy = record.dataset.distribution_strategy

    # Guard clause: reject anything that is not the single supported strategy.
    if strategy != DatasetDistributionStrategy.overlap:
        raise NotImplementedError(f"unsupported distribution strategy `{record.dataset.distribution_strategy}`")

    return _refresh_record_status_with_overlap_strategy(db, record)


def _refresh_record_status_with_overlap_strategy(db: Session, record: Record) -> None:
    """Set ``record.status`` from the number of submitted responses it has.

    Counts the ``Response`` rows that belong to ``record`` and have status
    ``ResponseStatus.submitted``. When that count reaches the dataset's
    ``distribution["min_submitted"]`` value the record is marked
    ``RecordStatus.completed``; otherwise it is marked ``RecordStatus.pending``.
    The record is then added to ``db``; this function itself neither flushes
    nor commits.

    Args:
        db: Synchronous SQLAlchemy session used for the count query.
        record: Record to update.
    """
    submitted_responses = db.query(Response).filter_by(
        status=ResponseStatus.submitted,
        record_id=record.id,
    )
    submitted_total = submitted_responses.count()
    required_total = record.dataset.distribution["min_submitted"]

    record.status = RecordStatus.completed if submitted_total >= required_total else RecordStatus.pending

    db.add(record)


@event.listens_for(Response, "after_insert")
@event.listens_for(Response, "after_update")
@event.listens_for(Response, "after_delete")
def refresh_record_status_listener(mapper, connection, response: Response):
    """Refresh the owning record's status whenever a response row changes.

    Registered for the mapper-level ``after_insert``, ``after_update`` and
    ``after_delete`` events of ``Response``.

    Args:
        mapper: Mapper that is the target of the event (unused).
        connection: Connection used to emit the SQL for this row (unused).
        response: The ``Response`` instance that was inserted, updated or
            deleted.

    NOTE(review): mapper-level flush events run while the session is in the
    middle of a flush. The SQLAlchemy documentation cautions that handlers at
    this level should limit themselves to attributes of the row being
    operated on and to SQL emitted on ``connection``. This handler changes
    another object (the record) through the session, so confirm that the new
    status is actually persisted in every code path, or move the logic to a
    session-level hook such as ``before_flush``.
    """
    session = Session.object_session(response)
    if session is None:
        # The response is not attached to any session; there is nothing to
        # query with, so leave the record untouched.
        return

    # ``Session.get`` replaces the legacy ``Query.get`` and, like it, looks in
    # the identity map before going to the database.
    record = session.get(Record, response.record_id)
    if record is None:
        return

    refresh_record_status(session, record)
Loading