Skip to content

Commit

Permalink
feat: add dataset support to be updated using distribution settings (#…
Browse files Browse the repository at this point in the history
…5028)

# Description

This PR add changes to support update dataset distribution settings.
Allowing for example to update `min_submitted` attribute when `overlap`
distribution strategy is in use.

Closes #5010 

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] Refactor (change restructuring the codebase without changing
functionality)
- [ ] Improvement (change adding some improvement to an existing
functionality)
- [ ] Documentation update

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [x] Adding new tests.

**Checklist**

- [ ] I added relevant documentation
- [ ] follows the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: Paco Aranda <[email protected]>
  • Loading branch information
jfcalvo and frascuchon authored Jul 1, 2024
1 parent 97dc916 commit 91eb6b1
Show file tree
Hide file tree
Showing 31 changed files with 979 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,8 @@ export class RecordRepository {
constructor(private readonly axios: NuxtAxiosInstance) {}

getRecords(criteria: RecordCriteria): Promise<BackendRecords> {
if (criteria.isFilteringByAdvanceSearch)
return this.getRecordsByAdvanceSearch(criteria);

return this.getRecordsByDatasetId(criteria);
return this.getRecordsByAdvanceSearch(criteria);
// return this.getRecordsByDatasetId(criteria);
}

async getRecord(recordId: string): Promise<BackendRecord> {
Expand Down Expand Up @@ -264,6 +262,30 @@ export class RecordRepository {
};
}

body.filters = {
and: [
{
type: "terms",
scope: {
entity: "response",
property: "status",
},
values: [status],
},
],
};

if (status === "pending") {
body.filters.and.push({
type: "terms",
scope: {
entity: "record",
property: "status",
},
values: ["pending"],
});
}

if (
isFilteringByMetadata ||
isFilteringByResponse ||
Expand Down
1 change: 1 addition & 0 deletions argilla-server/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ These are the section headers that we use:
### Added

- Added support to specify `distribution` attribute when creating a dataset. ([#5013](https://github.com/argilla-io/argilla/pull/5013))
- Added support to change `distribution` attribute when updating a dataset. ([#5028](https://github.com/argilla-io/argilla/pull/5028))

### Changed

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""add status column to records table
Revision ID: 237f7c674d74
Revises: 45a12f74448b
Create Date: 2024-06-18 17:59:36.992165
"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "237f7c674d74"
down_revision = "45a12f74448b"
branch_labels = None
depends_on = None


record_status_enum = sa.Enum("pending", "completed", name="record_status_enum")


def upgrade() -> None:
record_status_enum.create(op.get_bind())

op.add_column("records", sa.Column("status", record_status_enum, server_default="pending", nullable=False))
op.create_index(op.f("ix_records_status"), "records", ["status"], unique=False)

# NOTE: Updating existent records to have "completed" status when they have
# at least one response with "submitted" status.
op.execute("""
UPDATE records
SET status = 'completed'
WHERE id IN (
SELECT DISTINCT record_id
FROM responses
WHERE status = 'submitted'
);
""")


def downgrade() -> None:
op.drop_index(op.f("ix_records_status"), table_name="records")
op.drop_column("records", "status")

record_status_enum.drop(op.get_bind())
Original file line number Diff line number Diff line change
Expand Up @@ -302,4 +302,4 @@ async def update_dataset(

await authorize(current_user, DatasetPolicy.update(dataset))

return await datasets.update_dataset(db, dataset, dataset_update)
return await datasets.update_dataset(db, dataset, dataset_update.dict(exclude_unset=True))
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ async def update_response(
response = await Response.get_or_raise(
db,
response_id,
options=[selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions)],
options=[
selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions),
],
)

await authorize(current_user, ResponsePolicy.update(response))
Expand All @@ -83,7 +85,9 @@ async def delete_response(
response = await Response.get_or_raise(
db,
response_id,
options=[selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions)],
options=[
selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions),
],
)

await authorize(current_user, ResponsePolicy.delete(response))
Expand Down
10 changes: 9 additions & 1 deletion argilla-server/src/argilla_server/api/schemas/v1/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ class DatasetOverlapDistributionCreate(BaseModel):
DatasetDistributionCreate = DatasetOverlapDistributionCreate


class DatasetOverlapDistributionUpdate(DatasetDistributionCreate):
pass


DatasetDistributionUpdate = DatasetOverlapDistributionUpdate


class RecordMetrics(BaseModel):
count: int

Expand Down Expand Up @@ -122,5 +129,6 @@ class DatasetUpdate(UpdateSchema):
name: Optional[DatasetName]
guidelines: Optional[DatasetGuidelines]
allow_extra_metadata: Optional[bool]
distribution: Optional[DatasetDistributionUpdate]

__non_nullable_fields__ = {"name", "allow_extra_metadata"}
__non_nullable_fields__ = {"name", "allow_extra_metadata", "distribution"}
5 changes: 3 additions & 2 deletions argilla-server/src/argilla_server/api/schemas/v1/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyName
from argilla_server.api.schemas.v1.responses import Response, ResponseFilterScope, UserResponseCreate
from argilla_server.api.schemas.v1.suggestions import Suggestion, SuggestionCreate, SuggestionFilterScope
from argilla_server.enums import RecordInclude, RecordSortField, SimilarityOrder, SortOrder
from argilla_server.enums import RecordInclude, RecordSortField, SimilarityOrder, SortOrder, RecordStatus
from argilla_server.pydantic_v1 import BaseModel, Field, StrictStr, root_validator, validator
from argilla_server.pydantic_v1.utils import GetterDict
from argilla_server.search_engine import TextQuery
Expand Down Expand Up @@ -66,6 +66,7 @@ def get(self, key: str, default: Any) -> Any:

class Record(BaseModel):
id: UUID
status: RecordStatus
fields: Dict[str, Any]
metadata: Optional[Dict[str, Any]]
external_id: Optional[str]
Expand Down Expand Up @@ -196,7 +197,7 @@ def _has_relationships(self):

class RecordFilterScope(BaseModel):
entity: Literal["record"]
property: Union[Literal[RecordSortField.inserted_at], Literal[RecordSortField.updated_at]]
property: Union[Literal[RecordSortField.inserted_at], Literal[RecordSortField.updated_at], Literal["status"]]


class Records(BaseModel):
Expand Down
4 changes: 4 additions & 0 deletions argilla-server/src/argilla_server/bulk/records_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from argilla_server.api.schemas.v1.responses import UserResponseCreate
from argilla_server.api.schemas.v1.suggestions import SuggestionCreate
from argilla_server.contexts import distribution
from argilla_server.contexts.accounts import fetch_users_by_ids_as_dict
from argilla_server.contexts.records import (
fetch_records_by_external_ids_as_dict,
Expand Down Expand Up @@ -67,6 +68,7 @@ async def create_records_bulk(self, dataset: Dataset, bulk_create: RecordsBulkCr

await self._upsert_records_relationships(records, bulk_create.items)
await _preload_records_relationships_before_index(self._db, records)
await distribution.update_records_status(self._db, records)
await self._search_engine.index_records(dataset, records)

await self._db.commit()
Expand Down Expand Up @@ -207,6 +209,7 @@ async def upsert_records_bulk(self, dataset: Dataset, bulk_upsert: RecordsBulkUp

await self._upsert_records_relationships(records, bulk_upsert.items)
await _preload_records_relationships_before_index(self._db, records)
await distribution.update_records_status(self._db, records)
await self._search_engine.index_records(dataset, records)

await self._db.commit()
Expand Down Expand Up @@ -237,6 +240,7 @@ async def _preload_records_relationships_before_index(db: "AsyncSession", record
.filter(Record.id.in_([record.id for record in records]))
.options(
selectinload(Record.responses).selectinload(Response.user),
selectinload(Record.responses_submitted),
selectinload(Record.suggestions).selectinload(Suggestion.question),
selectinload(Record.vectors),
)
Expand Down
28 changes: 21 additions & 7 deletions argilla-server/src/argilla_server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
VectorSettingsCreate,
)
from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema
from argilla_server.contexts import accounts
from argilla_server.contexts import accounts, distribution
from argilla_server.enums import DatasetStatus, RecordInclude, UserRole
from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError
from argilla_server.models import (
Expand All @@ -79,7 +79,7 @@
)
from argilla_server.models.suggestions import SuggestionCreateWithRecordId
from argilla_server.search_engine import SearchEngine
from argilla_server.validators.datasets import DatasetCreateValidator
from argilla_server.validators.datasets import DatasetCreateValidator, DatasetUpdateValidator
from argilla_server.validators.responses import (
ResponseCreateValidator,
ResponseUpdateValidator,
Expand Down Expand Up @@ -170,6 +170,12 @@ async def publish_dataset(db: AsyncSession, search_engine: SearchEngine, dataset
return dataset


async def update_dataset(db: AsyncSession, dataset: Dataset, dataset_attrs: dict) -> Dataset:
await DatasetUpdateValidator.validate(db, dataset, dataset_attrs)

return await dataset.update(db, **dataset_attrs)


async def delete_dataset(db: AsyncSession, search_engine: SearchEngine, dataset: Dataset) -> Dataset:
async with db.begin_nested():
dataset = await dataset.delete(db, autocommit=False)
Expand All @@ -180,11 +186,6 @@ async def delete_dataset(db: AsyncSession, search_engine: SearchEngine, dataset:
return dataset


async def update_dataset(db: AsyncSession, dataset: Dataset, dataset_update: "DatasetUpdate") -> Dataset:
params = dataset_update.dict(exclude_unset=True)
return await dataset.update(db, **params)


async def create_field(db: AsyncSession, dataset: Dataset, field_create: FieldCreate) -> Field:
if dataset.is_ready:
raise UnprocessableEntityError("Field cannot be created for a published dataset")
Expand Down Expand Up @@ -939,6 +940,9 @@ async def create_response(
await db.flush([response])
await _touch_dataset_last_activity_at(db, record.dataset)
await search_engine.update_record_response(response)
await db.refresh(record, attribute_names=[Record.responses_submitted.key])
await distribution.update_record_status(db, record)
await search_engine.partial_record_update(record, status=record.status)

await db.commit()

Expand All @@ -962,6 +966,9 @@ async def update_response(
await _load_users_from_responses(response)
await _touch_dataset_last_activity_at(db, response.record.dataset)
await search_engine.update_record_response(response)
await db.refresh(response.record, attribute_names=[Record.responses_submitted.key])
await distribution.update_record_status(db, response.record)
await search_engine.partial_record_update(response.record, status=response.record.status)

await db.commit()

Expand Down Expand Up @@ -991,6 +998,9 @@ async def upsert_response(
await _load_users_from_responses(response)
await _touch_dataset_last_activity_at(db, response.record.dataset)
await search_engine.update_record_response(response)
await db.refresh(record, attribute_names=[Record.responses_submitted.key])
await distribution.update_record_status(db, record)
await search_engine.partial_record_update(record, status=record.status)

await db.commit()

Expand All @@ -1000,9 +1010,13 @@ async def upsert_response(
async def delete_response(db: AsyncSession, search_engine: SearchEngine, response: Response) -> Response:
async with db.begin_nested():
response = await response.delete(db, autocommit=False)

await _load_users_from_responses(response)
await _touch_dataset_last_activity_at(db, response.record.dataset)
await search_engine.delete_record_response(response)
await db.refresh(response.record, attribute_names=[Record.responses_submitted.key])
await distribution.update_record_status(db, response.record)
await search_engine.partial_record_update(record=response.record, status=response.record.status)

await db.commit()

Expand Down
42 changes: 42 additions & 0 deletions argilla-server/src/argilla_server/contexts/distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.enums import DatasetDistributionStrategy, RecordStatus
from argilla_server.models import Record


# TODO: Do this with one single update statement for all records if possible to avoid too many queries.
async def update_records_status(db: AsyncSession, records: List[Record]):
for record in records:
await update_record_status(db, record)


async def update_record_status(db: AsyncSession, record: Record) -> Record:
if record.dataset.distribution_strategy == DatasetDistributionStrategy.overlap:
return await _update_record_status_with_overlap_strategy(db, record)

raise NotImplementedError(f"unsupported distribution strategy `{record.dataset.distribution_strategy}`")


async def _update_record_status_with_overlap_strategy(db: AsyncSession, record: Record) -> Record:
if len(record.responses_submitted) >= record.dataset.distribution["min_submitted"]:
record.status = RecordStatus.completed
else:
record.status = RecordStatus.pending

return await record.save(db, autocommit=False)
5 changes: 5 additions & 0 deletions argilla-server/src/argilla_server/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ class UserRole(str, Enum):
annotator = "annotator"


class RecordStatus(str, Enum):
pending = "pending"
completed = "completed"


class RecordInclude(str, Enum):
responses = "responses"
suggestions = "suggestions"
Expand Down
Loading

0 comments on commit 91eb6b1

Please sign in to comment.