Skip to content

Commit

Permalink
SNOW-1325701: Use scoped temp object in write pandas (#2068)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-yuwang authored Oct 17, 2024
1 parent aeb771c commit a959fc6
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 16 deletions.
45 changes: 45 additions & 0 deletions src/snowflake/connector/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from __future__ import annotations

import string
from enum import Enum
from random import choice


# Kinds of temporary objects that can receive a Snowpark-style scoped temp name.
# Built with the Enum functional API; each member's value equals its name.
TempObjectType = Enum(
    "TempObjectType",
    {
        member: member
        for member in (
            "TABLE",
            "VIEW",
            "STAGE",
            "FUNCTION",
            "FILE_FORMAT",
            "QUERY_TAG",
            "COLUMN",
            "PROCEDURE",
            "TABLE_FUNCTION",
            "DYNAMIC_TABLE",
            "AGGREGATE_FUNCTION",
            "CTE",
        )
    },
)


# Prefix shared by every generated temp-object name, matching Snowpark's scheme.
TEMP_OBJECT_NAME_PREFIX = "SNOWPARK_TEMP_"
# Alphabet for random name suffixes: 0-9 plus a-z.
ALPHANUMERIC = string.digits + string.ascii_lowercase
# DDL keywords selecting ordinary vs. session-scoped temporary objects.
TEMPORARY_STRING = "TEMP"
SCOPED_TEMPORARY_STRING = "SCOPED TEMPORARY"
# Session parameter that opts a connection into scoped temp objects.
_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING = (
    "PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS"
)


def generate_random_alphanumeric(length: int = 10) -> str:
    """Return a random string of *length* characters drawn from ALPHANUMERIC."""
    picked = [choice(ALPHANUMERIC) for _ in range(length)]
    return "".join(picked)


def random_name_for_temp_object(object_type: TempObjectType) -> str:
    """Return a unique, prefixed name for a temp object of *object_type*."""
    suffix = generate_random_alphanumeric().upper()
    return f"{TEMP_OBJECT_NAME_PREFIX}{object_type.value}_{suffix}"


def get_temp_type_for_object(use_scoped_temp_objects: bool) -> str:
    """Return the DDL temp keyword: scoped when *use_scoped_temp_objects* is true."""
    if use_scoped_temp_objects:
        return SCOPED_TEMPORARY_STRING
    return TEMPORARY_STRING
25 changes: 19 additions & 6 deletions src/snowflake/connector/bind_upload_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from logging import getLogger
from typing import TYPE_CHECKING

from ._utils import (
_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING,
get_temp_type_for_object,
)
from .errors import BindUploadError, Error

if TYPE_CHECKING: # pragma: no cover
Expand All @@ -19,11 +23,6 @@


class BindUploadAgent:
_STAGE_NAME = "SYSTEMBIND"
_CREATE_STAGE_STMT = (
f"create or replace temporary stage {_STAGE_NAME} "
"file_format=(type=csv field_optionally_enclosed_by='\"')"
)

def __init__(
self,
Expand All @@ -38,13 +37,27 @@ def __init__(
rows: Rows of binding parameters in CSV format.
stream_buffer_size: Size of each file, default to 10MB.
"""
self._use_scoped_temp_object = (
cursor.connection._session_parameters.get(
_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING, False
)
if cursor.connection._session_parameters
else False
)
self._STAGE_NAME = (
"SNOWPARK_TEMP_STAGE_BIND" if self._use_scoped_temp_object else "SYSTEMBIND"
)
self.cursor = cursor
self.rows = rows
self._stream_buffer_size = stream_buffer_size
self.stage_path = f"@{self._STAGE_NAME}/{uuid.uuid4().hex}"

def _create_stage(self) -> None:
    """Create the upload stage, scoped-temporary when the session opted in."""
    temp_type = get_temp_type_for_object(self._use_scoped_temp_object)
    ddl = (
        f"create or replace {temp_type} stage {self._STAGE_NAME} "
        "file_format=(type=csv field_optionally_enclosed_by='\"')"
    )
    self.cursor.execute(ddl)

def upload(self) -> None:
try:
Expand Down
62 changes: 54 additions & 8 deletions src/snowflake/connector/pandas_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
from snowflake.connector.telemetry import TelemetryData, TelemetryField
from snowflake.connector.util_text import random_string

from ._utils import (
_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING,
TempObjectType,
get_temp_type_for_object,
random_name_for_temp_object,
)
from .cursor import SnowflakeCursor

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -77,8 +83,9 @@ def _do_create_temp_stage(
compression: str,
auto_create_table: bool,
overwrite: bool,
use_scoped_temp_object: bool,
) -> None:
create_stage_sql = f"CREATE TEMP STAGE /* Python:snowflake.connector.pandas_tools.write_pandas() */ {stage_location} FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression}{' BINARY_AS_TEXT=FALSE' if auto_create_table or overwrite else ''})"
create_stage_sql = f"CREATE {get_temp_type_for_object(use_scoped_temp_object)} STAGE /* Python:snowflake.connector.pandas_tools.write_pandas() */ {stage_location} FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression}{' BINARY_AS_TEXT=FALSE' if auto_create_table or overwrite else ''})"
logger.debug(f"creating stage with '{create_stage_sql}'")
cursor.execute(create_stage_sql, _is_internal=True).fetchall()

Expand All @@ -91,8 +98,13 @@ def _create_temp_stage(
compression: str,
auto_create_table: bool,
overwrite: bool,
use_scoped_temp_object: bool = False,
) -> str:
stage_name = random_string()
stage_name = (
random_name_for_temp_object(TempObjectType.STAGE)
if use_scoped_temp_object
else random_string()
)
stage_location = build_location_helper(
database=database,
schema=schema,
Expand All @@ -101,7 +113,12 @@ def _create_temp_stage(
)
try:
_do_create_temp_stage(
cursor, stage_location, compression, auto_create_table, overwrite
cursor,
stage_location,
compression,
auto_create_table,
overwrite,
use_scoped_temp_object,
)
except ProgrammingError as e:
# User may not have the privilege to create stage on the target schema, so fall back to use current schema as
Expand All @@ -111,7 +128,12 @@ def _create_temp_stage(
)
stage_location = stage_name
_do_create_temp_stage(
cursor, stage_location, compression, auto_create_table, overwrite
cursor,
stage_location,
compression,
auto_create_table,
overwrite,
use_scoped_temp_object,
)

return stage_location
Expand All @@ -122,9 +144,10 @@ def _do_create_temp_file_format(
file_format_location: str,
compression: str,
sql_use_logical_type: str,
use_scoped_temp_object: bool,
) -> None:
file_format_sql = (
f"CREATE TEMP FILE FORMAT {file_format_location} "
f"CREATE {get_temp_type_for_object(use_scoped_temp_object)} FILE FORMAT {file_format_location} "
f"/* Python:snowflake.connector.pandas_tools.write_pandas() */ "
f"TYPE=PARQUET COMPRESSION={compression}{sql_use_logical_type}"
)
Expand All @@ -139,8 +162,13 @@ def _create_temp_file_format(
quote_identifiers: bool,
compression: str,
sql_use_logical_type: str,
use_scoped_temp_object: bool = False,
) -> str:
file_format_name = random_string()
file_format_name = (
random_name_for_temp_object(TempObjectType.FILE_FORMAT)
if use_scoped_temp_object
else random_string()
)
file_format_location = build_location_helper(
database=database,
schema=schema,
Expand All @@ -149,7 +177,11 @@ def _create_temp_file_format(
)
try:
_do_create_temp_file_format(
cursor, file_format_location, compression, sql_use_logical_type
cursor,
file_format_location,
compression,
sql_use_logical_type,
use_scoped_temp_object,
)
except ProgrammingError as e:
# User may not have the privilege to create file format on the target schema, so fall back to use current schema
Expand All @@ -159,7 +191,11 @@ def _create_temp_file_format(
)
file_format_location = file_format_name
_do_create_temp_file_format(
cursor, file_format_location, compression, sql_use_logical_type
cursor,
file_format_location,
compression,
sql_use_logical_type,
use_scoped_temp_object,
)

return file_format_location
Expand Down Expand Up @@ -263,6 +299,14 @@ def write_pandas(
f"Invalid compression '{compression}', only acceptable values are: {compression_map.keys()}"
)

_use_scoped_temp_object = (
conn._session_parameters.get(
_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING, False
)
if conn._session_parameters
else False
)

if create_temp_table:
warnings.warn(
"create_temp_table is deprecated, we still respect this parameter when it is True but "
Expand Down Expand Up @@ -324,6 +368,7 @@ def write_pandas(
compression,
auto_create_table,
overwrite,
_use_scoped_temp_object,
)

with TemporaryDirectory() as tmp_folder:
Expand Down Expand Up @@ -370,6 +415,7 @@ def drop_object(name: str, object_type: str) -> None:
quote_identifiers,
compression_map[compression],
sql_use_logical_type,
_use_scoped_temp_object,
)
infer_schema_sql = f"SELECT COLUMN_NAME, TYPE FROM table(infer_schema(location=>'@{stage_location}', file_format=>'{file_format_location}'))"
logger.debug(f"inferring schema with '{infer_schema_sql}'")
Expand Down
56 changes: 54 additions & 2 deletions test/integ/pandas/test_pandas_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,59 @@ def mocked_execute(*args, **kwargs):
)
assert m_execute.called and any(
map(
lambda e: "CREATE TEMP STAGE" in str(e[0]),
lambda e: ("CREATE TEMP STAGE" in str(e[0])),
m_execute.call_args_list,
)
)


@pytest.mark.parametrize(
    "database,schema,quote_identifiers,expected_db_schema",
    [
        ("database", "schema", True, '"database"."schema"'),
        ("database", "schema", False, "database.schema"),
        (None, "schema", True, '"schema"'),
        (None, "schema", False, "schema"),
        (None, None, True, ""),
        (None, None, False, ""),
    ],
)
def test_use_scoped_object(
    conn_cnx,
    database: str | None,
    schema: str | None,
    quote_identifiers: bool,
    expected_db_schema: str,
):
    """This tests that write_pandas constructs stage location correctly with database and schema."""
    # NOTE(review): the docstring reads like it was copied from the stage-location
    # test; the final assertion shows the real intent is verifying that
    # write_pandas emits SCOPED TEMPORARY DDL once the
    # PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS session parameter is set.
    from snowflake.connector.cursor import SnowflakeCursor

    with conn_cnx() as cnx:

        def mocked_execute(*args, **kwargs):
            # NOTE(review): with the scoped flag on, write_pandas issues
            # "CREATE SCOPED TEMPORARY STAGE ..." (uppercase), so this lowercase
            # startswith never matches and the db_schema assertion is dead code.
            # Enabling it naively would also be wrong: the statement ends with a
            # FILE_FORMAT clause, so the last space-separated token is not the
            # stage location. Confirm whether the parametrized expected_db_schema
            # check is still wanted here.
            if len(args) >= 1 and args[0].startswith("create temporary stage"):
                db_schema = ".".join(args[0].split(" ")[-1].split(".")[:-1])
                assert db_schema == expected_db_schema
            # Return a cursor with an empty result so write_pandas can proceed.
            cur = SnowflakeCursor(cnx)
            cur._result = iter([])
            return cur

        with mock.patch(
            "snowflake.connector.cursor.SnowflakeCursor.execute",
            side_effect=mocked_execute,
        ) as m_execute:
            # Opt this session into scoped temp objects before calling write_pandas.
            cnx._update_parameters({"PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS": True})
            success, nchunks, nrows, _ = write_pandas(
                cnx,
                sf_connector_version_df.get(),
                "table",
                database=database,
                schema=schema,
                quote_identifiers=quote_identifiers,
            )
            # At least one executed statement must be the scoped-temporary stage DDL.
            assert m_execute.called and any(
                map(
                    lambda e: ("CREATE SCOPED TEMPORARY STAGE" in str(e[0])),
                    m_execute.call_args_list,
                )
            )
Expand Down Expand Up @@ -660,7 +712,7 @@ def mocked_execute(*args, **kwargs):
)
assert m_execute.called and any(
map(
lambda e: "CREATE TEMP FILE FORMAT" in str(e[0]),
lambda e: ("CREATE TEMP FILE FORMAT" in str(e[0])),
m_execute.call_args_list,
)
)
Expand Down
16 changes: 16 additions & 0 deletions test/unit/test_bind_upload_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

from unittest import mock
from unittest.mock import MagicMock


Expand All @@ -26,3 +27,18 @@ def test_bind_upload_agent_row_size_exceed_buffer_size():
agent = BindUploadAgent(csr, rows, stream_buffer_size=10)
agent.upload()
assert csr.execute.call_count == 11 # 1 for stage creation + 10 files


def test_bind_upload_agent_scoped_temp_object():
    """_create_stage must emit SCOPED TEMPORARY DDL when the scoped flag is set."""
    from snowflake.connector.bind_upload_agent import BindUploadAgent

    # Fix: the original passed MagicMock(auto_spec=True), but MagicMock treats
    # unknown keyword arguments as child-attribute assignments — it merely set a
    # `csr.auto_spec` attribute and performed no autospeccing (that option is
    # spelled `autospec` and belongs to mock.patch/create_autospec). Drop it.
    csr = MagicMock()
    rows = [bytes(15)] * 10
    agent = BindUploadAgent(csr, rows, stream_buffer_size=10)
    # Force the scoped path regardless of what the mocked session reported.
    with mock.patch.object(agent, "_use_scoped_temp_object", new=True):
        with mock.patch.object(agent.cursor, "execute") as mock_execute:
            agent._create_stage()
            assert (
                "create or replace SCOPED TEMPORARY stage"
                in mock_execute.call_args[0][0]
            )

0 comments on commit a959fc6

Please sign in to comment.