Skip to content

Commit

Permalink
Merge pull request #45 from ESGF/fix-leak
Browse files: browse the repository at this point in the history
Fix memory increase during download (leak)
  • Loading branch information
svenrdz authored Jul 9, 2024
2 parents f2a0ee0 + eda9d21 commit e2277de
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 39 deletions.
54 changes: 26 additions & 28 deletions esgpull/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,13 @@ class Task:
def __init__(
self,
config: Config,
auth: Auth,
fs: Filesystem,
# *,
# url: str | None = None,
file: File,
start_callbacks: list[Callback] | None = None,
) -> None:
self.config = config
self.auth = auth
self.fs = fs
self.ctx = DownloadCtx(file)
if not self.config.download.disable_checksum:
Expand All @@ -61,16 +59,6 @@ def __init__(
# else:
# raise ValueError("no arguments")
self.downloader = Simple()
msg: str | None = None
if not default_ssl_context_loaded:
msg = load_default_ssl_context()
self.ssl_context: ssl.SSLContext | bool
if self.config.download.disable_ssl:
self.ssl_context = False
else:
if msg is not None:
logger.info(msg)
self.ssl_context = default_ssl_context
if start_callbacks is None:
self.start_callbacks = []
else:
Expand All @@ -93,20 +81,13 @@ def file(self) -> File:
# raise ValueError(f"{url} is not valid")

async def stream(
self, semaphore: asyncio.Semaphore
self,
semaphore: asyncio.Semaphore,
client: AsyncClient,
) -> AsyncIterator[Result]:
ctx = self.ctx
try:
async with (
semaphore,
self.fs.open(ctx.file) as file_obj,
AsyncClient(
follow_redirects=True,
cert=self.auth.cert,
verify=self.ssl_context,
timeout=self.config.download.http_timeout,
) as client,
):
async with semaphore, self.fs.open(ctx.file) as file_obj:
for callback in self.start_callbacks:
callback()
stream = self.downloader.stream(
Expand All @@ -117,6 +98,7 @@ async def stream(
async for ctx in stream:
if ctx.chunk is not None:
await file_obj.write(ctx.chunk)
ctx.chunk = None
if ctx.error:
err = DownloadSizeError(ctx.completed, ctx.file.size)
yield Err(ctx, err)
Expand Down Expand Up @@ -146,13 +128,23 @@ def __init__(
start_callbacks: dict[str, list[Callback]],
) -> None:
self.config = config
self.auth = auth
self.fs = fs
self.files = list(filter(self.should_download, files))
self.tasks = []
msg: str | None = None
if not default_ssl_context_loaded:
msg = load_default_ssl_context()
self.ssl_context: ssl.SSLContext | bool
if self.config.download.disable_ssl:
self.ssl_context = False
else:
if msg is not None:
logger.info(msg)
self.ssl_context = default_ssl_context
for file in files:
task = Task(
config=config,
auth=auth,
fs=fs,
file=file,
start_callbacks=start_callbacks[file.sha],
Expand All @@ -167,7 +159,13 @@ def should_download(self, file: File) -> bool:

async def process(self) -> AsyncIterator[Result]:
semaphore = asyncio.Semaphore(self.config.download.max_concurrent)
streams = [task.stream(semaphore) for task in self.tasks]
async with merge(*streams).stream() as stream:
async for result in stream:
yield result
async with AsyncClient(
follow_redirects=True,
cert=self.auth.cert,
verify=self.ssl_context,
timeout=self.config.download.http_timeout,
) as client:
streams = [task.stream(semaphore, client) for task in self.tasks]
async with merge(*streams).stream() as stream:
async for result in stream:
yield result
17 changes: 6 additions & 11 deletions tests/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import httpx
import pytest

from esgpull.auth import Auth
from esgpull.fs import FileCheck, Filesystem
from esgpull.models import File
from esgpull.processor import Task
Expand Down Expand Up @@ -58,27 +57,23 @@ def fs(config):


@pytest.fixture
def auth(config):
return Auth.from_config(config)


@pytest.fixture
def task(config, auth, fs, smallfile):
return Task(config, auth, fs, file=smallfile)
def task(config, fs, smallfile):
return Task(config, fs, file=smallfile)


async def run_task(task_):
semaphore = asyncio.Semaphore(1)
async for result in task_.stream(semaphore):
...
async with httpx.AsyncClient() as client:
async for result in task_.stream(semaphore, client):
...
return result


@pytest.mark.xfail(
raises=(httpx.ConnectTimeout, httpx.ReadTimeout),
reason="this is dependent on the IPSL data node's health (unstable)",
)
def test_task(auth, fs, smallfile, task):
def test_task(fs, smallfile, task):
result = asyncio.run(run_task(task))
if not result.ok:
raise result.err
Expand Down

0 comments on commit e2277de

Please sign in to comment.