Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core[patch]: fix _sql_record_manager mypy for #17048 #17073

Merged
merged 2 commits into from
Feb 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 54 additions & 58 deletions libs/langchain/langchain/indexes/_sql_record_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
create_async_engine,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Query, Session, sessionmaker

from langchain.indexes.base import RecordManager

Expand Down Expand Up @@ -284,31 +284,35 @@ def update(

with self._make_session() as session:
if self.dialect == "sqlite":
from sqlalchemy.dialects.sqlite import Insert as SqliteInsertType
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
sqlite_insert_stmt: SqliteInsertType = sqlite_insert(
UpsertionRecord
).values(records_to_upsert)
stmt = sqlite_insert_stmt.on_conflict_do_update(
[UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
updated_at=sqlite_insert_stmt.excluded.updated_at,
group_id=sqlite_insert_stmt.excluded.group_id,
),
)
elif self.dialect == "postgresql":
from sqlalchemy.dialects.postgresql import Insert as PgInsertType
from sqlalchemy.dialects.postgresql import insert as pg_insert

# Note: uses SQLite insert to make on_conflict_do_update work.
# Note: uses postgresql insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
pg_insert_stmt: PgInsertType = pg_insert(UpsertionRecord).values(
records_to_upsert
)
stmt = pg_insert_stmt.on_conflict_do_update(
"uix_key_namespace", # Name of constraint
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
updated_at=pg_insert_stmt.excluded.updated_at,
group_id=pg_insert_stmt.excluded.group_id,
),
)
else:
Expand Down Expand Up @@ -359,31 +363,35 @@ async def aupdate(

async with self._amake_session() as session:
if self.dialect == "sqlite":
from sqlalchemy.dialects.sqlite import Insert as SqliteInsertType
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
sqlite_insert_stmt: SqliteInsertType = sqlite_insert(
UpsertionRecord
).values(records_to_upsert)
stmt = sqlite_insert_stmt.on_conflict_do_update(
[UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
updated_at=sqlite_insert_stmt.excluded.updated_at,
group_id=sqlite_insert_stmt.excluded.group_id,
),
)
elif self.dialect == "postgresql":
from sqlalchemy.dialects.postgresql import Insert as PgInsertType
from sqlalchemy.dialects.postgresql import insert as pg_insert

# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
pg_insert_stmt: PgInsertType = pg_insert(UpsertionRecord).values(
records_to_upsert
)
stmt = pg_insert_stmt.on_conflict_do_update(
"uix_key_namespace", # Name of constraint
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
updated_at=pg_insert_stmt.excluded.updated_at,
group_id=pg_insert_stmt.excluded.group_id,
),
)
else:
Expand All @@ -394,18 +402,15 @@ async def aupdate(

def exists(self, keys: Sequence[str]) -> List[bool]:
"""Check if the given keys exist in the SQLite database."""
session: Session
with self._make_session() as session:
records = (
# mypy does not recognize .all()
session.query(UpsertionRecord.key) # type: ignore[attr-defined]
.filter(
and_(
UpsertionRecord.key.in_(keys),
UpsertionRecord.namespace == self.namespace,
)
filtered_query: Query = session.query(UpsertionRecord.key).filter(
and_(
UpsertionRecord.key.in_(keys),
UpsertionRecord.namespace == self.namespace,
)
.all()
)
records = filtered_query.all()
found_keys = set(r.key for r in records)
return [k in found_keys for k in keys]

Expand Down Expand Up @@ -438,28 +443,22 @@ def list_keys(
limit: Optional[int] = None,
) -> List[str]:
"""List records in the SQLite database based on the provided date range."""
session: Session
with self._make_session() as session:
query = session.query(UpsertionRecord).filter(
query: Query = session.query(UpsertionRecord).filter(
UpsertionRecord.namespace == self.namespace
)

# mypy does not recognize .all() or .filter()
if after:
query = query.filter( # type: ignore[attr-defined]
UpsertionRecord.updated_at > after
)
query = query.filter(UpsertionRecord.updated_at > after)
if before:
query = query.filter( # type: ignore[attr-defined]
UpsertionRecord.updated_at < before
)
query = query.filter(UpsertionRecord.updated_at < before)
if group_ids:
query = query.filter( # type: ignore[attr-defined]
UpsertionRecord.group_id.in_(group_ids)
)
query = query.filter(UpsertionRecord.group_id.in_(group_ids))

if limit:
query = query.limit(limit) # type: ignore[attr-defined]
records = query.all() # type: ignore[attr-defined]
query = query.limit(limit)
records = query.all()
return [r.key for r in records]

async def alist_keys(
Expand All @@ -471,40 +470,37 @@ async def alist_keys(
limit: Optional[int] = None,
) -> List[str]:
"""List records in the SQLite database based on the provided date range."""
session: AsyncSession
async with self._amake_session() as session:
query = select(UpsertionRecord.key).filter(
query: Query = select(UpsertionRecord.key).filter(
UpsertionRecord.namespace == self.namespace
)

# mypy does not recognize .all() or .filter()
if after:
query = query.filter( # type: ignore[attr-defined]
UpsertionRecord.updated_at > after
)
query = query.filter(UpsertionRecord.updated_at > after)
if before:
query = query.filter( # type: ignore[attr-defined]
UpsertionRecord.updated_at < before
)
query = query.filter(UpsertionRecord.updated_at < before)
if group_ids:
query = query.filter( # type: ignore[attr-defined]
UpsertionRecord.group_id.in_(group_ids)
)
query = query.filter(UpsertionRecord.group_id.in_(group_ids))

if limit:
query = query.limit(limit) # type: ignore[attr-defined]
query = query.limit(limit)
records = (await session.execute(query)).scalars().all()
return list(records)

def delete_keys(self, keys: Sequence[str]) -> None:
"""Delete records from the SQLite database."""
session: Session
with self._make_session() as session:
# mypy does not recognize .delete()
session.query(UpsertionRecord).filter(
filtered_query: Query = session.query(UpsertionRecord).filter(
and_(
UpsertionRecord.key.in_(keys),
UpsertionRecord.namespace == self.namespace,
)
).delete() # type: ignore[attr-defined]
)

filtered_query.delete()
session.commit()

async def adelete_keys(self, keys: Sequence[str]) -> None:
Expand Down
Loading