Skip to content

Commit

Permalink
tweak db type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
dae committed Mar 20, 2020
1 parent b5c6134 commit 77cf7dd
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 13 deletions.
24 changes: 17 additions & 7 deletions pylib/anki/dbproxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
# fixme: progress

from sqlite3 import dbapi2 as sqlite
from typing import Any, Iterable, List, Optional
from typing import Any, Iterable, List, Optional, Sequence, Union

# DBValue is actually Union[str, int, float, None], but if defined
# that way, every call site needs to do a type check prior to using
# the return values.
ValueFromDB = Any
Row = Sequence[ValueFromDB]

ValueForDB = Union[str, int, float, None]


class DBProxy:
Expand Down Expand Up @@ -38,7 +46,9 @@ def setAutocommit(self, autocommit: bool) -> None:
# Querying
################

def _query(self, sql: str, *args, first_row_only: bool = False) -> List[List]:
def _query(
self, sql: str, *args: ValueForDB, first_row_only: bool = False
) -> List[Row]:
# mark modified?
s = sql.strip().lower()
for stmt in "insert", "update", "delete":
Expand All @@ -59,20 +69,20 @@ def _query(self, sql: str, *args, first_row_only: bool = False) -> List[List]:
# Query shortcuts
###################

def all(self, sql: str, *args) -> List:
def all(self, sql: str, *args: ValueForDB) -> List[Row]:
return self._query(sql, *args)

def list(self, sql: str, *args) -> List:
def list(self, sql: str, *args: ValueForDB) -> List[ValueFromDB]:
return [x[0] for x in self._query(sql, *args)]

def first(self, sql: str, *args) -> Optional[List]:
def first(self, sql: str, *args: ValueForDB) -> Optional[Row]:
rows = self._query(sql, *args, first_row_only=True)
if rows:
return rows[0]
else:
return None

def scalar(self, sql: str, *args) -> Optional[Any]:
def scalar(self, sql: str, *args: ValueForDB) -> ValueFromDB:
rows = self._query(sql, *args, first_row_only=True)
if rows:
return rows[0][0]
Expand All @@ -86,7 +96,7 @@ def scalar(self, sql: str, *args) -> Optional[Any]:
# Updates
################

def executemany(self, sql: str, args: Iterable) -> None:
def executemany(self, sql: str, args: Iterable[Iterable[ValueForDB]]) -> None:
self.mod = True
self._db.executemany(sql, args)

Expand Down
6 changes: 3 additions & 3 deletions pylib/anki/schedv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ def counts(self, card: Optional[Card] = None) -> Tuple[int, int, int]:

def dueForecast(self, days: int = 7) -> List[Any]:
"Return counts over next DAYS. Includes today."
daysd = dict(
self.col.db.all(
daysd: Dict[int, int] = dict(
self.col.db.all( # type: ignore
f"""
select due, count() from cards
where did in %s and queue = {QUEUE_TYPE_REV}
Expand Down Expand Up @@ -542,7 +542,7 @@ def _fillLrn(self) -> Union[bool, List[Any]]:
if self._lrnQueue:
return True
cutoff = intTime() + self.col.conf["collapseTime"]
self._lrnQueue = self.col.db.all(
self._lrnQueue = self.col.db.all( # type: ignore
f"""
select due, id from cards where
did in %s and queue in ({QUEUE_TYPE_LRN},{QUEUE_TYPE_PREVIEW}) and due < ?
Expand Down
6 changes: 3 additions & 3 deletions pylib/anki/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json
import os
import random
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import anki
from anki.consts import *
Expand All @@ -31,7 +31,7 @@ class UnexpectedSchemaChange(Exception):


class Syncer:
chunkRows: Optional[List[List]]
chunkRows: Optional[List[Sequence]]

def __init__(self, col: anki.storage._Collection, server=None) -> None:
self.col = col.weakref()
Expand Down Expand Up @@ -248,7 +248,7 @@ def prepareToChunk(self) -> None:
self.tablesLeft = ["revlog", "cards", "notes"]
self.chunkRows = None

def getChunkRows(self, table) -> List[List]:
def getChunkRows(self, table) -> List[Sequence]:
lim = self.usnLim()
x = self.col.db.all
d = (self.maxUsn, lim)
Expand Down

0 comments on commit 77cf7dd

Please sign in to comment.