diff --git a/piptools/writer.py b/piptools/writer.py index c35125bbc..1565ad505 100644 --- a/piptools/writer.py +++ b/piptools/writer.py @@ -1,11 +1,13 @@ from __future__ import annotations +import contextlib import io import os import re import sys +from dataclasses import dataclass from itertools import chain -from typing import BinaryIO, Iterable, Iterator, cast +from typing import BinaryIO, Generator, Iterable, Iterator, cast from click import unstyle from click.core import Context @@ -73,6 +75,39 @@ def annotation_style_line(required_by: set[str]) -> str: return f"# via {', '.join(sorted(required_by))}" +@dataclass +class _LineWriter: + _io: io.TextIOWrapper + + def write(self, line: str) -> None: + log.info(line) + self._io.write(unstyle(line)) + self._io.write("\n") + + @classmethod + @contextlib.contextmanager + def create( + cls, buffer: BinaryIO, newline: str + ) -> Generator[_LineWriter, object, None]: + wrapper = io.TextIOWrapper( + buffer=buffer, + encoding="utf8", + newline=newline, + line_buffering=True, + ) + try: + yield cls(wrapper) + finally: + wrapper.detach() + + +class _DryRunWriter: + @staticmethod + def write(line: str) -> None: + # Bypass the log level to always print this during a dry run + log.log(line) + + class OutputWriter: def __init__( self, @@ -250,26 +285,17 @@ def write( markers: dict[str, Marker], hashes: dict[InstallRequirement, set[str]] | None, ) -> None: - - if not self.dry_run: - dst_file = io.TextIOWrapper( - self.dst_file, - encoding="utf8", - newline=self.linesep, - line_buffering=True, - ) - try: + cmgr: ( + contextlib.AbstractContextManager[_DryRunWriter] + | contextlib.AbstractContextManager[_LineWriter] + ) = ( + contextlib.nullcontext(_DryRunWriter()) + if self.dry_run + else _LineWriter.create(buffer=self.dst_file, newline=self.linesep) + ) + with cmgr as line_writer: for line in self._iter_lines(results, unsafe_requirements, markers, hashes): - if self.dry_run: - # Bypass the log level to always print this during a dry run - log.log(line) - else: - log.info(line) - dst_file.write(unstyle(line)) - dst_file.write("\n") - finally: - if not self.dry_run: - dst_file.detach() + line_writer.write(line) def _format_requirement( self,