diff --git a/setup.py b/setup.py index 59f2d572..ebf69d77 100755 --- a/setup.py +++ b/setup.py @@ -83,11 +83,14 @@ ], python_requires=">=3.9", install_requires=[ + "lz4", "python-dateutil", "pytz", # requests CVE https://github.com/advisories/GHSA-j8r2-6x86-q33q "requests>=2.31.0", + "typing_extensions", "tzlocal", + "zstandard", ], extras_require={ "all": all_require, diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 0e347149..aea50d1a 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -29,12 +29,13 @@ from trino.transaction import IsolationLevel -@pytest.fixture -def trino_connection(run_trino): +@pytest.fixture(params=[None, "json+zstd", "json+lz4", "json"]) +def trino_connection(request, run_trino): host, port = run_trino + encoding = request.param yield trino.dbapi.Connection( - host=host, port=port, user="test", source="test", max_attempts=1 + host=host, port=port, user="test", source="test", max_attempts=1, encoding=encoding ) diff --git a/tests/integration/test_types_integration.py b/tests/integration/test_types_integration.py index 7704e97b..1d48ddf3 100644 --- a/tests/integration/test_types_integration.py +++ b/tests/integration/test_types_integration.py @@ -12,12 +12,18 @@ from tests.integration.conftest import trino_version -@pytest.fixture -def trino_connection(run_trino): +@pytest.fixture(params=[None, "json+zstd", "json+lz4", "json"]) +def trino_connection(request, run_trino): host, port = run_trino + encoding = request.param yield trino.dbapi.Connection( - host=host, port=port, user="test", source="test", max_attempts=1 + host=host, + port=port, + user="test", + source="test", + max_attempts=1, + encoding=encoding ) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 82e19a0f..1f84c6a7 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -99,6 +99,7 @@ def test_request_headers(mock_get_and_post): accept_encoding_value = "identity,deflate,gzip" client_info_header = constants.HEADER_CLIENT_INFO client_info_value = "some_client_info" + encoding = "json+zstd" with pytest.deprecated_call(): req = TrinoRequest( @@ -111,6 +112,7 @@ def test_request_headers(mock_get_and_post): catalog=catalog, schema=schema, timezone=timezone, + encoding=encoding, headers={ accept_encoding_header: accept_encoding_value, client_info_header: client_info_value, @@ -145,7 +147,8 @@ def assert_headers(headers): "catalog2=" + urllib.parse.quote("ROLE{catalog2_role}") ) assert headers["User-Agent"] == f"{constants.CLIENT_NAME}/{__version__}" - assert len(headers.keys()) == 13 + assert headers[constants.HEADER_ENCODING] == encoding + assert len(headers.keys()) == 14 req.post("URL") _, post_kwargs = post.call_args diff --git a/trino/client.py b/trino/client.py index 61db9444..6184c54c 100644 --- a/trino/client.py +++ b/trino/client.py @@ -34,8 +34,10 @@ """ from __future__ import annotations +import base64 import copy import functools +import json import os import random import re @@ -46,10 +48,13 @@ from datetime import datetime from email.utils import parsedate_to_datetime from time import sleep -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from zoneinfo import ZoneInfo +import lz4.block import requests +import zstandard +from typing_extensions import NotRequired, TypedDict from tzlocal import get_localzone_name # type: ignore import trino.logging @@ -107,6 +112,7 @@ class ClientSession: :param roles: roles for the current session. Some connectors do not support role management. See connector documentation for more details. :param timezone: The timezone for query processing. Defaults to the system's local timezone. + :param encoding: The encoding for the spooled protocol. Defaults to None. """ def __init__( @@ -123,6 +129,7 @@ def __init__( client_tags: List[str] = None, roles: Union[Dict[str, str], str] = None, timezone: str = None, + encoding: str = None, ): self._user = user self._authorization_user = authorization_user @@ -140,6 +147,7 @@ def __init__( self._timezone = timezone or get_localzone_name() if timezone: # Check timezone validity ZoneInfo(timezone) + self._encoding = encoding @property def user(self): @@ -235,6 +243,11 @@ def timezone(self): with self._object_lock: return self._timezone + @property + def encoding(self): + with self._object_lock: + return self._encoding + def _format_roles(self, roles): if isinstance(roles, str): roles = {"system": roles} @@ -462,6 +475,7 @@ def http_headers(self) -> Dict[str, str]: headers[constants.HEADER_USER] = self._client_session.user headers[constants.HEADER_AUTHORIZATION_USER] = self._client_session.authorization_user headers[constants.HEADER_TIMEZONE] = self._client_session.timezone + headers[constants.HEADER_ENCODING] = self._client_session.encoding headers[constants.HEADER_CLIENT_CAPABILITIES] = 'PARAMETRIC_DATETIME' headers["user-agent"] = f"{constants.CLIENT_NAME}/{__version__}" if len(self._client_session.roles.values()): @@ -849,7 +863,14 @@ def fetch(self) -> List[List[Any]]: if not self._row_mapper: return [] - return self._row_mapper.map(status.rows) + rows = status.rows + if isinstance(rows, dict): + # spooled protocol + encoding = rows["encoding"] + segments = rows["segments"] + return list(SegmentIterator(segments, encoding, self._row_mapper, self._request)) + else: + return self._row_mapper.map(rows) def cancel(self) -> None: """Cancel the current query""" @@ -925,3 +946,86 @@ def _parse_retry_after_header(retry_after): retry_date = parsedate_to_datetime(retry_after) now = datetime.utcnow() return (retry_date - now).total_seconds() + + +# Trino Spooled protocol transfer objects +SpooledSegmentMetadata = TypedDict('SpooledSegmentMetadata', {'uncompressedSize': str}) +SpooledSegment = TypedDict( + 'SpooledSegment', + { + 'type': str, + 'uri': str, + 'ackUri': NotRequired[str], + 'data': List[List[Any]], + 'metadata': SpooledSegmentMetadata + } +) + + +class SegmentIterator: + def __init__(self, segments: List[SpooledSegment], encoding: str, row_mapper: RowMapper, request: TrinoRequest): + self._segments = iter(segments) + self._encoding = encoding + self._row_mapper = row_mapper + self._request = request + self._rows: Iterator[List[List[Any]]] = iter([]) + self._finished = False + self._current_segment: Optional[SpooledSegment] = None + + def __iter__(self) -> Iterator[List[Any]]: + return self + + def __next__(self) -> List[Any]: + # If rows are exhausted, fetch the next segment + while True: + try: + return next(self._rows) + except StopIteration: + if self._current_segment and "ackUri" in self._current_segment: + ack_uri = self._current_segment["ackUri"] + http_response = self._request._get(ack_uri) + if not http_response.ok: + self._request.raise_response_error(http_response) + if self._finished: + raise StopIteration + self._load_next_row_set() + + def _load_next_row_set(self): + try: + self._current_segment = segment = next(self._segments) + segment_type = segment["type"] + + if segment_type == "inline": + data = segment["data"] + decoded_string = base64.b64decode(data) + rows = self._row_mapper.map(json.loads(decoded_string)) + self._rows = iter(rows) + + elif segment_type == "spooled": + decoded_string = self._load_spooled_segment(segment) + rows = self._row_mapper.map(json.loads(decoded_string)) + self._rows = iter(rows) + else: + raise ValueError(f"Unsupported segment type: {segment_type}") + + except StopIteration: + self._finished = True + + def _load_spooled_segment(self, segment: SpooledSegment) -> str: + uri = segment["uri"] + encoding = self._encoding + http_response = self._request._get(uri, stream=True) + if not http_response.ok: + self._request.raise_response_error(http_response) + + content = http_response.content + if encoding == "json+zstd": + zstd_decompressor = zstandard.ZstdDecompressor() + return zstd_decompressor.decompress(content).decode('utf-8') + elif encoding == "json+lz4": + expected_size = segment["metadata"]["uncompressedSize"] + return lz4.block.decompress(content, uncompressed_size=int(expected_size)).decode('utf-8') + elif encoding == "json": + return content.decode('utf-8') + else: + raise ValueError(f"Unsupported encoding: {encoding}") diff --git a/trino/constants.py b/trino/constants.py index 1dd0df94..ca29d138 100644 --- a/trino/constants.py +++ b/trino/constants.py @@ -37,6 +37,7 @@ HEADER_CLIENT_TAGS = "X-Trino-Client-Tags" HEADER_EXTRA_CREDENTIAL = "X-Trino-Extra-Credential" HEADER_TIMEZONE = "X-Trino-Time-Zone" +HEADER_ENCODING = "X-Trino-Query-Data-Encoding" HEADER_SESSION = "X-Trino-Session" HEADER_SET_SESSION = "X-Trino-Set-Session" diff --git a/trino/dbapi.py b/trino/dbapi.py index 4f3dfdc6..b24ff090 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -153,6 +153,7 @@ def __init__( legacy_prepared_statements=None, roles=None, timezone=None, + encoding=None, ): # Automatically assign http_schema, port based on hostname parsed_host = urlparse(host, allow_fragments=False) @@ -176,6 +177,7 @@ def __init__( client_tags=client_tags, roles=roles, timezone=timezone, + encoding=encoding, ) # mypy cannot follow module import if http_session is None: