Skip to content

Commit

Permalink
Fixes metadata propagation (#99)
Browse files Browse the repository at this point in the history
* Adds metadata propagation test.

This is currently failing.

* Fixes bug with filtering dropping metadata

The copy was given too loose a type annotation, so we lost the fact that the copy is a `TokenList`,
and with it the metadata, during the conversion.

The test required `isort` to be run again too.

* Adds pickling test and fixes metadata.

I still do not know how we get a `self` without metadata.

* Hopefully cleans up test junk.

* Trying to hit the coverage threshold.

* Correct misspelling.

---------

Co-authored-by: Emil Stenström <[email protected]>
  • Loading branch information
kylebgorman and EmilStenstrom authored Sep 19, 2024
1 parent 54e55af commit 97736cf
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 5 deletions.
12 changes: 7 additions & 5 deletions conllu/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ def copy(self) -> 'TokenList':
return TokenList(tokens_copy, self.metadata, self.default_fields)

def extend(self, iterable: T.Union['TokenList', T.Iterable[Token]]) -> None:
    """Append the tokens from *iterable* and merge its metadata into this list.

    Args:
        iterable: Another ``TokenList``, or any iterable of ``Token``; a plain
            iterable is first wrapped in a ``TokenList`` so it carries a
            (possibly empty) ``metadata`` mapping to merge from.
    """
    # NOTE(review): guards against instances that lack the ``metadata``
    # attribute entirely (observed after unpickling, per the commit notes);
    # recreating it here keeps the ``update`` below from raising
    # AttributeError.
    if not hasattr(self, "metadata"):
        self.metadata = Metadata()
    if not isinstance(iterable, TokenList):
        iterable = TokenList(iterable)

    super(TokenList, self).extend(iterable)

    # Propagate the source's metadata; keys present in both are overwritten
    # by the incoming values (dict.update semantics).
    self.metadata.update(iterable.metadata)

def _dict_to_token_and_set_defaults(self, token: T.Union[dict, Token]) -> Token:
Expand Down Expand Up @@ -187,7 +187,7 @@ def _create_tree(head_to_token_mapping: T.Dict[int, T.List[Token]], id_: int = 0
return root

def filter(self, **kwargs: T.Any) -> 'TokenList':
tokens: T.Iterable[Token] = self.copy()
tokens = self.copy()

for query, value in kwargs.items():
filtered_tokens = []
Expand All @@ -198,9 +198,9 @@ def filter(self, **kwargs: T.Any) -> 'TokenList':
if traverse_dict(token, query) == value:
filtered_tokens.append(token)

tokens = filtered_tokens
tokens[:] = filtered_tokens

return TokenList(tokens)
return tokens


_T = T.TypeVar("_T")
Expand Down Expand Up @@ -353,6 +353,8 @@ def copy(self) -> 'SentenceList':
return SentenceList(sentences_copy, self.metadata)

def extend(self, iterable: T.Union['SentenceList', T.Iterable[TokenList]]) -> None:
if not hasattr(self, "metadata"):
self.metadata = Metadata()
if not isinstance(iterable, SentenceList):
iterable = SentenceList(iterable)

Expand Down
31 changes: 31 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pickle
import tempfile
import unittest
from textwrap import dedent

Expand Down Expand Up @@ -252,10 +254,22 @@ def test_no_root_nodes(self):


class TestSerialize(unittest.TestCase):
    """Tests for TokenList serialization and pickling round-trips."""

    def test_serialize_on_tokenlist(self):
        """The instance serialize() must agree with the module-level serialize()."""
        tokenlist = TokenList([{"id": 1}])
        self.assertEqual(tokenlist.serialize(), serialize(tokenlist))

    def test_pickling_tokenlist(self):
        """A TokenList must survive a pickle round-trip, metadata included."""
        tokenlist = TokenList([
            {"id": 1, "form": "a", "field": "x"},
            {"id": 2, "form": "dog", "field": "x"},
        ], metadata={"text": "a dog"})
        # Round-trip through an anonymous temporary file. The previous version
        # used NamedTemporaryFile(delete=False) and never unlinked it, leaking
        # a .pkl file on disk with every test run; the context manager cleans
        # up automatically.
        with tempfile.TemporaryFile() as sink:
            pickle.dump(tokenlist, sink)
            sink.seek(0)
            tokenlist_copy = pickle.load(sink)
        self.assertEqual(tokenlist, tokenlist_copy)

class TestFilter(unittest.TestCase):
def test_basic_filtering(self):
Expand All @@ -280,6 +294,16 @@ def test_basic_filtering(self):
tokenlist
)

def test_metadata_propagates(self):
    """filter() must carry the source TokenList's metadata over to its result."""
    source = TokenList(
        [
            {"id": 1, "form": "a", "field": "x"},
            {"id": 2, "form": "dog", "field": "x"},
        ],
        metadata={"text": "a dog"},
    )
    filtered = source.filter()
    self.assertEqual(filtered.metadata, source.metadata)

def test_and_filtering(self):
tokenlist = TokenList([
{"id": 1, "form": "a", "field": "x"},
Expand Down Expand Up @@ -521,6 +545,13 @@ def test_init_nonlist_raises(self):
with self.assertRaises(ParseException):
SentenceList((1, 2, 4))

def test_extend_with_no_metadata(self):
    """extend() must recreate the metadata attribute when it is absent."""
    sentences = SentenceList([TokenList([{"id": 1}])])
    # Simulate an instance that somehow lost its metadata (e.g. unpickling).
    del sentences.metadata
    sentences.extend(sentences)
    self.assertTrue(hasattr(sentences, "metadata"))
    self.assertIsNotNone(sentences.metadata)

def test_equals(self):
tokenlists = [TokenList([{"id": 1}])]
self.assertEqual(SentenceList(tokenlists), SentenceList(tokenlists))
Expand Down

0 comments on commit 97736cf

Please sign in to comment.