From 6627551e679ec3ca669dd6ec77ec039458652913 Mon Sep 17 00:00:00 2001 From: Siddhant Goel Date: Tue, 5 Nov 2024 21:53:12 +0100 Subject: [PATCH] chore: update multiple targets test --- streaming_form_data/targets.py | 39 ++++++++++++++--- tests/test_parser.py | 80 ++++++++++++++++++---------------- 2 files changed, 74 insertions(+), 45 deletions(-) diff --git a/streaming_form_data/targets.py b/streaming_form_data/targets.py index 6c2bd66..97ea2f2 100644 --- a/streaming_form_data/targets.py +++ b/streaming_form_data/targets.py @@ -1,6 +1,6 @@ import hashlib from pathlib import Path -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Union import smart_open # type: ignore @@ -66,14 +66,31 @@ class MultipleTargets(BaseTarget): """ def __init__(self, next_target: Callable): + """ + Args: + next_target: + A callable that returns a new target which should be used for the next + input of the multiple inputs allowed for the specific field + """ + self._next_target = next_target - self._targets = [] + self._targets: list[BaseTarget] = [] self._validator = None # next_target should have a validator + self._next_multipart_filename: Optional[str] = None + self._next_multipart_content_type: Optional[str] = None + def on_start(self): target = self._next_target() + if self._next_multipart_filename is not None: + target.set_multipart_filename(self._next_multipart_filename) + self._next_multipart_filename = None + if self._next_multipart_content_type is not None: + target.set_multipart_filename(self._next_multipart_content_type) + self._next_multipart_content_type = None + self._targets.append(target) target.start() @@ -84,10 +101,10 @@ def on_finish(self): self._targets[-1].finish() def set_multipart_filename(self, filename: str): - self._targets[-1].set_multipart_filename(filename) + self._next_multipart_filename = filename def set_multipart_content_type(self, content_type: str): - self._targets[-1].set_multipart_content_type(content_type) + self._next_multipart_content_type = content_type class NullTarget(BaseTarget): @@ -171,7 +188,11 @@ class FileTarget(BaseTarget): """ def __init__( - self, filename: str | Callable, allow_overwrite: bool = True, *args, **kwargs + self, + filename: Union[str, Callable], + allow_overwrite: bool = True, + *args, + **kwargs, ): """ Args: @@ -208,7 +229,7 @@ class DirectoryTarget(BaseTarget): def __init__( self, - directory_path: str | Callable, + directory_path: Union[str, Callable], allow_overwrite: bool = True, *args, **kwargs, @@ -278,7 +299,11 @@ class SmartOpenTarget(BaseTarget): """ def __init__( - self, file_path: str | Callable, mode: str, transport_params=None, **kwargs + self, + file_path: Union[str, Callable], + mode: str, + transport_params=None, + **kwargs, ): """ Args: diff --git a/tests/test_parser.py b/tests/test_parser.py index 5b9f365..1ccb4ff 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -12,6 +12,7 @@ DirectoryTarget, SHA256Target, ValueTarget, + MultipleTargets, ) from streaming_form_data.validators import MaxSizeValidator, ValidationError @@ -126,44 +127,6 @@ def test_basic_multiple(): assert third.value == b"baz" -def test_multiple_inputs(): - data = b"""\ ------------------------------111217161535371856203688828 -Content-Disposition: form-data; name="files"; filename="first.txt" -Content-Type: text/plain - -first - ------------------------------111217161535371856203688828 -Content-Disposition: form-data; name="files"; filename="second.txt" -Content-Type: text/plain - -second - ------------------------------111217161535371856203688828 -Content-Disposition: form-data; name="files"; filename="third.txt" -Content-Type: text/plain - -third - ------------------------------111217161535371856203688828-- -""".replace(b"\n", b"\r\n") - - target = ValueTarget() - - parser = StreamingFormDataParser( - headers={ - "Content-Type": "multipart/form-data; boundary=111217161535371856203688828" - } - ) - parser.register("files") - - parser.data_received(data) - - breakpoint() - assert target.multipart_filename == "ab.txt" - - def test_chunked_single(): expected_value = "hello world" @@ -878,6 +841,47 @@ def test_extra_headers(): assert target.value == b"Joe owes =80100." +def test_multiple_inputs(tmp_path): + data = b"""\ +--111217161535371856203688828 +Content-Disposition: form-data; name="files"; filename="first.txt" +Content-Type: text/plain + +first +--111217161535371856203688828 +Content-Disposition: form-data; name="files"; filename="second.txt" +Content-Type: text/plain + +second +--111217161535371856203688828 +Content-Disposition: form-data; name="files"; filename="third.txt" +Content-Type: text/plain + +third +--111217161535371856203688828-- +""".replace(b"\n", b"\r\n") + + class next_target: + def __init__(self): + self._index = 0 + + def __call__(self): + return FileTarget(tmp_path / f"{self._index}.txt") + + target = MultipleTargets(next_target()) + + parser = StreamingFormDataParser( + headers={ + "Content-Type": "multipart/form-data; boundary=111217161535371856203688828" + } + ) + parser.register("files", target) + + parser.data_received(data) + + assert len(target._targets) == 3 + + def test_case_insensitive_content_disposition_header(): content_disposition_header = "Content-Disposition"