Skip to content

Commit 2541b32

Browse files
Add tests for dataset extraction script
1 parent 3af9ab2 commit 2541b32

2 files changed

Lines changed: 102 additions & 2 deletions

File tree

etc/scripts/dataset_pipeline/build_dataset.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,6 @@ def main(rules_dir, output_dir):
186186
click.echo(f' train: {len(splits["train"])} val: {len(splits["val"])} test: {len(splits["test"])}')
187187
click.echo(f' output: {out_dir}')
188188

189-
# stuff to do(follow up commits):
190-
# tests to be added in script
189+
191190
if __name__ == '__main__':
192191
main()
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# tests for build_dataset.py
2+
import sys
3+
from pathlib import Path
4+
5+
sys.path.insert(0, str(Path(__file__).parent))
6+
from build_dataset import normalize_phrase, tag_tokens, assign_splits
7+
8+
9+
class TestNormalizePhrase:
    """Checks for ``normalize_phrase``, one cleanup behaviour per test."""

    def test_html_entities(self):
        """Escaped HTML entities are decoded before the phrase is cleaned."""
        # The inputs are entity-escaped on purpose: with already-decoded text
        # the last two checks would be identity comparisons and the first one
        # would only repeat test_strips_xml_tags.
        # NOTE(review): assumes normalize_phrase unescapes entities before it
        # strips tags -- confirm against build_dataset.py.
        assert normalize_phrase('&lt;b&gt;MIT&lt;/b&gt;') == 'MIT'
        assert normalize_phrase('the &quot;License&quot;') == 'the "License"'
        assert normalize_phrase('foo &amp; bar') == 'foo & bar'

    def test_preserves_urls_in_angle_brackets(self):
        """A URL wrapped in angle brackets keeps its text, minus the brackets."""
        result = normalize_phrase('<http://example.com/LICENSE>')
        assert result == 'http://example.com/LICENSE'

    def test_strips_xml_tags(self):
        """Paired XML-style tags are removed and their content is kept."""
        assert normalize_phrase('<name>Apache 2.0</name>') == 'Apache 2.0'
        assert normalize_phrase('<license>MIT</license>') == 'MIT'

    def test_strips_backticks(self):
        """Surrounding backticks are dropped."""
        assert normalize_phrase('`MIT License`') == 'MIT License'

    def test_collapses_whitespace(self):
        """A newline plus space inside a phrase becomes a single space."""
        assert normalize_phrase('GNU General\n Public License') == 'GNU General Public License'

    def test_strips_trailing_punct(self):
        """Leading and trailing punctuation is trimmed; inner dots stay."""
        assert normalize_phrase('Apache 2.0.') == 'Apache 2.0'
        assert normalize_phrase(',MIT,') == 'MIT'

    def test_empty_after_strip(self):
        """Input made only of a tag or only of punctuation normalizes to ''."""
        assert normalize_phrase('<foo>') == ''
        assert normalize_phrase('...') == ''
37+
38+
39+
class TestTagTokens:
    """Checks for ``tag_tokens`` on text carrying ``{{...}}`` phrase markers."""

    def test_single_phrase(self):
        """A two-word marked phrase is labelled B-REQ then E-REQ."""
        words, tags = tag_tokens('under the {{Apache License}} terms')
        expected_words = ['under', 'the', 'Apache', 'License', 'terms']
        expected_tags = ['O', 'O', 'B-REQ', 'E-REQ', 'O']
        assert words == expected_words
        assert tags == expected_tags

    def test_single_word_phrase(self):
        """A one-word marked phrase is labelled S-REQ."""
        words, tags = tag_tokens('use {{MIT}} license')
        expected_words = ['use', 'MIT', 'license']
        expected_tags = ['O', 'S-REQ', 'O']
        assert words == expected_words
        assert tags == expected_tags

    def test_multiple_phrases(self):
        """Each marked phrase in the same text is labelled on its own."""
        words, tags = tag_tokens('{{Apache}} and {{MIT}} stuff')
        expected_words = ['Apache', 'and', 'MIT', 'stuff']
        expected_tags = ['S-REQ', 'O', 'S-REQ', 'O']
        assert words == expected_words
        assert tags == expected_tags

    def test_long_phrase(self):
        """Words between the first and last of a phrase are labelled I-REQ."""
        words, tags = tag_tokens('{{GNU General Public License}}')
        expected_words = ['GNU', 'General', 'Public', 'License']
        expected_tags = ['B-REQ', 'I-REQ', 'I-REQ', 'E-REQ']
        assert words == expected_words
        assert tags == expected_tags

    def test_no_markers(self):
        """Text without markers is labelled O throughout."""
        words, tags = tag_tokens('released under the license')
        expected_words = ['released', 'under', 'the', 'license']
        expected_tags = ['O', 'O', 'O', 'O']
        assert words == expected_words
        assert tags == expected_tags

    def test_alignment(self):
        """Tokens and labels come back with the same length."""
        words, tags = tag_tokens('licensed under {{Apache License}} or {{MIT}}')
        assert len(words) == len(tags)

    def test_empty_input(self):
        """An empty string yields no tokens and no labels."""
        words, tags = tag_tokens('')
        assert words == []
        assert tags == []

    def test_empty_markers_ignored(self):
        """An empty ``{{}}`` marker adds no token and no REQ label."""
        words, tags = tag_tokens('licensed under {{}} the GPL')
        expected_words = ['licensed', 'under', 'the', 'GPL']
        expected_tags = ['O', 'O', 'O', 'O']
        assert words == expected_words
        assert tags == expected_tags
79+
80+
81+
class TestAssignSplits:
    """Checks for ``assign_splits`` on lists of rule dicts."""

    def test_light_expressions_no_leakage(self):
        """Five expressions with ten rules each: none heavy, one split per expression."""
        # Same generation order as a nested i/j loop: all rules of license-0
        # first, then license-1, and so on.
        rules = [
            {'license_expression': f'license-{i}', 'identifier': f'rule_{i}_{j}.RULE'}
            for i in range(5)
            for j in range(10)
        ]
        split_names = ('train', 'val', 'test')

        heavy, assignment = assign_splits(rules)

        assert len(heavy) == 0
        assert len(assignment) == 5
        assert all(split in split_names for split in assignment.values())

    def test_heavy_expressions_detected(self):
        """An expression with 100 rules is heavy; a single-rule one is assigned."""
        mit_rules = [
            {'license_expression': 'mit', 'identifier': f'mit_{i}.RULE'}
            for i in range(100)
        ]
        rare_rule = {'license_expression': 'rare-1.0', 'identifier': 'rare_1.RULE'}
        rules = mit_rules + [rare_rule]

        heavy, assignment = assign_splits(rules)

        assert 'mit' in heavy
        assert 'rare-1.0' not in heavy
        assert 'rare-1.0' in assignment

0 commit comments

Comments
 (0)