diff --git a/src/pie_modules/taskmodules/re_text_classification_with_indices.py b/src/pie_modules/taskmodules/re_text_classification_with_indices.py index cd46f2ccb..20b0fc465 100644 --- a/src/pie_modules/taskmodules/re_text_classification_with_indices.py +++ b/src/pie_modules/taskmodules/re_text_classification_with_indices.py @@ -702,6 +702,8 @@ def encode_input( map(lambda x: (x[0], x[1].resolve()), arguments) ) if prev_label == rel.label: + # if 'collect_statistics=true' such duplicates won't be collected and are not counted in + # statistics if 'collect_statistics=true' either as 'available' or as 'skipped_same_arguments' logger.warning( f"doc.id={document.id}: Relation annotation `{rel.resolve()}` is duplicated. " f"We keep only one of them." @@ -714,8 +716,8 @@ def encode_input( f"previous label='{prev_label}' and current label='{rel.label}'. We only keep the first " f"occurring relation which has the label='{prev_label}'." ) - # if `keep_none`, first occurred relations are removed after _add_candidate_relations() call, - # so that none of them are re-added as 'no-relation' + # if `handle_relations_with_same_arguments="keep_none`, first occurred relations are removed + # after _add_candidate_relations() call, so that none of them are re-added as 'no-relation' elif self.handle_relations_with_same_arguments == "keep_none": logger.warning( f"doc.id={document.id}: there are multiple relations with the same arguments {arguments_resolved}: " @@ -723,7 +725,8 @@ def encode_input( ) else: raise ValueError( - f"'handle_relations_with_same_arguments' must be 'keep_first' or 'keep_none', but got `{self.handle_relations_with_same_arguments}`." + f"'handle_relations_with_same_arguments' must be 'keep_first' or 'keep_none', " + f"but got `{self.handle_relations_with_same_arguments}`." ) self.collect_relation("skipped_same_arguments", rel) else: @@ -742,7 +745,8 @@ def encode_input( self._add_candidate_relations( arguments2relation=arguments2relation, entities=entities, doc_id=document.id ) - # remove remaining relation duplicates + # remove remaining relation duplicates. It should be done here, after _add_candidate_relations() so that + # removed relations are not re-added with `no-relation` label. if self.handle_relations_with_same_arguments == "keep_none": for arguments in arguments_duplicates: rel = arguments2relation.pop(arguments) diff --git a/tests/taskmodules/test_re_text_classification_with_indices.py b/tests/taskmodules/test_re_text_classification_with_indices.py index d118f0509..bd4a60bb6 100644 --- a/tests/taskmodules/test_re_text_classification_with_indices.py +++ b/tests/taskmodules/test_re_text_classification_with_indices.py @@ -1044,7 +1044,7 @@ def test_collate_with_add_argument_indices(batch_with_argument_indices): @pytest.mark.parametrize( - "handle_relations_with_same_arguments", ["keep_first", "keep_none", "keep_second"] + "handle_relations_with_same_arguments", ["keep_first", "keep_none", "unknown_value"] ) def test_encode_input_multiple_relations_for_same_arguments( caplog, handle_relations_with_same_arguments @@ -1070,11 +1070,11 @@ def test_encode_input_multiple_relations_for_same_arguments( ) taskmodule.prepare([document]) - if handle_relations_with_same_arguments not in ["keep_first", "keep_none"]: + if handle_relations_with_same_arguments == "unknown_value": with pytest.raises(ValueError) as excinfo: encodings = taskmodule.encode_input(document) assert str(excinfo.value) == ( - "'handle_relations_with_same_arguments' must be 'keep_first' or 'keep_none', but got `keep_second`." + "'handle_relations_with_same_arguments' must be 'keep_first' or 'keep_none', but got `unknown_value`." ) else: encodings = taskmodule.encode_input(document) @@ -1097,6 +1097,8 @@ def test_encode_input_multiple_relations_for_same_arguments( "`('per:founded_by', (('PER', 'A'), ('PER', 'B')))` is duplicated. We keep " "only one of them." ) + # with 'keep_first', only first relation occurred is kept ('per:founded_by'). + # full duplicate of 'per:founded_by' is removed and appears neither as available, nor as skipped in statistics. assert ( caplog.messages[2] == "statistics:\n" "| | per:founded_by | per:founder | all_relations |\n" @@ -1125,6 +1127,8 @@ def test_encode_input_multiple_relations_for_same_arguments( "`('per:founded_by', (('PER', 'A'), ('PER', 'B')))` is duplicated. " "We keep only one of them." ) + # with 'keep_none' both relations sharing same arguments are removed + # full duplicate of 'per:founded_by' is removed and appears neither as available, nor as skipped in statistics. assert ( caplog.messages[2] == "statistics:\n" "| | per:founded_by | per:founder | all_relations |\n" @@ -1135,6 +1139,54 @@ def test_encode_input_multiple_relations_for_same_arguments( assert len(encodings) == 0 +@pytest.mark.parametrize("handle_relations_with_same_arguments", ["keep_first", "keep_none"]) +def test_encode_input_duplicated_relations(caplog, handle_relations_with_same_arguments): + taskmodule = RETextClassificationWithIndicesTaskModule( + relation_annotation="relations", + tokenizer_name_or_path="bert-base-cased", + handle_relations_with_same_arguments=handle_relations_with_same_arguments, + collect_statistics=True, + ) + document = TestDocument(text="A founded B.", id="multiple_relations_for_same_arguments") + document.entities.append(LabeledSpan(start=0, end=1, label="PER")) + document.entities.append(LabeledSpan(start=10, end=11, label="PER")) + entities = document.entities + assert str(entities[0]) == "A" + assert str(entities[1]) == "B" + document.relations.extend( + [ + BinaryRelation(head=entities[0], tail=entities[1], label="per:founded_by"), + BinaryRelation(head=entities[0], tail=entities[1], label="per:founded_by"), + ] + ) + taskmodule.prepare([document]) + encodings = taskmodule.encode_input(document) + + with caplog.at_level(logging.INFO): + taskmodule.show_statistics() + assert len(caplog.messages) == 2 + assert ( + caplog.messages[0] == "doc.id=multiple_relations_for_same_arguments: Relation annotation " + "`('per:founded_by', (('PER', 'A'), ('PER', 'B')))` is duplicated. We keep " + "only one of them." + ) + # equally for 'keep_first' and 'keep_last', full duplicates are not affected and do not appear in statistics, but still + # generate a warning. + assert ( + caplog.messages[1] == "statistics:\n" + "| | per:founded_by |\n" + "|:----------|-----------------:|\n" + "| available | 1 |\n" + "| used | 1 |\n" + "| used % | 100 |" + ) + assert len(encodings) == 1 + relation = encodings[0].metadata["candidate_annotation"] + assert str(relation.head) == "A" + assert str(relation.tail) == "B" + assert relation.label == "per:founded_by" + + def test_encode_input_argument_role_unknown(documents): taskmodule = RETextClassificationWithIndicesTaskModule( relation_annotation="relations",