-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: implement database reset methods for tests
Generated-by: aiautocommit
- Loading branch information
1 parent
4626861
commit 9cec29d
Showing
3 changed files
with
80 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,2 @@ | ||
from sqlmodel import SQLModel | ||
|
||
from activemodel import logger | ||
|
||
from ..session_manager import get_engine | ||
|
||
|
||
def truncate_db(): | ||
# TODO Problem with truncation is you can't run multiple tests in parallel without separate containers | ||
|
||
logger.info("Truncating database") | ||
|
||
# TODO get additonal tables to preserve from config | ||
exception_tables = ["alembic_version"] | ||
|
||
assert ( | ||
SQLModel.metadata.sorted_tables | ||
), "No model metadata. Ensure model metadata is imported before running truncate_db" | ||
|
||
with get_engine().connect() as connection: | ||
for table in reversed(SQLModel.metadata.sorted_tables): | ||
transaction = connection.begin() | ||
|
||
if table.name not in exception_tables: | ||
logger.debug("truncating table=%s", table.name) | ||
connection.execute(table.delete()) | ||
|
||
transaction.commit() | ||
from .transaction import database_reset_transaction | ||
from .truncate import database_reset_truncate |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from activemodel import SessionManager | ||
|
||
|
||
def database_reset_transaction(): | ||
""" | ||
Wrap all database interactions for a given test in a nested transaction and roll it back after the test. | ||
>>> from activemodel.pytest import database_reset_transaction | ||
>>> pytest.fixture(scope="function", autouse=True)(database_reset_transaction) | ||
References: | ||
- https://stackoverflow.com/questions/62433018/how-to-make-sqlalchemy-transaction-rollback-drop-tables-it-created | ||
- https://aalvarez.me/posts/setting-up-a-sqlalchemy-and-pytest-based-test-suite/ | ||
- https://github.com/nickjj/docker-flask-example/blob/93af9f4fbf185098ffb1d120ee0693abcd77a38b/test/conftest.py#L77 | ||
- https://github.com/caiola/vinhos.com/blob/c47d0a5d7a4bf290c1b726561d1e8f5d2ac29bc8/backend/test/conftest.py#L46 | ||
""" | ||
|
||
engine = SessionManager.get_instance().get_engine() | ||
|
||
with engine.begin() as connection: | ||
transaction = connection.begin_nested() | ||
|
||
SessionManager.get_instance().session_connection = connection | ||
|
||
try: | ||
yield | ||
finally: | ||
transaction.rollback() | ||
# TODO is this necessary? | ||
connection.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from sqlmodel import SQLModel | ||
|
||
from activemodel import logger | ||
|
||
from ..session_manager import get_engine | ||
|
||
|
||
def database_reset_truncate(): | ||
""" | ||
Transaction is most likely the better way to go, but there are some scenarios where the session override | ||
logic does not work properly and you need to truncate tables back to their original state. | ||
Here's how to do this once at the start of the test: | ||
>>> from activemodel.pytest import database_reset_truncation | ||
>>> def pytest_configure(config): | ||
>>> database_reset_truncation() | ||
Or, if you want to use this as a fixture: | ||
>>> pytest.fixture(scope="function")(database_reset_truncation) | ||
>>> def test_the_thing(database_reset_truncation) | ||
This approach has a couple of problems: | ||
* You can't run multiple tests in parallel without separate databases | ||
* If you have important seed data and want to truncate those tables, the seed data will be lost | ||
""" | ||
|
||
logger.info("truncating database") | ||
|
||
# TODO get additonal tables to preserve from config | ||
exception_tables = ["alembic_version"] | ||
|
||
assert ( | ||
SQLModel.metadata.sorted_tables | ||
), "No model metadata. Ensure model metadata is imported before running truncate_db" | ||
|
||
with get_engine().connect() as connection: | ||
for table in reversed(SQLModel.metadata.sorted_tables): | ||
transaction = connection.begin() | ||
|
||
if table.name not in exception_tables: | ||
logger.debug("truncating table=%s", table.name) | ||
connection.execute(table.delete()) | ||
|
||
transaction.commit() |