Skip to content

Commit

Permalink
Check that the correct test dataset is being used for each trained reference.
Browse files Browse the repository at this point in the history
  • Loading branch information
LTLA committed Dec 14, 2024
1 parent 3a9204f commit 64aeafc
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 6 deletions.
4 changes: 4 additions & 0 deletions src/singler/classify_integrated.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ def classify_integrated(
"""
if isinstance(test_data, summarizedexperiment.SummarizedExperiment):
test_data = test_data.assay(assay_type)

if test_data.shape[0] != integrated_prebuilt._test_num_features: # TODO: move to singlepp.
raise ValueError("number of rows in 'test_data' is not consistent with 'test_features=' used to create 'integrated_prebuilt'")

ref_labs = integrated_prebuilt.reference_labels

# Applying the sanity checks.
Expand Down
4 changes: 4 additions & 0 deletions src/singler/classify_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ def classify_single(
"""
if isinstance(test_data, summarizedexperiment.SummarizedExperiment):
test_data = test_data.assay(assay_type)

if test_data.shape[0] != ref_prebuilt._test_num_features: # TODO: move to singlepp
raise ValueError("number of rows in 'test_data' is not consistent with 'test_features=' used to create 'ref_prebuilt'")

test_ptr = mattress.initialize(test_data)

best, raw_scores, delta = lib.classify_single(
Expand Down
8 changes: 5 additions & 3 deletions src/singler/train_integrated.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ class TrainedIntegratedReferences:
"""Object containing integrated references, typically constructed by
:py:meth:`~singler.train_integrated.train_integrated`."""

def __init__(self, ptr, ref_names, ref_labels):
def __init__(self, ptr: int, ref_names: Optional[Sequence], ref_labels: list, test_num_features: int):
self._ptr = ptr
self._names = ref_names
self._labels = ref_labels
self._test_num_features = test_num_features # TODO: move to singlepp.

@property
def reference_names(self) -> Union[Sequence[str], None]:
def reference_names(self) -> Union[Sequence, None]:
"""Sequence containing the names of the references. Alternatively
``None``, if no names were supplied."""
return self._names
Expand Down Expand Up @@ -98,5 +99,6 @@ def train_integrated(
return TrainedIntegratedReferences(
ptr=ibuilt,
ref_names=ref_names,
ref_labels=[x.labels for x in ref_prebuilt]
ref_labels=[x.labels for x in ref_prebuilt],
test_num_features = len(test_features),
)
11 changes: 8 additions & 3 deletions src/singler/train_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,21 @@ class TrainedSingleReference:

def __init__(
self,
ptr,
full_data,
ptr: int,
full_data: Any,
full_label_codes: numpy.ndarray,
labels: Sequence,
features: Sequence,
markers: dict[Any, dict[Any, Sequence]]
markers: dict[Any, dict[Any, Sequence]],
test_num_features: int,
):
self._ptr = ptr
self._full_data = full_data
self._full_label_codes = full_label_codes
self._features = features
self._labels = labels
self._markers = markers
self._test_num_features = test_num_features # TODO: move to singlepp.

def num_markers(self) -> int:
"""
Expand Down Expand Up @@ -233,10 +235,12 @@ def train_single(
if test_features is None:
test_features_idx = numpy.array(range(len(ref_features)), dtype=numpy.uint32)
ref_features_idx = numpy.array(range(len(ref_features)), dtype=numpy.uint32)
test_num_features = len(ref_features)
else:
common_features = _stable_intersect(test_features, ref_features)
test_features_idx = biocutils.match(common_features, test_features, dtype=numpy.uint32)
ref_features_idx = biocutils.match(common_features, ref_features, dtype=numpy.uint32)
test_num_features = len(test_features)

ref_ptr = mattress.initialize(ref_data)
builder, _ = knncolle.define_builder(nn_parameters)
Expand All @@ -255,6 +259,7 @@ def train_single(
labels = unique_labels,
features = ref_features,
markers = markers,
test_num_features = test_num_features,
)


Expand Down

0 comments on commit 64aeafc

Please sign in to comment.