Support dropout for training samples whose mean label value falls below a configured threshold. #158

Open · wants to merge 3 commits into main
27 changes: 27 additions & 0 deletions usl_models/usl_models/flood_ml/dataset.py
@@ -1,5 +1,6 @@
"""tf.data.Datasets for training FloodML model on CityCAT data."""

import dataclasses
import logging
import random
from typing import Any, Iterator, Tuple
@@ -16,6 +17,22 @@
from usl_models.flood_ml import model


@dataclasses.dataclass
class Filters:
    """Dataset filter settings."""

    # Minimum mean label value; samples whose mean falls below this are
    # candidates for being dropped (filters out sparse data).
    min_mean_value: float = 0.0
    # Probability of dropping a sample whose mean is below min_mean_value,
    # in [0.0, 1.0].
    min_mean_value_dropout_rate: float = 0.0

    def filter(self, label: tf.Tensor) -> bool:
        """Returns True if the sample should be filtered out."""
        return (tf.reduce_mean(label) < self.min_mean_value) and (
            random.random() < self.min_mean_value_dropout_rate
        )


def geospatial_dataset_signature() -> tf.TensorSpec:
    return tf.TensorSpec(
        shape=(
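
A minimal sketch of the dropout behavior the new `Filters` class above implements; the threshold and rate values are illustrative, not taken from this PR (whose defaults disable filtering entirely):

import tensorflow as tf

from usl_models.flood_ml.dataset import Filters

# Illustrative settings: drop roughly 90% of samples whose mean label
# value falls below 0.05.
filters = Filters(min_mean_value=0.05, min_mean_value_dropout_rate=0.9)

sparse_label = tf.zeros((100, 100))  # mean 0.0, below the threshold
dense_label = tf.ones((100, 100))  # mean 1.0, above the threshold

filters.filter(sparse_label)  # Truthy with probability 0.9: sample skipped.
filters.filter(dense_label)  # Always falsy: sample kept.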
@@ -55,6 +72,7 @@ def load_dataset(
    max_chunks: int | None = None,
    firestore_client: firestore.Client = None,
    storage_client: storage.Client | None = None,
    filters: Filters | None = None,
) -> tf.data.Dataset:
    """Creates a dataset which generates chunks for the flood model.

@@ -75,7 +93,9 @@
        If `None` (default) yields all examples from the simulations.
      firestore_client: The client to use when interacting with Firestore.
      storage_client: The client to use when interacting with Cloud Storage.
      filters: Filters to be applied to the training data.
    """
    filters = filters or Filters()
    firestore_client = firestore_client or firestore.Client()
    storage_client = storage_client or storage.Client()

@@ -91,6 +111,8 @@ def generator():
            max_chunks,
            dataset_split,
        ):
            if filters.filter(labels):
                continue
            yield model_input, labels

    # Create the dataset for this simulation
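
Note that since the defaults are min_mean_value=0.0 and min_mean_value_dropout_rate=0.0, the `Filters()` substituted above never drops a sample (random.random() < 0.0 is always false), so existing callers see unchanged behavior. A quick check of that claim, using only the code in this diff:

import tensorflow as tf

from usl_models.flood_ml.dataset import Filters

# With the defaults, the dropout clause is always false, so every sample
# is kept no matter how sparse its label is.
assert not Filters().filter(tf.zeros((4, 4)))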
@@ -124,6 +146,7 @@ def load_dataset_windowed(
    max_chunks: int | None = None,
    firestore_client: firestore.Client | None = None,
    storage_client: storage.Client | None = None,
    filters: Filters | None = None,
) -> tf.data.Dataset:
    """Creates a dataset which generates chunks for flood model training.

@@ -146,7 +169,9 @@
        If `None` (default) yields all examples from the simulations.
      firestore_client: The client to use when interacting with Firestore.
      storage_client: The client to use when interacting with Cloud Storage.
      filters: Filters to be applied to the training data.
    """
    filters = filters or Filters()
    firestore_client = firestore_client or firestore.Client()
    storage_client = storage_client or storage.Client()

@@ -165,6 +190,8 @@ def generator():
        for window_input, window_label in _generate_windows(
            model_input, labels, n_flood_maps
        ):
            if filters.filter(window_label):
                continue
            yield (window_input, window_label)

    dataset = tf.data.Dataset.from_generator(
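
One consequence of drawing the dropout decision from Python's random module inside the generators, rather than from tf.random ops traced into the graph: every pass over the dataset re-invokes the generator and re-draws the decisions, so each epoch sees a different subset of the sparse samples. Seeding the module-level RNG makes the draws reproducible; a sketch under the assumption that nothing else consumes the global random state (threshold values illustrative):

import random

import tensorflow as tf

from usl_models.flood_ml.dataset import Filters

filters = Filters(min_mean_value=0.05, min_mean_value_dropout_rate=0.5)
sparse_label = tf.zeros((4, 4))  # mean below the threshold

random.seed(42)  # Fixes the global RNG that Filters.filter draws from.
first_pass = [bool(filters.filter(sparse_label)) for _ in range(10)]

random.seed(42)  # Re-seeding replays the same dropout decisions.
second_pass = [bool(filters.filter(sparse_label)) for _ in range(10)]

assert first_pass == second_pass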