Skip to content
This repository was archived by the owner on Sep 11, 2024. It is now read-only.

Commit c6c9b88

Browse files
authored
Merge pull request #59 from climatepolicyradar/RND-479-dataset-filter-by-corpus
Dataset - filter_by_corpus
2 parents 0875e03 + 24d53a9 commit c6c9b88

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

src/cpr_data_access/models.py

+6
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,12 @@ def filter(self, attribute: str, value: Any) -> "Dataset":
10091009

10101010
return Dataset(**instance_attributes, documents=documents)
10111011

1012+
def filter_by_corpus(self, corpus_name: str) -> "Dataset":
    """
    Return the documents whose document ID starts with the given corpus name.

    The comparison is a case-insensitive prefix match on ``document_id``;
    no separator is required after the corpus name.

    :param corpus_name: corpus prefix to match, e.g. "CCLW"
    :return: the result of ``self.filter`` applied to ``document_id``
    """
    # Lower-case the prefix once instead of once per document.
    prefix = corpus_name.lower()
    return self.filter(
        "document_id", lambda doc_id: doc_id.lower().startswith(prefix)
    )
1017+
10121018
def filter_by_language(self, language: str) -> "Dataset":
    """
    Return documents whose only language is the given language.

    Delegates to ``self.filter`` on the ``languages`` attribute, passing the
    single-element list ``[language]`` as the value to match.
    """
    sole_language = [language]
    return self.filter(attribute="languages", value=sole_language)

tests/test_models.py

+11
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,17 @@ def test_dataset_filter_by_language(test_dataset):
181181
assert dataset.documents[1].languages == ["en"]
182182

183183

184+
def test_dataset_filter_by_corpus(test_dataset):
    """Check how many documents Dataset.filter_by_corpus keeps for each corpus."""
    # Same order as before: a corpus with no matches first, then one with matches.
    expected_counts = (("UNFCCC", 0), ("CCLW", 3))

    for corpus_name, expected in expected_counts:
        dataset = test_dataset.filter_by_corpus(corpus_name)

        assert len(dataset) == expected
193+
194+
184195
def test_dataset_get_all_text_blocks(test_dataset):
185196
text_blocks = test_dataset.get_all_text_blocks()
186197
num_text_blocks = sum(

0 commit comments

Comments
 (0)