Fix ImageDataLoader and improve efficiency in endpoints that load datasets #248

Merged · 2 commits · Mar 24, 2025
16 changes: 12 additions & 4 deletions DashAI/back/api/api_v1/endpoints/datasets.py
@@ -3,6 +3,8 @@
import shutil
from typing import Any, Dict

import pyarrow as pa
import pyarrow.ipc as ipc
from fastapi import APIRouter, Depends, Response, status
from fastapi.exceptions import HTTPException
from kink import di, inject
@@ -11,10 +13,8 @@

from DashAI.back.api.api_v1.schemas.datasets_params import DatasetUpdateParams
from DashAI.back.dataloaders.classes.dashai_dataset import (
    DashAIDataset,
    get_columns_spec,
    get_dataset_info,
    load_dataset,
    update_columns_spec,
)
from DashAI.back.dependencies.database.models import Dataset
@@ -126,8 +126,16 @@ async def get_sample(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="Dataset not found",
                )
            dataset: DashAIDataset = load_dataset(f"{file_path}/dataset")
            sample = dataset.sample(n=10)

            arrow_path = os.path.join(file_path, "dataset", "data.arrow")

            with pa.OSFile(arrow_path, "rb") as source:
                reader = ipc.open_file(source)
                batch = reader.get_batch(0)
                sample_size = min(10, batch.num_rows)
                sample_batch = batch.slice(0, sample_size)
                sample = sample_batch.to_pydict()

        except exc.SQLAlchemyError as e:
            logger.exception(e)
            raise HTTPException(
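For reference, the pattern introduced above — reading only the first record batch of the saved Arrow IPC file instead of materializing the whole dataset — can be reduced to a standalone helper. This is a minimal sketch; the function name and path are illustrative, not part of the PR, and it assumes the dataset was written as a single `data.arrow` IPC file as `save_dataset` does:

```python
import pyarrow as pa
import pyarrow.ipc as ipc


def read_sample(arrow_path: str, n: int = 10) -> dict:
    """Return at most n rows from the first record batch of an Arrow IPC file."""
    with pa.OSFile(arrow_path, "rb") as source:
        reader = ipc.open_file(source)  # random-access reader over the IPC file
        batch = reader.get_batch(0)     # only the first batch is deserialized
        sample = batch.slice(0, min(n, batch.num_rows))
        return sample.to_pydict()       # {"column_name": [values, ...]}


# sample = read_sample("user_datasets/1/dataset/data.arrow")  # illustrative path
```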
68 changes: 43 additions & 25 deletions DashAI/back/api/api_v1/endpoints/experiments.py
@@ -1,6 +1,9 @@
import logging
import os
from typing import Union

import pyarrow as pa
import pyarrow.ipc as ipc
from fastapi import APIRouter, Depends, Response, status
from fastapi.exceptions import HTTPException
from kink import di, inject
@@ -11,10 +14,7 @@
    ColumnsValidationParams,
    ExperimentParams,
)
from DashAI.back.dataloaders.classes.dashai_dataset import (
    get_column_names_from_indexes,
    load_dataset,
)
from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset
from DashAI.back.dependencies.database.models import Dataset, Experiment
from DashAI.back.dependencies.registry import ComponentRegistry
from DashAI.back.tasks.base_task import BaseTask
@@ -101,6 +101,7 @@ async def validate_columns(
    component_registry: ComponentRegistry = Depends(lambda: di["component_registry"]),
    session_factory: sessionmaker = Depends(lambda: di["session_factory"]),
):
    """Validate if dataset columns are compatible with a task."""
    with session_factory() as db:
        try:
            dataset = db.get(Dataset, params.dataset_id)
@@ -109,35 +110,48 @@
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="Dataset not found",
                )
            datasetdict = load_dataset(f"{dataset.file_path}/dataset")
            if not datasetdict:

            dataset_path = f"{dataset.file_path}/dataset"
            data_filepath = os.path.join(dataset_path, "data.arrow")
            with pa.OSFile(data_filepath, "rb") as source:
                reader = ipc.open_file(source)
                batch = reader.get_batch(0)
                sample_size = min(5, batch.num_rows)
                sample_batch = batch.slice(0, sample_size)

            table = pa.Table.from_batches([sample_batch])
            minimal_dataset = DashAIDataset(table)

            column_names = minimal_dataset.column_names

            if max(params.inputs_columns + params.outputs_columns) > len(column_names):
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="Error while loading the dataset.",
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Column index out of range",
                )
            inputs_names = get_column_names_from_indexes(
                dataset=datasetdict, indexes=params.inputs_columns
            )
            outputs_names = get_column_names_from_indexes(
                dataset=datasetdict, indexes=params.outputs_columns
            )

            inputs_names = [column_names[i - 1] for i in params.inputs_columns]
            outputs_names = [column_names[i - 1] for i in params.outputs_columns]

        except exc.SQLAlchemyError as e:
            log.exception(e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Internal database error",
            ) from e

    if params.task_name not in component_registry:
        raise HTTPException(
            status_code=404,
            detail=f"Task {params.task_name} not found in the registry.",
        )

    task: BaseTask = component_registry[params.task_name]["class"]()
    validation_response = {}

    try:
        prepared_dataset = task.prepare_for_task(
            datasetdict=datasetdict,
            datasetdict=minimal_dataset,
            outputs_columns=outputs_names,
        )
        task.validate_dataset_for_task(
@@ -186,18 +200,22 @@ async def create_experiment(
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND, detail="Dataset not found"
                )
            datasetdict = load_dataset(f"{dataset.file_path}/dataset")
            if not datasetdict:
            dataset_path = f"{dataset.file_path}/dataset"
            data_filepath = os.path.join(dataset_path, "data.arrow")

            with pa.OSFile(data_filepath, "rb") as source:
                reader = ipc.open_file(source)
                schema = reader.schema
                column_names = schema.names

            if max(params.input_columns + params.output_columns) > len(column_names):
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="Error while loading the dataset.",
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Column index out of range",
                )
            inputs_columns = get_column_names_from_indexes(
                dataset=datasetdict, indexes=params.input_columns
            )
            outputs_columns = get_column_names_from_indexes(
                dataset=datasetdict, indexes=params.output_columns
            )

            inputs_columns = [column_names[i - 1] for i in params.input_columns]
            outputs_columns = [column_names[i - 1] for i in params.output_columns]
            experiment = Experiment(
                dataset_id=params.dataset_id,
                task_name=params.task_name,
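Both endpoints now resolve 1-based column indexes against the Arrow schema instead of loading the full dataset. A rough standalone sketch of that lookup — the helper name is hypothetical, but the 1-based indexing matches the `column_names[i - 1]` expressions in the diff:

```python
import pyarrow as pa
import pyarrow.ipc as ipc


def column_names_from_indexes(arrow_path: str, indexes: list) -> list:
    """Resolve 1-based column indexes to names using only the file's schema."""
    with pa.OSFile(arrow_path, "rb") as source:
        schema = ipc.open_file(source).schema  # reads metadata, not the batches
    if max(indexes) > len(schema.names):
        raise ValueError("Column index out of range")
    return [schema.names[i - 1] for i in indexes]
```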
69 changes: 35 additions & 34 deletions DashAI/back/dataloaders/classes/dashai_dataset.py
@@ -9,6 +9,7 @@
import pyarrow.ipc as ipc
from beartype import beartype
from datasets import ClassLabel, Dataset, DatasetDict, Value, concatenate_datasets
from datasets.features import Features
from sklearn.model_selection import train_test_split


@@ -93,21 +94,6 @@ def keys(self) -> List[str]:
            return list(self.splits["split_indices"].keys())
        return []

    @beartype
    def save_to_disk(self, dataset_path: Union[str, os.PathLike]) -> None:
        """
        Overrides the default save_to_disk method to save the dataset as
        a single directory with:
        - "data.arrow": the dataset's Arrow table.
        - "splits.json": the dataset's splits (e.g., original split indices).

        Parameters
        ----------
        dataset_path : Union[str, os.PathLike]
            path where the dataset will be saved
        """
        save_dataset(self, dataset_path)

    @beartype
    def change_columns_type(self, column_types: Dict[str, str]) -> "DashAIDataset":
        """Change the type of some columns.
@@ -678,16 +664,22 @@ def get_columns_spec(dataset_path: str) -> Dict[str, Dict]:
    Dict
        Dict with the columns and types
    """
    dataset = load_dataset(dataset_path)
    dataset_features = dataset.features

    data_filepath = os.path.join(dataset_path, "data.arrow")
    with pa.OSFile(data_filepath, "rb") as source:
        reader = ipc.open_file(source)
        schema = reader.schema

    features = Features.from_arrow_schema(schema)

    column_types = {}
    for column in dataset_features:
        if dataset_features[column]._type == "Value":
    for column, feature in features.items():
        if feature._type == "Value":
            column_types[column] = {
                "type": "Value",
                "dtype": dataset_features[column].dtype,
                "dtype": feature.dtype,
            }
        elif dataset_features[column]._type == "ClassLabel":
        elif feature._type == "ClassLabel":
            column_types[column] = {
                "type": "Classlabel",
                "dtype": "",
@@ -748,28 +740,37 @@ def get_dataset_info(dataset_path: str) -> object:
    object
        Dictionary with the information of the dataset
    """
    dataset = load_dataset(dataset_path=dataset_path)
    total_rows = dataset.num_rows
    total_columns = len(dataset.features)
    splits = dataset.splits.get("split_indices", {})
    metadata_filepath = os.path.join(dataset_path, "splits.json")
    if os.path.exists(metadata_filepath):
        with open(metadata_filepath, "r") as f:
            splits_data = json.load(f)
    else:
        splits_data = {"split_indices": {}}

    data_filepath = os.path.join(dataset_path, "data.arrow")
    with pa.OSFile(data_filepath, "rb") as source:
        reader = ipc.open_file(source)
        schema = reader.schema

        total_rows = 0
        for i in range(reader.num_record_batches):
            total_rows += reader.get_batch(i).num_rows

    splits = splits_data.get("split_indices", {})
    train_indices = splits.get("train", [])
    test_indices = splits.get("test", [])
    val_indices = splits.get("validation", [])
    train_size = len(train_indices)
    test_size = len(test_indices)
    val_size = len(val_indices)

    dataset_info = {
    return {
        "total_rows": total_rows,
        "total_columns": total_columns,
        "train_size": train_size,
        "test_size": test_size,
        "val_size": val_size,
        "total_columns": len(schema),
        "train_size": len(train_indices),
        "test_size": len(test_indices),
        "val_size": len(val_indices),
        "train_indices": train_indices,
        "test_indices": test_indices,
        "val_indices": val_indices,
    }
    return dataset_info


@beartype
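`get_dataset_info` now sums row counts over the IPC record batches and reads split sizes from `splits.json`, so nothing is materialized as a `Dataset`. A condensed sketch of the row-count part; it swaps `pa.OSFile` for `pa.memory_map`, which is a variation on what the PR ships, not the PR code itself:

```python
import pyarrow as pa
import pyarrow.ipc as ipc


def count_rows(arrow_path: str) -> int:
    """Sum row counts over all record batches without building a Dataset."""
    with pa.memory_map(arrow_path, "r") as source:  # zero-copy reads
        reader = ipc.open_file(source)
        return sum(
            reader.get_batch(i).num_rows for i in range(reader.num_record_batches)
        )
```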
21 changes: 15 additions & 6 deletions DashAI/back/dataloaders/classes/image_dataloader.py
@@ -1,22 +1,35 @@
"""DashAI Image Dataloader."""

import shutil
from typing import Any, Dict

from beartype import beartype
from datasets import load_dataset

from DashAI.back.core.schema_fields import none_type, schema_field, string_field
from DashAI.back.core.schema_fields.base_schema import BaseSchema
from DashAI.back.dataloaders.classes.dashai_dataset import (
    DashAIDataset,
    to_dashai_dataset,
)
from DashAI.back.dataloaders.classes.dataloader import BaseDataLoader


class ImageDataloaderSchema(BaseSchema):
    name: schema_field(
        none_type(string_field()),
        "",
        (
            "Custom name to register your dataset. If no name is specified, "
            "the name of the uploaded file will be used."
        ),
    )  # type: ignore


class ImageDataLoader(BaseDataLoader):
"""Data loader for data from image files."""

COMPATIBLE_COMPONENTS = ["ImageClassificationTask"]
SCHEMA = ImageDataloaderSchema

    @beartype
    def load_data(
@@ -46,11 +59,7 @@ def load_data(
        prepared_path = self.prepare_files(filepath_or_buffer, temp_path)

        if prepared_path[1] == "dir":
            dataset = load_dataset(
                "imagefolder",
                data_dir=prepared_path[0],
            )
            shutil.rmtree(prepared_path[0])
            dataset = load_dataset("imagefolder", data_dir=prepared_path[0])
        else:
            raise Exception(
                "The image dataloader requires the input file to be a zip file."
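The loader no longer deletes the extracted folder after `load_dataset("imagefolder", ...)`, presumably because the resulting dataset can still reference the image files on disk and the job below now extracts them into the final dataset folder. A hypothetical usage sketch of the updated loader with the new `name` schema field; the zip path, temp path, and the no-argument constructor are assumptions for illustration, not taken from the PR:

```python
from DashAI.back.dataloaders.classes.image_dataloader import ImageDataLoader

loader = ImageDataLoader()  # assumed no-argument constructor
dataset = loader.load_data(
    filepath_or_buffer="path/to/beans_dataset_small.zip",  # zip of class folders
    temp_path="tmp/beans_dataset_small",                   # extraction target
    params={"name": "beans_dataset_small"},                # new schema field
)
```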
24 changes: 18 additions & 6 deletions DashAI/back/job/dataset_job.py
@@ -12,6 +12,7 @@
from DashAI.back.api.api_v1.schemas.datasets_params import DatasetParams
from DashAI.back.api.utils import parse_params
from DashAI.back.dataloaders.classes.dashai_dataset import save_dataset
from DashAI.back.dataloaders.classes.image_dataloader import ImageDataLoader
from DashAI.back.dependencies.database.models import Dataset
from DashAI.back.dependencies.registry import ComponentRegistry
from DashAI.back.job.base_job import BaseJob, JobError
@@ -70,13 +71,24 @@ def run(

        try:
            log.debug("Storing dataset in %s", folder_path)
            new_dataset = dataloader.load_data(
                filepath_or_buffer=str(file_path) if file_path is not None else url,
                temp_path=str(temp_dir),
                params=parsed_params.model_dump(),
            )
            gc.collect()
            dataset_save_path = folder_path / "dataset"
            if dataloader == ImageDataLoader:
                new_dataset = dataloader.load_data(
                    filepath_or_buffer=str(file_path)
                    if file_path is not None
                    else url,
                    temp_path=str(dataset_save_path),
                    params=parsed_params.model_dump(),
                )
            else:
                new_dataset = dataloader.load_data(
                    filepath_or_buffer=str(file_path)
                    if file_path is not None
                    else url,
                    temp_path=str(temp_dir),
                    params=parsed_params.model_dump(),
                )
            gc.collect()
            log.debug("Saving dataset in %s", str(dataset_save_path))
            save_dataset(new_dataset, dataset_save_path)
        except Exception as e:
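The job now branches on the dataloader: image datasets are extracted directly into the final `dataset` folder so the image files outlive the temporary directory, while every other loader keeps using `temp_dir`. A condensed sketch of that choice — the helper name is hypothetical, and it mirrors the diff's `dataloader == ImageDataLoader` comparison, i.e. it assumes the registry hands back the dataloader class rather than an instance:

```python
from pathlib import Path

from DashAI.back.dataloaders.classes.image_dataloader import ImageDataLoader


def pick_load_path(dataloader, folder_path: Path, temp_dir: Path) -> Path:
    """Choose where a dataloader should materialize its files (hypothetical helper).

    Image datasets go straight into the final ``dataset`` folder; every other
    loader keeps using ``temp_dir`` and only the saved Arrow data ends up under
    ``dataset``.
    """
    dataset_save_path = folder_path / "dataset"
    return dataset_save_path if dataloader == ImageDataLoader else temp_dir
```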
2 changes: 1 addition & 1 deletion tests/back/dataloaders/test_image_dataloaders.py
@@ -14,7 +14,7 @@ def test_image_dataloader_from_zip():
    dataset = image_dataloader.load_data(
        filepath_or_buffer=test_dataset_path,
        temp_path="tests/back/dataloaders/beans_dataset_small",
        params={},
        params={"name": "beans_dataset_small"},
    )

    assert isinstance(dataset, DashAIDataset)