55 changes: 27 additions & 28 deletions openfl-workspace/torch/histology_s3/src/dataloader.py
@@ -2,34 +2,32 @@
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""

from collections.abc import Iterable
from logging import getLogger
import os
import sys
from collections.abc import Iterable
from logging import getLogger


from openfl.federated import PyTorchDataLoader
import numpy as np
from openfl.federated.data.sources.torch.verifiable_map_style_image_folder import VerifiableImageFolder
from openfl.federated.data.sources.data_sources_json_parser import DataSourcesJsonParser
from openfl.utilities.path_check import is_directory_traversal
import torch
from torch.utils.data import random_split
from torchvision.transforms import ToTensor

from openfl.federated import PyTorchDataLoader
from openfl.federated.data.sources.data_sources_json_parser import DataSourcesJsonParser
from openfl.federated.data.sources.torch.verifiable_map_style_image_folder import VerifiableImageFolder
from openfl.utilities.path_check import is_directory_traversal

logger = getLogger(__name__)


class PyTorchHistologyVerifiableDataLoader(PyTorchDataLoader):
"""PyTorch data loader for Histology dataset."""

def __init__(self, data_path, batch_size, **kwargs):
def __init__(self, data_path=None, batch_size=32, **kwargs):
"""Instantiate the data object.

Args:
data_path: The file path to the data
data_path: The file path to the data. If None, initialize for model creation only.
batch_size: The batch size of the data loader
**kwargs: Additional arguments, passed to super init
and load_mnist_shard
@@ -61,17 +59,19 @@ def __init__(self, data_path, batch_size, **kwargs):
else:
logger.info("The dataset is valid.")

_, num_classes, X_train, y_train, X_valid, y_valid = load_histology_shard(
verifible_dataset_info=verifible_dataset_info, verify_dataset_items=verify_dataset_items, **kwargs)
X_train, y_train, X_valid, y_valid = load_histology_shard(
verifible_dataset_info=verifible_dataset_info,
verify_dataset_items=verify_dataset_items,
feature_shape=self.feature_shape,
num_classes=self.num_classes,
**kwargs
)

self.X_train = X_train
self.y_train = y_train
self.X_valid = X_valid
self.y_valid = y_valid

self.num_classes = num_classes


def get_feature_shape(self):
"""Returns the shape of an example feature array.

@@ -101,7 +101,6 @@ def get_verifiable_dataset_info(self, data_path):
Raises:
SystemExit: If `data_path` is invalid or missing `datasources.json`.
"""
"""Return the verifiable dataset info object for the given data sources."""
if data_path and is_directory_traversal(data_path):
logger.error("Data path is out of the openfl workspace scope.")
if not os.path.isdir(data_path):
@@ -152,7 +151,8 @@ def _load_raw_data(verifiable_dataset_info, verify_dataset_items=False, train_sp
n_train = int(train_split_ratio * len(dataset))
n_valid = len(dataset) - n_train
ds_train, ds_val = random_split(
dataset, lengths=[n_train, n_valid], generator=torch.manual_seed(0))
dataset, lengths=[n_train, n_valid], generator=torch.manual_seed(0)
)

# create the shards
X_train, y_train = list(zip(*ds_train))
@@ -164,41 +164,40 @@
return (X_train, y_train), (X_valid, y_valid)



def load_histology_shard(verifible_dataset_info, verify_dataset_items,
def load_histology_shard(verifible_dataset_info, verify_dataset_items, feature_shape=None, num_classes=None,
categorical=False, channels_last=False, **kwargs):
"""
Load the Histology dataset.

Args:
data_path (str): path to data directory
verifible_dataset_info (VerifiableDatasetInfo): The verifiable dataset info object.
verify_dataset_items (bool): True = verify the dataset items while loading data
feature_shape (list, optional): The shape of input features.
num_classes (int, optional): Number of classes.
categorical (bool): True = convert the labels to one-hot encoded
vectors (Default = True)
channels_last (bool): True = The input images have the channels
last (Default = True)
**kwargs: Additional parameters to pass to the function

Returns:
list: The input shape
int: The number of classes
numpy.ndarray: The training data
numpy.ndarray: The training labels
numpy.ndarray: The validation data
numpy.ndarray: The validation labels
"""
img_rows, img_cols = 150, 150
num_classes = 8
img_rows, img_cols = feature_shape[1], feature_shape[2]

(X_train, y_train), (X_valid, y_valid) = _load_raw_data(verifible_dataset_info, verify_dataset_items, **kwargs)
(X_train, y_train), (X_valid, y_valid) = _load_raw_data(
verifible_dataset_info, verify_dataset_items, **kwargs
)

if channels_last:
X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 3)
X_valid = X_valid.reshape(X_valid.shape[0], img_rows, img_cols, 3)
input_shape = (img_rows, img_cols, 3)
else:
X_train = X_train.reshape(X_train.shape[0], 3, img_rows, img_cols)
X_valid = X_valid.reshape(X_valid.shape[0], 3, img_rows, img_cols)
input_shape = (3, img_rows, img_cols)

logger.info(f'Histology > X_train Shape : {X_train.shape}')
logger.info(f'Histology > y_train Shape : {y_train.shape}')
@@ -210,4 +209,4 @@ def load_histology_shard(verifible_dataset_info, verify_dataset_items,
y_train = np.eye(num_classes)[y_train]
y_valid = np.eye(num_classes)[y_valid]

return input_shape, num_classes, X_train, y_train, X_valid, y_valid
return X_train, y_train, X_valid, y_valid
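
A hedged usage sketch of the refactored loader (not part of the diff; the module path and data directory below are assumptions). It reflects the two visible changes: data_path now defaults to None for model-creation-only use, and load_histology_shard takes feature_shape/num_classes from the loader instead of returning them.

    # Hedged sketch, not part of this PR. Module path and data directory are
    # assumptions for illustration.
    from src.dataloader import PyTorchHistologyVerifiableDataLoader

    # New default: no data path, e.g. when only the model needs to be built.
    model_only_loader = PyTorchHistologyVerifiableDataLoader(data_path=None)

    # With a data directory (containing datasources.json), the shard is loaded;
    # feature_shape and num_classes are supplied by the loader rather than
    # returned by load_histology_shard.
    loader = PyTorchHistologyVerifiableDataLoader(data_path="data/histology", batch_size=32)
    print(loader.get_feature_shape())
    print(loader.X_train.shape, loader.y_train.shape)
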
65 changes: 44 additions & 21 deletions openfl/federated/data/sources/data_sources_json_parser.py
@@ -15,7 +15,9 @@

class DataSourcesJsonParser:
@staticmethod
def parse(json_string: str) -> VerifiableDatasetInfo:
def parse(
json_string: str, label="", metadata="", check_dir_traversal=False
) -> VerifiableDatasetInfo:
"""
Parse a JSON string into a dictionary.

@@ -31,48 +33,69 @@ def parse(json_string: str) -> VerifiableDatasetInfo:
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON format: {e}")

datasources = DataSourcesJsonParser.process_data_sources(data)
datasources = DataSourcesJsonParser.process_data_sources(data, check_dir_traversal)
if not datasources:
raise ValueError("No data sources were found.")
return VerifiableDatasetInfo(
data_sources=datasources,
label="",
label=label,
metadata=metadata,
)

@staticmethod
def process_data_sources(data):
def process_data_sources(data, check_dir_traversal=False):
"""Process and validate data sources."""
cwd = os.getcwd()
datasources = []
local_datasources = {}
for source_name, source_info in data.items():
source_type = source_info.get("type", None)
if source_type is None:
raise ValueError(f"Missing 'type' key in data source configuration: {source_info}")
params = source_info.get("params", {})
if source_type == "local":
datasources.append(
DataSourcesJsonParser.process_local_source(source_name, params, cwd)
)
if source_type == "fs":
local_datasources[source_name] = params
elif source_type == "s3":
datasources.append(DataSourcesJsonParser.process_s3_source(source_name, params))
elif source_type == "azure_blob":
elif source_type == "ab":
datasources.append(
DataSourcesJsonParser.process_azure_blob_source(source_name, params)
)
if local_datasources:
DataSourcesJsonParser.process_local_sources(
local_datasources, datasources, check_dir_traversal
)
return [ds for ds in datasources if ds]

@staticmethod
def process_local_source(source_name, params, cwd):
"""Process a local data source."""
path = params.get("path", None)
if not path:
raise ValueError(f"Missing 'path' parameter for local data source '{source_name}'")
abs_path = os.path.abspath(path)
rel_path = os.path.relpath(abs_path, cwd)
if rel_path and not is_directory_traversal(rel_path):
return LocalDataSource(source_name, rel_path, base_path=Path("."))
else:
raise ValueError(f"Invalid path for local data source '{source_name}': {path}.")
def process_local_sources(local_datasources, datasources, check_dir_traversal=False):
[Review thread]
Collaborator: I guess this is also supposed to be staticmethod?
Contributor Author: Yes, right.

"""Process and validate local data sources."""
absolute_paths = {}
for source_name, params in local_datasources.items():
if "path" not in params:
raise ValueError(
f"Missing required field 'path' for local data source '{source_name}'."
)
absolute_paths[source_name] = os.path.realpath(params.get("path"))

# The reason we use common base_dir and source_path relative to that base
# is to simplify path management in containerized environments, such as Docker.
# By using a common base_dir, we can ensure that paths remain consistent
# when mounting volumes, as only the base_dir needs to be adjusted to point
# to the mount path inside the container.
# This way, we only need to adjust the base_dir to point to the mount path.
base_dir = os.path.commonpath(absolute_paths.values())
for source_name, data_path in absolute_paths.items():
relative_path = os.path.relpath(data_path, base_dir)
if check_dir_traversal and is_directory_traversal(data_path):
raise ValueError(
f"Invalid path for local data source '{source_name}': {data_path}."
f" Data path is out of the openfl workspace scope."
)
datasources.append(
LocalDataSource(
name=source_name, source_path=Path(relative_path), base_path=base_dir
)
)

@staticmethod
def process_s3_source(source_name, params):
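
To make the base_dir comment in process_local_sources concrete, here is a small self-contained sketch (paths are hypothetical) of the commonpath/relpath split that the method performs:

    # Hedged sketch, illustration only, not part of the PR: how the shared
    # base_dir and per-source relative paths are derived for two 'fs' sources.
    import os
    from pathlib import Path

    absolute_paths = {
        "train": os.path.realpath("/data/histology/train"),  # hypothetical paths
        "val": os.path.realpath("/data/histology/val"),
    }

    # Mirrors the logic above: one shared base_dir, each source relative to it.
    base_dir = os.path.commonpath(list(absolute_paths.values()))  # "/data/histology"
    for name, data_path in absolute_paths.items():
        relative_path = os.path.relpath(data_path, base_dir)      # "train", "val"
        print(name, base_dir, Path(relative_path))

    # In a container, only base_dir has to be remapped (e.g. to "/mnt/dataset");
    # the relative source paths stay unchanged.
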
10 changes: 6 additions & 4 deletions openfl/interface/collaborator.py
@@ -231,9 +231,11 @@ def register_data_path(collaborator_name, data_path=None, silent=False):
type=ClickPath(exists=True),
help=(
"Path to directory containing sources.json file defining the data sources of the dataset. "
"This file should contain a JSON object with the data sources to be registered. For 'local'"
" type, 'params' must include: 'path'. For 's3' type, 'params' must include: 'uri', "
"'access_key_env_name', 'secret_key_env_name', 'secret_name', and optionally 'endpoint'."
"This file should contain a JSON object with the data sources to be registered. For local "
"data source, 'type' is 'fs', and 'params' must include: 'path'. For 's3' type, 'params' "
"must include: 'uri', 'access_key_env_name', 'secret_key_env_name', 'secret_name', and "
"optionally 'endpoint'. For azure_blob, 'type' is 'ab', and 'params' must include: "
"'connection_string', 'container_name', and optionally 'folder_prefix'."
[Review thread on lines +234 to +238]
Collaborator: My suggestion would be to point to an OpenFL documentation URL that describes data sources in detail with examples, in addition to the format of the JSON files.
Contributor Author: OK. @teoparvanov, what doc file should we insert that info into?
),
)
def calchash(data_path):
@@ -258,7 +260,7 @@ def calchash(data_path):
sys.exit(1)
with open(datasources_json_path, "r", encoding="utf-8") as file:
data = file.read()
vds = DataSourcesJsonParser.parse(data)
vds = DataSourcesJsonParser.parse(data, check_dir_traversal=True)
[Review thread]
Collaborator: Instead, can we check_dir_traversal before calling this function? I don't think the class needs to know whether it is in a valid subdirectory...
Contributor Author: OK. That means we go over the JSON one more time to look for the local data sources. Makes sense? Currently this happens only in calchash.

root_hash = vds.create_dataset_hash()
hash_file_path = os.path.join(data_path, "hash.txt")
with open(hash_file_path, "w", encoding="utf-8") as hash_file:
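
For reference, a hedged sketch of a datasources.json consistent with the updated help text (source names and all parameter values are hypothetical; only the keys and type strings come from the help string):

    # Hedged sketch, not part of the diff: writing a datasources.json whose
    # structure matches the CLI help text above. All values are placeholders.
    import json

    datasources = {
        "local_images": {
            "type": "fs",
            "params": {"path": "data/histology/train"},
        },
        "bucket_images": {
            "type": "s3",
            "params": {
                "uri": "s3://my-bucket/histology",
                "access_key_env_name": "AWS_ACCESS_KEY_ID",
                "secret_key_env_name": "AWS_SECRET_ACCESS_KEY",
                "secret_name": "my-s3-secret",
                "endpoint": "https://s3.example.com",  # optional
            },
        },
        "blob_images": {
            "type": "ab",
            "params": {
                "connection_string": "<azure-connection-string>",
                "container_name": "histology",
                "folder_prefix": "train/",  # optional
            },
        },
    }

    with open("datasources.json", "w", encoding="utf-8") as f:
        json.dump(datasources, f, indent=2)

    # The calchash command above reads this file via DataSourcesJsonParser.parse(...)
    # and writes the resulting dataset hash to hash.txt in the same directory.
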