diff --git a/datasette/database.py b/datasette/database.py index 4a0babfb64..554f9fbf70 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -32,6 +32,7 @@ class Database: # For table counts stop at this many rows: count_limit = 10000 + _thread_local_id_counter = 1 def __init__( self, @@ -43,6 +44,8 @@ def __init__( mode=None, ): self.name = None + self._thread_local_id = f"x{self._thread_local_id_counter}" + Database._thread_local_id_counter += 1 self.route = None self.ds = ds self.path = path @@ -278,11 +281,11 @@ async def execute_fn(self, fn): # threaded mode def in_thread(): - conn = getattr(connections, self.name, None) + conn = getattr(connections, self._thread_local_id, None) if not conn: conn = self.connect() self.ds._prepare_connection(conn, self.name) - setattr(connections, self.name, conn) + setattr(connections, self._thread_local_id, conn) return fn(conn) return await asyncio.get_event_loop().run_in_executor( diff --git a/tests/test_internals_database.py b/tests/test_internals_database.py index edfc6bc7d9..eeaf8e9ada 100644 --- a/tests/test_internals_database.py +++ b/tests/test_internals_database.py @@ -721,3 +721,34 @@ async def test_hidden_tables(app_client): "r_parent", "r_rowid", ] + + +@pytest.mark.asyncio +async def test_replace_database(tmpdir): + path1 = str(tmpdir / "data1.db") + (tmpdir / "two").mkdir() + path2 = str(tmpdir / "two" / "data1.db") + sqlite3.connect(path1).executescript( + """ + create table t (id integer primary key); + insert into t (id) values (1); + insert into t (id) values (2); + """ + ) + sqlite3.connect(path2).executescript( + """ + create table t (id integer primary key); + insert into t (id) values (1); + """ + ) + datasette = Datasette([path1]) + db = datasette.get_database("data1") + count = (await db.execute("select count(*) from t")).first()[0] + assert count == 2 + # Now replace that database + datasette.get_database("data1").close() + datasette.remove_database("data1") + datasette.add_database(Database(datasette, path2), "data1") + db2 = datasette.get_database("data1") + count = (await db2.execute("select count(*) from t")).first()[0] + assert count == 1