|
| 1 | +# tests for build_dataset.py |
| 2 | +import sys |
| 3 | +from pathlib import Path |
| 4 | + |
| 5 | +sys.path.insert(0, str(Path(__file__).parent)) |
| 6 | +from build_dataset import normalize_phrase, tag_tokens, assign_splits |
| 7 | + |
| 8 | + |
class TestNormalizePhrase:
    """Expected outputs of ``normalize_phrase`` for markup, whitespace and punctuation."""

    def test_html_entities(self):
        """Markup characters normalize the same in literal and entity-encoded form."""
        # Literal forms of the markup characters.
        assert normalize_phrase('<b>MIT</b>') == 'MIT'
        assert normalize_phrase('the "License"') == 'the "License"'
        assert normalize_phrase('foo & bar') == 'foo & bar'
        # Entity-encoded forms, so the test actually exercises entity decoding.
        # NOTE(review): assumes normalize_phrase decodes HTML entities before
        # stripping tags -- confirm against build_dataset.normalize_phrase.
        assert normalize_phrase('&lt;b&gt;MIT&lt;/b&gt;') == 'MIT'
        assert normalize_phrase('the &quot;License&quot;') == 'the "License"'
        assert normalize_phrase('foo &amp; bar') == 'foo & bar'

    def test_preserves_urls_in_angle_brackets(self):
        """A URL wrapped in angle brackets keeps the URL and loses the brackets."""
        result = normalize_phrase('<http://example.com/LICENSE>')
        assert result == 'http://example.com/LICENSE'

    def test_strips_xml_tags(self):
        """Paired XML-style tags are removed and their inner text is kept."""
        assert normalize_phrase('<name>Apache 2.0</name>') == 'Apache 2.0'
        assert normalize_phrase('<license>MIT</license>') == 'MIT'

    def test_strips_backticks(self):
        """Surrounding backticks are removed."""
        assert normalize_phrase('`MIT License`') == 'MIT License'

    def test_collapses_whitespace(self):
        """A newline followed by spaces collapses to a single space."""
        assert normalize_phrase('GNU General\n Public License') == 'GNU General Public License'

    def test_strips_trailing_punct(self):
        """Periods and commas at the edges of the phrase are removed."""
        assert normalize_phrase('Apache 2.0.') == 'Apache 2.0'
        assert normalize_phrase(',MIT,') == 'MIT'

    def test_empty_after_strip(self):
        """Inputs made up only of a tag or only of punctuation normalize to ''."""
        assert normalize_phrase('<foo>') == ''
        assert normalize_phrase('...') == ''
| 37 | + |
| 38 | + |
| 39 | +class TestTagTokens: |
| 40 | + |
| 41 | + def test_single_phrase(self): |
| 42 | + tokens, labels = tag_tokens('under the {{Apache License}} terms') |
| 43 | + assert tokens == ['under', 'the', 'Apache', 'License', 'terms'] |
| 44 | + assert labels == ['O', 'O', 'B-REQ', 'E-REQ', 'O'] |
| 45 | + |
| 46 | + def test_single_word_phrase(self): |
| 47 | + tokens, labels = tag_tokens('use {{MIT}} license') |
| 48 | + assert tokens == ['use', 'MIT', 'license'] |
| 49 | + assert labels == ['O', 'S-REQ', 'O'] |
| 50 | + |
| 51 | + def test_multiple_phrases(self): |
| 52 | + tokens, labels = tag_tokens('{{Apache}} and {{MIT}} stuff') |
| 53 | + assert tokens == ['Apache', 'and', 'MIT', 'stuff'] |
| 54 | + assert labels == ['S-REQ', 'O', 'S-REQ', 'O'] |
| 55 | + |
| 56 | + def test_long_phrase(self): |
| 57 | + tokens, labels = tag_tokens('{{GNU General Public License}}') |
| 58 | + assert tokens == ['GNU', 'General', 'Public', 'License'] |
| 59 | + assert labels == ['B-REQ', 'I-REQ', 'I-REQ', 'E-REQ'] |
| 60 | + |
| 61 | + def test_no_markers(self): |
| 62 | + tokens, labels = tag_tokens('released under the license') |
| 63 | + assert tokens == ['released', 'under', 'the', 'license'] |
| 64 | + assert labels == ['O', 'O', 'O', 'O'] |
| 65 | + |
| 66 | + def test_alignment(self): |
| 67 | + tokens, labels = tag_tokens('licensed under {{Apache License}} or {{MIT}}') |
| 68 | + assert len(tokens) == len(labels) |
| 69 | + |
| 70 | + def test_empty_input(self): |
| 71 | + tokens, labels = tag_tokens('') |
| 72 | + assert tokens == [] |
| 73 | + assert labels == [] |
| 74 | + |
| 75 | + def test_empty_markers_ignored(self): |
| 76 | + tokens, labels = tag_tokens('licensed under {{}} the GPL') |
| 77 | + assert tokens == ['licensed', 'under', 'the', 'GPL'] |
| 78 | + assert labels == ['O', 'O', 'O', 'O'] |
| 79 | + |
| 80 | + |
| 81 | +class TestAssignSplits: |
| 82 | + |
| 83 | + def test_light_expressions_no_leakage(self): |
| 84 | + results = [] |
| 85 | + for i in range(5): |
| 86 | + for j in range(10): |
| 87 | + results.append({'license_expression': f'license-{i}', 'identifier': f'rule_{i}_{j}.RULE'}) |
| 88 | + |
| 89 | + heavy, assignment = assign_splits(results) |
| 90 | + assert len(heavy) == 0 |
| 91 | + assert len(assignment) == 5 |
| 92 | + assert all(s in ('train', 'val', 'test') for s in assignment.values()) |
| 93 | + |
| 94 | + def test_heavy_expressions_detected(self): |
| 95 | + results = [{'license_expression': 'mit', 'identifier': f'mit_{i}.RULE'} for i in range(100)] |
| 96 | + results += [{'license_expression': 'rare-1.0', 'identifier': 'rare_1.RULE'}] |
| 97 | + |
| 98 | + heavy, assignment = assign_splits(results) |
| 99 | + assert 'mit' in heavy |
| 100 | + assert 'rare-1.0' not in heavy |
| 101 | + assert 'rare-1.0' in assignment |
0 commit comments