Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/multipart serialization #235

Merged
merged 6 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.3.0] - 2024-02-28

### Added
- Added multipart body class to support multipart serialization. [microsoft/kiota#3030](https://github.com/microsoft/kiota/issues/3030)

### Changed

## [1.2.0] - 2024-01-31

### Added
Expand Down
2 changes: 1 addition & 1 deletion kiota_abstractions/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION: str = "1.2.0"
VERSION: str = "1.3.0"
144 changes: 144 additions & 0 deletions kiota_abstractions/multipart_body.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation. All Rights Reserved.
# Licensed under the MIT License.
# See License in the project root for license information.
# ------------------------------------
from __future__ import annotations

import io
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Tuple, TypeVar

from .serialization import Parsable

if TYPE_CHECKING:
from .request_adapter import RequestAdapter
from .serialization import ParseNode, SerializationWriter

T = TypeVar("T")


@dataclass
class MultipartBody(Parsable, Generic[T]):
"""Represents a multipart body for a request or a response.
Example usage:
multipart = MultipartBody()
multipart.add_or_replace_part("file", "image/jpeg", open("image.jpg", "rb").read())
multipart.add_or_replace_part("text", "text/plain", "Hello, World!")
with open("output.txt", "w") as output_file:
multipart.serialize(output_file)
"""
boundary: str = str(uuid.uuid4())
parts: Dict[str, Tuple[str, Any]] = field(default_factory=dict)
request_adapter: Optional[RequestAdapter] = None

def add_or_replace_part(self, part_name: str, content_type: str, part_value: T) -> None:
"""Adds or replaces a part to the multipart body.

Args:
part_name (str): The name of the part to add or replace.
content_type (str): The content type of the part.
part_value (T): The value of the part.

Returns:
None
"""
if not part_name:
raise ValueError("Part name cannot be null")
if not content_type:
raise ValueError("Content type cannot be null")
if not part_value:
raise ValueError("Part value cannot be null")
value: Tuple[str, Any] = (content_type, part_value)
self.parts[self._normalize_part_name(part_name)] = value

def get_part_value(self, part_name: str) -> T:
"""Gets the value of a part from the multipart body."""
if not part_name:
raise ValueError("Part name cannot be null")
value = self.parts.get(self._normalize_part_name(part_name))
return value[1] if value else None

def remove_part(self, part_name: str) -> bool:
"""Removes a part from the multipart body.

Args:
part_name (str): The name of the part to remove.

Returns:
bool: True if the part was removed, False otherwise.
"""
if not part_name:
raise ValueError("Part name cannot be null")
return self.parts.pop(self._normalize_part_name(part_name), None) is not None

def get_field_deserializers(self) -> Dict[str, Callable[[ParseNode], None]]:
"""Gets the deserialization information for this object.

Returns:
Dict[str, Callable[[ParseNode], None]]: The deserialization information for this
object where each entry is a property key with its deserialization callback.
"""
raise NotImplementedError()

def serialize(self, writer: SerializationWriter) -> None:
"""Writes the objects properties to the current writer.

Args:
writer (SerializationWriter): The writer to write to.
"""
if not writer:
raise ValueError("Serialization writer cannot be null")
if not self.request_adapter or not self.request_adapter.get_serialization_writer_factory():
raise ValueError("Request adapter or serialization writer factory cannot be null")
if not self.parts:
raise ValueError("No parts to serialize")

first = True
for part_name, part_value in self.parts.items():
if first:
first = False
else:
self._add_new_line(writer)

writer.write_str_value("", f"--{self.boundary}")
writer.write_str_value("Content-Type", f"{part_value[0]}")
writer.write_str_value("Content-Disposition", f'form-data; name="{part_name}"')
self._add_new_line(writer)

if isinstance(part_value[1], Parsable):
self._write_parsable(writer, part_value[1])
elif isinstance(part_value[1], str):
writer.write_str_value("", part_value[1])
elif isinstance(part_value[1], bytes):
writer.write_bytes_value("", part_value[1])
elif isinstance(part_value[1], io.IOBase):
writer.write_bytes_value("", part_value[1].read())
else:
raise ValueError(f"Unsupported type {type(part_value[1])} for part {part_name}")

self._add_new_line(writer)
writer.write_str_value("", f"--{self.boundary}--")

def _normalize_part_name(self, original: str) -> str:
return original.lower()

def _add_new_line(self, writer: SerializationWriter) -> None:
writer.write_str_value("", "")

def _write_parsable(self, writer, part_value) -> None:
if not self.request_adapter or not self.request_adapter.get_serialization_writer_factory():
raise ValueError("Request adapter or serialization writer factory cannot be null")
part_writer = (
self.request_adapter.get_serialization_writer_factory().get_serialization_writer(
part_value[0]
)
)
part_writer.write_object_value("", part_value[1], None)
part_content = part_writer.get_serialized_content()
if hasattr(part_content, "seek"): # seekable
part_content.seek(0)
writer.write_bytes_value("", part_content.read()) #type: ignore
else:
writer.write_bytes_value("", part_content)
20 changes: 15 additions & 5 deletions kiota_abstractions/request_information.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation. All Rights Reserved.
# Licensed under the MIT License.
# See License in the project root for license information.
# ------------------------------------
from __future__ import annotations

from dataclasses import fields, is_dataclass
from datetime import date, datetime, time, timedelta
from enum import Enum
Expand All @@ -13,6 +20,7 @@
from .base_request_configuration import RequestConfiguration
from .headers_collection import HeadersCollection
from .method import Method
from .multipart_body import MultipartBody
from .request_option import RequestOption
from .serialization import Parsable, SerializationWriter

Expand Down Expand Up @@ -69,7 +77,7 @@ def __init__(
self.headers: HeadersCollection = HeadersCollection()

# The Request Body
self.content: Optional[BytesIO] = None
self.content: Optional[bytes] = None

def configure(self, request_configuration: RequestConfiguration) -> None:
"""Configures the current request information headers, query parameters, and options
Expand Down Expand Up @@ -145,8 +153,8 @@ def remove_request_options(self, options: List[RequestOption]) -> None:

def set_content_from_parsable(
self,
request_adapter: Optional["RequestAdapter"],
content_type: Optional[str],
request_adapter: RequestAdapter,
baywet marked this conversation as resolved.
Show resolved Hide resolved
content_type: str,
values: Union[T, List[T]],
) -> None:
"""Sets the request body from a model with the specified content type.
Expand All @@ -161,7 +169,9 @@ def set_content_from_parsable(
self._create_parent_span_name("set_content_from_parsable")
) as span:
writer = self._get_serialization_writer(request_adapter, content_type, values, span)

if isinstance(values, MultipartBody):
content_type += f"; boundary={values.boundary}"
values.request_adapter = request_adapter
if isinstance(values, list):
writer.write_collection_of_object_values(None, values)
span.set_attribute(self.REQUEST_TYPE_KEY, "[]")
Expand Down Expand Up @@ -217,7 +227,7 @@ def set_content_from_scalar(
writer_func(None, values)
self._set_content_and_content_type_header(writer, content_type)

def set_stream_content(self, value: BytesIO, content_type: Optional[str] = None) -> None:
def set_stream_content(self, value: bytes, content_type: Optional[str] = None) -> None:
"""Sets the request body to be a binary stream.

Args:
Expand Down
7 changes: 3 additions & 4 deletions kiota_abstractions/serialization/parse_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from abc import ABC, abstractmethod
from datetime import date, datetime, time, timedelta
from enum import Enum
from io import BytesIO
from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar
from uuid import UUID

Expand Down Expand Up @@ -166,11 +165,11 @@ def get_object_value(self, factory: ParsableFactory) -> Parsable:
pass

@abstractmethod
def get_bytes_value(self) -> BytesIO:
"""Get a bytearray value from the nodes
def get_bytes_value(self) -> bytes:
"""Get a bytes value from the nodes

Returns:
bytearray: The bytearray value from the nodes
bytes: The bytes value from the nodes
"""
pass

Expand Down
5 changes: 2 additions & 3 deletions kiota_abstractions/serialization/parse_node_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
from io import BytesIO

from .parse_node import ParseNode

Expand All @@ -18,12 +17,12 @@ def get_valid_content_type(self) -> str:
pass

@abstractmethod
def get_root_parse_node(self, content_type: str, content: BytesIO) -> ParseNode:
def get_root_parse_node(self, content_type: str, content: bytes) -> ParseNode:
"""Creates a ParseNode from the given binary stream and content type

Args:
content_type (str): The content type of the binary stream
content (BytesIO): The array buffer to read from
content (bytes): The array buffer to read from

Returns:
ParseNode: A ParseNode that can deserialize the given binary stream
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import re
from io import BytesIO
from typing import Dict

from .parse_node import ParseNode
Expand Down Expand Up @@ -31,7 +30,7 @@ def get_valid_content_type(self) -> str:
"The registry supports multiple content types. Get the registered factory instead"
)

def get_root_parse_node(self, content_type: str, content: BytesIO) -> ParseNode:
def get_root_parse_node(self, content_type: str, content: bytes) -> ParseNode:
if not content_type:
raise Exception("Content type cannot be null")
if not content:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# See License in the project root for license information.
# ------------------------------------------------------------------------------

from io import BytesIO
from typing import Callable

from .parsable import Parsable
Expand Down Expand Up @@ -44,12 +43,12 @@ def get_valid_content_type(self) -> str:
"""
return self._concrete.get_valid_content_type()

def get_root_parse_node(self, content_type: str, content: BytesIO) -> ParseNode:
def get_root_parse_node(self, content_type: str, content: bytes) -> ParseNode:
"""Create a parse node from the given stream and content type.

Args:
content_type (str): The content type of the parse node.
content (BytesIO): The stream to read the parse node from.
content (bytes): The stream to read the parse node from.

Returns:
ParseNode: A parse node.
Expand Down
9 changes: 4 additions & 5 deletions kiota_abstractions/serialization/serialization_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from abc import ABC, abstractmethod
from datetime import date, datetime, time, timedelta
from enum import Enum
from io import BytesIO
from typing import Any, Callable, Dict, List, Optional, TypeVar
from uuid import UUID

Expand Down Expand Up @@ -146,13 +145,13 @@ def write_collection_of_enum_values(
pass

@abstractmethod
def write_bytes_value(self, key: Optional[str], value: BytesIO) -> None:
def write_bytes_value(self, key: Optional[str], value: bytes) -> None:
baywet marked this conversation as resolved.
Show resolved Hide resolved
"""Writes the specified byte array as a base64 string to the stream with an optional
given key.

Args:
key (Optional[str]): The key to be used for the written value. May be null.
value (BytesIO): The byte array to be written.
value (bytes): The bytes to be written.
"""
pass

Expand Down Expand Up @@ -198,11 +197,11 @@ def write_additional_data_value(self, value: Dict[str, Any]) -> None:
pass

@abstractmethod
def get_serialized_content(self) -> BytesIO:
def get_serialized_content(self) -> bytes:
"""Gets the value of the serialized content.

Returns:
BytesIO: The value of the serialized content.
bytes: The value of the serialized content.
"""
pass

Expand Down
26 changes: 23 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,21 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import Mock

import pytest

from kiota_abstractions.authentication.access_token_provider import AccessTokenProvider
from kiota_abstractions.authentication.allowed_hosts_validator import AllowedHostsValidator
from kiota_abstractions.multipart_body import MultipartBody
from kiota_abstractions.request_adapter import RequestAdapter
from kiota_abstractions.request_information import RequestInformation
from kiota_abstractions.serialization import (
AdditionalDataHolder,
Parsable,
ParseNode,
SerializationWriter,
SerializationWriterFactory
)
from kiota_abstractions.store import BackedModel, BackingStore, BackingStoreFactorySingleton

Expand Down Expand Up @@ -111,6 +114,8 @@ class TestEnum(Enum):
VALUE2 = "value2"
VALUE3 = "value3"

__test__ = False

@dataclass
class QueryParams:
dataset: Union[TestEnum, List[TestEnum]]
Expand Down Expand Up @@ -138,6 +143,21 @@ def mock_access_token_provider():


@pytest.fixture
def mock_request_adapter(mocker):
mocker.patch.multiple(RequestAdapter, __abstractmethods__=set())
return RequestAdapter()
def mock_request_adapter():
request_adapter = Mock(spec=RequestAdapter)
return request_adapter

@pytest.fixture
def mock_serialization_writer():
return Mock(spec=SerializationWriter)


@pytest.fixture
def mock_serialization_writer_factory():
mock_factory = Mock(spec=SerializationWriterFactory)
return mock_factory

@pytest.fixture
def mock_multipart_body():
mock_multipart_body = MultipartBody()
return mock_multipart_body
Loading
Loading