Skip to content

Commit

Permalink
Support Athena parameterized queries when paramstyle is qmark (fix #545)
Browse files Browse the repository at this point in the history
  • Loading branch information
laughingman7743 committed Aug 17, 2024
1 parent 1235a6d commit d04b1a0
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 2 deletions.
2 changes: 2 additions & 0 deletions pyathena/arrow/async_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def execute(
cache_expiration_time: Optional[int] = 0,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
paramstyle: Optional[str] = None,
**kwargs,
) -> Tuple[str, "Future[Union[AthenaArrowResultSet, Any]]"]:
if self._unload:
Expand All @@ -125,6 +126,7 @@ def execute(
cache_expiration_time=cache_expiration_time,
result_reuse_enable=result_reuse_enable,
result_reuse_minutes=result_reuse_minutes,
paramstyle=paramstyle,
)
return (
query_id,
Expand Down
2 changes: 2 additions & 0 deletions pyathena/arrow/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def execute(
cache_expiration_time: Optional[int] = 0,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
paramstyle: Optional[str] = None,
**kwargs,
) -> ArrowCursor:
self._reset_state()
Expand All @@ -129,6 +130,7 @@ def execute(
cache_expiration_time=cache_expiration_time,
result_reuse_enable=result_reuse_enable,
result_reuse_minutes=result_reuse_minutes,
paramstyle=paramstyle,
)
query_execution = cast(AthenaQueryExecution, self._poll(self.query_id))
if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED:
Expand Down
2 changes: 2 additions & 0 deletions pyathena/async_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def execute(
cache_expiration_time: Optional[int] = 0,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
paramstyle: Optional[str] = None,
**kwargs,
) -> Tuple[str, "Future[Union[AthenaResultSet, Any]]"]:
query_id = self._execute(
Expand All @@ -115,6 +116,7 @@ def execute(
cache_expiration_time=cache_expiration_time,
result_reuse_enable=result_reuse_enable,
result_reuse_minutes=result_reuse_minutes,
paramstyle=paramstyle,
)
return query_id, self._executor.submit(self._collect_result_set, query_id)

Expand Down
15 changes: 13 additions & 2 deletions pyathena/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast

import pyathena
from pyathena.converter import Converter, DefaultTypeConverter
from pyathena.error import DatabaseError, OperationalError, ProgrammingError
from pyathena.formatter import Formatter
Expand Down Expand Up @@ -144,6 +145,7 @@ def _build_start_query_execution_request(
s3_staging_dir: Optional[str] = None,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
execution_parameters: Optional[List[str]] = None,
) -> Dict[str, Any]:
request: Dict[str, Any] = {
"QueryString": query,
Expand Down Expand Up @@ -177,6 +179,8 @@ def _build_start_query_execution_request(
else self._result_reuse_minutes,
}
request["ResultReuseConfiguration"] = {"ResultReuseByAgeConfiguration": reuse_conf}
if execution_parameters:
request["ExecutionParameters"] = execution_parameters
return request

def _build_start_calculation_execution_request(
Expand Down Expand Up @@ -546,15 +550,21 @@ def _find_previous_query_id(
def _execute(
self,
operation: str,
- parameters: Optional[Dict[str, Any]] = None,
+ parameters: Optional[Union[Dict[str, Any], List[str]]] = None,
work_group: Optional[str] = None,
s3_staging_dir: Optional[str] = None,
cache_size: Optional[int] = 0,
cache_expiration_time: Optional[int] = 0,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
paramstyle: Optional[str] = None,
) -> str:
- query = self._formatter.format(operation, parameters)
+ if pyathena.paramstyle == "qmark" or paramstyle == "qmark":
+ query = operation
+ execution_parameters = cast(Optional[List[str]], parameters)
+ else:
+ query = self._formatter.format(operation, cast(Optional[Dict[str, Any]], parameters))
+ execution_parameters = None
_logger.debug(query)

request = self._build_start_query_execution_request(
Expand All @@ -563,6 +573,7 @@ def _execute(
s3_staging_dir=s3_staging_dir,
result_reuse_enable=result_reuse_enable,
result_reuse_minutes=result_reuse_minutes,
execution_parameters=execution_parameters,
)
query_id = self._find_previous_query_id(
query,
Expand Down
2 changes: 2 additions & 0 deletions pyathena/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def execute(
cache_expiration_time: int = 0,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
paramstyle: Optional[str] = None,
**kwargs,
) -> Cursor:
self._reset_state()
Expand All @@ -94,6 +95,7 @@ def execute(
cache_expiration_time=cache_expiration_time,
result_reuse_enable=result_reuse_enable,
result_reuse_minutes=result_reuse_minutes,
paramstyle=paramstyle,
)
query_execution = cast(AthenaQueryExecution, self._poll(self.query_id))
if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED:
Expand Down
2 changes: 2 additions & 0 deletions pyathena/pandas/async_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def execute(
cache_expiration_time: Optional[int] = 0,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
paramstyle: Optional[str] = None,
keep_default_na: bool = False,
na_values: Optional[Iterable[str]] = ("",),
quoting: int = 1,
Expand All @@ -138,6 +139,7 @@ def execute(
cache_expiration_time=cache_expiration_time,
result_reuse_enable=result_reuse_enable,
result_reuse_minutes=result_reuse_minutes,
paramstyle=paramstyle,
)
return (
query_id,
Expand Down
2 changes: 2 additions & 0 deletions pyathena/pandas/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def execute(
cache_expiration_time: Optional[int] = 0,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
paramstyle: Optional[str] = None,
keep_default_na: bool = False,
na_values: Optional[Iterable[str]] = ("",),
quoting: int = 1,
Expand All @@ -154,6 +155,7 @@ def execute(
cache_expiration_time=cache_expiration_time,
result_reuse_enable=result_reuse_enable,
result_reuse_minutes=result_reuse_minutes,
paramstyle=paramstyle,
)
query_execution = cast(AthenaQueryExecution, self._poll(self.query_id))
if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED:
Expand Down

0 comments on commit d04b1a0

Please sign in to comment.