diff --git a/spacy/pipeline/_parser_internals/ner.pyx b/spacy/pipeline/_parser_internals/ner.pyx index fab872f002d..5a3db37190c 100644 --- a/spacy/pipeline/_parser_internals/ner.pyx +++ b/spacy/pipeline/_parser_internals/ner.pyx @@ -120,6 +120,10 @@ cdef bint _entity_is_sunk(const StateC* state, Transition* golds) nogil: return False +cdef bint _next_is_sent_start(const StateC* state) nogil: + return state.B(1) != -1 and state.B_(1).sent_start == 1 + + cdef class BiluoPushDown(TransitionSystem): def __init__(self, *args, **kwargs): TransitionSystem.__init__(self, *args, **kwargs) @@ -388,7 +392,7 @@ cdef class Begin: elif st.B_(1).ent_iob == 3: # If the next word is B, we can't B now return False - elif st.B_(1).sent_start == 1: + elif _next_is_sent_start(st): # Don't allow entities to extend across sentence boundaries return False # Don't allow entities to start on whitespace @@ -466,7 +470,7 @@ cdef class In: # Otherwise, force acceptance, even if we're across a sentence # boundary or the token is whitespace. return True - elif st.B(1) != -1 and st.B_(1).sent_start == 1: + elif _next_is_sent_start(st): # Don't allow entities to extend across sentence boundaries return False else: @@ -558,8 +562,9 @@ cdef class Last: # L, Gold B --> True pass elif g_act == IN: - # L, Gold I --> True iff this entity sunk - cost += not _entity_is_sunk(s, gold.ner) + # L, Gold I --> True iff this entity sunk or there is a sentence + # break after the next buffer token. + cost += not (_entity_is_sunk(s, gold.ner) or _next_is_sent_start(s)) elif g_act == LAST: # L, Gold L --> True pass @@ -674,8 +679,9 @@ cdef class Out: if g_act == MISSING: pass elif g_act == BEGIN: - # O, Gold B --> False - cost += 1 + # O, Gold B --> False, unless there is a sentence break after + # the next buffer token. 
+ cost += not _next_is_sent_start(s) elif g_act == IN: # O, Gold I --> True pass diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py index 030182a6355..eb3e67740d9 100644 --- a/spacy/tests/parser/test_ner.py +++ b/spacy/tests/parser/test_ner.py @@ -816,6 +816,31 @@ def test_ner_warns_no_lookups(caplog): assert "W033" not in caplog.text +def test_train_sent_split_in_entity(): + # Check that we can train on inputs when entities are sentence-split + # by an annotating component. + nlp = English() + ner = nlp.add_pipe("ner", config={"update_with_oracle_cut_size": 3}) + + eg = Example.from_dict( + nlp.make_doc("I like the Kinesis Advantage2 LF very much."), + {"entities": [(11, 32, "MISC")]}, + ) + + # Go berserk, put a boundary on every combination of tokens. + train_examples = [] + for i in range(1, len(eg.predicted)): + for j in range(1, len(eg.predicted)): + eg_ij = eg.copy() + eg_ij.predicted[i].is_sent_start = True + eg_ij.predicted[j].is_sent_start = True + train_examples.append(eg_ij) + + ner.add_label("MISC") + nlp.initialize() + nlp.update(train_examples, sgd=False, annotates=[]) + + @Language.factory("blocker") class BlockerComponent1: def __init__(self, nlp, start, end, name="my_blocker"):