diff --git a/DashAI/back/api/api_v1/endpoints/datasets.py b/DashAI/back/api/api_v1/endpoints/datasets.py
index 5caed79f..84aebb60 100644
--- a/DashAI/back/api/api_v1/endpoints/datasets.py
+++ b/DashAI/back/api/api_v1/endpoints/datasets.py
@@ -3,6 +3,8 @@
 import shutil
 from typing import Any, Dict
 
+import pyarrow as pa
+import pyarrow.ipc as ipc
 from fastapi import APIRouter, Depends, Response, status
 from fastapi.exceptions import HTTPException
 from kink import di, inject
@@ -11,10 +13,8 @@
 from DashAI.back.api.api_v1.schemas.datasets_params import DatasetUpdateParams
 from DashAI.back.dataloaders.classes.dashai_dataset import (
-    DashAIDataset,
     get_columns_spec,
     get_dataset_info,
-    load_dataset,
     update_columns_spec,
 )
 from DashAI.back.dependencies.database.models import Dataset
@@ -126,8 +126,16 @@ async def get_sample(
                     status_code=status.HTTP_404_NOT_FOUND,
                     detail="Dataset not found",
                 )
-            dataset: DashAIDataset = load_dataset(f"{file_path}/dataset")
-            sample = dataset.sample(n=10)
+
+            arrow_path = os.path.join(file_path, "dataset", "data.arrow")
+
+            with pa.OSFile(arrow_path, "rb") as source:
+                reader = ipc.open_file(source)
+                batch = reader.get_batch(0)
+                sample_size = min(10, batch.num_rows)
+                sample_batch = batch.slice(0, sample_size)
+                sample = sample_batch.to_pydict()
+
     except exc.SQLAlchemyError as e:
         logger.exception(e)
         raise HTTPException(
diff --git a/DashAI/back/api/api_v1/endpoints/experiments.py b/DashAI/back/api/api_v1/endpoints/experiments.py
index 61c057ac..96ce3d5b 100644
--- a/DashAI/back/api/api_v1/endpoints/experiments.py
+++ b/DashAI/back/api/api_v1/endpoints/experiments.py
@@ -1,6 +1,9 @@
 import logging
+import os
 from typing import Union
 
+import pyarrow as pa
+import pyarrow.ipc as ipc
 from fastapi import APIRouter, Depends, Response, status
 from fastapi.exceptions import HTTPException
 from kink import di, inject
@@ -11,10 +14,7 @@
     ColumnsValidationParams,
     ExperimentParams,
 )
-from DashAI.back.dataloaders.classes.dashai_dataset import (
-    get_column_names_from_indexes,
-    load_dataset,
-)
+from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset
 from DashAI.back.dependencies.database.models import Dataset, Experiment
 from DashAI.back.dependencies.registry import ComponentRegistry
 from DashAI.back.tasks.base_task import BaseTask
@@ -101,6 +101,7 @@ async def validate_columns(
     component_registry: ComponentRegistry = Depends(lambda: di["component_registry"]),
     session_factory: sessionmaker = Depends(lambda: di["session_factory"]),
 ):
+    """Validate if dataset columns are compatible with a task."""
     with session_factory() as db:
         try:
             dataset = db.get(Dataset, params.dataset_id)
@@ -109,18 +110,28 @@ async def validate_columns(
                     status_code=status.HTTP_404_NOT_FOUND,
                     detail="Dataset not found",
                 )
