Skip to content

Commit

Permalink
Support spooled protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet committed Nov 26, 2024
1 parent 0f4083d commit d7f4ee3
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 9 deletions.
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,14 @@
],
python_requires=">=3.9",
install_requires=[
"lz4",
"python-dateutil",
"pytz",
# requests CVE https://github.com/advisories/GHSA-j8r2-6x86-q33q
"requests>=2.31.0",
"typing_extensions",
"tzlocal",
"zstandard",
],
extras_require={
"all": all_require,
Expand Down
7 changes: 4 additions & 3 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@
from trino.transaction import IsolationLevel


@pytest.fixture
def trino_connection(run_trino):
@pytest.fixture(params=[None, "json+zstd", "json+lz4", "json"])
def trino_connection(request, run_trino):
host, port = run_trino
encoding = request.param

yield trino.dbapi.Connection(
host=host, port=port, user="test", source="test", max_attempts=1
host=host, port=port, user="test", source="test", max_attempts=1, encoding=encoding
)


Expand Down
12 changes: 9 additions & 3 deletions tests/integration/test_types_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@
from tests.integration.conftest import trino_version


@pytest.fixture
def trino_connection(run_trino):
@pytest.fixture(params=[None, "json+zstd", "json+lz4", "json"])
def trino_connection(request, run_trino):
host, port = run_trino
encoding = request.param

yield trino.dbapi.Connection(
host=host, port=port, user="test", source="test", max_attempts=1
host=host,
port=port,
user="test",
source="test",
max_attempts=1,
encoding=encoding
)


Expand Down
5 changes: 4 additions & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def test_request_headers(mock_get_and_post):
accept_encoding_value = "identity,deflate,gzip"
client_info_header = constants.HEADER_CLIENT_INFO
client_info_value = "some_client_info"
encoding = "json+zstd"

with pytest.deprecated_call():
req = TrinoRequest(
Expand All @@ -111,6 +112,7 @@ def test_request_headers(mock_get_and_post):
catalog=catalog,
schema=schema,
timezone=timezone,
encoding=encoding,
headers={
accept_encoding_header: accept_encoding_value,
client_info_header: client_info_value,
Expand Down Expand Up @@ -145,7 +147,8 @@ def assert_headers(headers):
"catalog2=" + urllib.parse.quote("ROLE{catalog2_role}")
)
assert headers["User-Agent"] == f"{constants.CLIENT_NAME}/{__version__}"
assert len(headers.keys()) == 13
assert headers[constants.HEADER_ENCODING] == encoding
assert len(headers.keys()) == 14

req.post("URL")
_, post_kwargs = post.call_args
Expand Down
108 changes: 106 additions & 2 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@
"""
from __future__ import annotations

import base64
import copy
import functools
import json
import os
import random
import re
Expand All @@ -46,10 +48,13 @@
from datetime import datetime
from email.utils import parsedate_to_datetime
from time import sleep
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from zoneinfo import ZoneInfo

import lz4.block
import requests
import zstandard
from typing_extensions import NotRequired, TypedDict
from tzlocal import get_localzone_name # type: ignore

import trino.logging
Expand Down Expand Up @@ -107,6 +112,7 @@ class ClientSession:
:param roles: roles for the current session. Some connectors do not
support role management. See connector documentation for more details.
:param timezone: The timezone for query processing. Defaults to the system's local timezone.
:param encoding: The encoding for the spooled protocol. Defaults to None.
"""

def __init__(
Expand All @@ -123,6 +129,7 @@ def __init__(
client_tags: List[str] = None,
roles: Union[Dict[str, str], str] = None,
timezone: str = None,
encoding: str = None,
):
self._user = user
self._authorization_user = authorization_user
Expand All @@ -140,6 +147,7 @@ def __init__(
self._timezone = timezone or get_localzone_name()
if timezone: # Check timezone validity
ZoneInfo(timezone)
self._encoding = encoding

@property
def user(self):
Expand Down Expand Up @@ -235,6 +243,11 @@ def timezone(self):
with self._object_lock:
return self._timezone

@property
def encoding(self):
    """Return the spooled-protocol encoding configured for this session.

    The value is whatever was passed to the constructor (``None`` when no
    encoding was requested) and is read while holding the session lock.
    """
    with self._object_lock:
        current_encoding = self._encoding
    return current_encoding

def _format_roles(self, roles):
if isinstance(roles, str):
roles = {"system": roles}
Expand Down Expand Up @@ -462,6 +475,7 @@ def http_headers(self) -> Dict[str, str]:
headers[constants.HEADER_USER] = self._client_session.user
headers[constants.HEADER_AUTHORIZATION_USER] = self._client_session.authorization_user
headers[constants.HEADER_TIMEZONE] = self._client_session.timezone
headers[constants.HEADER_ENCODING] = self._client_session.encoding
headers[constants.HEADER_CLIENT_CAPABILITIES] = 'PARAMETRIC_DATETIME'
headers["user-agent"] = f"{constants.CLIENT_NAME}/{__version__}"
if len(self._client_session.roles.values()):
Expand Down Expand Up @@ -849,7 +863,14 @@ def fetch(self) -> List[List[Any]]:
if not self._row_mapper:
return []

return self._row_mapper.map(status.rows)
rows = status.rows
if isinstance(rows, dict):
# spooled protocol
encoding = rows["encoding"]
segments = rows["segments"]
return list(SegmentIterator(segments, encoding, self._row_mapper, self._request))
else:
return self._row_mapper.map(rows)

def cancel(self) -> None:
"""Cancel the current query"""
Expand Down Expand Up @@ -925,3 +946,86 @@ def _parse_retry_after_header(retry_after):
retry_date = parsedate_to_datetime(retry_after)
now = datetime.utcnow()
return (retry_date - now).total_seconds()


# Trino Spooled protocol transfer objects
SpooledSegmentMetadata = TypedDict('SpooledSegmentMetadata', {'uncompressedSize': str})
SpooledSegment = TypedDict(
'SpooledSegment',
{
'type': str,
'uri': str,
'ackUri': NotRequired[str],
'data': List[List[Any]],
'metadata': SpooledSegmentMetadata
}
)


class SegmentIterator:
    """Iterate over the mapped rows of a spooled-protocol result.

    Segments are processed lazily and in order. An ``"inline"`` segment
    carries its rows as base64-encoded JSON; a ``"spooled"`` segment is
    downloaded from its ``uri`` and decoded according to ``encoding``.
    Once all rows of a segment have been handed out, its ``ackUri`` (if
    any) is requested exactly once.

    :param segments: segment descriptors as returned by the server.
    :param encoding: payload encoding of spooled segments; one of
        ``"json"``, ``"json+zstd"`` or ``"json+lz4"``.
    :param row_mapper: object whose ``map`` method converts the decoded
        JSON rows into the rows yielded by this iterator.
    :param request: request object used to download and acknowledge
        segments.
    """

    def __init__(
        self,
        segments: "List[SpooledSegment]",
        encoding: str,
        row_mapper: "RowMapper",
        request: "TrinoRequest",
    ):
        self._segments = iter(segments)
        self._encoding = encoding
        self._row_mapper = row_mapper
        self._request = request
        self._rows: Iterator[List[List[Any]]] = iter([])
        self._finished = False
        # Segment whose rows are currently being served and which has not
        # been acknowledged yet; None when there is nothing to acknowledge.
        self._current_segment: Optional[SpooledSegment] = None

    def __iter__(self) -> Iterator[List[Any]]:
        return self

    def __next__(self) -> List[Any]:
        """Return the next row, loading further segments as needed.

        :raises StopIteration: when every segment has been consumed.
        """
        while True:
            try:
                return next(self._rows)
            except StopIteration:
                pass
            # The rows of the current segment are exhausted.
            self._acknowledge_current_segment()
            if self._finished:
                raise StopIteration
            self._load_next_row_set()

    def _acknowledge_current_segment(self) -> None:
        """Request the ``ackUri`` of the consumed segment, at most once."""
        segment = self._current_segment
        if segment and "ackUri" in segment:
            http_response = self._request._get(segment["ackUri"])
            if not http_response.ok:
                self._request.raise_response_error(http_response)
        # Forget the segment so that later calls (including calls made
        # after the iterator is exhausted) do not acknowledge it again.
        self._current_segment = None

    def _load_next_row_set(self):
        """Decode the next segment into ``self._rows`` or mark the end.

        :raises ValueError: if the segment type is not supported.
        """
        # Only the advance over the segment list may signal exhaustion; a
        # StopIteration escaping from decoding must not be mistaken for it.
        try:
            segment = next(self._segments)
        except StopIteration:
            self._finished = True
            return

        self._current_segment = segment
        segment_type = segment["type"]

        if segment_type == "inline":
            decoded_string = base64.b64decode(segment["data"])
        elif segment_type == "spooled":
            decoded_string = self._load_spooled_segment(segment)
        else:
            raise ValueError(f"Unsupported segment type: {segment_type}")

        rows = self._row_mapper.map(json.loads(decoded_string))
        self._rows = iter(rows)

    def _load_spooled_segment(self, segment: "SpooledSegment") -> str:
        """Download a spooled segment and return its payload as JSON text.

        :raises ValueError: if the configured encoding is not supported.
        """
        uri = segment["uri"]
        encoding = self._encoding
        http_response = self._request._get(uri, stream=True)
        if not http_response.ok:
            self._request.raise_response_error(http_response)

        content = http_response.content
        if encoding == "json+zstd":
            zstd_decompressor = zstandard.ZstdDecompressor()
            return zstd_decompressor.decompress(content).decode('utf-8')
        elif encoding == "json+lz4":
            # lz4 block data is decompressed using the size reported in
            # the segment metadata.
            expected_size = segment["metadata"]["uncompressedSize"]
            return lz4.block.decompress(content, uncompressed_size=int(expected_size)).decode('utf-8')
        elif encoding == "json":
            return content.decode('utf-8')
        else:
            raise ValueError(f"Unsupported encoding: {encoding}")
1 change: 1 addition & 0 deletions trino/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
HEADER_CLIENT_TAGS = "X-Trino-Client-Tags"
HEADER_EXTRA_CREDENTIAL = "X-Trino-Extra-Credential"
HEADER_TIMEZONE = "X-Trino-Time-Zone"
HEADER_ENCODING = "X-Trino-Query-Data-Encoding"

HEADER_SESSION = "X-Trino-Session"
HEADER_SET_SESSION = "X-Trino-Set-Session"
Expand Down
2 changes: 2 additions & 0 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def __init__(
legacy_prepared_statements=None,
roles=None,
timezone=None,
encoding=None,
):
# Automatically assign http_schema, port based on hostname
parsed_host = urlparse(host, allow_fragments=False)
Expand All @@ -176,6 +177,7 @@ def __init__(
client_tags=client_tags,
roles=roles,
timezone=timezone,
encoding=encoding,
)
# mypy cannot follow module import
if http_session is None:
Expand Down

0 comments on commit d7f4ee3

Please sign in to comment.