diff --git a/multipart/__init__.py b/multipart/__init__.py index b49d100..4844321 100644 --- a/multipart/__init__.py +++ b/multipart/__init__.py @@ -13,3 +13,12 @@ create_form_parser, parse_form, ) + +__all__ = [ + "FormParser", + "MultipartParser", + "OctetStreamParser", + "QuerystringParser", + "create_form_parser", + "parse_form", +] diff --git a/multipart/decoders.py b/multipart/decoders.py index 0d7ab32..1ef9583 100644 --- a/multipart/decoders.py +++ b/multipart/decoders.py @@ -1,5 +1,7 @@ import base64 import binascii +from io import IOBase +from typing import overload from .exceptions import DecodeError @@ -33,10 +35,16 @@ class Base64Decoder: :param underlying: the underlying object to pass writes to """ - def __init__(self, underlying): + def __init__(self, underlying: IOBase): self.cache = bytearray() self.underlying = underlying + @overload + def write(self, data: str) -> int: ... + + @overload + def write(self, data: bytes) -> int: ... + def write(self, data): """Takes any input data provided, decodes it as base64, and passes it on to the underlying object. If the data provided is invalid base64 @@ -95,8 +103,8 @@ def finalize(self): 'Base64Decoder cache when finalize() is called' % len(self.cache)) - if hasattr(self.underlying, 'finalize'): - self.underlying.finalize() + if hasattr(self.underlying, 'finalize') and callable(getattr(self.underlying, 'finalize')): + self.underlying.finalize() # type:ignore [reportGeneralTypeIssues] def __repr__(self): return f"{self.__class__.__name__}(underlying={self.underlying!r})" @@ -111,10 +119,16 @@ class QuotedPrintableDecoder: :param underlying: the underlying object to pass writes to """ - def __init__(self, underlying): + def __init__(self, underlying: IOBase): self.cache = b'' self.underlying = underlying + @overload + def write(self, data: str) -> int: ... + + @overload + def write(self, data: bytes) -> int: ... 
+ def write(self, data): """Takes any input data provided, decodes it as quoted-printable, and passes it on to the underlying object. @@ -164,8 +178,8 @@ def finalize(self): self.cache = b'' # Finalize our underlying stream. - if hasattr(self.underlying, 'finalize'): - self.underlying.finalize() + if hasattr(self.underlying, 'finalize') and callable(getattr(self.underlying, 'finalize')): + self.underlying.finalize() # type:ignore [reportGeneralTypeIssues] def __repr__(self): return f"{self.__class__.__name__}(underlying={self.underlying!r})" diff --git a/multipart/multipart.py b/multipart/multipart.py index a9f1f9f..99be084 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -1,5 +1,7 @@ -from .decoders import * -from .exceptions import * +from __future__ import annotations + +from .decoders import Base64Decoder, QuotedPrintableDecoder +from .exceptions import FormParserError, MultipartParseError, QuerystringParseError, FileError import os import re @@ -9,9 +11,9 @@ import tempfile from io import BytesIO from numbers import Number +from typing import overload, Generic, TypeVar, Callable, Dict, List, Mapping, Optional, Union -# Unique missing object. -_missing = object() +T = TypeVar("T") # States for the querystring parser. 
STATE_BEFORE_FIELD = 0 @@ -76,7 +78,7 @@ QUOTE = b'"'[0] -def parse_options_header(value): +def parse_options_header(value: Union[str, bytes]): """ Parses a Content-Type header into a value in the following format: (content_type, {parameters}) @@ -117,6 +119,35 @@ def parse_options_header(value): return ctype, options +class Cache(Generic[T]): + __slots__ = ('_value', '_is_set') + + def __init__(self): + self._value: Optional[T] = None + self._is_set: bool = False + + @property + def value(self) -> Optional[T]: + if self.is_set(): + return self._value + else: + raise ValueError("Value not yet set") + + @value.setter + def value(self, v: T): + self._value = v + self._is_set = True + + def clear(self): + """Reset the cache""" + self._value = None + self._is_set = False + + def is_set(self) -> bool: + """Check if value has been set""" + return self._is_set + + class Field: """A Field object represents a (parsed) form field. It represents a single field with a corresponding name and value. @@ -132,15 +163,17 @@ class Field: :param name: the name of the form field """ - def __init__(self, name): + __slots__ = ('_name', '_value', '_cache') + + def __init__(self, name: Union[bytes, str]): self._name = name - self._value = [] + self._value: List[bytes] = [] # We cache the joined version of _value for speed. - self._cache = _missing + self._cache = Cache() @classmethod - def from_value(klass, name, value): + def from_value(cls, name: Union[bytes, str], value: Optional[bytes]): """Create an instance of a :class:`Field`, and set the corresponding value - either None or an actual value. This method will also finalize the Field itself. @@ -150,7 +183,7 @@ def from_value(klass, name, value): None """ - f = klass(name) + f = cls(name) if value is None: f.set_none() else: @@ -158,28 +191,28 @@ def from_value(klass, name, value): f.finalize() return f - def write(self, data): + def write(self, data: bytes) -> int: """Write some data into the form field. 
:param data: a bytestring """ return self.on_data(data) - def on_data(self, data): + def on_data(self, data: bytes) -> int: """This method is a callback that will be called whenever data is written to the Field. :param data: a bytestring """ self._value.append(data) - self._cache = _missing + self._cache.clear() return len(data) def on_end(self): """This method is called whenever the Field is finalized. """ - if self._cache is _missing: - self._cache = b''.join(self._value) + if not self._cache.is_set(): + self._cache.value = b''.join(self._value) def finalize(self): """Finalize the form field. @@ -190,8 +223,8 @@ def close(self): """Close the Field object. This will free any underlying cache. """ # Free our value array. - if self._cache is _missing: - self._cache = b''.join(self._value) + if not self._cache.is_set(): + self._cache.value = b''.join(self._value) del self._value @@ -202,7 +235,7 @@ def set_none(self): with name "baz" and value "asdf". Since the write() interface doesn't support writing None, this function will set the field value to None. """ - self._cache = None + self._cache.value = None @property def field_name(self): @@ -210,12 +243,12 @@ def field_name(self): return self._name @property - def value(self): + def value(self) -> Optional[bytes]: """This property returns the value of the form field.""" - if self._cache is _missing: - self._cache = b''.join(self._value) + if not self._cache.is_set(): + self._cache.value = b''.join(self._value) - return self._cache + return self._cache.value def __eq__(self, other): if isinstance(other, Field): @@ -227,7 +260,7 @@ def __eq__(self, other): return NotImplemented def __repr__(self): - if len(self.value) > 97: + if self.value is not None and len(self.value) > 97: # We get the repr, and then insert three dots before the final # quote. v = repr(self.value[:97])[:-1] + "...'" @@ -300,10 +333,10 @@ class File: :param config: The configuration for this File. 
See above for valid configuration keys and their corresponding values. """ - def __init__(self, file_name, field_name=None, config={}): + def __init__(self, file_name, field_name=None, config=None): # Save configuration, set other variables default. self.logger = logging.getLogger(__name__) - self._config = config + self._config = {} if config is None else config self._in_memory = True self._bytes_written = 0 self._fileobj = BytesIO() @@ -424,7 +457,7 @@ def _get_disk_file(self): try: self.logger.info("Opening file: %r", path) tmp_file = open(path, 'w+b') - except OSError as e: + except OSError: tmp_file = None self.logger.exception("Error opening temporary file") @@ -466,14 +499,14 @@ def _get_disk_file(self): self._actual_file_name = fname return tmp_file - def write(self, data): + def write(self, data: bytes) -> int: """Write some data to the File. :param data: a bytestring """ return self.on_data(data) - def on_data(self, data): + def on_data(self, data: bytes) -> int: """This method is a callback that will be called whenever data is written to the File. @@ -554,6 +587,7 @@ class BaseParser: """ def __init__(self): self.logger = logging.getLogger(__name__) + self.callbacks: Dict[str, Callable] = {} def callback(self, name, data=None, start=None, end=None): """This function calls a provided callback with some data. If the @@ -649,7 +683,7 @@ def __init__(self, callbacks={}, max_size=float('inf')): self.max_size = max_size self._current_size = 0 - def write(self, data): + def write(self, data: bytes) -> int: """Write some data to the parser, which will perform size verification, and then pass the data to the underlying callback. @@ -748,7 +782,7 @@ def __init__(self, callbacks={}, strict_parsing=False, # Should parsing be strict? 
self.strict_parsing = strict_parsing - def write(self, data): + def write(self, data: bytes) -> int: """Write some data to the parser, which will perform size verification, parse into either a field name or value, and then pass the corresponding data to the underlying callback. If an error is @@ -1032,7 +1066,7 @@ def __init__(self, boundary, callbacks={}, max_size=float('inf')): # '--\r\n' is 8 bytes. self.lookbehind = [NULL for x in range(len(boundary) + 8)] - def write(self, data): + def write(self, data: bytes) -> int: """Write some data to the parser, which will perform size verification, and then parse the data into the appropriate location (e.g. header, data, etc.), and pass this on to the underlying callback. If an error @@ -1567,7 +1601,7 @@ class vars: def on_start(): vars.f = FileClass(file_name, None, config=self.config) - def on_data(data, start, end): + def on_data(data: bytes, start: int, end: int): vars.f.write(data[start:end]) def on_end(): @@ -1765,7 +1799,7 @@ def on_end(): self.parser = parser - def write(self, data): + def write(self, data: bytes) -> int: """Write some data. The parser will forward this to the appropriate underlying parser. @@ -1793,8 +1827,8 @@ def __repr__(self): ) -def create_form_parser(headers, on_field, on_file, trust_x_headers=False, - config={}): +def create_form_parser(headers: Mapping, on_field: Callable, on_file: Callable, trust_x_headers: bool=False, + config={}) -> FormParser: """This function is a helper function to aid in creating a FormParser instances. 
Given a dictionary-like headers object, it will determine the correct information needed, instantiate a FormParser with the @@ -1842,7 +1876,7 @@ def create_form_parser(headers, on_field, on_file, trust_x_headers=False, return form_parser -def parse_form(headers, input_stream, on_field, on_file, chunk_size=1048576, +def parse_form(headers: Mapping, input_stream: BytesIO, on_field: Callable, on_file: Callable, chunk_size: int=1048576, **kwargs): """This function is useful if you just want to parse a request body, without too much work. Pass it a dictionary-like object of the request's @@ -1866,19 +1900,22 @@ def parse_form(headers, input_stream, on_field, on_file, chunk_size=1048576, # Create our form parser. parser = create_form_parser(headers, on_field, on_file) - # Read chunks of 100KiB and write to the parser, but never read more than + # Read chunks of chunk_size and write to the parser, but never read more than # the given Content-Length, if any. content_length = headers.get('Content-Length') - if content_length is not None: - content_length = int(content_length) + if content_length is None: + calculate_max_readable_bytes = lambda _: chunk_size else: - content_length = float('inf') + content_length = int(content_length) + calculate_max_readable_bytes = lambda bytes_read: min(content_length - bytes_read, chunk_size) + bytes_read = 0 while True: # Read only up to the Content-Length given. - max_readable = min(content_length - bytes_read, 1048576) - buff = input_stream.read(max_readable) + max_readable_bytes = calculate_max_readable_bytes(bytes_read) + # TODO: why not simply use f.read(chunk_size) and let it reach EOF by itself? + buff = input_stream.read(max_readable_bytes) # Write to the parser and update our length. parser.write(buff) @@ -1886,7 +1923,7 @@ def parse_form(headers, input_stream, on_field, on_file, chunk_size=1048576, # If we get a buffer that's smaller than the size requested, or if we # have read up to our content length, we're done. 
- if len(buff) != max_readable or bytes_read == content_length: + if len(buff) != max_readable_bytes or bytes_read == content_length: break # Tell our parser that we're done writing data. diff --git a/pyproject.toml b/pyproject.toml index 1ad20d2..a8c7df6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,3 +58,7 @@ packages = ["multipart"] [tool.hatch.build.targets.sdist] include = ["/multipart", "/tests"] + +[tool.pyright] +include = ["multipart", "tests"] +reportUndefinedVariable = false # TODO: required because pyright does not work with __future__.annotations diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 031515b..a883431 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -12,7 +12,26 @@ from io import BytesIO from unittest.mock import Mock -from multipart.multipart import * +from multipart.multipart import ( + BaseParser, + Field, + File, + FormParser, + MultipartParser, + OctetStreamParser, + QuerystringParser, + create_form_parser, + parse_form, + parse_options_header, +) +from multipart.decoders import Base64Decoder, QuotedPrintableDecoder +from multipart.exceptions import ( + DecodeError, + FileError, + FormParserError, + MultipartParseError, + QuerystringParseError, +) # Get the current directory for our later test cases.