Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIKE] feat: refresh record count_submitted_responses column using context functions #5106

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select a commit; hold Shift and click to select a range
d5b762c
feat: add dataset support to be created using distribution settings
jfcalvo Jun 13, 2024
96a0bde
feat: use dataset_attrs as dictionary to avoid the use of DatasetCrea…
jfcalvo Jun 13, 2024
04e9e92
feat: add DatasetCreateValidator and move some validations from context
jfcalvo Jun 13, 2024
5eb30fc
feat: add dataset support to update distribution settings
jfcalvo Jun 14, 2024
6953fb7
chore: update CHANGELOG.md
jfcalvo Jun 14, 2024
2082e12
feat: add additional test to check invalid distribution strategy
jfcalvo Jun 14, 2024
6cfcf43
chore: fix typo
jfcalvo Jun 14, 2024
abc3693
chore: move change to the right section at CHANGELOG.md
jfcalvo Jun 14, 2024
cf60aaa
Merge branch 'feat/create-datasets-with-distribution' into feat/updat…
jfcalvo Jun 14, 2024
f1090d5
chore: update CHANGELOG.md
jfcalvo Jun 14, 2024
acce365
Merge branch 'feat/add-dataset-automatic-task-distribution' into feat…
jfcalvo Jun 19, 2024
f2a3f08
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2024
8e5a7c0
Merge branch 'feat/create-datasets-with-distribution' into feat/updat…
jfcalvo Jun 19, 2024
5671f49
feat: refresh record count_submitted_responses column
jfcalvo Jun 25, 2024
3d51d51
Merge branch 'feat/add-dataset-automatic-task-distribution' into feat…
jfcalvo Jun 28, 2024
1ffe9c8
Merge branch 'feat/create-datasets-with-distribution' into feat/updat…
jfcalvo Jun 28, 2024
382e892
Merge branch 'feat/update-datasets-distribution' into feat/refresh-re…
jfcalvo Jun 28, 2024
97dc916
chore: fix migration
jfcalvo Jun 28, 2024
b23f3d8
Merge branch 'feat/create-datasets-with-distribution' into feat/updat…
jfcalvo Jun 28, 2024
5adc94b
Merge branch 'feat/update-datasets-distribution' into feat/refresh-re…
jfcalvo Jun 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion argilla-server/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@ These are the section headers that we use:

## [Unreleased]()

## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1)
### Added