-            datasetdict = load_dataset(f"{dataset.file_path}/dataset")
-            if not datasetdict:
+
+            dataset_path = f"{dataset.file_path}/dataset"
+            data_filepath = os.path.join(dataset_path, "data.arrow")
+            with pa.OSFile(data_filepath, "rb") as source:
+                reader = ipc.open_file(source)
+                batch = reader.get_batch(0)
+                sample_size = min(5, batch.num_rows)
+                sample_batch = batch.slice(0, sample_size)
+
+                table = pa.Table.from_batches([sample_batch])
+                minimal_dataset = DashAIDataset(table)
+
+            column_names = minimal_dataset.column_names
+
+            if max(params.inputs_columns + params.outputs_columns) > len(column_names):
                 raise HTTPException(
-                    status_code=status.HTTP_404_NOT_FOUND,
-                    detail="Error while loading the dataset.",
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail="Column index out of range",
                 )
-            inputs_names = get_column_names_from_indexes(
-                dataset=datasetdict, indexes=params.inputs_columns
-            )
-            outputs_names = get_column_names_from_indexes(
-                dataset=datasetdict, indexes=params.outputs_columns
-            )
+
+            inputs_names = [column_names[i - 1] for i in params.inputs_columns]
+            outputs_names = [column_names[i - 1] for i in params.outputs_columns]
         except exc.SQLAlchemyError as e:
             log.exception(e)
             raise HTTPException(
@@ -128,16 +139,19 @@
                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                 detail="Internal database error",
             ) from e
+
         if params.task_name not in component_registry:
             raise HTTPException(
                 status_code=404,
                 detail=f"Task {params.task_name} not found in the registry.",
             )
+
         task: BaseTask = component_registry[params.task_name]["class"]()
         validation_response = {}
