diff --git a/src/bids_validator/context.py b/src/bids_validator/context.py index 434018c..7b940a1 100644 --- a/src/bids_validator/context.py +++ b/src/bids_validator/context.py @@ -113,6 +113,7 @@ class Dataset: """A dataset object that loads properties on first access.""" tree: FileTree + schema: Namespace ignored: list[str] = attrs.field(factory=list) subjects: Subjects = attrs.field(init=False) @@ -129,14 +130,41 @@ def dataset_description(self) -> Namespace: @cached_property def modalities(self) -> list[str]: """List of modalities found in the dataset.""" - ... - return [] + result = set() + + modalities = self.schema.rules.modalities + for datatype in self.datatypes: + for mod_name, mod_dtypes in modalities.items(): + if datatype in mod_dtypes.datatypes: + result.add(mod_name) + + return list(result) @cached_property def datatypes(self) -> list[str]: """List of datatypes found in the dataset.""" - ... - return [] + return list(find_datatypes(self.tree, self.schema.objects.datatypes)) + + +def find_datatypes( + tree: FileTree, datatypes: Namespace, result: set[str] | None = None, max_depth: int = 2 +) -> set[str]: + """Recursively work through tree to find datatypes.""" + if result is None: + result = set() + + for child_name, child_obj in tree.children.items(): + if not child_obj.is_dir: + continue + + if child_name in datatypes.keys(): + result.add(child_name) + elif max_depth == 0: + continue + else: + result = find_datatypes(child_obj, datatypes, result, max_depth=max_depth - 1) + + return result @attrs.define diff --git a/tests/test_context.py b/tests/test_context.py index c58c6f7..0cc6c45 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -1,16 +1,33 @@ from bids_validator import context from bids_validator.types.files import FileTree +import pytest -def test_load(examples): +def test_load(examples, schema): tree = FileTree.read_from_filesystem(examples / 'synthetic') - ds = context.Dataset(tree) + ds = context.Dataset(tree, schema) assert 
@pytest.mark.parametrize(
    "depth, expected",
    [
        (2, {"anat", "beh", "func"}),
        (1, set()),
    ],
)
def test_find_datatypes(examples, schema, depth, expected):
    """Check the datatypes discovered in the synthetic example for each depth limit."""
    dataset_root = examples / 'synthetic'

    found = context.find_datatypes(
        FileTree.read_from_filesystem(dataset_root),
        schema.objects.datatypes,
        max_depth=depth,
    )

    assert found == expected