- Added support to specify `distribution` attribute when creating a dataset. ([#5013](https://github.com/argilla-io/argilla/pull/5013))
- Added support to change `distribution` attribute when updating a dataset. ([#5028](https://github.com/argilla-io/argilla/pull/5028))

### Changed

- Change `responses` table to delete rows on cascade when a user is deleted. ([#5126](https://github.com/argilla-io/argilla/pull/5126))

## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1)

### Removed

- Removed all API v0 endpoints. ([#4852](https://github.com/argilla-io/argilla/pull/4852))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""add record metadata column
"""add metadata column to records table
Revision ID: 3ff6484f8b37
Revises: ae5522b4c674
Expand All @@ -31,12 +31,8 @@


def upgrade() -> None:
    """Add a nullable JSON ``metadata`` column to the ``records`` table."""
    op.add_column("records", sa.Column("metadata", sa.JSON(), nullable=True))


def downgrade() -> None:
    """Drop the ``metadata`` column added by :func:`upgrade`."""
    op.drop_column("records", "metadata")
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""add distribution column to datasets table
Revision ID: 45a12f74448b
Revises: d00f819ccc67
Create Date: 2024-06-13 11:23:43.395093
"""

import json

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "45a12f74448b"
down_revision = "d00f819ccc67"
branch_labels = None
depends_on = None

DISTRIBUTION_VALUE = json.dumps({"strategy": "overlap", "min_submitted": 1})


def upgrade() -> None:
    """Add a NOT NULL JSON ``distribution`` column to the ``datasets`` table.

    The column is created nullable first (existing rows have no value),
    every dataset is backfilled with the default overlap distribution, and
    only then is the NOT NULL constraint applied. ``batch_alter_table`` is
    used so the constraint change also works on SQLite.
    """
    op.add_column("datasets", sa.Column("distribution", sa.JSON(), nullable=True))
    # Bind the JSON payload as a parameter instead of interpolating it into
    # the SQL string with an f-string, so quote characters inside the JSON
    # can never break (or inject into) the statement.
    op.execute(
        sa.text("UPDATE datasets SET distribution = :distribution").bindparams(distribution=DISTRIBUTION_VALUE)
    )
    with op.batch_alter_table("datasets") as batch_op:
        batch_op.alter_column("distribution", nullable=False)


def downgrade() -> None:
    # Remove the "distribution" column added by upgrade().
    op.drop_column("datasets", "distribution")
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""add count_submitted_responses to records table

Revision ID: b4e101b124d2
Revises: 45a12f74448b
Create Date: 2024-06-24 17:07:18.614728

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "b4e101b124d2"
down_revision = "45a12f74448b"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add the denormalized counter with server_default="0" so existing rows
    # get a value and the NOT NULL constraint can be applied in one step.
    op.add_column("records", sa.Column("count_submitted_responses", sa.Integer(), server_default="0", nullable=False))
    # Backfill the counter for existing records from the responses table,
    # counting only responses whose status is 'submitted'.
    op.execute("""
        UPDATE records
        SET count_submitted_responses = (
            SELECT COUNT(*)
            FROM responses
            WHERE responses.record_id = records.id AND responses.status = 'submitted'
        )
    """)


def downgrade() -> None:
    # Remove the "count_submitted_responses" column added by upgrade().
    op.drop_column("records", "count_submitted_responses")
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""add allow_extra_metadata column to dataset table
"""add allow_extra_metadata column to datasets table
Revision ID: b8458008b60e
Revises: 7cbcccf8b57a
Expand All @@ -31,14 +31,10 @@


def upgrade() -> None:
    """Add a NOT NULL boolean ``allow_extra_metadata`` column to ``datasets``.

    ``server_default=true`` both backfills existing rows and applies to new
    rows that do not set the column explicitly.
    """
    op.add_column(
        "datasets", sa.Column("allow_extra_metadata", sa.Boolean(), server_default=sa.text("true"), nullable=False)
    )


def downgrade() -> None:
    """Drop the ``allow_extra_metadata`` column added by :func:`upgrade`."""
    op.drop_column("datasets", "allow_extra_metadata")
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ async def create_dataset(
):
await authorize(current_user, DatasetPolicy.create(dataset_create.workspace_id))

return await datasets.create_dataset(db, dataset_create)
return await datasets.create_dataset(db, dataset_create.dict())


@router.post("/datasets/{dataset_id}/fields", status_code=status.HTTP_201_CREATED, response_model=Field)
Expand Down Expand Up @@ -302,4 +302,4 @@ async def update_dataset(

await authorize(current_user, DatasetPolicy.update(dataset))

return await datasets.update_dataset(db, dataset, dataset_update)
return await datasets.update_dataset(db, dataset, dataset_update.dict(exclude_unset=True))
38 changes: 35 additions & 3 deletions argilla-server/src/argilla_server/api/schemas/v1/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# limitations under the License.

from datetime import datetime
from typing import List, Optional
from typing import List, Literal, Optional, Union
from uuid import UUID

from argilla_server.api.schemas.v1.commons import UpdateSchema
from argilla_server.enums import DatasetStatus
from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus
from argilla_server.pydantic_v1 import BaseModel, Field, constr

try:
Expand All @@ -44,6 +44,32 @@
]


class DatasetOverlapDistribution(BaseModel):
    """Read schema for a dataset's "overlap" task-distribution settings."""

    # Discriminator: always DatasetDistributionStrategy.overlap for this schema.
    strategy: Literal[DatasetDistributionStrategy.overlap]
    # NOTE(review): unlike the Create schema, there is no ge=1 constraint here —
    # presumably values read back from storage are already validated; confirm.
    min_submitted: int


DatasetDistribution = DatasetOverlapDistribution


class DatasetOverlapDistributionCreate(BaseModel):
    """Input schema for setting "overlap" distribution when creating a dataset."""

    # Discriminator: must be DatasetDistributionStrategy.overlap for this schema.
    strategy: Literal[DatasetDistributionStrategy.overlap]
    # Validated to be >= 1: a record needs at least one submitted response
    # to ever be considered completed.
    min_submitted: int = Field(
        ge=1,
        description="Minimum number of submitted responses to consider a record as completed",
    )


DatasetDistributionCreate = DatasetOverlapDistributionCreate


class DatasetOverlapDistributionUpdate(DatasetOverlapDistributionCreate):
    """Input schema for changing "overlap" distribution on an existing dataset.

    Inherits every field and constraint from the create schema. The base class
    is named explicitly (instead of through the ``DatasetDistributionCreate``
    alias, which points to the same class) for consistency and clarity.
    """


DatasetDistributionUpdate = DatasetOverlapDistributionUpdate


class RecordMetrics(BaseModel):
count: int

Expand Down Expand Up @@ -74,6 +100,7 @@ class Dataset(BaseModel):
guidelines: Optional[str]
allow_extra_metadata: bool
status: DatasetStatus
distribution: DatasetDistribution
workspace_id: UUID
last_activity_at: datetime
inserted_at: datetime
Expand All @@ -91,12 +118,17 @@ class DatasetCreate(BaseModel):
name: DatasetName
guidelines: Optional[DatasetGuidelines]
allow_extra_metadata: bool = True
distribution: DatasetDistributionCreate = DatasetOverlapDistributionCreate(
strategy=DatasetDistributionStrategy.overlap,
min_submitted=1,
)
workspace_id: UUID


class DatasetUpdate(UpdateSchema):
    """Partial-update schema for a dataset; unset fields are left unchanged."""

    name: Optional[DatasetName]
    guidelines: Optional[DatasetGuidelines]
    allow_extra_metadata: Optional[bool]
    distribution: Optional[DatasetDistributionUpdate]

    # These fields may be omitted from an update but can never be set to null.
    # (The block previously contained a duplicated, immediately-overwritten
    # assignment of this attribute — dead diff residue, now removed.)
    __non_nullable_fields__ = {"name", "allow_extra_metadata", "distribution"}
74 changes: 51 additions & 23 deletions argilla-server/src/argilla_server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import contains_eager, joinedload, selectinload

from argilla_server.api.schemas.v1.datasets import (
DatasetCreate,
DatasetProgress,
)
from argilla_server.api.schemas.v1.datasets import DatasetProgress
from argilla_server.api.schemas.v1.fields import FieldCreate
from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyCreate, MetadataPropertyUpdate
from argilla_server.api.schemas.v1.records import (
Expand Down Expand Up @@ -82,6 +79,7 @@
)
from argilla_server.models.suggestions import SuggestionCreateWithRecordId
from argilla_server.search_engine import SearchEngine
from argilla_server.validators.datasets import DatasetCreateValidator
from argilla_server.validators.responses import (
ResponseCreateValidator,
ResponseUpdateValidator,
Expand Down Expand Up @@ -122,22 +120,18 @@ async def list_datasets_by_workspace_id(db: AsyncSession, workspace_id: UUID) ->
return result.scalars().all()


async def create_dataset(db: AsyncSession, dataset_attrs: dict):
    """Create and persist a new dataset from a raw attributes dict.

    Builds an unsaved ``Dataset`` from ``dataset_attrs`` (expects keys
    ``name``, ``guidelines``, ``allow_extra_metadata``, ``distribution`` and
    ``workspace_id``), delegates all validation (e.g. workspace existence,
    name uniqueness) to ``DatasetCreateValidator``, then saves it.

    This block previously contained interleaved lines from the replaced
    implementation (old signature, inline workspace/uniqueness checks and a
    ``Dataset.create`` call) — dead diff residue, now removed; validation
    lives in the validator.
    """
    dataset = Dataset(
        name=dataset_attrs["name"],
        guidelines=dataset_attrs["guidelines"],
        allow_extra_metadata=dataset_attrs["allow_extra_metadata"],
        distribution=dataset_attrs["distribution"],
        workspace_id=dataset_attrs["workspace_id"],
    )

    # Raises if the workspace is missing or the name is already taken;
    # must run before the dataset is persisted.
    await DatasetCreateValidator.validate(db, dataset)

    return await dataset.save(db)


async def _count_required_fields_by_dataset_id(db: AsyncSession, dataset_id: UUID) -> int:
Expand Down Expand Up @@ -176,6 +170,10 @@ async def publish_dataset(db: AsyncSession, search_engine: SearchEngine, dataset
return dataset


async def update_dataset(db: AsyncSession, dataset: Dataset, dataset_attrs: dict) -> Dataset:
    # Apply the given attribute values to the dataset and persist the change;
    # callers pass only the fields they want to update (exclude_unset dict).
    return await dataset.update(db, **dataset_attrs)


async def delete_dataset(db: AsyncSession, search_engine: SearchEngine, dataset: Dataset) -> Dataset:
async with db.begin_nested():
dataset = await dataset.delete(db, autocommit=False)
Expand All @@ -186,11 +184,6 @@ async def delete_dataset(db: AsyncSession, search_engine: SearchEngine, dataset:
return dataset


async def update_dataset(db: AsyncSession, dataset: Dataset, dataset_update: "DatasetUpdate") -> Dataset:
params = dataset_update.dict(exclude_unset=True)
return await dataset.update(db, **params)


async def create_field(db: AsyncSession, dataset: Dataset, field_create: FieldCreate) -> Field:
if dataset.is_ready:
raise UnprocessableEntityError("Field cannot be created for a published dataset")
Expand Down Expand Up @@ -943,6 +936,15 @@ async def create_response(
)

await db.flush([response])

if response_create.status == ResponseStatus.submitted:
await db.execute(
sqlalchemy.update(Record)
.where(Record.id == record.id)
.values(count_submitted_responses=Record.count_submitted_responses + 1)
)
# TODO: refresh record at search engine

await _touch_dataset_last_activity_at(db, record.dataset)
await search_engine.update_record_response(response)

Expand All @@ -957,6 +959,8 @@ async def update_response(
ResponseUpdateValidator(response_update).validate_for(response.record)

async with db.begin_nested():
previous_response_status = response.status

response = await response.update(
db,
values=jsonable_encoder(response_update.values),
Expand All @@ -966,6 +970,21 @@ async def update_response(
)

await _load_users_from_responses(response)

if response_update.status == ResponseStatus.submitted and previous_response_status != ResponseStatus.submitted:
await db.execute(
sqlalchemy.update(Record)
.where(Record.id == response.record_id)
.values(count_submitted_responses=Record.count_submitted_responses + 1)
)
if response_update.status != ResponseStatus.submitted and previous_response_status == ResponseStatus.submitted:
await db.execute(
sqlalchemy.update(Record)
.where(Record.id == response.record_id)
.values(count_submitted_responses=Record.count_submitted_responses - 1)
)
# TODO: refresh record at search engine

await _touch_dataset_last_activity_at(db, response.record.dataset)
await search_engine.update_record_response(response)

Expand Down Expand Up @@ -1007,6 +1026,15 @@ async def delete_response(db: AsyncSession, search_engine: SearchEngine, respons
async with db.begin_nested():
response = await response.delete(db, autocommit=False)
await _load_users_from_responses(response)

if response.status == ResponseStatus.submitted:
await db.execute(
sqlalchemy.update(Record)
.where(Record.id == response.record_id)
.values(count_submitted_responses=Record.count_submitted_responses - 1)
)
# TODO: refresh record at search engine

await _touch_dataset_last_activity_at(db, response.record.dataset)
await search_engine.delete_record_response(response)

Expand Down
4 changes: 4 additions & 0 deletions argilla-server/src/argilla_server/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class DatasetStatus(str, Enum):
ready = "ready"


class DatasetDistributionStrategy(str, Enum):
    """Supported strategies for distributing annotation tasks in a dataset."""

    # Currently the only strategy: multiple users may respond to the same
    # record (see the distribution schemas' min_submitted setting).
    overlap = "overlap"


class UserRole(str, Enum):
owner = "owner"
admin = "admin"
Expand Down
Loading