Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: trino cursor #25897

Merged
merged 3 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 30 additions & 24 deletions superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import contextlib
import logging
import threading
import time
from typing import Any, TYPE_CHECKING

import simplejson as json
Expand Down Expand Up @@ -152,7 +151,13 @@ def get_tracking_url(cls, cursor: Cursor) -> str | None:
return None

@classmethod
def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
def handle_cursor_with_query_id(
cls,
cursor: Cursor,
query: Query,
session: Session,
cancel_query_id: str,
) -> None:
"""
Handle a trino client cursor.

Expand All @@ -162,7 +167,6 @@ def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
"""

# Adds the executed query id to the extra payload so the query can be cancelled
cancel_query_id = cursor.query_id
logger.debug("Query %d: queryId %s found in cursor", query.id, cancel_query_id)
query.set_extra_json_key(key=QUERY_CANCEL_KEY, value=cancel_query_id)

Expand All @@ -180,11 +184,15 @@ def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
cancel_query_id=cancel_query_id,
)

super().handle_cursor(cursor=cursor, query=query, session=session)
cls.handle_cursor(cursor=cursor, query=query, session=session)

@classmethod
def execute_with_cursor(
cls, cursor: Any, sql: str, query: Query, session: Session
cls,
cursor: Cursor,
sql: str,
query: Query,
session: Session,
) -> None:
"""
Trigger execution of a query and handle the resulting cursor.
Expand All @@ -193,34 +201,32 @@ def execute_with_cursor(
in another thread and invoke `handle_cursor` to poll for the query ID
to appear on the cursor in parallel.
"""
# Fetch the query ID before hand, since it might fail inside the thread due to
# how the SQLAlchemy session is handled.
query_id = query.id
betodealmeida marked this conversation as resolved.
Show resolved Hide resolved

execute_result: dict[str, Any] = {}
execute_event = threading.Event()

def _execute(results: dict[str, Any]) -> None:
logger.debug("Query %d: Running query: %s", query.id, sql)
def _execute(results: dict[str, Any], event: threading.Event) -> None:
logger.debug("Query %d: Running query: %s", query_id, sql)

# Pass result / exception information back to the parent thread
try:
cls.execute(cursor, sql)
results["complete"] = True
except Exception as ex: # pylint: disable=broad-except
results["complete"] = True
results["error"] = ex
finally:
event.set()

execute_thread = threading.Thread(target=_execute, args=(execute_result,))
execute_thread = threading.Thread(
target=_execute,
args=(execute_result, execute_event),
)
execute_thread.start()
execute_event.wait()
betodealmeida marked this conversation as resolved.
Show resolved Hide resolved

# Wait for a query ID to be available before handling the cursor, as
# it's required by that method; it may never become available on error.
while not cursor.query_id and not execute_result.get("complete"):
time.sleep(0.1)

logger.debug("Query %d: Handling cursor", query.id)
cls.handle_cursor(cursor, query, session)

# Block until the query completes; same behaviour as the client itself
logger.debug("Query %d: Waiting for query to complete", query.id)
while not execute_result.get("complete"):
time.sleep(0.5)
logger.debug("Query %d: Handling cursor", query_id)
cls.handle_cursor_with_query_id(cursor, query, session, query_id)
betodealmeida marked this conversation as resolved.
Show resolved Hide resolved

# Unfortunately we'll mangle the stack trace due to the thread, but
# throwing the original exception allows mapping database errors as normal
Expand All @@ -234,7 +240,7 @@ def prepare_cancel_query(cls, query: Query, session: Session) -> None:
session.commit()

@classmethod
def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool:
def cancel_query(cls, cursor: Cursor, query: Query, cancel_query_id: str) -> bool:
"""
Cancel query in the underlying database.

Expand Down
11 changes: 9 additions & 2 deletions tests/unit_tests/db_engine_specs/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,12 @@ def test_handle_cursor_early_cancel(
if cancel_early:
TrinoEngineSpec.prepare_cancel_query(query=query, session=session_mock)

TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query, session=session_mock)
TrinoEngineSpec.handle_cursor_with_query_id(
cursor=cursor_mock,
query=query,
session=session_mock,
cancel_query_id=query_id,
)

if cancel_early:
assert cancel_query_mock.call_args[1]["cancel_query_id"] == query_id
Expand All @@ -378,6 +383,7 @@ def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
mock_cursor.query_id = None

mock_query = mocker.MagicMock()
mock_query.id = query_id
mock_session = mocker.MagicMock()

def _mock_execute(*args, **kwargs):
Expand All @@ -393,5 +399,6 @@ def _mock_execute(*args, **kwargs):
)

mock_query.set_extra_json_key.assert_called_once_with(
key=QUERY_CANCEL_KEY, value=query_id
key=QUERY_CANCEL_KEY,
value=query_id,
)