diff --git a/conllu/models.py b/conllu/models.py index 25e5b3b..1b73591 100644 --- a/conllu/models.py +++ b/conllu/models.py @@ -95,11 +95,11 @@ def copy(self) -> 'TokenList': return TokenList(tokens_copy, self.metadata, self.default_fields) def extend(self, iterable: T.Union['TokenList', T.Iterable[Token]]) -> None: + if not hasattr(self, "metadata"): + self.metadata = Metadata() if not isinstance(iterable, TokenList): iterable = TokenList(iterable) - super(TokenList, self).extend(iterable) - self.metadata.update(iterable.metadata) def _dict_to_token_and_set_defaults(self, token: T.Union[dict, Token]) -> Token: @@ -187,7 +187,7 @@ def _create_tree(head_to_token_mapping: T.Dict[int, T.List[Token]], id_: int = 0 return root def filter(self, **kwargs: T.Any) -> 'TokenList': - tokens: T.Iterable[Token] = self.copy() + tokens = self.copy() for query, value in kwargs.items(): filtered_tokens = [] @@ -198,9 +198,9 @@ def filter(self, **kwargs: T.Any) -> 'TokenList': if traverse_dict(token, query) == value: filtered_tokens.append(token) - tokens = filtered_tokens + tokens[:] = filtered_tokens - return TokenList(tokens) + return tokens _T = T.TypeVar("_T") @@ -353,6 +353,8 @@ def copy(self) -> 'SentenceList': return SentenceList(sentences_copy, self.metadata) def extend(self, iterable: T.Union['SentenceList', T.Iterable[TokenList]]) -> None: + if not hasattr(self, "metadata"): + self.metadata = Metadata() if not isinstance(iterable, SentenceList): iterable = SentenceList(iterable) diff --git a/tests/test_models.py b/tests/test_models.py index 838f6d9..f132dad 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,3 +1,5 @@ +import pickle +import tempfile import unittest from textwrap import dedent @@ -252,10 +254,22 @@ def test_no_root_nodes(self): class TestSerialize(unittest.TestCase): + def test_serialize_on_tokenlist(self): tokenlist = TokenList([{"id": 1}]) self.assertEqual(tokenlist.serialize(), serialize(tokenlist)) + def test_pickling_tokenlist(self): + tokenlist = TokenList([ + {"id": 1, "form": "a", "field": "x"}, + {"id": 2, "form": "dog", "field": "x"}, + ], metadata={"text": "a dog"}) + sink = tempfile.NamedTemporaryFile("wb", suffix=".pkl", delete=False) + pickle.dump(tokenlist, sink) + sink.close() + with open(sink.name, "rb") as source: + tokenlist_copy = pickle.load(source) + self.assertEqual(tokenlist, tokenlist_copy) class TestFilter(unittest.TestCase): def test_basic_filtering(self): @@ -280,6 +294,16 @@ def test_basic_filtering(self): tokenlist ) + def test_metadata_propagates(self): + tokenlist = TokenList([ + {"id": 1, "form": "a", "field": "x"}, + {"id": 2, "form": "dog", "field": "x"}, + ], metadata={"text": "a dog"}) + self.assertEqual( + tokenlist.filter().metadata, + tokenlist.metadata, + ) + def test_and_filtering(self): tokenlist = TokenList([ {"id": 1, "form": "a", "field": "x"}, @@ -521,6 +545,13 @@ def test_init_nonlist_raises(self): with self.assertRaises(ParseException): SentenceList((1, 2, 4)) + def test_extend_with_no_metadata(self): + sl = SentenceList([TokenList([{"id": 1}])]) + del sl.metadata + sl.extend(sl) + self.assertTrue(hasattr(sl, "metadata")) + self.assertIsNotNone(sl.metadata) + def test_equals(self): tokenlists = [TokenList([{"id": 1}])] self.assertEqual(SentenceList(tokenlists), SentenceList(tokenlists))