datastream: optimize memory usage on ORCID sync #439

Merged (2 commits) on Dec 9, 2024
7 changes: 7 additions & 0 deletions CHANGES.rst
@@ -8,6 +8,13 @@
Changes
=======

Version v6.8.0 (released 2024-12-09)

- names: extract affiliation identifiers from employments
- names: optimize memory usage on ORCID sync
- subjects: improve search with CompositeSuggestQueryParser
- subjects: added datastream for bodc

Version v6.7.0 (released 2024-11-27)

- contrib: improve search accuracy for names, funders, affiliations
2 changes: 1 addition & 1 deletion invenio_vocabularies/__init__.py
@@ -10,6 +10,6 @@

from .ext import InvenioVocabularies

__version__ = "6.7.0"
__version__ = "6.8.0"

__all__ = ("__version__", "InvenioVocabularies")
2 changes: 2 additions & 0 deletions invenio_vocabularies/cli.py
@@ -29,6 +29,8 @@ def _process_vocab(config, num_samples=None):
readers_config=config["readers"],
transformers_config=config.get("transformers"),
writers_config=config["writers"],
batch_size=config.get("batch_size", 1000),
write_many=config.get("write_many", False),
)

success, errored, filtered = 0, 0, 0
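With this change, `batch_size` and `write_many` can be tuned per vocabulary config instead of being fixed at call time. A hypothetical config dict in the shape consumed by `_process_vocab`; the reader/transformer/writer type names below are placeholders, not taken from this PR:

```python
# Illustrative config; only "batch_size" and "write_many" are new in this PR.
config = {
    "readers": [{"type": "yaml", "args": {"origin": "./subjects.yaml"}}],
    "transformers": [],
    "writers": [{"type": "service", "args": {"service_type": "subjects"}}],
    "batch_size": 500,   # _process_vocab falls back to 1000 when omitted
    "write_many": True,  # falls back to False when omitted
}
```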
86 changes: 57 additions & 29 deletions invenio_vocabularies/contrib/names/datastreams.py
@@ -13,6 +13,7 @@
import tarfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import timedelta
from itertools import islice
from pathlib import Path

import arrow
@@ -48,6 +49,8 @@ def _fetch_orcid_data(self, orcid_to_sync, bucket):
suffix = orcid_to_sync[-3:]
key = f"{suffix}/{orcid_to_sync}.xml"
try:
# Potential improvement: use a streaming (SAX-style) XML parser to avoid loading
# the whole file in memory and choose the sections we need to read (probably the summary)
return self.s3_client.read_file(f"s3://{bucket}/{key}")
except Exception:
current_app.logger.exception("Failed to fetch ORCiD record.")
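The comment above points at a possible follow-up: stream-parse the summary XML instead of materializing the whole document tree. A minimal sketch of that idea using `xml.etree.ElementTree.iterparse`; the tag names are assumptions about the ORCiD summary schema, not part of this PR, and here the bytes are already fetched (a fully streaming variant would parse the S3 body directly):

```python
import io
import xml.etree.ElementTree as ET

def iter_summary_fields(xml_bytes, wanted=frozenset({"given-names", "family-name"})):
    """Stream-parse an ORCiD summary, yielding only the fields we care about."""
    for _event, elem in ET.iterparse(io.BytesIO(xml_bytes), events=("end",)):
        tag = elem.tag.rsplit("}", 1)[-1]  # drop the XML namespace prefix
        if tag in wanted:
            yield tag, (elem.text or "").strip()
        elem.clear()  # free the parsed subtree so memory stays flat
```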
@@ -67,42 +70,54 @@ def _process_lambda_file(self, fileobj):
if self.since:
time_shift = self.since
last_sync = arrow.now() - timedelta(**time_shift)

file_content = fileobj.read().decode("utf-8")

csv_reader = csv.DictReader(file_content.splitlines())

for row in csv_reader: # Skip the header line
orcid = row["orcid"]

# Lambda file is ordered by last modified date
last_modified_str = row["last_modified"]
try:
last_modified_date = arrow.get(last_modified_str, date_format)
except arrow.parser.ParserError:
last_modified_date = arrow.get(last_modified_str, date_format_no_millis)

if last_modified_date < last_sync:
break
yield orcid
try:
content = io.TextIOWrapper(fileobj, encoding="utf-8")
csv_reader = csv.DictReader(content)

for row in csv_reader: # Skip the header line
orcid = row["orcid"]

# Lambda file is ordered by last modified date
last_modified_str = row["last_modified"]
try:
last_modified_date = arrow.get(last_modified_str, date_format)
except arrow.parser.ParserError:
last_modified_date = arrow.get(
last_modified_str, date_format_no_millis
)

if last_modified_date < last_sync:
break
yield orcid
finally:
fileobj.close()
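Compared to the old `fileobj.read().decode("utf-8")` plus `splitlines()`, wrapping the raw file object in `io.TextIOWrapper` decodes and parses the CSV lazily, so only one row is held in memory at a time. The pattern in isolation (illustrative function name):

```python
import csv
import io

def iter_csv_rows(binary_fileobj):
    """Decode a binary CSV stream lazily and yield one dict per row."""
    text = io.TextIOWrapper(binary_fileobj, encoding="utf-8")
    for row in csv.DictReader(text):  # the header row is consumed automatically
        yield row
```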

def _iter(self, orcids):
"""Iterates over the ORCiD records yielding each one."""
with ThreadPoolExecutor(
max_workers=current_app.config["VOCABULARIES_ORCID_SYNC_MAX_WORKERS"]
) as executor:
futures = [
executor.submit(
# futures is a dictionary mapping each ORCID to its Future object
futures = {
orcid: executor.submit(
self._fetch_orcid_data,
orcid,
current_app.config["VOCABULARIES_ORCID_SUMMARIES_BUCKET"],
)
for orcid in orcids
]
for future in as_completed(futures):
result = future.result()
if result is not None:
yield result
}

for orcid in list(futures.keys()):
try:
result = futures[orcid].result()
if result:
yield result
finally:
# Explicitly release memory, as we don't need the future anymore.
# This is mostly required because as long as we keep a reference to the future
# (in the above futures dict), the garbage collector won't collect it
# and it will keep the memory allocated.
del futures[orcid]
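Distilled, the new `_iter` keys each Future by its ORCID, consumes them in submission order, and drops each reference as soon as its result is read, so completed results don't stay referenced for the lifetime of the loop. A generic sketch of the pattern, where the `fetch` callable and worker count are placeholders:

```python
from concurrent.futures import ThreadPoolExecutor

def fetch_all(keys, fetch, max_workers=4):
    """Fetch one result per key, releasing each Future once it is consumed."""
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {key: executor.submit(fetch, key) for key in keys}
        for key in list(futures):
            try:
                result = futures[key].result()  # blocks until this task finishes
                if result:
                    yield result
            finally:
                del futures[key]  # allow the result to be garbage-collected
```

Unlike `as_completed`, this yields in submission order, but no consumed result remains referenced while the rest of the batch is processed.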

def read(self, item=None, *args, **kwargs):
"""Streams the ORCiD lambda file, process it to get the ORCiDS to sync and yields it's data."""
@@ -111,18 +126,31 @@
"s3://orcid-lambda-file/last_modified.csv.tar"
)

orcids_to_sync = []
# Open the tar file and process it
with tarfile.open(fileobj=io.BytesIO(tar_content)) as tar:
# Iterate over each member (file or directory) in the tar file
for member in tar.getmembers():
# Extract the file
extracted_file = tar.extractfile(member)
if extracted_file:
current_app.logger.info(f"[ORCID Reader] Processing lambda file...")
# Process the file and get the ORCiDs to sync
orcids_to_sync.extend(self._process_lambda_file(extracted_file))

yield from self._iter(orcids_to_sync)
orcids_to_sync = set(self._process_lambda_file(extracted_file))

# Close the file explicitly after processing
extracted_file.close()

# Process ORCIDs in smaller batches
for orcid_batch in self._chunked_iter(
orcids_to_sync, batch_size=100
):
yield from self._iter(orcid_batch)

def _chunked_iter(self, iterable, batch_size):
"""Yield successive chunks of a given size."""
it = iter(iterable)
while chunk := list(islice(it, batch_size)):
yield chunk
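For reference, the walrus-operator loop (Python 3.8+) chunks an arbitrary iterable like this:

```python
from itertools import islice

def chunked(iterable, batch_size):
    it = iter(iterable)
    while chunk := list(islice(it, batch_size)):
        yield chunk

assert list(chunked(range(5), 2)) == [[0, 1], [2, 3], [4]]
```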


class OrcidHTTPReader(SimpleHTTPReader):
25 changes: 18 additions & 7 deletions invenio_vocabularies/datastreams/datastreams.py
@@ -48,7 +48,16 @@ def log_errors(self, logger=None):
class DataStream:
"""Data stream."""

def __init__(self, readers, writers, transformers=None, *args, **kwargs):
def __init__(
self,
readers,
writers,
transformers=None,
batch_size=100,
write_many=False,
*args,
**kwargs,
):
"""Constructor.

:param readers: an ordered list of readers.
@@ -58,12 +67,14 @@ def __init__(self, readers, writers, transformers=None, *args, **kwargs):
self._readers = readers
self._transformers = transformers
self._writers = writers
self.batch_size = batch_size
self.write_many = write_many

def filter(self, stream_entry, *args, **kwargs):
"""Checks if an stream_entry should be filtered out (skipped)."""
return False

def process_batch(self, batch, write_many=False):
def process_batch(self, batch):
"""Process a batch of entries."""
transformed_entries = []
for stream_entry in batch:
@@ -79,12 +90,12 @@
else:
transformed_entries.append(transformed_entry)
if transformed_entries:
if write_many:
if self.write_many:
yield from self.batch_write(transformed_entries)
else:
yield from (self.write(entry) for entry in transformed_entries)

def process(self, batch_size=100, write_many=False, *args, **kwargs):
def process(self, *args, **kwargs):
"""Iterates over the entries.

Uses the reader to get the raw entries and transforms them.
@@ -95,13 +106,13 @@
batch = []
for stream_entry in self.read():
batch.append(stream_entry)
if len(batch) >= batch_size:
yield from self.process_batch(batch, write_many=write_many)
if len(batch) >= self.batch_size:
yield from self.process_batch(batch)
batch = []

# Process any remaining entries in the last batch
if batch:
yield from self.process_batch(batch, write_many=write_many)
yield from self.process_batch(batch)

def read(self):
"""Recursively read the entries."""
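With the new constructor arguments, batching is configured once on the stream rather than per `process()` call. Rough usage, where the reader and writer instances are placeholders:

```python
# Illustrative only: my_reader and my_writer stand in for configured instances.
ds = DataStream(
    readers=[my_reader],
    writers=[my_writer],
    batch_size=500,   # flush a batch every 500 entries
    write_many=True,  # use the writers' bulk path for each batch
)
for stream_entry in ds.process():
    ...  # inspect stream_entry for errors, as the CLI does
```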
4 changes: 3 additions & 1 deletion invenio_vocabularies/datastreams/factories.py
@@ -81,4 +81,6 @@ def create(cls, readers_config, writers_config, transformers_config=None, **kwargs):
for t_conf in transformers_config:
transformers.append(TransformerFactory.create(t_conf))

return DataStream(readers=readers, writers=writers, transformers=transformers)
return DataStream(
readers=readers, writers=writers, transformers=transformers, **kwargs
)
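Because the factory now forwards `**kwargs` to `DataStream`, the batching options flow straight through from any config-driven caller (the reader/writer type names here are illustrative):

```python
ds = DataStreamFactory.create(
    readers_config=[{"type": "yaml", "args": {"origin": "./names.yaml"}}],
    writers_config=[{"type": "names-service"}],
    batch_size=1000,
    write_many=True,
)
```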
8 changes: 6 additions & 2 deletions invenio_vocabularies/datastreams/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pathlib import Path

import yaml
from flask import current_app
from invenio_access.permissions import system_identity
from invenio_pidstore.errors import PIDAlreadyExists, PIDDoesNotExistError
from invenio_records.systemfields.relations.errors import InvalidRelationValue
Expand Down Expand Up @@ -120,11 +121,14 @@ def write(self, stream_entry, *args, **kwargs):

def write_many(self, stream_entries, *args, **kwargs):
"""Writes the input entries using a given service."""
current_app.logger.info(f"Writing {len(stream_entries)} entries")
entries = [entry.entry for entry in stream_entries]
entries_with_id = [(self._entry_id(entry), entry) for entry in entries]
results = self._service.create_or_update_many(self._identity, entries_with_id)
result_list = self._service.create_or_update_many(
self._identity, entries_with_id
)
stream_entries_processed = []
for entry, result in zip(entries, results):
for entry, result in zip(entries, result_list.results):
processed_stream_entry = StreamEntry(
entry=entry,
record=result.record,
2 changes: 2 additions & 0 deletions invenio_vocabularies/fixtures.py
@@ -28,6 +28,8 @@ def _load_vocabulary(self, config, delay=True, **kwargs):
readers_config=config["readers"],
transformers_config=config.get("transformers"),
writers_config=config["writers"],
batch_size=config.get("batch_size", 1000),
write_many=config.get("write_many", False),
)

errors = []
2 changes: 2 additions & 0 deletions invenio_vocabularies/services/tasks.py
@@ -20,6 +20,8 @@ def process_datastream(config):
readers_config=config["readers"],
transformers_config=config.get("transformers"),
writers_config=config["writers"],
batch_size=config.get("batch_size", 1000),
write_many=config.get("write_many", False),
)

for result in ds.process():