From 77cf7dd4b7f1a7e8c607ef29b76a51dfcbc311c9 Mon Sep 17 00:00:00 2001 From: Damien Elmes Date: Tue, 3 Mar 2020 12:05:33 +1000 Subject: [PATCH] tweak db type hints --- pylib/anki/dbproxy.py | 24 +++++++++++++++++------- pylib/anki/schedv2.py | 6 +++--- pylib/anki/sync.py | 6 +++--- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/pylib/anki/dbproxy.py b/pylib/anki/dbproxy.py index 62e8ef4a4b8..be73a741091 100644 --- a/pylib/anki/dbproxy.py +++ b/pylib/anki/dbproxy.py @@ -5,7 +5,15 @@ # fixme: progress from sqlite3 import dbapi2 as sqlite -from typing import Any, Iterable, List, Optional +from typing import Any, Iterable, List, Optional, Sequence, Union + +# DBValue is actually Union[str, int, float, None], but if defined +# that way, every call site needs to do a type check prior to using +# the return values. +ValueFromDB = Any +Row = Sequence[ValueFromDB] + +ValueForDB = Union[str, int, float, None] class DBProxy: @@ -38,7 +46,9 @@ def setAutocommit(self, autocommit: bool) -> None: # Querying ################ - def _query(self, sql: str, *args, first_row_only: bool = False) -> List[List]: + def _query( + self, sql: str, *args: ValueForDB, first_row_only: bool = False + ) -> List[Row]: # mark modified? s = sql.strip().lower() for stmt in "insert", "update", "delete": @@ -59,20 +69,20 @@ def _query(self, sql: str, *args, first_row_only: bool = False) -> List[List]: # Query shortcuts ################### - def all(self, sql: str, *args) -> List: + def all(self, sql: str, *args: ValueForDB) -> List[Row]: return self._query(sql, *args) - def list(self, sql: str, *args) -> List: + def list(self, sql: str, *args: ValueForDB) -> List[ValueFromDB]: return [x[0] for x in self._query(sql, *args)] - def first(self, sql: str, *args) -> Optional[List]: + def first(self, sql: str, *args: ValueForDB) -> Optional[Row]: rows = self._query(sql, *args, first_row_only=True) if rows: return rows[0] else: return None - def scalar(self, sql: str, *args) -> Optional[Any]: + def scalar(self, sql: str, *args: ValueForDB) -> ValueFromDB: rows = self._query(sql, *args, first_row_only=True) if rows: return rows[0][0] @@ -86,7 +96,7 @@ def scalar(self, sql: str, *args) -> Optional[Any]: # Updates ################ - def executemany(self, sql: str, args: Iterable) -> None: + def executemany(self, sql: str, args: Iterable[Iterable[ValueForDB]]) -> None: self.mod = True self._db.executemany(sql, args) diff --git a/pylib/anki/schedv2.py b/pylib/anki/schedv2.py index d11f7808d8d..a9a7eadd239 100644 --- a/pylib/anki/schedv2.py +++ b/pylib/anki/schedv2.py @@ -138,8 +138,8 @@ def counts(self, card: Optional[Card] = None) -> Tuple[int, int, int]: def dueForecast(self, days: int = 7) -> List[Any]: "Return counts over next DAYS. Includes today." - daysd = dict( - self.col.db.all( + daysd: Dict[int, int] = dict( + self.col.db.all( # type: ignore f""" select due, count() from cards where did in %s and queue = {QUEUE_TYPE_REV} @@ -542,7 +542,7 @@ def _fillLrn(self) -> Union[bool, List[Any]]: if self._lrnQueue: return True cutoff = intTime() + self.col.conf["collapseTime"] - self._lrnQueue = self.col.db.all( + self._lrnQueue = self.col.db.all( # type: ignore f""" select due, id from cards where did in %s and queue in ({QUEUE_TYPE_LRN},{QUEUE_TYPE_PREVIEW}) and due < ? diff --git a/pylib/anki/sync.py b/pylib/anki/sync.py index 78459d520a7..396d8fd1a3e 100644 --- a/pylib/anki/sync.py +++ b/pylib/anki/sync.py @@ -8,7 +8,7 @@ import json import os import random -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import anki from anki.consts import * @@ -31,7 +31,7 @@ class UnexpectedSchemaChange(Exception): class Syncer: - chunkRows: Optional[List[List]] + chunkRows: Optional[List[Sequence]] def __init__(self, col: anki.storage._Collection, server=None) -> None: self.col = col.weakref() @@ -248,7 +248,7 @@ def prepareToChunk(self) -> None: self.tablesLeft = ["revlog", "cards", "notes"] self.chunkRows = None - def getChunkRows(self, table) -> List[List]: + def getChunkRows(self, table) -> List[Sequence]: lim = self.usnLim() x = self.col.db.all d = (self.maxUsn, lim)