Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ydb provider: add bulk upsert support #40631

Merged
merged 4 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 51 additions & 39 deletions airflow/providers/ydb/hooks/ydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,8 @@ def description(self):
class YDBConnection:
"""YDB connection wrapper."""

def __init__(self, endpoint: str, database: str, credentials: Any, is_ddl: bool = False):
def __init__(self, ydb_session_pool: Any, is_ddl: bool):
self.is_ddl = is_ddl
driver_config = ydb.DriverConfig(
endpoint=endpoint,
database=database,
table_client_settings=YDBConnection._get_table_client_settings(),
credentials=credentials,
)
driver = ydb.Driver(driver_config)
# wait until the driver is initialized
driver.wait(fail_fast=True, timeout=10)
ydb_session_pool = ydb.SessionPool(driver, size=5)
self.delegatee: DbApiConnection = DbApiConnection(ydb_session_pool=ydb_session_pool)

def cursor(self) -> YDBCursor:
Expand All @@ -130,16 +120,8 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
def close(self) -> None:
self.delegatee.close()

@staticmethod
def _get_table_client_settings() -> ydb.TableClientSettings:
return (
ydb.TableClientSettings()
.with_native_date_in_result_sets(True)
.with_native_datetime_in_result_sets(True)
.with_native_timestamp_in_result_sets(True)
.with_native_interval_in_result_sets(True)
.with_native_json_in_result_sets(False)
)
def bulk_upsert(self, table_name: str, rows: Sequence, column_types: ydb.BulkUpsertColumns) -> None:
    """
    Bulk-upsert rows through the table client of the wrapped connection's driver.

    :param table_name: path of the target table, forwarded unchanged to the table client.
    :param rows: rows to upsert, forwarded unchanged as ``rows``.
    :param column_types: column type description, forwarded unchanged as ``column_types``.
    """
    # The call bypasses the DB-API cursor and goes straight to the driver's table client.
    self.delegatee.driver.table_client.bulk_upsert(table_name, rows=rows, column_types=column_types)


