datastream: optimize memory usage on ORCID sync
jrcastro2 committed Dec 6, 2024
1 parent b798bf3 commit b60d386
Showing 7 changed files with 90 additions and 39 deletions.
2 changes: 2 additions & 0 deletions invenio_vocabularies/cli.py
@@ -29,6 +29,8 @@ def _process_vocab(config, num_samples=None):
readers_config=config["readers"],
transformers_config=config.get("transformers"),
writers_config=config["writers"],
batch_size=config.get("batch_size", 1000),
write_many=config.get("write_many", False),
)

success, errored, filtered = 0, 0, 0
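With this change `_process_vocab` forwards `batch_size` and `write_many` from the vocabulary config to the datastream factory. A hedged sketch of what such a config could look like; the reader/writer entries are placeholders, only the two new keys come from this commit:

```python
# Illustrative vocabulary datastream config as consumed by _process_vocab.
# Reader/writer entries are placeholders; only the keys matter here.
config = {
    "readers": [{"type": "yaml", "args": {"origin": "./names.yaml"}}],
    "transformers": [],
    "writers": [{"type": "names-service"}],
    "batch_size": 500,   # defaults to 1000 when omitted
    "write_many": True,  # defaults to False when omitted
}
```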
86 changes: 57 additions & 29 deletions invenio_vocabularies/contrib/names/datastreams.py
@@ -14,6 +14,7 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import timedelta
from pathlib import Path
from itertools import islice

import arrow
import regex as re
@@ -48,6 +49,8 @@ def _fetch_orcid_data(self, orcid_to_sync, bucket):
suffix = orcid_to_sync[-3:]
key = f"{suffix}/{orcid_to_sync}.xml"
try:
# Potential improvement: use a SAX (streaming) XML parser to avoid loading the whole file in memory
# and choose the sections we need to read (probably the summary)
return self.s3_client.read_file(f"s3://{bucket}/{key}")
except Exception:
current_app.logger.exception("Failed to fetch ORCiD record.")
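The comment above flags a streaming XML parser as a possible follow-up. A minimal sketch of that idea using `xml.etree.ElementTree.iterparse`; the tag name is illustrative only and not the actual ORCiD summary schema:

```python
# Sketch of the "streaming parser" improvement mentioned in the comment above.
# Tag names are illustrative; the real ORCiD summary uses namespaced elements.
import xml.etree.ElementTree as ET

def extract_summary(xml_source, wanted_tag="summary"):
    """Stream-parse an XML file and return the first matching element's text."""
    for _event, elem in ET.iterparse(xml_source, events=("end",)):
        if elem.tag.endswith(wanted_tag):
            text = elem.text
            elem.clear()  # drop parsed children to keep memory flat
            return text
        elem.clear()
    return None
```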
@@ -67,42 +70,54 @@ def _process_lambda_file(self, fileobj):
if self.since:
time_shift = self.since
last_sync = arrow.now() - timedelta(**time_shift)

file_content = fileobj.read().decode("utf-8")

csv_reader = csv.DictReader(file_content.splitlines())

for row in csv_reader: # Skip the header line
orcid = row["orcid"]

# Lambda file is ordered by last modified date
last_modified_str = row["last_modified"]
try:
last_modified_date = arrow.get(last_modified_str, date_format)
except arrow.parser.ParserError:
last_modified_date = arrow.get(last_modified_str, date_format_no_millis)

if last_modified_date < last_sync:
break
yield orcid
try:
content = io.TextIOWrapper(fileobj, encoding="utf-8")
csv_reader = csv.DictReader(content)

for row in csv_reader: # Skip the header line
orcid = row["orcid"]

# Lambda file is ordered by last modified date
last_modified_str = row["last_modified"]
try:
last_modified_date = arrow.get(last_modified_str, date_format)
except arrow.parser.ParserError:
last_modified_date = arrow.get(
last_modified_str, date_format_no_millis
)

if last_modified_date < last_sync:
break
yield orcid
finally:
fileobj.close()

def _iter(self, orcids):
"""Iterates over the ORCiD records yielding each one."""
with ThreadPoolExecutor(
max_workers=current_app.config["VOCABULARIES_ORCID_SYNC_MAX_WORKERS"]
) as executor:
futures = [
executor.submit(
# futures is a dictionary where the key is the ORCID and the value is the Future object
futures = {
orcid: executor.submit(
self._fetch_orcid_data,
orcid,
current_app.config["VOCABULARIES_ORCID_SUMMARIES_BUCKET"],
)
for orcid in orcids
]
for future in as_completed(futures):
result = future.result()
if result is not None:
yield result
}

for orcid in list(futures.keys()):
try:
result = futures[orcid].result()
if result:
yield result
finally:
# Explicitly release memory, as we don't need the future anymore.
# This is mostly required because as long as we keep a reference to the future
# (in the above futures dict), the garbage collector won't collect it
# and it will keep the memory allocated.
del futures[orcid]
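For reference, a standalone sketch of the pattern introduced in `_iter` above: keep submitted futures in a dict keyed by identifier and drop each reference once its result has been consumed, so completed results do not accumulate in memory. The fetch function and inputs below are dummies, not project code:

```python
# Standalone sketch of the keep-a-dict-and-del pattern used in _iter() above,
# so consumed futures (and their results) can be garbage collected.
from concurrent.futures import ThreadPoolExecutor

def fetch(identifier):
    """Dummy stand-in for the real S3 fetch."""
    return f"record-{identifier}"

def iter_records(identifiers, max_workers=4):
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {i: executor.submit(fetch, i) for i in identifiers}
        for identifier in list(futures.keys()):
            try:
                result = futures[identifier].result()
                if result:
                    yield result
            finally:
                del futures[identifier]  # release the future once consumed

print(list(iter_records(["0000-0001", "0000-0002"])))
```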

def read(self, item=None, *args, **kwargs):
"""Streams the ORCiD lambda file, process it to get the ORCiDS to sync and yields it's data."""
@@ -111,18 +126,31 @@ def read(self, item=None, *args, **kwargs):
"s3://orcid-lambda-file/last_modified.csv.tar"
)

orcids_to_sync = []
# Opens the tar file and processes it
with tarfile.open(fileobj=io.BytesIO(tar_content)) as tar:
# Iterate over each member (file or directory) in the tar file
for member in tar.getmembers():
# Extract the file
extracted_file = tar.extractfile(member)
if extracted_file:
current_app.logger.info(f"[ORCID Reader] Processing lambda file...")
# Process the file and get the ORCiDs to sync
orcids_to_sync.extend(self._process_lambda_file(extracted_file))

yield from self._iter(orcids_to_sync)
orcids_to_sync = set(self._process_lambda_file(extracted_file))

# Close the file explicitly after processing
extracted_file.close()

# Process ORCIDs in smaller batches
for orcid_batch in self._chunked_iter(
orcids_to_sync, batch_size=100
):
yield from self._iter(orcid_batch)

def _chunked_iter(self, iterable, batch_size):
"""Yield successive chunks of a given size."""
it = iter(iterable)
while chunk := list(islice(it, batch_size)):
yield chunk


class OrcidHTTPReader(SimpleHTTPReader):
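The `_chunked_iter` helper added above is a generic batching utility. A standalone copy with a usage line, for illustration: `islice` pulls at most `batch_size` items per pass, so only one chunk is held in memory at a time.

```python
# Standalone copy of the chunking helper for illustration.
from itertools import islice

def chunked_iter(iterable, batch_size):
    """Yield successive chunks of a given size."""
    it = iter(iterable)
    while chunk := list(islice(it, batch_size)):
        yield chunk

# e.g. chunking 7 ORCiDs into batches of 3:
print(list(chunked_iter(range(7), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]
```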
25 changes: 18 additions & 7 deletions invenio_vocabularies/datastreams/datastreams.py
@@ -48,7 +48,16 @@ def log_errors(self, logger=None):
class DataStream:
"""Data stream."""

def __init__(self, readers, writers, transformers=None, *args, **kwargs):
def __init__(
self,
readers,
writers,
transformers=None,
batch_size=100,
write_many=False,
*args,
**kwargs,
):
"""Constructor.
:param readers: an ordered list of readers.
@@ -58,12 +67,14 @@ def __init__(self, readers, writers, transformers=None, *args, **kwargs):
self._readers = readers
self._transformers = transformers
self._writers = writers
self.batch_size = batch_size
self.write_many = write_many

def filter(self, stream_entry, *args, **kwargs):
"""Checks if an stream_entry should be filtered out (skipped)."""
return False

def process_batch(self, batch, write_many=False):
def process_batch(self, batch):
"""Process a batch of entries."""
transformed_entries = []
for stream_entry in batch:
@@ -79,12 +90,12 @@
else:
transformed_entries.append(transformed_entry)
if transformed_entries:
if write_many:
if self.write_many:
yield from self.batch_write(transformed_entries)
else:
yield from (self.write(entry) for entry in transformed_entries)

def process(self, batch_size=100, write_many=False, *args, **kwargs):
def process(self, *args, **kwargs):
"""Iterates over the entries.
Uses the reader to get the raw entries and transforms them.
@@ -95,13 +106,13 @@
batch = []
for stream_entry in self.read():
batch.append(stream_entry)
if len(batch) >= batch_size:
yield from self.process_batch(batch, write_many=write_many)
if len(batch) >= self.batch_size:
yield from self.process_batch(batch)
batch = []

# Process any remaining entries in the last batch
if batch:
yield from self.process_batch(batch, write_many=write_many)
yield from self.process_batch(batch)

def read(self):
"""Recursively read the entries."""
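A condensed, standalone sketch of the control flow after this change: `batch_size` and `write_many` are now instance attributes, and `process()` flushes each batch either entry-by-entry or via a single bulk write. Class and method names below are stand-ins, not the project's classes:

```python
# Condensed sketch of the batching control flow in DataStream.process();
# entries and the flush step are stand-ins for the real transform/write logic.
class MiniStream:
    def __init__(self, entries, batch_size=100, write_many=False):
        self.entries = entries
        self.batch_size = batch_size
        self.write_many = write_many

    def process(self):
        batch = []
        for entry in self.entries:
            batch.append(entry)
            if len(batch) >= self.batch_size:
                yield from self._flush(batch)
                batch = []
        if batch:  # last, possibly short, batch
            yield from self._flush(batch)

    def _flush(self, batch):
        if self.write_many:
            yield f"bulk write of {len(batch)} entries"
        else:
            yield from (f"write {entry}" for entry in batch)

print(list(MiniStream(range(5), batch_size=2, write_many=True).process()))
```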
4 changes: 3 additions & 1 deletion invenio_vocabularies/datastreams/factories.py
@@ -81,4 +81,6 @@ def create(cls, readers_config, writers_config, transformers_config=None, **kwar
for t_conf in transformers_config:
transformers.append(TransformerFactory.create(t_conf))

return DataStream(readers=readers, writers=writers, transformers=transformers)
return DataStream(
readers=readers, writers=writers, transformers=transformers, **kwargs
)
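Since the factory now forwards extra keyword arguments to the `DataStream` constructor, callers can pass the new options directly. A hedged example, assuming the factory shown here is exposed as `DataStreamFactory` and that the reader/writer type names (placeholders below) are registered:

```python
# Hedged example of passing the new options through the factory; the "yaml"
# reader and "names-service" writer type names are placeholders.
from invenio_vocabularies.datastreams.factories import DataStreamFactory

ds = DataStreamFactory.create(
    readers_config=[{"type": "yaml", "args": {"origin": "names.yaml"}}],
    writers_config=[{"type": "names-service"}],
    batch_size=1000,
    write_many=True,
)
for result in ds.process():
    pass
```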
8 changes: 6 additions & 2 deletions invenio_vocabularies/datastreams/writers.py
@@ -12,6 +12,7 @@
from pathlib import Path

import yaml
from flask import current_app
from invenio_access.permissions import system_identity
from invenio_pidstore.errors import PIDAlreadyExists, PIDDoesNotExistError
from invenio_records.systemfields.relations.errors import InvalidRelationValue
@@ -120,11 +121,14 @@ def write(self, stream_entry, *args, **kwargs):

def write_many(self, stream_entries, *args, **kwargs):
"""Writes the input entries using a given service."""
current_app.logger.info(f"Writing {len(stream_entries)} entries")
entries = [entry.entry for entry in stream_entries]
entries_with_id = [(self._entry_id(entry), entry) for entry in entries]
results = self._service.create_or_update_many(self._identity, entries_with_id)
result_list = self._service.create_or_update_many(
self._identity, entries_with_id
)
stream_entries_processed = []
for entry, result in zip(entries, results):
for entry, result in zip(entries, result_list.results):
processed_stream_entry = StreamEntry(
entry=entry,
record=result.record,
2 changes: 2 additions & 0 deletions invenio_vocabularies/fixtures.py
@@ -28,6 +28,8 @@ def _load_vocabulary(self, config, delay=True, **kwargs):
readers_config=config["readers"],
transformers_config=config.get("transformers"),
writers_config=config["writers"],
batch_size=config.get("batch_size", 1000),
write_many=config.get("write_many", False),
)

errors = []
2 changes: 2 additions & 0 deletions invenio_vocabularies/services/tasks.py
@@ -20,6 +20,8 @@ def process_datastream(config):
readers_config=config["readers"],
transformers_config=config.get("transformers"),
writers_config=config["writers"],
batch_size=config.get("batch_size", 1000),
write_many=config.get("write_many", False),
)

for result in ds.process():