+
         try:
             prepared_dataset = task.prepare_for_task(
-                datasetdict=datasetdict,
+                datasetdict=minimal_dataset,
                 outputs_columns=outputs_names,
             )
             task.validate_dataset_for_task(
@@ -186,18 +200,22 @@ async def create_experiment(
                 raise HTTPException(
                     status_code=status.HTTP_404_NOT_FOUND, detail="Dataset not found"
                 )
-            datasetdict = load_dataset(f"{dataset.file_path}/dataset")
-            if not datasetdict:
+            dataset_path = f"{dataset.file_path}/dataset"
+            data_filepath = os.path.join(dataset_path, "data.arrow")
+
+            with pa.OSFile(data_filepath, "rb") as source:
+                reader = ipc.open_file(source)
+                schema = reader.schema
+                column_names = schema.names
+
+            if max(params.input_columns + params.output_columns) > len(column_names):
                 raise HTTPException(
-                    status_code=status.HTTP_404_NOT_FOUND,
-                    detail="Error while loading the dataset.",
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail="Column index out of range",
                 )
-            inputs_columns = get_column_names_from_indexes(
-                dataset=datasetdict, indexes=params.input_columns
-            )
-            outputs_columns = get_column_names_from_indexes(
-                dataset=datasetdict, indexes=params.output_columns
-            )
+
+            inputs_columns = [column_names[i - 1] for i in params.input_columns]
+            outputs_columns = [column_names[i - 1] for i in params.output_columns]
             experiment = Experiment(
                 dataset_id=params.dataset_id,
                 task_name=params.task_name,
diff --git a/DashAI/back/dataloaders/classes/dashai_dataset.py b/DashAI/back/dataloaders/classes/dashai_dataset.py
index e0deab11..2a2910d3 100644
--- a/DashAI/back/dataloaders/classes/dashai_dataset.py
+++ b/DashAI/back/dataloaders/classes/dashai_dataset.py
@@ -9,6 +9,7 @@
 import pyarrow.ipc as ipc
 from beartype import beartype
 from datasets import ClassLabel, Dataset, DatasetDict, Value, concatenate_datasets
+from datasets.features import Features
 from sklearn.model_selection import train_test_split
 
@@ -93,21 +94,6 @@ def keys(self) -> List[str]:
             return list(self.splits["split_indices"].keys())
         return []
 
-    @beartype
-    def save_to_disk(self, dataset_path: Union[str, os.PathLike]) -> None:
-        """
-        Overrides the default save_to_disk method to save the dataset as
-        a single directory with:
-        - "data.arrow": the dataset's Arrow table.
-        - "splits.json": the dataset's splits (e.g., original split indices).
-
-        Parameters
-        ----------
-        dataset_path : Union[str, os.PathLike]
-            path where the dataset will be saved
-        """
-        save_dataset(self, dataset_path)
-
     @beartype
     def change_columns_type(self, column_types: Dict[str, str]) -> "DashAIDataset":
         """Change the type of some columns.
@@ -678,16 +664,22 @@ def get_columns_spec(dataset_path: str) -> Dict[str, Dict]:
     Dict
         Dict with the columns and types
     """
-    dataset = load_dataset(dataset_path)
-    dataset_features = dataset.features
+
+    data_filepath = os.path.join(dataset_path, "data.arrow")
+    with pa.OSFile(data_filepath, "rb") as source:
+        reader = ipc.open_file(source)
+        schema = reader.schema
+
+    features = Features.from_arrow_schema(schema)
+
     column_types = {}
-    for column in dataset_features:
-        if dataset_features[column]._type == "Value":
+    for column, feature in features.items():
+        if feature._type == "Value":
             column_types[column] = {
                 "type": "Value",
-                "dtype": dataset_features[column].dtype,
+                "dtype": feature.dtype,
             }
-        elif dataset_features[column]._type == "ClassLabel":
+        elif feature._type == "ClassLabel":
             column_types[column] = {
                 "type": "Classlabel",
                 "dtype": "",
@@ -748,28 +740,37 @@ def get_dataset_info(dataset_path: str) -> object:
     object
         Dictionary with the information of the dataset
     """
-    dataset = load_dataset(dataset_path=dataset_path)
-    total_rows = dataset.num_rows
-    total_columns = len(dataset.features)
-    splits = dataset.splits.get("split_indices", {})
+    metadata_filepath = os.path.join(dataset_path, "splits.json")
+    if os.path.exists(metadata_filepath):
+        with open(metadata_filepath, "r") as f:
+            splits_data = json.load(f)
+    else:
+        splits_data = {"split_indices": {}}
+
+    data_filepath = os.path.join(dataset_path, "data.arrow")
+    with pa.OSFile(data_filepath, "rb") as source:
+        reader = ipc.open_file(source)
+        schema = reader.schema
+
+        total_rows = 0
+        for i in range(reader.num_record_batches):
+            total_rows += reader.get_batch(i).num_rows
+
+    splits = splits_data.get("split_indices", {})
     train_indices = splits.get("train", [])
     test_indices = splits.get("test", [])
    val_indices = splits.get("validation", [])
-    train_size = len(train_indices)
-    test_size = len(test_indices)
-    val_size = len(val_indices)
-    dataset_info = {
+    return {
         "total_rows": total_rows,
-        "total_columns": total_columns,
-        "train_size": train_size,
-        "test_size": test_size,
-        "val_size": val_size,
+        "total_columns": len(schema),
+        "train_size": len(train_indices),
+        "test_size": len(test_indices),
+        "val_size": len(val_indices),
         "train_indices": train_indices,
         "test_indices": test_indices,
         "val_indices": val_indices,
     }
-    return dataset_info
 
 
 @beartype
diff --git a/DashAI/back/dataloaders/classes/image_dataloader.py b/DashAI/back/dataloaders/classes/image_dataloader.py
index d494ef4c..327f560d 100644
--- a/DashAI/back/dataloaders/classes/image_dataloader.py
+++ b/DashAI/back/dataloaders/classes/image_dataloader.py
@@ -1,11 +1,12 @@
 """DashAI Image Dataloader."""
 
-import shutil
 from typing import Any, Dict
 
 from beartype import beartype
 from datasets import load_dataset
 
+from DashAI.back.core.schema_fields import none_type, schema_field, string_field
+from DashAI.back.core.schema_fields.base_schema import BaseSchema
 from DashAI.back.dataloaders.classes.dashai_dataset import (
     DashAIDataset,
     to_dashai_dataset,
@@ -13,10 +14,22 @@
 from DashAI.back.dataloaders.classes.dataloader import BaseDataLoader
 
 
+class ImageDataloaderSchema(BaseSchema):
+    name: schema_field(
+        none_type(string_field()),
+        "",
+        (
+            "Custom name to register your dataset. If no name is specified, "
+            "the name of the uploaded file will be used."
+        ),
+    )  # type: ignore
+
+
 class ImageDataLoader(BaseDataLoader):
     """Data loader for data from image files."""
 
     COMPATIBLE_COMPONENTS = ["ImageClassificationTask"]
+    SCHEMA = ImageDataloaderSchema
 
     @beartype
     def load_data(
@@ -46,11 +59,7 @@ def load_data(
         prepared_path = self.prepare_files(filepath_or_buffer, temp_path)
 
         if prepared_path[1] == "dir":
-            dataset = load_dataset(
-                "imagefolder",
-                data_dir=prepared_path[0],
-            )
-            shutil.rmtree(prepared_path[0])
+            dataset = load_dataset("imagefolder", data_dir=prepared_path[0])
         else:
             raise Exception(
                 "The image dataloader requires the input file to be a zip file."
diff --git a/DashAI/back/job/dataset_job.py b/DashAI/back/job/dataset_job.py
index 5ddef7cd..1bd50538 100644
--- a/DashAI/back/job/dataset_job.py
+++ b/DashAI/back/job/dataset_job.py
@@ -12,6 +12,7 @@
 from DashAI.back.api.api_v1.schemas.datasets_params import DatasetParams
 from DashAI.back.api.utils import parse_params
 from DashAI.back.dataloaders.classes.dashai_dataset import save_dataset
+from DashAI.back.dataloaders.classes.image_dataloader import ImageDataLoader
 from DashAI.back.dependencies.database.models import Dataset
 from DashAI.back.dependencies.registry import ComponentRegistry
 from DashAI.back.job.base_job import BaseJob, JobError
@@ -70,13 +71,24 @@ def run(
         try:
             log.debug("Storing dataset in %s", folder_path)
-            new_dataset = dataloader.load_data(
-                filepath_or_buffer=str(file_path) if file_path is not None else url,
-                temp_path=str(temp_dir),
-                params=parsed_params.model_dump(),
-            )
-            gc.collect()
 
             dataset_save_path = folder_path / "dataset"
+            if dataloader == ImageDataLoader:
+                new_dataset = dataloader.load_data(
+                    filepath_or_buffer=str(file_path)
+                    if file_path is not None
+                    else url,
+                    temp_path=str(dataset_save_path),
+                    params=parsed_params.model_dump(),
+                )
+            else:
+                new_dataset = dataloader.load_data(
+                    filepath_or_buffer=str(file_path)
+                    if file_path is not None
+                    else url,
+                    temp_path=str(temp_dir),
+                    params=parsed_params.model_dump(),
+                )
+            gc.collect()
             log.debug("Saving dataset in %s", str(dataset_save_path))
             save_dataset(new_dataset, dataset_save_path)
         except Exception as e:
diff --git a/tests/back/dataloaders/test_image_dataloaders.py b/tests/back/dataloaders/test_image_dataloaders.py
index 46b3f2dd..dff5c097 100644
--- a/tests/back/dataloaders/test_image_dataloaders.py
+++ b/tests/back/dataloaders/test_image_dataloaders.py
@@ -14,7 +14,7 @@ def test_image_dataloader_from_zip():
     dataset = image_dataloader.load_data(
         filepath_or_buffer=test_dataset_path,
         temp_path="tests/back/dataloaders/beans_dataset_small",
-        params={},
+        params={"name": "beans_dataset_small"},
     )
 
     assert isinstance(dataset, DashAIDataset)