Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add dataset support to be updated using distribution settings #5028

1 change: 1 addition & 0 deletions argilla-server/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ These are the section headers that we use:
### Added

- Added support to specify `distribution` attribute when creating a dataset. ([#5013](https://github.com/argilla-io/argilla/pull/5013))
- Added support to change `distribution` attribute when updating a dataset. ([#5028](https://github.com/argilla-io/argilla/pull/5028))

### Changed

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,4 +302,4 @@ async def update_dataset(

await authorize(current_user, DatasetPolicy.update(dataset))

return await datasets.update_dataset(db, dataset, dataset_update)
return await datasets.update_dataset(db, dataset, dataset_update.dict(exclude_unset=True))
10 changes: 9 additions & 1 deletion argilla-server/src/argilla_server/api/schemas/v1/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ class DatasetOverlapDistributionCreate(BaseModel):
DatasetDistributionCreate = DatasetOverlapDistributionCreate


# Update-side schema for the overlap distribution. It declares no fields of its
# own; everything is inherited from ``DatasetDistributionCreate``.
class DatasetOverlapDistributionUpdate(DatasetDistributionCreate):
    ...


# Alias under which the update schema is referenced by other schemas.
DatasetDistributionUpdate = DatasetOverlapDistributionUpdate


class RecordMetrics(BaseModel):
count: int

Expand Down Expand Up @@ -122,5 +129,6 @@ class DatasetUpdate(UpdateSchema):
name: Optional[DatasetName]
guidelines: Optional[DatasetGuidelines]
allow_extra_metadata: Optional[bool]
distribution: Optional[DatasetDistributionUpdate]

__non_nullable_fields__ = {"name", "allow_extra_metadata"}
__non_nullable_fields__ = {"name", "allow_extra_metadata", "distribution"}
13 changes: 7 additions & 6 deletions argilla-server/src/argilla_server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
)
from argilla_server.models.suggestions import SuggestionCreateWithRecordId
from argilla_server.search_engine import SearchEngine
from argilla_server.validators.datasets import DatasetCreateValidator
from argilla_server.validators.datasets import DatasetCreateValidator, DatasetUpdateValidator
from argilla_server.validators.responses import (
ResponseCreateValidator,
ResponseUpdateValidator,
Expand Down Expand Up @@ -170,6 +170,12 @@ async def publish_dataset(db: AsyncSession, search_engine: SearchEngine, dataset
return dataset


async def update_dataset(db: AsyncSession, dataset: Dataset, dataset_attrs: dict) -> Dataset:
    """Validate the requested changes and then apply them to ``dataset``.

    Args:
        db: Session passed to both the validator and ``dataset.update``.
        dataset: Dataset being modified.
        dataset_attrs: Attribute names mapped to their new values; expanded as
            keyword arguments to ``dataset.update``.

    Returns:
        The result of ``dataset.update``.

    Raises:
        Any error raised by ``DatasetUpdateValidator.validate``; in that case
        ``dataset.update`` is never called.
    """
    await DatasetUpdateValidator.validate(db, dataset, dataset_attrs)

    updated_dataset = await dataset.update(db, **dataset_attrs)
    return updated_dataset


async def delete_dataset(db: AsyncSession, search_engine: SearchEngine, dataset: Dataset) -> Dataset:
async with db.begin_nested():
dataset = await dataset.delete(db, autocommit=False)
Expand All @@ -180,11 +186,6 @@ async def delete_dataset(db: AsyncSession, search_engine: SearchEngine, dataset:
return dataset


async def update_dataset(db: AsyncSession, dataset: Dataset, dataset_update: "DatasetUpdate") -> Dataset:
    """Apply the explicitly-set fields of ``dataset_update`` to ``dataset``.

    Only fields that were set on ``dataset_update`` are forwarded
    (``exclude_unset=True``); no validation is performed here.

    NOTE(review): a dict-based ``update_dataset`` also appears on this page, so
    this is presumably the pre-change version -- confirm before relying on it.
    """
    changed_fields = dataset_update.dict(exclude_unset=True)
    return await dataset.update(db, **changed_fields)


async def create_field(db: AsyncSession, dataset: Dataset, field_create: FieldCreate) -> Field:
if dataset.is_ready:
raise UnprocessableEntityError("Field cannot be created for a published dataset")
Expand Down
17 changes: 14 additions & 3 deletions argilla-server/src/argilla_server/validators/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,27 @@

class DatasetCreateValidator:
    """Validations for a dataset: its workspace exists and its name is unique there."""

    @classmethod
    async def validate(cls, db: AsyncSession, dataset: Dataset) -> None:
        """Run all validations for ``dataset``.

        Args:
            db: Session used for the lookups.
            dataset: Dataset whose ``workspace_id`` and ``name`` are checked.

        Raises:
            UnprocessableEntityError: If no workspace has ``dataset.workspace_id``.
            NotUniqueError: If the workspace already has a dataset with that name.
        """
        # Workspace existence is checked first, so a missing workspace is
        # reported before any duplicate-name lookup is attempted.
        await cls._validate_workspace_is_present(db, dataset.workspace_id)
        await cls._validate_name_is_not_duplicated(db, dataset.name, dataset.workspace_id)

    @classmethod
    async def _validate_workspace_is_present(cls, db: AsyncSession, workspace_id: UUID) -> None:
        """Raise ``UnprocessableEntityError`` when ``Workspace.get`` finds nothing."""
        if await Workspace.get(db, workspace_id) is None:
            raise UnprocessableEntityError(f"Workspace with id `{workspace_id}` not found")

    @classmethod
    async def _validate_name_is_not_duplicated(cls, db: AsyncSession, name: str, workspace_id: UUID) -> None:
        """Raise ``NotUniqueError`` when the workspace already holds a dataset named ``name``."""
        if await Dataset.get_by(db, name=name, workspace_id=workspace_id):
            raise NotUniqueError(f"Dataset with name `{name}` already exists for workspace with id `{workspace_id}`")


class DatasetUpdateValidator:
    """Validations applied to the attributes of a dataset update."""

    @classmethod
    async def validate(cls, db: AsyncSession, dataset: Dataset, dataset_attrs: dict) -> None:
        """Run all update validations for ``dataset``.

        Nothing here awaits and ``db`` is unused; the method is deliberately
        kept asynchronous and session-taking so that every validator exposes
        the same ``validate`` signature.

        Args:
            db: Session (currently unused).
            dataset: Dataset being updated.
            dataset_attrs: Attribute names mapped to the requested new values.

        Raises:
            UnprocessableEntityError: If a distribution change is requested for
                a dataset whose ``is_ready`` is true.
        """
        cls._validate_distribution(dataset, dataset_attrs)

    @classmethod
    def _validate_distribution(cls, dataset: Dataset, dataset_attrs: dict) -> None:
        """Reject a non-``None`` ``distribution`` value when ``dataset.is_ready``."""
        # A missing key and an explicit ``None`` are both treated as
        # "no distribution change requested".
        if dataset.is_ready and dataset_attrs.get("distribution") is not None:
            # Plain literal: the message has no placeholders to interpolate.
            raise UnprocessableEntityError("Distribution settings cannot be modified for a published dataset")
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from uuid import UUID

import pytest
from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus
from httpx import AsyncClient

from tests.factories import DatasetFactory


@pytest.mark.asyncio
class TestUpdateDataset:
    """Tests for PATCH ``/api/v1/datasets/{dataset_id}`` centred on the ``distribution`` attribute."""

    def url(self, dataset_id: UUID) -> str:
        """Return the endpoint path for the dataset with ``dataset_id``."""
        return f"/api/v1/datasets/{dataset_id}"

    def _initial_distribution(self) -> dict:
        """Return the distribution these tests expect on a factory-created dataset."""
        return {
            "strategy": DatasetDistributionStrategy.overlap,
            "min_submitted": 1,
        }

    async def test_update_dataset_distribution(self, async_client: AsyncClient, owner_auth_header: dict):
        # A factory-default dataset accepts new distribution settings.
        dataset = await DatasetFactory.create()
        requested_distribution = {
            "strategy": DatasetDistributionStrategy.overlap,
            "min_submitted": 4,
        }

        response = await async_client.patch(
            self.url(dataset.id),
            headers=owner_auth_header,
            json={"distribution": requested_distribution},
        )

        assert response.status_code == 200
        assert response.json()["distribution"] == requested_distribution

        assert dataset.distribution == requested_distribution

    async def test_update_dataset_without_distribution(self, async_client: AsyncClient, owner_auth_header: dict):
        # Updating another attribute leaves the distribution as it was.
        dataset = await DatasetFactory.create()
        payload = {"name": "Dataset updated name"}

        response = await async_client.patch(
            self.url(dataset.id),
            headers=owner_auth_header,
            json=payload,
        )

        assert response.status_code == 200
        assert response.json()["distribution"] == self._initial_distribution()

        assert dataset.name == "Dataset updated name"
        assert dataset.distribution == self._initial_distribution()

    async def test_update_dataset_without_distribution_for_published_dataset(
        self, async_client: AsyncClient, owner_auth_header: dict
    ):
        # A ready dataset can still be updated when distribution is not part of the request.
        dataset = await DatasetFactory.create(status=DatasetStatus.ready)
        payload = {"name": "Dataset updated name"}

        response = await async_client.patch(
            self.url(dataset.id),
            headers=owner_auth_header,
            json=payload,
        )

        assert response.status_code == 200
        assert response.json()["distribution"] == self._initial_distribution()

        assert dataset.name == "Dataset updated name"
        assert dataset.distribution == self._initial_distribution()

    async def test_update_dataset_distribution_for_published_dataset(
        self, async_client: AsyncClient, owner_auth_header: dict
    ):
        # A ready dataset rejects distribution changes with a 422 and a fixed detail message.
        dataset = await DatasetFactory.create(status=DatasetStatus.ready)
        requested_distribution = {
            "strategy": DatasetDistributionStrategy.overlap,
            "min_submitted": 4,
        }

        response = await async_client.patch(
            self.url(dataset.id),
            headers=owner_auth_header,
            json={"distribution": requested_distribution},
        )

        assert response.status_code == 422
        assert response.json() == {"detail": "Distribution settings cannot be modified for a published dataset"}

        assert dataset.distribution == self._initial_distribution()

    async def test_update_dataset_distribution_with_invalid_strategy(
        self, async_client: AsyncClient, owner_auth_header: dict
    ):
        # An unknown strategy value is rejected and nothing changes.
        dataset = await DatasetFactory.create()
        payload = {"distribution": {"strategy": "invalid_strategy"}}

        response = await async_client.patch(
            self.url(dataset.id),
            headers=owner_auth_header,
            json=payload,
        )

        assert response.status_code == 422
        assert dataset.distribution == self._initial_distribution()

    async def test_update_dataset_distribution_with_invalid_min_submitted_value(
        self, async_client: AsyncClient, owner_auth_header: dict
    ):
        # ``min_submitted`` of 0 is rejected and nothing changes.
        dataset = await DatasetFactory.create()
        payload = {
            "distribution": {
                "strategy": DatasetDistributionStrategy.overlap,
                "min_submitted": 0,
            },
        }

        response = await async_client.patch(
            self.url(dataset.id),
            headers=owner_auth_header,
            json=payload,
        )

        assert response.status_code == 422
        assert dataset.distribution == self._initial_distribution()

    async def test_update_dataset_distribution_as_none(self, async_client: AsyncClient, owner_auth_header: dict):
        # An explicit null distribution is rejected and nothing changes.
        dataset = await DatasetFactory.create()
        payload = {"distribution": None}

        response = await async_client.patch(
            self.url(dataset.id),
            headers=owner_auth_header,
            json=payload,
        )

        assert response.status_code == 422
        assert dataset.distribution == self._initial_distribution()