From 7b477d4e7e3ef05f92cf3a95e86f9befc83c9a15 Mon Sep 17 00:00:00 2001 From: eltbus <33374178+eltbus@users.noreply.github.com> Date: Sun, 21 Jan 2024 14:50:41 +0100 Subject: [PATCH 01/14] Use "__all__" --- multipart/__init__.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/multipart/__init__.py b/multipart/__init__.py index b49d100..4844321 100644 --- a/multipart/__init__.py +++ b/multipart/__init__.py @@ -13,3 +13,12 @@ create_form_parser, parse_form, ) + +__all__ = [ + "FormParser", + "MultipartParser", + "OctetStreamParser", + "QuerystringParser", + "create_form_parser", + "parse_form", +] From 470b26acc9c4608e35ef4862e3f12528c115e031 Mon Sep 17 00:00:00 2001 From: eltbus <33374178+eltbus@users.noreply.github.com> Date: Sun, 21 Jan 2024 14:52:52 +0100 Subject: [PATCH 02/14] Add type hints to decoders --- multipart/decoders.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/multipart/decoders.py b/multipart/decoders.py index 0d7ab32..1ef9583 100644 --- a/multipart/decoders.py +++ b/multipart/decoders.py @@ -1,5 +1,7 @@ import base64 import binascii +from io import IOBase +from typing import overload from .exceptions import DecodeError @@ -33,10 +35,16 @@ class Base64Decoder: :param underlying: the underlying object to pass writes to """ - def __init__(self, underlying): + def __init__(self, underlying: IOBase): self.cache = bytearray() self.underlying = underlying + @overload + def write(self, data: str) -> int: ... + + @overload + def write(self, data: bytes) -> int: ... + def write(self, data): """Takes any input data provided, decodes it as base64, and passes it on to the underlying object. If the data provided is invalid base64 @@ -95,8 +103,8 @@ def finalize(self): 'Base64Decoder cache when finalize() is called' % len(self.cache)) - if hasattr(self.underlying, 'finalize'): - self.underlying.finalize() + if hasattr(self.underlying, 'finalize') and callable(getattr(self.underlying, 'finalize')): + self.underlying.finalize() # type:ignore [reportGeneralTypeIssues] def __repr__(self): return f"{self.__class__.__name__}(underlying={self.underlying!r})" @@ -111,10 +119,16 @@ class QuotedPrintableDecoder: :param underlying: the underlying object to pass writes to """ - def __init__(self, underlying): + def __init__(self, underlying: IOBase): self.cache = b'' self.underlying = underlying + @overload + def write(self, data: str) -> int: ... + + @overload + def write(self, data: bytes) -> int: ... + def write(self, data): """Takes any input data provided, decodes it as quoted-printable, and passes it on to the underlying object. @@ -164,8 +178,8 @@ def finalize(self): self.cache = b'' # Finalize our underlying stream. - if hasattr(self.underlying, 'finalize'): - self.underlying.finalize() + if hasattr(self.underlying, 'finalize') and callable(getattr(self.underlying, 'finalize')): + self.underlying.finalize() # type:ignore [reportGeneralTypeIssues] def __repr__(self): return f"{self.__class__.__name__}(underlying={self.underlying!r})" From 743165d2599c662a4c2481d8115f301f0fecdcb2 Mon Sep 17 00:00:00 2001 From: eltbus <33374178+eltbus@users.noreply.github.com> Date: Sun, 21 Jan 2024 14:57:56 +0100 Subject: [PATCH 03/14] Add overload signatures to parse_options_header --- multipart/multipart.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/multipart/multipart.py b/multipart/multipart.py index a9f1f9f..914520a 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -9,6 +9,7 @@ import tempfile from io import BytesIO from numbers import Number +from typing import overload, Tuple # Unique missing object. _missing = object() @@ -76,6 +77,12 @@ QUOTE = b'"'[0] +@overload +def parse_options_header(value: str) -> Tuple[bytes, dict[bytes, bytes]]: ... + +@overload +def parse_options_header(value: bytes) -> Tuple[bytes, dict[bytes, bytes]]: ... + def parse_options_header(value): """ Parses a Content-Type header into a value in the following format: From 9adde7ca0822d9d2e141b0522005a08bafc0d2cf Mon Sep 17 00:00:00 2001 From: eltbus <33374178+eltbus@users.noreply.github.com> Date: Sun, 21 Jan 2024 15:52:54 +0100 Subject: [PATCH 04/14] Use Dict instead of dict for compatibility with Python <3.9 --- multipart/multipart.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/multipart/multipart.py b/multipart/multipart.py index 914520a..59809ce 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -9,7 +9,7 @@ import tempfile from io import BytesIO from numbers import Number -from typing import overload, Tuple +from typing import overload, Dict, Tuple # Unique missing object. _missing = object() @@ -78,10 +78,10 @@ @overload -def parse_options_header(value: str) -> Tuple[bytes, dict[bytes, bytes]]: ... +def parse_options_header(value: str) -> Tuple[bytes, Dict[bytes, bytes]]: ... @overload -def parse_options_header(value: bytes) -> Tuple[bytes, dict[bytes, bytes]]: ... +def parse_options_header(value: bytes) -> Tuple[bytes, Dict[bytes, bytes]]: ... def parse_options_header(value): """ From 72238e3c8b7b32907313784b1ee252b0a798a98e Mon Sep 17 00:00:00 2001 From: eltbus <33374178+eltbus@users.noreply.github.com> Date: Sun, 21 Jan 2024 17:13:37 +0100 Subject: [PATCH 05/14] Type hint "write" methods --- multipart/multipart.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/multipart/multipart.py b/multipart/multipart.py index 59809ce..d09e65c 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -165,14 +165,14 @@ def from_value(klass, name, value): f.finalize() return f - def write(self, data): + def write(self, data: bytes) -> int: """Write some data into the form field. :param data: a bytestring """ return self.on_data(data) - def on_data(self, data): + def on_data(self, data: bytes) -> int: """This method is a callback that will be called whenever data is written to the Field. @@ -473,14 +473,14 @@ def _get_disk_file(self): self._actual_file_name = fname return tmp_file - def write(self, data): + def write(self, data: bytes) -> int: """Write some data to the File. :param data: a bytestring """ return self.on_data(data) - def on_data(self, data): + def on_data(self, data: bytes) -> int: """This method is a callback that will be called whenever data is written to the File. @@ -656,7 +656,7 @@ def __init__(self, callbacks={}, max_size=float('inf')): self.max_size = max_size self._current_size = 0 - def write(self, data): + def write(self, data: bytes) -> int: """Write some data to the parser, which will perform size verification, and then pass the data to the underlying callback. @@ -755,7 +755,7 @@ def __init__(self, callbacks={}, strict_parsing=False, # Should parsing be strict? self.strict_parsing = strict_parsing - def write(self, data): + def write(self, data: bytes) -> int: """Write some data to the parser, which will perform size verification, parse into either a field name or value, and then pass the corresponding data to the underlying callback. If an error is @@ -1039,7 +1039,7 @@ def __init__(self, boundary, callbacks={}, max_size=float('inf')): # '--\r\n' is 8 bytes. self.lookbehind = [NULL for x in range(len(boundary) + 8)] - def write(self, data): + def write(self, data: bytes) -> int: """Write some data to the parser, which will perform size verification, and then parse the data into the appropriate location (e.g. header, data, etc.), and pass this on to the underlying callback. If an error @@ -1574,7 +1574,7 @@ class vars: def on_start(): vars.f = FileClass(file_name, None, config=self.config) - def on_data(data, start, end): + def on_data(data: bytes, start: int, end: int): vars.f.write(data[start:end]) def on_end(): @@ -1772,7 +1772,7 @@ def on_end(): self.parser = parser - def write(self, data): + def write(self, data: bytes) -> int: """Write some data. The parser will forward this to the appropriate underlying parser. From 83f5914e594775480983c46bbab48a3006504edf Mon Sep 17 00:00:00 2001 From: eltbus <33374178+eltbus@users.noreply.github.com> Date: Sun, 21 Jan 2024 17:25:43 +0100 Subject: [PATCH 06/14] Use "cls". Assume name can be string or bytestring in type-hints --- multipart/multipart.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/multipart/multipart.py b/multipart/multipart.py index d09e65c..b13b809 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -9,7 +9,7 @@ import tempfile from io import BytesIO from numbers import Number -from typing import overload, Dict, Tuple +from typing import overload, Dict, List, Optional, Tuple, Union # Unique missing object. _missing = object() @@ -139,15 +139,15 @@ class Field: :param name: the name of the form field """ - def __init__(self, name): + def __init__(self, name: Union[bytes, str]): self._name = name - self._value = [] + self._value: List[bytes] = [] # We cache the joined version of _value for speed. self._cache = _missing @classmethod - def from_value(klass, name, value): + def from_value(cls, name: Union[bytes, str], value: Optional[bytes]): """Create an instance of a :class:`Field`, and set the corresponding value - either None or an actual value. This method will also finalize the Field itself. @@ -157,7 +157,7 @@ def from_value(klass, name, value): None """ - f = klass(name) + f = cls(name) if value is None: f.set_none() else: From dbd31f2dcba875901dbf7d7f1e41ed6d818c3bd9 Mon Sep 17 00:00:00 2001 From: eltbus <33374178+eltbus@users.noreply.github.com> Date: Sun, 21 Jan 2024 19:52:38 +0100 Subject: [PATCH 07/14] Create object of cache. Use specific imports --- multipart/multipart.py | 60 +++++++++++++++++++++++++++++------------ tests/test_multipart.py | 21 ++++++++++++++- 2 files changed, 63 insertions(+), 18 deletions(-) diff --git a/multipart/multipart.py b/multipart/multipart.py index b13b809..d6a8c8e 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -1,5 +1,5 @@ -from .decoders import * -from .exceptions import * +from .decoders import Base64Decoder, QuotedPrintableDecoder +from .exceptions import FormParserError, MultipartParseError, QuerystringParseError, FileError import os import re @@ -9,10 +9,9 @@ import tempfile from io import BytesIO from numbers import Number -from typing import overload, Dict, List, Optional, Tuple, Union +from typing import overload, Dict, Generic, List, Optional, Tuple, TypeVar, Union -# Unique missing object. -_missing = object() +T = TypeVar("T") # States for the querystring parser. STATE_BEFORE_FIELD = 0 @@ -124,6 +123,33 @@ def parse_options_header(value): return ctype, options +class Cache(Generic[T]): + def __init__(self): + self._value: Optional[T] = None + self._is_set: bool = False + + @property + def value(self) -> Optional[T]: + if self.is_set: + return self._value + else: + raise ValueError("Value not yet set") + + @value.setter + def value(self, v: T): + self._value = v + self._is_set = True + + def clear(self): + """Reset the cache""" + self._value = None + self._is_set = False + + def is_set(self) -> bool: + """Check if value has been set""" + return self._is_set + + class Field: """A Field object represents a (parsed) form field. It represents a single field with a corresponding name and value. @@ -144,7 +170,7 @@ def __init__(self, name: Union[bytes, str]): self._value: List[bytes] = [] # We cache the joined version of _value for speed. - self._cache = _missing + self._cache = Cache() @classmethod def from_value(cls, name: Union[bytes, str], value: Optional[bytes]): @@ -179,14 +205,14 @@ def on_data(self, data: bytes) -> int: :param data: a bytestring """ self._value.append(data) - self._cache = _missing + self._cache.clear() return len(data) def on_end(self): """This method is called whenever the Field is finalized. """ - if self._cache is _missing: - self._cache = b''.join(self._value) + if not self._cache.is_set(): + self._cache.value = b''.join(self._value) def finalize(self): """Finalize the form field. @@ -197,8 +223,8 @@ def close(self): """Close the Field object. This will free any underlying cache. """ # Free our value array. - if self._cache is _missing: - self._cache = b''.join(self._value) + if not self._cache.is_set(): + self._cache.value = b''.join(self._value) del self._value @@ -209,7 +235,7 @@ def set_none(self): with name "baz" and value "asdf". Since the write() interface doesn't support writing None, this function will set the field value to None. """ - self._cache = None + self._cache.value = None @property def field_name(self): @@ -217,12 +243,12 @@ def field_name(self): return self._name @property - def value(self): + def value(self) -> Optional[bytes]: """This property returns the value of the form field.""" - if self._cache is _missing: - self._cache = b''.join(self._value) + if not self._cache.is_set(): + self._cache.value = b''.join(self._value) - return self._cache + return self._cache.value def __eq__(self, other): if isinstance(other, Field): @@ -234,7 +260,7 @@ def __eq__(self, other): return NotImplemented def __repr__(self): - if len(self.value) > 97: + if self.value is not None and len(self.value) > 97: # We get the repr, and then insert three dots before the final # quote. v = repr(self.value[:97])[:-1] + "...'" diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 031515b..a883431 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -12,7 +12,26 @@ from io import BytesIO from unittest.mock import Mock -from multipart.multipart import * +from multipart.multipart import ( + BaseParser, + Field, + File, + FormParser, + MultipartParser, + OctetStreamParser, + QuerystringParser, + create_form_parser, + parse_form, + parse_options_header, +) +from multipart.decoders import Base64Decoder, QuotedPrintableDecoder +from multipart.exceptions import ( + DecodeError, + FileError, + FormParserError, + MultipartParseError, + QuerystringParseError, +) # Get the current directory for our later test cases. From 3b0eef8d039a07690542939d99df4c4078dc574d Mon Sep 17 00:00:00 2001 From: eltbus <33374178+eltbus@users.noreply.github.com> Date: Sun, 21 Jan 2024 21:04:18 +0100 Subject: [PATCH 08/14] Use __slots__ to reduce memory footprint Cache and Field --- multipart/multipart.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/multipart/multipart.py b/multipart/multipart.py index d6a8c8e..94ccbe3 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -124,6 +124,8 @@ def parse_options_header(value): class Cache(Generic[T]): + __slots__ = ('_value', '_is_set') + def __init__(self): self._value: Optional[T] = None self._is_set: bool = False @@ -165,6 +167,8 @@ class Field: :param name: the name of the form field """ + __slots__ = ('_name', '_value', '_cache') + def __init__(self, name: Union[bytes, str]): self._name = name self._value: List[bytes] = [] From 64f5c1ceadcc65522a010c38e355e624c713df76 Mon Sep 17 00:00:00 2001 From: eltbus <33374178+eltbus@users.noreply.github.com> Date: Mon, 22 Jan 2024 16:25:45 +0100 Subject: [PATCH 09/14] Added minimal pyright configuration to pyproject.toml --- multipart/multipart.py | 4 +++- pyproject.toml | 7 +++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/multipart/multipart.py b/multipart/multipart.py index 94ccbe3..328c522 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from .decoders import Base64Decoder, QuotedPrintableDecoder from .exceptions import FormParserError, MultipartParseError, QuerystringParseError, FileError @@ -9,7 +11,7 @@ import tempfile from io import BytesIO from numbers import Number -from typing import overload, Dict, Generic, List, Optional, Tuple, TypeVar, Union +from typing import overload, Generic, TypeVar T = TypeVar("T") diff --git a/pyproject.toml b/pyproject.toml index 1ad20d2..e911e1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,3 +58,10 @@ packages = ["multipart"] [tool.hatch.build.targets.sdist] include = ["/multipart", "/tests"] + +[tool.pyright] +include = ["multipart", "tests"] +exclude = [ + "**/__pycache__", +] +reportUndefinedVariable = false # TODO: required because pyright does not work with __future__.annotations From 28b177311138f65168e02afed9e02a58d6269c5f Mon Sep 17 00:00:00 2001 From: eltbus <33374178+eltbus@users.noreply.github.com> Date: Mon, 22 Jan 2024 17:31:18 +0100 Subject: [PATCH 10/14] Avoid mutable default argument --- multipart/multipart.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/multipart/multipart.py b/multipart/multipart.py index 328c522..e05b68f 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -339,10 +339,10 @@ class File: :param config: The configuration for this File. See above for valid configuration keys and their corresponding values. """ - def __init__(self, file_name, field_name=None, config={}): + def __init__(self, file_name, field_name=None, config=None): # Save configuration, set other variables default. self.logger = logging.getLogger(__name__) - self._config = config + self._config = {} if config is None else config self._in_memory = True self._bytes_written = 0 self._fileobj = BytesIO() From dd85408e10cb9883e8abc1a70b8e6a32b35e9461 Mon Sep 17 00:00:00 2001 From: eltbus <33374178+eltbus@users.noreply.github.com> Date: Thu, 25 Jan 2024 22:12:34 +0100 Subject: [PATCH 11/14] Why not use f.read(chunk_size) regardless of Content-Length and let it reach EOF? --- multipart/multipart.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/multipart/multipart.py b/multipart/multipart.py index e05b68f..c0eca42 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -593,6 +593,7 @@ class BaseParser: """ def __init__(self): self.logger = logging.getLogger(__name__) + self.callbacks: Dict[Callable] = {} def callback(self, name, data=None, start=None, end=None): """This function calls a provided callback with some data. If the @@ -1832,8 +1833,8 @@ def __repr__(self): ) -def create_form_parser(headers, on_field, on_file, trust_x_headers=False, - config={}): +def create_form_parser(headers: Mapping, on_field: Callable, on_file: Callable, trust_x_headers: bool=False, + config={}) -> FormParser: """This function is a helper function to aid in creating a FormParser instances. Given a dictionary-like headers object, it will determine the correct information needed, instantiate a FormParser with the @@ -1881,7 +1882,7 @@ def create_form_parser(headers, on_field, on_file, trust_x_headers=False, return form_parser -def parse_form(headers, input_stream, on_field, on_file, chunk_size=1048576, +def parse_form(headers: Mapping, input_stream: BytesIO, on_field: Callable, on_file: Callable, chunk_size: int=1048576, **kwargs): """This function is useful if you just want to parse a request body, without too much work. Pass it a dictionary-like object of the request's @@ -1905,19 +1906,22 @@ def parse_form(headers, input_stream, on_field, on_file, chunk_size=1048576, # Create our form parser. parser = create_form_parser(headers, on_field, on_file) - # Read chunks of 100KiB and write to the parser, but never read more than + # Read chunks of chunk_size and write to the parser, but never read more than # the given Content-Length, if any. content_length = headers.get('Content-Length') - if content_length is not None: - content_length = int(content_length) + if content_length is None: + calculate_max_readable_bytes = lambda _: chunk_size else: - content_length = float('inf') + content_length = int(content_length) + calculate_max_readable_bytes = lambda bytes_read: min(content_length - bytes_read, chunk_size) + bytes_read = 0 while True: # Read only up to the Content-Length given. - max_readable = min(content_length - bytes_read, 1048576) - buff = input_stream.read(max_readable) + max_readable_bytes = calculate_max_readable_bytes(bytes_read) + # TODO: why not simply use f.read(chunk_size) and let it reach EOF by itself? + buff = input_stream.read(max_readable_bytes) # Write to the parser and update our length. parser.write(buff) @@ -1925,7 +1929,7 @@ def parse_form(headers, input_stream, on_field, on_file, chunk_size=1048576, # If we get a buffer that's smaller than the size requested, or if we # have read up to our content length, we're done. - if len(buff) != max_readable or bytes_read == content_length: + if len(buff) != max_readable_bytes or bytes_read == content_length: break # Tell our parser that we're done writing data. From 4bd44c1cac7f74d71aa6ba31c36802da72f84d31 Mon Sep 17 00:00:00 2001 From: eltbus <33374178+eltbus@users.noreply.github.com> Date: Mon, 5 Feb 2024 01:54:41 +0100 Subject: [PATCH 12/14] remove unused alias --- multipart/multipart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/multipart/multipart.py b/multipart/multipart.py index c0eca42..7630367 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -463,7 +463,7 @@ def _get_disk_file(self): try: self.logger.info("Opening file: %r", path) tmp_file = open(path, 'w+b') - except OSError as e: + except OSError: tmp_file = None self.logger.exception("Error opening temporary file") From fb2a4f452bc25df4b18da0b0191c12bd87d940ce Mon Sep 17 00:00:00 2001 From: eltbus <33374178+eltbus@users.noreply.github.com> Date: Mon, 5 Feb 2024 01:55:01 +0100 Subject: [PATCH 13/14] remove "pyright.exclude" --- pyproject.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e911e1d..a8c7df6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,4 @@ include = ["/multipart", "/tests"] [tool.pyright] include = ["multipart", "tests"] -exclude = [ - "**/__pycache__", -] reportUndefinedVariable = false # TODO: required because pyright does not work with __future__.annotations From 871a91e8e7e542d46fde5cdb4b5c91e93c41d9c2 Mon Sep 17 00:00:00 2001 From: eltbus <33374178+eltbus@users.noreply.github.com> Date: Mon, 5 Feb 2024 02:13:25 +0100 Subject: [PATCH 14/14] No need to use overload here --- multipart/multipart.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/multipart/multipart.py b/multipart/multipart.py index 7630367..99be084 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -78,13 +78,7 @@ QUOTE = b'"'[0] -@overload -def parse_options_header(value: str) -> Tuple[bytes, Dict[bytes, bytes]]: ... - -@overload -def parse_options_header(value: bytes) -> Tuple[bytes, Dict[bytes, bytes]]: ... - -def parse_options_header(value): +def parse_options_header(value: Union[str, bytes]): """ Parses a Content-Type header into a value in the following format: (content_type, {parameters})