diff --git a/CHANGES.rst b/CHANGES.rst index 5019208e4..65828cedb 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -9,6 +9,8 @@ Unreleased - Use modern packaging metadata with ``pyproject.toml`` instead of ``setup.cfg``. :pr:`1793` - Use ``flit_core`` instead of ``setuptools`` as build backend. +- Add support for ``pathlib.Path`` and ``os.Pathlike`` objects to + ``TemplateStream.dump``. :issue:`2039` Version 3.1.5 diff --git a/src/jinja2/environment.py b/src/jinja2/environment.py index a99cc2d1d..533350478 100644 --- a/src/jinja2/environment.py +++ b/src/jinja2/environment.py @@ -61,6 +61,8 @@ from .ext import Extension from .loaders import BaseLoader + StrOrBytesPath = t.Union[str, bytes, "os.PathLike[str]", "os.PathLike[bytes]"] + _env_bound = t.TypeVar("_env_bound", bound="Environment") @@ -1592,7 +1594,7 @@ def __init__(self, gen: t.Iterator[str]) -> None: def dump( self, - fp: t.Union[str, t.IO[bytes]], + fp: t.Union["StrOrBytesPath", t.IO[bytes]], encoding: t.Optional[str] = None, errors: t.Optional[str] = "strict", ) -> None: @@ -1604,17 +1606,20 @@ def dump( Template('Hello {{ name }}!').stream(name='foo').dump('hello.html') """ + real_fp: t.IO[bytes] + close = False - if isinstance(fp, str): + try: + real_fp = open(fp, "wb") # type: ignore[arg-type] + except TypeError: + real_fp = fp # type: ignore[assignment] + else: + close = True + if encoding is None: encoding = "utf-8" - real_fp: t.IO[bytes] = open(fp, "wb") - close = True - else: - real_fp = fp - try: if encoding is not None: iterable = (x.encode(encoding, errors) for x in self) # type: ignore @@ -1623,9 +1628,15 @@ def dump( if hasattr(real_fp, "writelines"): real_fp.writelines(iterable) - else: + elif hasattr(real_fp, "write"): for item in iterable: real_fp.write(item) + else: + raise AttributeError( + f"'{real_fp.__class__.__name__}' object has no attribute" + f" 'write' or 'writelines'" + ) + finally: if close: real_fp.close() diff --git a/tests/test_api.py b/tests/test_api.py index 4472b85ac..9b2d61b3f 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,5 +1,8 @@ +import io +import os import shutil import tempfile +import typing as t from pathlib import Path import pytest @@ -242,16 +245,54 @@ def test_streaming_behavior(self, env): stream.disable_buffering() assert not stream.buffered - def test_dump_stream(self, env): + +class TestStreamingDump: + class CustomPathLike: + def __init__(self, path: Path): + self._path = path + + def __fspath__(self): + return self._path.__fspath__() + + class CustomWithWrite: + def __init__(self): + self._value = [] + + def write(self, value: bytes) -> None: + self._value.append(value) + + def getvalue(self) -> bytes: + return b"".join(self._value) + + class CustomWithWriteLines: + def __init__(self): + self._value = [] + + def writelines(self, value: t.Iterable[bytes]) -> None: + self._value += list(value) + + def getvalue(self) -> bytes: + return b"".join(self._value) + + @pytest.mark.parametrize("cast_to", [str, Path, os.fspath, CustomPathLike, bytes]) + def test_dump_stream_file_path(self, env, cast_to): tmp = Path(tempfile.mkdtemp()) try: tmpl = env.from_string("\u2713") stream = tmpl.stream() - stream.dump(str(tmp / "dump.txt"), "utf-8") + stream.dump(cast_to(tmp / "dump.txt"), "utf-8") assert (tmp / "dump.txt").read_bytes() == b"\xe2\x9c\x93" finally: shutil.rmtree(tmp) + @pytest.mark.parametrize("obj", [io.BytesIO, CustomWithWrite, CustomWithWriteLines]) + def test_dump_stream_io(self, env, obj): + tmpl = env.from_string("\u2713") + stream = tmpl.stream() + _io = obj() + stream.dump(_io, "utf-8") + assert _io.getvalue() == b"\xe2\x9c\x93" + class TestUndefined: def test_stopiteration_is_undefined(self):