Skip to content

Commit

Permalink
add new test + comments
Browse files Browse the repository at this point in the history
  • Loading branch information
RainbowRivey committed Feb 11, 2025
1 parent 502aa75 commit 80eb926
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,8 @@ def encode_input(
map(lambda x: (x[0], x[1].resolve()), arguments)
)
if prev_label == rel.label:
# exact duplicates (same arguments and same label) are not collected: with 'collect_statistics=true',
# they are counted in the statistics neither as 'available' nor as 'skipped_same_arguments'
logger.warning(
f"doc.id={document.id}: Relation annotation `{rel.resolve()}` is duplicated. "
f"We keep only one of them."
Expand All @@ -714,16 +716,17 @@ def encode_input(
f"previous label='{prev_label}' and current label='{rel.label}'. We only keep the first "
f"occurring relation which has the label='{prev_label}'."
)
# if `keep_none`, first occurred relations are removed after _add_candidate_relations() call,
# so that none of them are re-added as 'no-relation'
# if `handle_relations_with_same_arguments="keep_none"`, the first-occurring relations are removed
# after the _add_candidate_relations() call, so that none of them are re-added as 'no-relation'
elif self.handle_relations_with_same_arguments == "keep_none":
logger.warning(
f"doc.id={document.id}: there are multiple relations with the same arguments {arguments_resolved}: "
f"previous label='{prev_label}' and current label='{rel.label}'. Both relations will be removed."
)
else:
raise ValueError(
f"'handle_relations_with_same_arguments' must be 'keep_first' or 'keep_none', but got `{self.handle_relations_with_same_arguments}`."
f"'handle_relations_with_same_arguments' must be 'keep_first' or 'keep_none', "
f"but got `{self.handle_relations_with_same_arguments}`."
)
self.collect_relation("skipped_same_arguments", rel)
else:
Expand All @@ -742,7 +745,8 @@ def encode_input(
self._add_candidate_relations(
arguments2relation=arguments2relation, entities=entities, doc_id=document.id
)
# remove remaining relation duplicates
# remove the remaining relation duplicates. This must be done here, after _add_candidate_relations(),
# so that the removed relations are not re-added with the `no-relation` label.
if self.handle_relations_with_same_arguments == "keep_none":
for arguments in arguments_duplicates:
rel = arguments2relation.pop(arguments)
Expand Down
58 changes: 55 additions & 3 deletions tests/taskmodules/test_re_text_classification_with_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,7 @@ def test_collate_with_add_argument_indices(batch_with_argument_indices):


@pytest.mark.parametrize(
"handle_relations_with_same_arguments", ["keep_first", "keep_none", "keep_second"]
"handle_relations_with_same_arguments", ["keep_first", "keep_none", "unknown_value"]
)
def test_encode_input_multiple_relations_for_same_arguments(
caplog, handle_relations_with_same_arguments
Expand All @@ -1070,11 +1070,11 @@ def test_encode_input_multiple_relations_for_same_arguments(
)
taskmodule.prepare([document])

if handle_relations_with_same_arguments not in ["keep_first", "keep_none"]:
if handle_relations_with_same_arguments == "unknown_value":
with pytest.raises(ValueError) as excinfo:
encodings = taskmodule.encode_input(document)
assert str(excinfo.value) == (
"'handle_relations_with_same_arguments' must be 'keep_first' or 'keep_none', but got `keep_second`."
"'handle_relations_with_same_arguments' must be 'keep_first' or 'keep_none', but got `unknown_value`."
)
else:
encodings = taskmodule.encode_input(document)
Expand All @@ -1097,6 +1097,8 @@ def test_encode_input_multiple_relations_for_same_arguments(
"`('per:founded_by', (('PER', 'A'), ('PER', 'B')))` is duplicated. We keep "
"only one of them."
)
# with 'keep_first', only the first-occurring relation ('per:founded_by') is kept.
# The exact duplicate of 'per:founded_by' is removed and appears neither as available nor as skipped in the statistics.
assert (
caplog.messages[2] == "statistics:\n"
"| | per:founded_by | per:founder | all_relations |\n"
Expand Down Expand Up @@ -1125,6 +1127,8 @@ def test_encode_input_multiple_relations_for_same_arguments(
"`('per:founded_by', (('PER', 'A'), ('PER', 'B')))` is duplicated. "
"We keep only one of them."
)
# with 'keep_none', both relations sharing the same arguments are removed.
# The exact duplicate of 'per:founded_by' is removed and appears neither as available nor as skipped in the statistics.
assert (
caplog.messages[2] == "statistics:\n"
"| | per:founded_by | per:founder | all_relations |\n"
Expand All @@ -1135,6 +1139,54 @@ def test_encode_input_multiple_relations_for_same_arguments(
assert len(encodings) == 0


# Checks how encode_input() treats an exact duplicate relation (same head, tail and label)
# for both supported values of `handle_relations_with_same_arguments`.
@pytest.mark.parametrize("handle_relations_with_same_arguments", ["keep_first", "keep_none"])
def test_encode_input_duplicated_relations(caplog, handle_relations_with_same_arguments):
taskmodule = RETextClassificationWithIndicesTaskModule(
relation_annotation="relations",
tokenizer_name_or_path="bert-base-cased",
handle_relations_with_same_arguments=handle_relations_with_same_arguments,
collect_statistics=True,
)
document = TestDocument(text="A founded B.", id="multiple_relations_for_same_arguments")
document.entities.append(LabeledSpan(start=0, end=1, label="PER"))
document.entities.append(LabeledSpan(start=10, end=11, label="PER"))
entities = document.entities
assert str(entities[0]) == "A"
assert str(entities[1]) == "B"
# add the very same relation twice: identical head, tail and label
document.relations.extend(
[
BinaryRelation(head=entities[0], tail=entities[1], label="per:founded_by"),
BinaryRelation(head=entities[0], tail=entities[1], label="per:founded_by"),
]
)
taskmodule.prepare([document])
encodings = taskmodule.encode_input(document)

with caplog.at_level(logging.INFO):
taskmodule.show_statistics()
# expected messages: the duplicate warning, followed by the statistics table
assert len(caplog.messages) == 2
assert (
caplog.messages[0] == "doc.id=multiple_relations_for_same_arguments: Relation annotation "
"`('per:founded_by', (('PER', 'A'), ('PER', 'B')))` is duplicated. We keep "
"only one of them."
)
# for both 'keep_first' and 'keep_none', exact duplicates are handled the same way: one occurrence is kept,
# the duplicate does not appear in the statistics, but it still generates a warning.
assert (
caplog.messages[1] == "statistics:\n"
"| | per:founded_by |\n"
"|:----------|-----------------:|\n"
"| available | 1 |\n"
"| used | 1 |\n"
"| used % | 100 |"
)
# exactly one encoding remains, and it carries the original relation
assert len(encodings) == 1
relation = encodings[0].metadata["candidate_annotation"]
assert str(relation.head) == "A"
assert str(relation.tail) == "B"
assert relation.label == "per:founded_by"


def test_encode_input_argument_role_unknown(documents):
taskmodule = RETextClassificationWithIndicesTaskModule(
relation_annotation="relations",
Expand Down

0 comments on commit 80eb926

Please sign in to comment.