names: add orcid public data sync
* closes #353
jrcastro2 committed Jul 17, 2024
1 parent c038416 commit 56000b8
Showing 8 changed files with 223 additions and 36 deletions.
24 changes: 22 additions & 2 deletions invenio_vocabularies/cli.py
@@ -26,6 +26,8 @@ def vocabularies():

def _process_vocab(config, num_samples=None):
"""Import a vocabulary."""
import time
start_time = time.time()
ds = DataStreamFactory.create(
readers_config=config["readers"],
transformers_config=config.get("transformers"),
@@ -34,7 +36,8 @@

success, errored, filtered = 0, 0, 0
left = num_samples or -1
for result in ds.process():
    for result in ds.process(
        batch_size=config.get("batch_size", 100),
        write_many=config.get("write_many", False),
    ):
left = left - 1
if result.filtered:
filtered += 1
@@ -47,6 +50,20 @@ def _process_vocab(config, num_samples=None):
if left == 0:
click.secho(f"Number of samples reached {num_samples}", fg="green")
break

end_time = time.time()

elapsed_time = end_time - start_time
friendly_time = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
friendly_time_per_record = 0
if success:
elapsed_time_per_record = elapsed_time/success * 1000
friendly_time_per_record = time.strftime("%H:%M:%S", time.gmtime(elapsed_time_per_record))

print(f"CLI elapsed time: {friendly_time} for {success} entries. An average of {friendly_time_per_record} per 1000 entry.\n")
with open("/tmp/elapsed_time.txt", "a") as file:
file.write(f"CLI elapsed time: {friendly_time} for {success} entries. An average of {friendly_time_per_record} per 1000 entry.\n")

return success, errored, filtered


@@ -101,7 +118,10 @@ def update(vocabulary, filepath=None, origin=None):
config = vc.get_config(filepath, origin)

for w_conf in config["writers"]:
w_conf["args"]["update"] = True
if w_conf["type"] == "async":
w_conf["args"]["writer"]["args"]["update"] = True
else:
w_conf["args"]["update"] = True

success, errored, filtered = _process_vocab(config)

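For illustration only, a sketch of the config shape that the reworked _process_vocab now understands: batch_size and write_many are the new keys, falling back to 100 and False when absent. The reader, transformer and writer entries mirror the ORCiD names datastream config further down; the snippet is a sketch, not part of the commit.

# Hypothetical vocabulary import config; only "batch_size" and "write_many" are new.
config = {
    "readers": [{"type": "orcid-data-sync"}, {"type": "xml"}],
    "transformers": [{"type": "orcid"}],
    "writers": [
        {
            "type": "async",
            "args": {"writer": {"type": "names-service", "args": {}}},
        }
    ],
    "batch_size": 1000,  # entries buffered before each write
    "write_many": True,  # hand whole batches to writer.write_many()
}

success, errored, filtered = _process_vocab(config)
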
12 changes: 11 additions & 1 deletion invenio_vocabularies/config.py
@@ -23,7 +23,7 @@
ZipReader,
)
from .datastreams.transformers import XMLTransformer
from .datastreams.writers import ServiceWriter, YamlWriter
from .datastreams.writers import AsyncWriter, ServiceWriter, YamlWriter
from .resources.resource import VocabulariesResourceConfig
from .services.service import VocabulariesServiceConfig

@@ -122,5 +122,15 @@
VOCABULARIES_DATASTREAM_WRITERS = {
"service": ServiceWriter,
"yaml": YamlWriter,
"async": AsyncWriter,
}
"""Data Streams writers."""


VOCABULARIES_ORCID_ACCESS_KEY = "CHANGE_ME"
VOCABULARIES_ORCID_SECRET_KEY = "CHANGE_ME"
VOCABULARIES_ORCID_FOLDER = "/tmp/ORCID_public_data_files/"
VOCABULARIES_ORCID_SUMMARIES_BUCKET = "v3.0-summaries"
VOCABULARIES_DATASTREAM_BATCH_SIZE = 100
VOCABULARIES_ORCID_SYNC_MAX_WORKERS = 32
VOCABULARIES_ORCID_SYNC_DAYS = 1
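
The access and secret keys default to placeholders and are meant to be overridden per instance; a sketch of what that could look like in an instance's invenio.cfg (all values below are made up):

# invenio.cfg: override the ORCiD sync defaults shipped above
VOCABULARIES_ORCID_ACCESS_KEY = "<orcid-s3-access-key>"
VOCABULARIES_ORCID_SECRET_KEY = "<orcid-s3-secret-key>"
VOCABULARIES_ORCID_FOLDER = "/var/data/orcid/"
VOCABULARIES_ORCID_SUMMARIES_BUCKET = "v3.0-summaries"
VOCABULARIES_ORCID_SYNC_MAX_WORKERS = 16  # throttle the ThreadPoolExecutor used by the reader
VOCABULARIES_ORCID_SYNC_DAYS = 7  # pick up records modified in the last week
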
98 changes: 90 additions & 8 deletions invenio_vocabularies/contrib/names/datastreams.py
@@ -8,13 +8,84 @@

"""Names datastreams, transformers, writers and readers."""

from invenio_access.permissions import system_identity
from invenio_records.dictutils import dict_lookup

from ...datastreams.errors import TransformerError
from ...datastreams.readers import SimpleHTTPReader
from ...datastreams.readers import SimpleHTTPReader, BaseReader
from ...datastreams.transformers import BaseTransformer
from ...datastreams.writers import ServiceWriter
import boto3
from flask import current_app
from datetime import datetime
from datetime import timedelta
import tarfile
import io
from concurrent.futures import ThreadPoolExecutor, as_completed

class OrcidDataSyncReader(BaseReader):
"""ORCiD Data Sync Reader."""

def _iter(self, fp, *args, **kwargs):
"""."""
raise NotImplementedError(
"OrcidDataSyncReader downloads one file and therefore does not iterate through items"
)

def read(self, item=None, *args, **kwargs):
"""Downloads the ORCiD lambda file and yields an in-memory binary stream of it."""

path = current_app.config["VOCABULARIES_ORCID_FOLDER"]
date_format = '%Y-%m-%d %H:%M:%S.%f'
date_format_no_millis = '%Y-%m-%d %H:%M:%S'

s3client = boto3.client('s3', aws_access_key_id=current_app.config["VOCABULARIES_ORCID_ACCESS_KEY"], aws_secret_access_key=current_app.config["VOCABULARIES_ORCID_SECRET_KEY"])
response = s3client.get_object(Bucket='orcid-lambda-file', Key='last_modified.csv.tar')
tar_content = response['Body'].read()

last_sync = datetime.now() - timedelta(days=current_app.config["VOCABULARIES_ORCID_SYNC_DAYS"])

def process_file(fileobj):
file_content = fileobj.read().decode('utf-8')
orcids = []
for line in file_content.splitlines()[1:]: # Skip the header line
elements = line.split(',')
orcid = elements[0]

last_modified_str = elements[3]
try:
last_modified_date = datetime.strptime(last_modified_str, date_format)
except ValueError:
last_modified_date = datetime.strptime(last_modified_str, date_format_no_millis)

if last_modified_date >= last_sync:
orcids.append(orcid)
else:
break
return orcids

orcids_to_sync = []
with tarfile.open(fileobj=io.BytesIO(tar_content)) as tar:
for member in tar.getmembers():
f = tar.extractfile(member)
if f:
orcids_to_sync.extend(process_file(f))

def fetch_orcid_data(orcid_to_sync, bucket):
suffix = orcid_to_sync[-3:]
key = f'{suffix}/{orcid_to_sync}.xml'
try:
file_response = s3client.get_object(Bucket=bucket, Key=key)
return file_response['Body'].read()
except Exception as e:
# TODO: log
return None

with ThreadPoolExecutor(max_workers=current_app.config["VOCABULARIES_ORCID_SYNC_MAX_WORKERS"]) as executor: # TODO allow to configure max_workers / test to use asyncio
futures = [executor.submit(fetch_orcid_data, orcid, current_app.config["VOCABULARIES_ORCID_SUMMARIES_BUCKET"]) for orcid in orcids_to_sync]
for future in as_completed(futures):
result = future.result()
if result is not None:
yield result


class OrcidHTTPReader(SimpleHTTPReader):
@@ -89,6 +160,7 @@ def _entry_id(self, entry):

VOCABULARIES_DATASTREAM_READERS = {
"orcid-http": OrcidHTTPReader,
"orcid-data-sync": OrcidDataSyncReader,
}


@@ -107,22 +179,32 @@ def _entry_id(self, entry):
DATASTREAM_CONFIG = {
"readers": [
{
"type": "tar",
"args": {
"regex": "\\.xml$",
},
"type": "orcid-data-sync",
},
{"type": "xml"},
],
"transformers": [{"type": "orcid"}],
# "writers": [
# {
# "type": "names-service",
# "args": {
# "identity": system_identity,
# },
# }
# ],
"writers": [
{
"type": "names-service",
"type": "async",
"args": {
"identity": system_identity,
"writer":{
"type": "names-service",
"args": {},
}
},
}
],
"batch_size": 1000, # TODO: current_app.config["VOCABULARIES_DATASTREAM_BATCH_SIZE"],
"write_many": True,
}
"""ORCiD Data Stream configuration.
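A rough sketch of driving this datastream programmatically, mirroring what _process_vocab does in the CLI. Only readers_config and transformers_config are visible in this diff, so the writers_config keyword and the factory import path are assumptions:

from invenio_vocabularies.contrib.names.datastreams import DATASTREAM_CONFIG
from invenio_vocabularies.datastreams import DataStreamFactory  # import path assumed

ds = DataStreamFactory.create(
    readers_config=DATASTREAM_CONFIG["readers"],
    transformers_config=DATASTREAM_CONFIG.get("transformers"),
    writers_config=DATASTREAM_CONFIG["writers"],  # keyword name assumed
)

# Stream ORCiD summaries modified in the last VOCABULARIES_ORCID_SYNC_DAYS days,
# handing batches of 1000 transformed entries to the async writer.
for result in ds.process(batch_size=1000, write_many=True):
    if result.errors:
        print(result.errors)
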
64 changes: 47 additions & 17 deletions invenio_vocabularies/datastreams/datastreams.py
@@ -10,16 +10,21 @@

from .errors import ReaderError, TransformerError, WriterError


class StreamEntry:
"""Object to encapsulate streams processing."""

def __init__(self, entry, errors=None):
"""Constructor."""
self.entry = entry
self.filtered = False
self.errors = errors or []
def __init__(self, entry, errors=None, op_type=None):
"""Constructor for the StreamEntry class.
Args:
entry (object): The entry object, usually a record dict.
errors (list, optional): List of errors. Defaults to None.
op_type (str, optional): The operation type. Defaults to None.
"""
self.entry = entry
self.filtered = False
self.errors = errors or []
self.op_type = op_type

class DataStream:
"""Data stream."""
@@ -38,16 +43,10 @@ def __init__(self, readers, writers, transformers=None, *args, **kwargs):
def filter(self, stream_entry, *args, **kwargs):
"""Checks if an stream_entry should be filtered out (skipped)."""
return False

def process(self, *args, **kwargs):
"""Iterates over the entries.
Uses the reader to get the raw entries and transforms them.
It will iterate over the `StreamEntry` objects returned by
the reader, apply the transformations and yield the result of
writing it.
"""
for stream_entry in self.read():

def process_batch(self, batch, write_many=False):
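"""Transform, filter and write a batch of entries, yielding each processed StreamEntry."""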
transformed_entries = []
for stream_entry in batch:
if stream_entry.errors:
yield stream_entry # reading errors
else:
Expand All @@ -58,7 +57,33 @@ def process(self, *args, **kwargs):
transformed_entry.filtered = True
yield transformed_entry
else:
yield self.write(transformed_entry)
transformed_entries.append(transformed_entry)
if transformed_entries:
if write_many:
print(f"write_many {len(transformed_entries)} entries.")
yield from self.batch_write(transformed_entries)
else:
print(f"write {len(transformed_entries)} entries.")
yield from (self.write(entry) for entry in transformed_entries)

def process(self, batch_size=100, write_many=False, *args, **kwargs):
"""Iterates over the entries.
Uses the reader to get the raw entries and transforms them.
It will iterate over the `StreamEntry` objects returned by
the reader, apply the transformations and yield the result of
writing it.
"""
batch = []
for stream_entry in self.read():
batch.append(stream_entry)
if len(batch) >= batch_size:
yield from self.process_batch(batch, write_many=write_many)
batch = []

# Process any remaining entries in the last batch
if batch:
yield from self.process_batch(batch, write_many=write_many)

def read(self):
"""Recursively read the entries."""
@@ -106,6 +131,11 @@ def write(self, stream_entry, *args, **kwargs):
stream_entry.errors.append(f"{writer.__class__.__name__}: {str(err)}")

return stream_entry

def batch_write(self, stream_entries, *args, **kwargs):
"""Apply the transformations to an stream_entry. Errors are handler in the service layer."""
for writer in self._writers:
yield from writer.write_many(stream_entries)

def total(self, *args, **kwargs):
"""The total of entries obtained from the origin."""
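The batching added to process() is a plain buffer-and-flush loop; a self-contained toy version of the same control flow, with no Invenio dependencies, purely for illustration:

def chunked(iterable, batch_size=100):
    """Yield lists of up to batch_size items, flushing the remainder at the end."""
    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) >= batch_size:
            yield batch
            batch = []
    if batch:  # the last, possibly smaller, batch
        yield batch

assert list(chunked(range(7), batch_size=3)) == [[0, 1, 2], [3, 4, 5], [6]]
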
7 changes: 4 additions & 3 deletions invenio_vocabularies/datastreams/readers.py
@@ -19,12 +19,13 @@

import requests
import yaml
from lxml.html import parse as html_parse
from lxml.html import fromstring

from .errors import ReaderError
from .xml import etree_to_dict



class BaseReader(ABC):
"""Base reader."""

@@ -219,8 +220,8 @@ class XMLReader(BaseReader):
def _iter(self, fp, *args, **kwargs):
"""Read and parse an XML file to dict."""
# NOTE: We parse HTML, to skip XML validation and strip XML namespaces
xml_tree = html_parse(fp).getroot()
record = etree_to_dict(xml_tree)["html"]["body"].get("record")
xml_tree = fromstring(fp)
record = etree_to_dict(xml_tree).get("record")

if not record:
raise ReaderError(f"Record not found in XML entry.")
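fromstring accepts the raw bytes that OrcidDataSyncReader now yields, whereas parse expected a file-like object. A minimal sketch of the new parsing path; the payload is made up and the exact shape returned by etree_to_dict is not shown in this diff:

from lxml.html import fromstring

from invenio_vocabularies.datastreams.xml import etree_to_dict

xml_bytes = b"<record><orcid-identifier><path>0000-0001-2345-6789</path></orcid-identifier></record>"
xml_tree = fromstring(xml_bytes)  # lenient HTML parsing, so no XML validation or namespace handling
record = etree_to_dict(xml_tree).get("record")  # the nested dict under <record>, as XMLReader expects
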
15 changes: 13 additions & 2 deletions invenio_vocabularies/datastreams/tasks.py
@@ -15,11 +15,22 @@


@shared_task(ignore_result=True)
def write_entry(writer, entry):
def write_entry(writer_config, entry):
"""Write an entry.
:param writer_config: writer configuration as accepted by the WriterFactory.
:param entry: dictionary, StreamEntry is not serializable.
"""
writer = WriterFactory.create(config=writer)
writer = WriterFactory.create(config=writer_config)
writer.write(StreamEntry(entry))

@shared_task(ignore_result=True)
def write_many_entry(writer_config, entries):
"""Write many entries.
:param writer_config: writer configuration as accepted by the WriterFactory.
:param entries: list of dictionaries, StreamEntry is not serializable.
"""
writer = WriterFactory.create(config=writer_config)
stream_entries = [StreamEntry(entry) for entry in entries]
writer.write_many(stream_entries)
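
A sketch of how the new batch task could be dispatched, for instance by an asynchronous writer; .delay() is standard Celery, while the writer config and the entry payload below are made up:

from invenio_vocabularies.datastreams.tasks import write_many_entry

writer_config = {"type": "names-service", "args": {}}
entries = [{"id": "0000-0001-2345-6789", "name": "Doe, Jane"}]  # plain dicts: StreamEntry is not serializable

# Queue a single Celery task that writes the whole batch via writer.write_many().
write_many_entry.delay(writer_config, entries)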
