Skip to content

Commit

Permalink
Add bulk upsert to connection
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvoleg committed Nov 7, 2024
1 parent 7b4fac3 commit c61e240
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 0 deletions.
55 changes: 55 additions & 0 deletions tests/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,52 @@ def _test_errors(
maybe_await(cur.execute_scheme("DROP TABLE test"))
maybe_await(cur.close())

def _test_bulk_upsert(self, connection: dbapi.Connection):
cursor = connection.cursor()
maybe_await(cursor.execute_scheme(
"""
CREATE TABLE pet (
pet_id INT,
name TEXT NOT NULL,
pet_type TEXT NOT NULL,
birth_date TEXT NOT NULL,
owner TEXT NOT NULL,
PRIMARY KEY (pet_id)
);
"""
))

column_types = (
ydb.BulkUpsertColumns()
.add_column("pet_id", ydb.OptionalType(ydb.PrimitiveType.Int32))
.add_column("name", ydb.PrimitiveType.Utf8)
.add_column("pet_type", ydb.PrimitiveType.Utf8)
.add_column("birth_date", ydb.PrimitiveType.Utf8)
.add_column("owner", ydb.PrimitiveType.Utf8)
)

rows = [
{
"pet_id": 3,
"name": "Lester",
"pet_type": "Hamster",
"birth_date": "2020-06-23",
"owner": "Lily"
},
{
"pet_id": 4,
"name": "Quincy",
"pet_type": "Parrot",
"birth_date": "2013-08-11",
"owner": "Anne"
},
]

maybe_await(connection.bulk_upsert("pet", rows, column_types))

maybe_await(cursor.execute("SELECT * FROM pet"))
assert cursor.rowcount == 2


class TestConnection(BaseDBApiTestSuit):
@pytest.fixture
Expand Down Expand Up @@ -191,6 +237,9 @@ def test_cursor_raw_query(self, connection: dbapi.Connection) -> None:
def test_errors(self, connection: dbapi.Connection) -> None:
self._test_errors(connection)

def test_bulk_upsert(self, connection: dbapi.Connection) -> None:
self._test_bulk_upsert(connection)


class TestAsyncConnection(BaseDBApiTestSuit):
@pytest_asyncio.fixture
Expand Down Expand Up @@ -244,3 +293,9 @@ async def test_cursor_raw_query(
@pytest.mark.asyncio
async def test_errors(self, connection: dbapi.AsyncConnection) -> None:
await greenlet_spawn(self._test_errors, connection)

@pytest.mark.asyncio
async def test_bulk_upsert(
self, connection: dbapi.AsyncConnection
) -> None:
await greenlet_spawn(self._test_bulk_upsert, connection)
23 changes: 23 additions & 0 deletions ydb_dbapi/connections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import posixpath
from collections.abc import Sequence
from enum import Enum
from typing import NamedTuple

Expand Down Expand Up @@ -301,6 +302,17 @@ def callee() -> ydb.Directory:
result.extend(self._get_table_names(child_abs_path))
return result

@handle_ydb_errors
def bulk_upsert(
self,
table_name: str,
rows: Sequence,
column_types: ydb.BulkUpsertColumns,
) -> None:
self._driver.table_client.bulk_upsert(
table_name, rows=rows, column_types=column_types
)


class AsyncConnection(BaseConnection):
_driver_cls = ydb.aio.Driver
Expand Down Expand Up @@ -446,6 +458,17 @@ async def callee() -> ydb.Directory:
result.extend(await self._get_table_names(child_abs_path))
return result

@handle_ydb_errors
async def bulk_upsert(
self,
table_name: str,
rows: Sequence,
column_types: ydb.BulkUpsertColumns,
) -> None:
await self._driver.table_client.bulk_upsert(
table_name, rows=rows, column_types=column_types
)


def connect(*args: tuple, **kwargs: dict) -> Connection:
conn = Connection(*args, **kwargs) # type: ignore
Expand Down

0 comments on commit c61e240

Please sign in to comment.