Skip to content

Commit 9d37e90

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Implement MultiTaskDataset.__eq__ (#2594)
Summary: Pull Request resolved: #2594 Previously, this would fallback to `SupervisedDataset.__eq__`, which uses `self.X` for comparison. If the underlying datasets have heterogeneous feature sets, `self.X` errors out. The new `MultiTaskDataset.__eq__` resolves this issue by comparing the underlying datasets one by one. Reviewed By: Balandat Differential Revision: D64911436 fbshipit-source-id: ecb7343d86c4526d06f61725c1663e50f1f1902f
1 parent e7539db commit 9d37e90

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

Diff for: botorch/utils/datasets.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,14 @@ def get_dataset_without_task_feature(self, outcome_name: str) -> SupervisedDatas
492492
outcome_names=[outcome_name],
493493
)
494494

495+
def __eq__(self, other: Any) -> bool:
496+
return (
497+
type(other) is type(self)
498+
and self.datasets == other.datasets
499+
and self.target_outcome_name == other.target_outcome_name
500+
and self.task_feature_index == other.task_feature_index
501+
)
502+
495503

496504
class ContextualDataset(SupervisedDataset):
497505
"""This is a contextual dataset that is constructed from either a single
@@ -548,7 +556,7 @@ def Y(self) -> Tensor:
548556
return torch.cat(Ys, dim=-1)
549557

550558
@property
551-
def Yvar(self) -> Tensor:
559+
def Yvar(self) -> Tensor | None:
552560
"""Concatenates the Yvars from the child datasets to create the Y expected
553561
by LCEM model if there are multiple datasets; Or return the Yvar expected
554562
by LCEA model if there is only one dataset.

Diff for: test/utils/test_datasets.py

+11
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,17 @@ def test_multi_task(self):
354354
):
355355
mt_dataset.X
356356

357+
# Test equality.
358+
self.assertEqual(mt_dataset, mt_dataset)
359+
self.assertNotEqual(mt_dataset, dataset_5)
360+
self.assertNotEqual(
361+
mt_dataset, MultiTaskDataset(datasets=[dataset_1], target_outcome_name="y")
362+
)
363+
self.assertNotEqual(
364+
mt_dataset,
365+
MultiTaskDataset(datasets=[dataset_1, dataset_5], target_outcome_name="z"),
366+
)
367+
357368
def test_contextual_datasets(self):
358369
num_contexts = 3
359370
feature_names = [f"x_c{i}" for i in range(num_contexts)]

0 commit comments

Comments
 (0)