diff --git a/pylib/anki/collection.py b/pylib/anki/collection.py index 833c474b53e..6a3e94bf9fe 100644 --- a/pylib/anki/collection.py +++ b/pylib/anki/collection.py @@ -23,7 +23,7 @@ from anki.decks import DeckManager from anki.errors import AnkiError from anki.lang import _ -from anki.media import MediaManager +from anki.media import MediaManager, media_paths_from_col_path from anki.models import ModelManager from anki.notes import Note from anki.rsbackend import TR, DBError, RustBackend @@ -41,18 +41,19 @@ class _Collection: def __init__( self, - db: DBProxy, - backend: RustBackend, + path: str, + backend: Optional[RustBackend], server: bool = False, log: bool = False, ) -> None: - self.backend = backend - self._debugLog = log - self.db = db - self.path = db._path - self._openLog() - self.log(self.path, anki.version) + self.backend = backend or RustBackend(server=server) + self.db = None + self._should_log = log self.server = server + self.path = os.path.abspath(path) + self.reopen() + + self.log(self.path, anki.version) self._lastSave = time.time() self.clearUndo() self.media = MediaManager(self, server) @@ -219,6 +220,24 @@ def rollback(self) -> None: self.db.rollback() self.db.begin() + def reopen(self) -> None: + assert not self.db + assert self.path.endswith(".anki2") + + (media_dir, media_db) = media_paths_from_col_path(self.path) + + log_path = "" + should_log = not self.server and self._should_log + if should_log: + log_path = self.path.replace(".anki2", "2.log") + + # connect + self.backend.open_collection(self.path, media_dir, media_db, log_path) + self.db = DBProxy(weakref.proxy(self.backend)) + self.db.begin() + + self._openLog() + def modSchema(self, check: bool) -> None: "Mark schema modified. Call this first so user can abort if necessary." if not self.schemaChanged(): @@ -586,7 +605,7 @@ def optimize(self) -> None: ########################################################################## def log(self, *args, **kwargs) -> None: - if not self._debugLog: + if not self._should_log: return def customRepr(x): @@ -606,7 +625,7 @@ def customRepr(x): print(buf) def _openLog(self) -> None: - if not self._debugLog: + if not self._should_log: return lpath = re.sub(r"\.anki2$", ".log", self.path) if os.path.exists(lpath) and os.path.getsize(lpath) > 10 * 1024 * 1024: @@ -617,7 +636,7 @@ def _openLog(self) -> None: self._logHnd = open(lpath, "a", encoding="utf8") def _closeLog(self) -> None: - if not self._debugLog: + if not self._should_log: return self._logHnd.close() self._logHnd = None diff --git a/pylib/anki/dbproxy.py b/pylib/anki/dbproxy.py index 3c989720d80..e555c74b359 100644 --- a/pylib/anki/dbproxy.py +++ b/pylib/anki/dbproxy.py @@ -21,9 +21,8 @@ class DBProxy: # Lifecycle ############### - def __init__(self, backend: anki.rsbackend.RustBackend, path: str) -> None: + def __init__(self, backend: anki.rsbackend.RustBackend) -> None: self._backend = backend - self._path = path self.mod = False self.last_begin_at = 0 diff --git a/pylib/anki/storage.py b/pylib/anki/storage.py index 1ab88f9550f..2fa407f673f 100644 --- a/pylib/anki/storage.py +++ b/pylib/anki/storage.py @@ -1,13 +1,9 @@ # Copyright: Ankitects Pty Ltd and contributors # License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html -import os -import weakref from typing import Optional from anki.collection import _Collection -from anki.dbproxy import DBProxy -from anki.media import media_paths_from_col_path from anki.rsbackend import RustBackend @@ -18,22 +14,4 @@ def Collection( log: bool = False, ) -> _Collection: "Open a new or existing collection. Path must be unicode." - assert path.endswith(".anki2") - if backend is None: - backend = RustBackend(server=server) - - (media_dir, media_db) = media_paths_from_col_path(path) - log_path = "" - should_log = not server and log - if should_log: - log_path = path.replace(".anki2", "2.log") - path = os.path.abspath(path) - - # connect - backend.open_collection(path, media_dir, media_db, log_path) - db = DBProxy(weakref.proxy(backend), path) - - # add db to col and do any remaining upgrades - col = _Collection(db, backend=backend, server=server) - db.begin() - return col + return _Collection(path, backend, server, log) diff --git a/qt/aqt/main.py b/qt/aqt/main.py index 0a70d3e5a10..6386de29d7d 100644 --- a/qt/aqt/main.py +++ b/qt/aqt/main.py @@ -505,12 +505,12 @@ def loadCollection(self) -> bool: return True def _loadCollection(self): - self.reopen() + cpath = self.pm.collectionPath() + self.col = Collection(cpath, backend=self.backend, log=True) self.setEnabled(True) def reopen(self): - cpath = self.pm.collectionPath() - self.col = Collection(cpath, backend=self.backend, log=True) + self.col.reopen() def unloadCollection(self, onsuccess: Callable) -> None: def callback():