class YDBHook(DbApiHook):
Expand All @@ -156,6 +138,33 @@ def __init__(self, *args, is_ddl: bool = False, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.is_ddl = is_ddl

conn: Connection = self.get_connection(getattr(self, self.conn_name_attr))
host: str | None = conn.host
if not host:
raise ValueError("YDB host must be specified")
port: int = conn.port or DEFAULT_YDB_GRPCS_PORT

connection_extra: dict[str, Any] = conn.extra_dejson
database: str | None = connection_extra.get("database")
if not database:
raise ValueError("YDB database must be specified")

endpoint = f"{host}:{port}"
credentials = get_credentials_from_connection(
endpoint=endpoint, database=database, connection=conn, connection_extra=connection_extra
)

driver_config = ydb.DriverConfig(
endpoint=endpoint,
database=database,
table_client_settings=YDBHook._get_table_client_settings(),
credentials=credentials,
)
driver = ydb.Driver(driver_config)
# wait until the driver is initialized
driver.wait(fail_fast=True, timeout=10)
self.ydb_session_pool = ydb.SessionPool(driver, size=5)

@classmethod
def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Return connection widgets to add to YDB connection form."""
Expand Down Expand Up @@ -226,26 +235,29 @@ def sqlalchemy_url(self) -> URL:

def get_conn(self) -> YDBConnection:
"""Establish a connection to a YDB database."""
conn: Connection = self.get_connection(getattr(self, self.conn_name_attr))
host: str | None = conn.host
if not host:
raise ValueError("YDB host must be specified")
port: int = conn.port or DEFAULT_YDB_GRPCS_PORT
return YDBConnection(self.ydb_session_pool, is_ddl=self.is_ddl)

connection_extra: dict[str, Any] = conn.extra_dejson
database: str | None = connection_extra.get("database")
if not database:
raise ValueError("YDB database must be specified")
@staticmethod
def _serialize_cell(cell: object, conn: YDBConnection | None = None) -> Any:
return cell

endpoint = f"{host}:{port}"
credentials = get_credentials_from_connection(
endpoint=endpoint, database=database, connection=conn, connection_extra=connection_extra
)
def bulk_upsert(self, table_name: str, rows: Sequence, column_types: ydb.BulkUpsertColumns) -> None:
    """
    Insert or update many rows at once using YDB BulkUpsert.

    Obtains a connection via ``get_conn()`` and delegates the call to its
    ``bulk_upsert`` method, passing all arguments through unchanged.

    :param table_name: path of the target table.
    :param rows: rows to upsert.
    :param column_types: column type description for the supplied rows.

    .. seealso::

        https://ydb.tech/docs/en/recipes/ydb-sdk/bulk-upsert
    """
    self.get_conn().bulk_upsert(table_name, rows, column_types)

@staticmethod
def _serialize_cell(cell: object, conn: YDBConnection | None = None) -> Any:
return cell
def _get_table_client_settings() -> ydb.TableClientSettings:
return (
ydb.TableClientSettings()
.with_native_date_in_result_sets(True)
.with_native_datetime_in_result_sets(True)
.with_native_timestamp_in_result_sets(True)
.with_native_interval_in_result_sets(True)
.with_native_json_in_result_sets(False)
)
36 changes: 36 additions & 0 deletions tests/integration/providers/ydb/operators/test_ydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from unittest.mock import MagicMock

import pytest
import ydb

from airflow.models.connection import Connection
from airflow.models.dag import DAG
Expand Down Expand Up @@ -58,3 +59,38 @@ def test_execute_hello(self):

results = operator.execute(self.mock_context)
assert results == [(987,)]

def test_bulk_upsert(self):
    """Create a table, bulk-upsert three rows through the hook, and check the summed ``age`` column."""
    create_table_op = YDBExecuteQueryOperator(
        task_id="create",
        sql="""
CREATE TABLE team (
id INT,
name TEXT,
age UINT32,
PRIMARY KEY (id)
);""",
        is_ddl=True,
    )

    create_table_op.execute(self.mock_context)

    age_sum_op = YDBExecuteQueryOperator(task_id="age_sum", sql="SELECT SUM(age) as age_sum FROM team")

    # The hook is taken from the query operator so the upsert and the check share one hook configuration.
    hook = age_sum_op.get_db_hook()
    column_types = (
        ydb.BulkUpsertColumns()
        .add_column("id", ydb.OptionalType(ydb.PrimitiveType.Int32))
        .add_column("name", ydb.OptionalType(ydb.PrimitiveType.Utf8))
        .add_column("age", ydb.OptionalType(ydb.PrimitiveType.Uint32))
    )

    rows = [
        {"id": 1, "name": "rabbits", "age": 17},
        {"id": 2, "name": "bears", "age": 22},
        {"id": 3, "name": "foxes", "age": 9},
    ]
    hook.bulk_upsert("/local/team", rows=rows, column_types=column_types)

    # 17 + 22 + 9 == 48
    result = age_sum_op.execute(self.mock_context)
    assert result == [(48,)]
46 changes: 42 additions & 4 deletions tests/providers/ydb/operators/test_ydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from unittest.mock import MagicMock, PropertyMock, patch

import pytest
import ydb

from airflow.models import Connection
from airflow.models.dag import DAG
Expand Down Expand Up @@ -54,6 +55,15 @@ def __init__(self, driver):
self._driver = driver


class FakeTableClient:
    """Test double for a table client that records ``bulk_upsert`` calls instead of executing them."""

    def __init__(self, *args):
        # Constructor arguments are accepted and ignored.
        # Each recorded call is stored as a (table_path, rows, column_types) tuple.
        self.bulk_upsert_args = []

    def bulk_upsert(self, table_path, rows, column_types, settings=None):
        """Record one call; the code under test is expected not to pass ``settings``."""
        assert settings is None
        self.bulk_upsert_args.append((table_path, rows, column_types))


class FakeSessionPool:
def __init__(self, driver):
self._pool_impl = FakeSessionPoolImpl(driver)
Expand Down Expand Up @@ -96,19 +106,29 @@ def setup_method(self):
@patch("airflow.hooks.base.BaseHook.get_connection")
@patch("ydb.Driver")
@patch("ydb.SessionPool")
@patch(
"airflow.providers.ydb.hooks._vendor.dbapi.connection.Connection._ydb_table_client_class",
new_callable=PropertyMock,
)
@patch(
"airflow.providers.ydb.hooks._vendor.dbapi.connection.Connection._cursor_class",
new_callable=PropertyMock,
)
def test_execute_query(self, cursor_class, mock_session_pool, mock_driver, mock_get_connection):
def test_execute_query(
self, cursor_class, table_client_class, mock_session_pool, mock_driver, mock_get_connection
):
mock_get_connection.return_value = Connection(
conn_type="ydb", host="localhost", extra={"database": "my_db"}
)
driver_instance = FakeDriver()

cursor_class.return_value = FakeYDBCursor
mock_driver.return_value = driver_instance
mock_session_pool.return_value = FakeSessionPool(driver_instance)
table_client_class.return_value = FakeTableClient

driver = FakeDriver()
mock_driver.return_value = driver

session_pool = FakeSessionPool(driver)
mock_session_pool.return_value = session_pool
context = {"ti": MagicMock()}
operator = YDBExecuteQueryOperator(
task_id="simple_sql", sql="select 987", is_ddl=False, handler=fetch_one_handler
Expand All @@ -123,3 +143,21 @@ def test_execute_query(self, cursor_class, mock_session_pool, mock_driver, mock_

results = operator.execute(context)
assert results == "fetchall: result"

hook = operator.get_db_hook()

column_types = (
ydb.BulkUpsertColumns()
.add_column("a", ydb.OptionalType(ydb.PrimitiveType.Uint64))
.add_column("b", ydb.OptionalType(ydb.PrimitiveType.Utf8))
)

rows = [
{"a": 1, "b": "hello"},
{"a": 888, "b": "world"},
]
hook.bulk_upsert("/root/my_table", rows=rows, column_types=column_types)
assert len(session_pool._pool_impl._driver.table_client.bulk_upsert_args) == 1
arg0 = session_pool._pool_impl._driver.table_client.bulk_upsert_args[0]
assert arg0[0] == "/root/my_table"
assert len(arg0[1]) == 2