Skip to content

Commit

Permalink
fix: empty dataset splits (#73)
Browse files Browse the repository at this point in the history
If a dataset is small enough, or the values for the train/val/test splits
are very small, it is possible that no piece of data is assigned to one of
the splits. This change enforces that no split is empty. A split may still
have size 1 and thus not be really useful, but it won't break the execution;
it leaves the responsibility to the user to take action in order to fix the
disparity caused by their choice of split values or by the very small
dataset size.
  • Loading branch information
eloy-encord authored May 8, 2024
1 parent 219f89f commit 76c6cb2
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
2 changes: 1 addition & 1 deletion tti_eval/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def build_command(
embd_defn.save_embeddings(embeddings=embeddings, split=split, overwrite=True)
print(f"Embeddings saved successfully to file at `{embd_defn.embedding_path(split)}`")
except Exception as e:
print(f"Failed to build embeddings for this bastard: {embd_defn}")
print(f"Failed to build embeddings for {embd_defn} on the specified split {split}")
print(e)
import traceback

Expand Down
4 changes: 2 additions & 2 deletions tti_eval/dataset/types/encord_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def _setup(
ssh_key_path = ssh_key_path or os.getenv("ENCORD_SSH_KEY_PATH")
if ssh_key_path is None:
raise ValueError(
"The `ssh_key_path` parameter and the `ENCORD_SSH_KEY_PATH` environment variable are both missing."
"Please set one of them to proceed"
"The `ssh_key_path` parameter and the `ENCORD_SSH_KEY_PATH` environment variable are both missing. "
"Please set one of them to proceed."
)
client = EncordUserClient.create_with_ssh_private_key(ssh_private_key_path=ssh_key_path)
self._project = client.get_project(project_hash)
Expand Down
15 changes: 12 additions & 3 deletions tti_eval/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,25 @@ def simple_random_split(
:raises ValueError: If the sum of `train_split` and `validation_split` is greater than 1,
or if `train_split` or `validation_split` are less than 0.
"""
if dataset_size < 3:
raise ValueError(f"Expected a dataset with size at least 3, got {dataset_size}")

if train_split < 0 or validation_split < 0:
raise ValueError(f"Expected positive splits, got ({train_split=}, {validation_split=})")
if train_split + validation_split > 1:
if train_split + validation_split >= 1:
raise ValueError(
f"Expected `train_split` and `validation_split` sum between 0 and 1, got {train_split + validation_split}"
)
rng = np.random.default_rng(seed)
selection = rng.permutation(dataset_size)
train_size = int(dataset_size * train_split)
validation_size = int(dataset_size * validation_split)
train_size = max(1, int(dataset_size * train_split))
validation_size = max(1, int(dataset_size * validation_split))
# Ensure that the TEST split has at least an element
if train_size + validation_size == dataset_size:
if train_size > 1:
train_size -= 1
if validation_size > 1:
validation_size -= 1
return {
Split.TRAIN: selection[:train_size],
Split.VALIDATION: selection[train_size : train_size + validation_size],
Expand Down

0 comments on commit 76c6cb2

Please sign in to comment.