Skip to content

Commit

Permalink
datastreams: consolidate the use of StreamEntry
Browse files Browse the repository at this point in the history
  • Loading branch information
Pablo Panero committed Jan 13, 2022
1 parent 075d0cc commit ccf09c8
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 66 deletions.
9 changes: 5 additions & 4 deletions invenio_vocabularies/contrib/names/datastreams.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@
from invenio_access.permissions import system_identity
from invenio_records.dictutils import dict_lookup

from ...datastreams import StreamEntry
from ...datastreams.errors import TransformerError
from ...datastreams.transformers import XMLTransformer


class OrcidXMLTransformer(XMLTransformer):
"""ORCiD XML Transformer."""

def apply(self, entry, **kwargs):
"""Applies the transformation to the entry."""
xml_tree = self._xml_to_etree(entry)
def apply(self, stream_entry, **kwargs):
"""Applies the transformation to the stream entry."""
xml_tree = self._xml_to_etree(stream_entry.entry)
researcher = self._etree_to_dict(xml_tree)
record = researcher["html"]["body"].get("record")

Expand Down Expand Up @@ -60,7 +61,7 @@ def apply(self, entry, **kwargs):
except Exception:
pass

return entry
return StreamEntry(entry)


VOCABULARIES_DATASTREAM_TRANSFORMERS = {
Expand Down
3 changes: 2 additions & 1 deletion invenio_vocabularies/datastreams/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

"""Datastreams module."""

from .datastreams import BaseDataStream
from .datastreams import BaseDataStream, StreamEntry
from .factories import DataStreamFactory

__all__ = (
"BaseDataStream",
"DataStreamFactory",
"StreamEntry",
)
50 changes: 27 additions & 23 deletions invenio_vocabularies/datastreams/datastreams.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from .errors import TransformerError, WriterError


class StreamResult:
"""Result object for streams processing."""
class StreamEntry:
"""Object to encapsulate streams processing."""

def __init__(self, entry, errors=None):
"""Constructor."""
Expand All @@ -34,47 +34,51 @@ def __init__(self, reader, writers, transformers=None, *args, **kwargs):
self._transformers = transformers
self._writers = writers

def filter(self, entry, *args, **kwargs):
"""Checks if an entry should be filtered out (skipped)."""
def filter(self, stream_entry, *args, **kwargs):
"""Checks if a stream_entry should be filtered out (skipped)."""
return False

def process(self, *args, **kwargs):
"""Iterates over the entries.
Uses the reader to get the raw entries and transforms them.
It will iterate over the `StreamEntry` objects returned by
the reader, apply the transformations and yield the result of
writing it.
"""
for entry in self._reader.read():
result = self.transform(entry)
if result.errors:
yield StreamResult(entry=entry, errors=result.errors)
elif not self.filter(result.entry):
yield self.write(result.entry)

def transform(self, entry, *args, **kwargs):
"""Apply the transformations to an entry."""
for stream_entry in self._reader.read():
transformed_entry = self.transform(stream_entry)
if transformed_entry.errors:
yield transformed_entry
elif not self.filter(transformed_entry):
yield self.write(transformed_entry)

def transform(self, stream_entry, *args, **kwargs):
"""Apply the transformations to a stream_entry."""
for transformer in self._transformers:
try:
entry = transformer.apply(entry)
stream_entry = transformer.apply(stream_entry)
except TransformerError as err:
return StreamResult(
entry,
return StreamEntry(
stream_entry.entry,
# FIXME: __ is ugly, add name cls attr?
[f"{transformer.__class__.__name__}: {str(err)}"]
)

return StreamResult(entry)
return stream_entry

def write(self, entry, *args, **kwargs):
"""Apply the transformations to an entry."""
errors = []
def write(self, stream_entry, *args, **kwargs):
"""Writes a stream_entry with every configured writer."""
for writer in self._writers:
try:
writer.write(entry)
writer.write(stream_entry)
except WriterError as err:
# FIXME: __ is ugly, add name cls attr?
errors.append(f"{writer.__class__.__name__}: {str(err)}")
stream_entry.errors.append(
f"{writer.__class__.__name__}: {str(err)}"
)

return StreamResult(entry, errors)
return stream_entry

def total(self, *args, **kwargs):
"""The total of entries obtained from the origin."""
Expand Down
11 changes: 8 additions & 3 deletions invenio_vocabularies/datastreams/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

import yaml

from .datastreams import StreamEntry


class BaseReader:
"""Base reader."""
Expand All @@ -22,7 +24,10 @@ def __init__(self, origin, *args, **kwargs):
self._origin = origin

def read(self, *args, **kwargs):
"""Reads the content from the origin."""
"""Reads the content from the origin.
Yields `StreamEntry` objects.
"""
pass


Expand All @@ -34,7 +39,7 @@ def read(self):
with open(self._origin) as f:
data = yaml.safe_load(f) or []
for entry in data:
yield entry
yield StreamEntry(entry)


class TarReader(BaseReader):
Expand All @@ -53,4 +58,4 @@ def read(self):
match = not self._regex or self._regex.search(member.name)
if member.isfile() and match:
content = archive.extractfile(member).read()
yield content
yield StreamEntry(entry=content)
6 changes: 3 additions & 3 deletions invenio_vocabularies/datastreams/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ def __init__(self, *args, **kwargs):
"""Constructor."""
pass

def apply(self, entry, *args, **kwargs):
def apply(self, stream_entry, *args, **kwargs):
"""Applies the transformation to the entry.
:returns: The transformed entry, this allow them to be chained
raises TransformerError in case of errors.
:returns: A StreamEntry. The transformed entry.
Raises TransformerError in case of errors.
"""
pass

Expand Down
29 changes: 12 additions & 17 deletions invenio_vocabularies/datastreams/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from invenio_records_resources.proxies import current_service_registry
from marshmallow import ValidationError

from .datastreams import StreamResult
from .errors import WriterError


Expand All @@ -26,10 +25,12 @@ def __init__(self, *args, **kwargs):
"""Constructor."""
pass

def write(self, entry, *args, **kwargs):
"""Writes the input entry to the target output.
def write(self, stream_entry, *args, **kwargs):
"""Writes the input stream entry to the target output.
:returns: A StreamEntry. The result of writing the entry.
Raises WriterError in case of errors.
Raises WriterError in case of errors.
"""
pass

Expand All @@ -52,17 +53,12 @@ def __init__(self, service_or_name, identity, *args, **kwargs):

super().__init__(*args, **kwargs)

def write(self, entry, *args, **kwargs):
def write(self, stream_entry, *args, **kwargs):
"""Writes the input entry using a given service."""
try:
result = self._service.create(self._identity, entry)
result = self._service.create(self._identity, stream_entry.entry)
except ValidationError as err:
result = StreamResult(
entry=entry,
errors=[{"ValidationError": err.messages}]
)
if result.errors:
raise WriterError(result.errors)
raise WriterError([{"ValidationError": err.messages}])

return result

Expand All @@ -79,12 +75,11 @@ def __init__(self, filepath, *args, **kwargs):

super().__init__(*args, **kwargs)

def write(self, entry, *args, **kwargs):
"""Writes the input entry using a given service."""
def write(self, stream_entry, *args, **kwargs):
"""Writes the input stream entry to the YAML file."""
with open(self._filepath, 'a') as file:
# made into array for safer append
# will always read array (good for reader)
yaml.safe_dump([entry], file)
result = StreamResult(entry=entry)
yaml.safe_dump([stream_entry.entry], file)

return result
return stream_entry
7 changes: 4 additions & 3 deletions tests/contrib/names/test_names_datastreams.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pytest

from invenio_vocabularies.contrib.names.datastreams import OrcidXMLTransformer
from invenio_vocabularies.datastreams import StreamEntry


@pytest.fixture(scope='module')
Expand All @@ -32,7 +33,7 @@ def expected_from_xml():
@pytest.fixture(scope='module')
def xml_entry():
# simplified version of an XML file of the ORCiD dump
return bytes(
return StreamEntry(bytes(
'<?xml version="1.0" encoding="UTF-8" standalone="yes"?>\n'
'<record:record path="/0000-0001-8135-3489">\n'
' <common:orcid-identifier>\n'
Expand Down Expand Up @@ -60,9 +61,9 @@ def xml_entry():
' </activities:activities-summary>\n'
'</record:record>\n',
encoding="raw_unicode_escape"
)
))


def test_orcid_xml_transformer(xml_entry, expected_from_xml):
transformer = OrcidXMLTransformer()
assert expected_from_xml == transformer.apply(xml_entry)
assert expected_from_xml == transformer.apply(xml_entry).entry
13 changes: 7 additions & 6 deletions tests/datastreams/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import pytest

from invenio_vocabularies.datastreams import StreamEntry
from invenio_vocabularies.datastreams.errors import TransformerError, \
WriterError
from invenio_vocabularies.datastreams.readers import BaseReader
Expand All @@ -27,18 +28,18 @@ class TestReader(BaseReader):
def read(self, *args, **kwargs):
"""Yields the values in the origin."""
for value in self._origin:
yield value
yield StreamEntry(value)


class TestTransformer(BaseTransformer):
"""Test transformer."""

def apply(self, entry, *args, **kwargs):
def apply(self, stream_entry, *args, **kwargs):
"""Sums up one to the value."""
if entry < 0:
if stream_entry.entry < 0:
raise TransformerError("Value cannot be negative")

return entry + 1
return StreamEntry(stream_entry.entry + 1)


class TestWriter(BaseWriter):
Expand All @@ -53,9 +54,9 @@ def __init__(self, fail_on):
super().__init__()
self.fail_on = fail_on

def write(self, entry):
def write(self, stream_entry):
"""Return the entry."""
if entry == self.fail_on:
if stream_entry.entry == self.fail_on:
raise WriterError(f"{self.fail_on} value found.")


Expand Down
8 changes: 4 additions & 4 deletions tests/datastreams/test_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def yaml_file(expected_from_yaml):
def test_yaml_reader(yaml_file, expected_from_yaml):
reader = YamlReader(yaml_file)

for idx, entry in enumerate(reader.read()):
assert entry == expected_from_yaml[idx]
for idx, stream_entry in enumerate(reader.read()):
assert stream_entry.entry == expected_from_yaml[idx]


@pytest.fixture(scope='module')
Expand Down Expand Up @@ -85,8 +85,8 @@ def test_tar_reader(tar_file, expected_from_tar):
reader = TarReader(tar_file, regex=".yaml$")

total = 0
for entry in reader.read():
assert yaml.safe_load(entry) == expected_from_tar
for stream_entry in reader.read():
assert yaml.safe_load(stream_entry.entry) == expected_from_tar
total += 1

assert total == 2 # ignored the `.other` file
5 changes: 3 additions & 2 deletions tests/datastreams/test_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@

import yaml

from invenio_vocabularies.datastreams import StreamEntry
from invenio_vocabularies.datastreams.writers import ServiceWriter, YamlWriter


def test_service_writer(lang_type, lang_data, service, identity):
writer = ServiceWriter(service, identity)
lang = writer.write(entry=lang_data)
lang = writer.write(stream_entry=StreamEntry(lang_data))
record = service.read(identity, ("languages", lang.id))
record = record.to_dict()

Expand All @@ -33,7 +34,7 @@ def test_yaml_writer():

writer = YamlWriter(filepath=filepath)
for output in test_output:
assert not writer.write(entry=output).errors
assert not writer.write(stream_entry=StreamEntry(output)).errors

with open(filepath) as file:
assert yaml.safe_load(file) == test_output
Expand Down

0 comments on commit ccf09c8

Please sign in to comment.