diff --git a/qcodes/dataset/sqlite_base.py b/qcodes/dataset/sqlite_base.py index 7edf3ef91f3..2dde25d5f9e 100644 --- a/qcodes/dataset/sqlite_base.py +++ b/qcodes/dataset/sqlite_base.py @@ -30,6 +30,14 @@ VALUE = Union[str, Number, List, ndarray, bool] VALUES = List[VALUE] + +# Functions decorated as 'upgrader' are inserted into this dict +# The newest database version is thus determined by the number of upgrades +# in this module +# The key is the TARGET VERSION of the upgrade, i.e. the first key is 1 +_UPGRADE_ACTIONS: Dict[int, Callable] = {} + + _experiment_table_schema = """ CREATE TABLE IF NOT EXISTS experiments ( -- this will autoncrement by default if @@ -172,6 +180,8 @@ def do_upgrade(conn: SomeConnection) -> None: log.info(f'Succesfully performed upgrade {from_version} ' f'-> {to_version}') + _UPGRADE_ACTIONS[to_version] = do_upgrade + return do_upgrade @@ -354,16 +364,14 @@ def perform_db_upgrade(conn: SomeConnection, version: int=-1) -> None: 'newest version' """ - upgrade_actions = [perform_db_upgrade_0_to_1, perform_db_upgrade_1_to_2, - perform_db_upgrade_2_to_3] - newest_version = len(upgrade_actions) + newest_version = len(_UPGRADE_ACTIONS) version = newest_version if version == -1 else version current_version = get_user_version(conn) - if current_version < newest_version: + if current_version < version: log.info("Commencing database upgrade") - for action in upgrade_actions[:version]: - action(conn) + for target_version in sorted(_UPGRADE_ACTIONS)[:version]: + _UPGRADE_ACTIONS[target_version](conn) @upgrader @@ -684,6 +692,25 @@ def perform_db_upgrade_2_to_3(conn: SomeConnection) -> None: log.debug(f"Upgrade in transition, run number {run_id}: OK") +def get_db_version_and_newest_available_version(path_to_db: str) -> Tuple[int, + int]: + """ + Connect to a DB without performing any upgrades and get the version of + that database file along with the newest available version (the one that + a normal "connect" will automatically upgrade to) + + Args: + path_to_db: the absolute path to the DB file + + Returns: + A tuple of (db_version, latest_available_version) + """ + conn = connect(path_to_db, version=0) + db_version = get_user_version(conn) + + return (db_version, len(_UPGRADE_ACTIONS)) + + def transaction(conn: SomeConnection, sql: str, *args: Any) -> sqlite3.Cursor: """Perform a transaction. diff --git a/qcodes/tests/dataset/test_database_creation_and_upgrading.py b/qcodes/tests/dataset/test_database_creation_and_upgrading.py index 71b77b06180..9bc668239fa 100644 --- a/qcodes/tests/dataset/test_database_creation_and_upgrading.py +++ b/qcodes/tests/dataset/test_database_creation_and_upgrading.py @@ -21,6 +21,7 @@ from qcodes.dataset.sqlite_base import (connect, one, update_GUIDs, + get_db_version_and_newest_available_version, get_user_version, atomic_transaction, perform_db_upgrade_0_to_1, @@ -358,3 +359,20 @@ def test_update_existing_guids(caplog): guid_comps_5 = parse_guid(ds5.guid) assert guid_comps_5['location'] == old_loc assert guid_comps_5['work_station'] == old_ws + + +@pytest.mark.parametrize('version', [0, 1, 2]) +def test_getting_db_version(version): + + fixpath = os.path.join(fixturepath, 'db_files', f'version{version}') + + if not os.path.exists(fixpath): + pytest.skip("No db-file fixtures found. You can generate test db-files" + " using the scripts in the legacy_DB_generation folder") + + dbname = os.path.join(fixpath, 'empty.db') + + (db_v, new_v) = get_db_version_and_newest_available_version(dbname) + + assert db_v == version + assert new_v == 3