Skip to content

Commit

Permalink
Add a test case for LabelAlignment
Browse files Browse the repository at this point in the history
  • Loading branch information
yasufumy committed Sep 23, 2023
1 parent 4409611 commit 9508735
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,29 @@ def test_labels_unchanged_after_encoding_and_decoding(
assert labels == label_set.decode(label_set.encode_to_tag_indices(labels))


@pytest.fixture()
def alignment() -> LabelAlignment:
@st.composite
def target_label(draw: st.DrawFn) -> SequenceLabel:
size = 9
num_tags = draw(st.integers(min_value=1, max_value=5))
tags = []
last = 1
for _ in range(num_tags):
if last >= size:
break
start = draw(st.integers(min_value=last, max_value=size - 1))
end = draw(st.integers(min_value=start + 1, max_value=size))
label = draw(st.sampled_from(["ORG", "LOC", "PER", "MISC"]))
# NOTE: mypy cannot infer a type of the dictionary below.
tags.append(cast(TagDict, {"start": start, "end": end, "label": label}))
last = end + 1

return SequenceLabel.from_dict(tags=tags, size=size + 1, base=Base.TARGET)


@given(label=target_label())
def test_labels_unchanged_after_alignment(label: SequenceLabel) -> None:
# Tokenized by RoBERTa
return LabelAlignment(
alignment = LabelAlignment(
(
None,
Span(0, 3),
Expand Down Expand Up @@ -93,6 +112,10 @@ def alignment() -> LabelAlignment:
),
)

assert label == alignment.align_with_target(
label=alignment.align_with_source(label=label)
)


def test_tags_define_in_truncated_part_ignored() -> None:
truncated_alignment = LabelAlignment(
Expand Down

0 comments on commit 9508735

Please sign in to comment.