From 5bd8a110ea4a665f0234dde20784c90f3b232101 Mon Sep 17 00:00:00 2001 From: Manuel Giffels Date: Thu, 25 Nov 2021 13:30:40 +0100 Subject: [PATCH 1/5] Fix machine_type unique constraint to support a machine_type for multiple sites --- CONTRIBUTORS | 2 +- docs/source/changelog.rst | 12 ++++++++++-- tardis/plugins/sqliteregistry.py | 3 ++- tests/plugins_t/test_sqliteregistry.py | 22 ++++++++++++++++++---- 4 files changed, 31 insertions(+), 8 deletions(-) diff --git a/CONTRIBUTORS b/CONTRIBUTORS index 30b4ad0e..55ad2c78 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -4,9 +4,9 @@ Manuel Giffels Stefan Kroboth Eileen Kuehn matthias.schnepf +Max Fischer ubdsv Rene Caspart -Max Fischer Leon Schuhmacher R. Florian von Cube mschnepf diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 3dfd2d94..5fdafdce 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,11 +1,19 @@ -.. Created by changelog.py at 2021-10-06, command - '/Users/giffler/.cache/pre-commit/repor6pnmwlm/py_env-python3.9/bin/changelog docs/source/changes compile --output=docs/source/changelog.rst' +.. Created by changelog.py at 2021-11-25, command + '/Users/giffler/.cache/pre-commit/repor6pnmwlm/py_env-default/bin/changelog docs/source/changes compile --output=docs/source/changelog.rst' based on the format of 'https://keepachangelog.com/' ######### CHANGELOG ######### +[Unreleased] - 2021-11-25 +========================= + +Changed +------- + +* SSHExecutor respects the remote MaxSessions via queueing + [0.6.0] - 2021-08-09 ==================== diff --git a/tardis/plugins/sqliteregistry.py b/tardis/plugins/sqliteregistry.py index ff86dbd4..8a90d739 100644 --- a/tardis/plugins/sqliteregistry.py +++ b/tardis/plugins/sqliteregistry.py @@ -65,9 +65,10 @@ def _deploy_db_schema(self): tables = { "MachineTypes": [ "machine_type_id INTEGER PRIMARY KEY AUTOINCREMENT", - "machine_type VARCHAR(255) UNIQUE", + "machine_type VARCHAR(255)", "site_id INTEGER", "FOREIGN KEY(site_id) REFERENCES Sites(site_id)", + "CONSTRAINT unique_machine_type_per_site UNIQUE (machine_type, site_id)", # noqa B950 ], "Resources": [ "id INTEGER PRIMARY KEY AUTOINCREMENT," diff --git a/tests/plugins_t/test_sqliteregistry.py b/tests/plugins_t/test_sqliteregistry.py index 7e151969..269a8de9 100644 --- a/tests/plugins_t/test_sqliteregistry.py +++ b/tests/plugins_t/test_sqliteregistry.py @@ -16,9 +16,12 @@ class TestSqliteRegistry(TestCase): + mock_config_patcher = None + @classmethod def setUpClass(cls): cls.test_site_name = "MyGreatTestSite" + cls.other_test_site_name = "MyOtherTestSite" cls.test_machine_type = "MyGreatTestMachineType" cls.tables_in_db = {"MachineTypes", "Resources", "ResourceStates", "Sites"} cls.test_resource_attributes = { @@ -90,8 +93,11 @@ def setUp(self): def test_add_machine_types(self): registry = SqliteRegistry() - registry.add_site(self.test_site_name) - registry.add_machine_types(self.test_site_name, self.test_machine_type) + test_site_names = (self.test_site_name, self.other_test_site_name) + + for site_name in test_site_names: + registry.add_site(site_name) + registry.add_machine_types(site_name, self.test_machine_type) with sqlite3.connect(self.test_db) as connection: cursor = connection.cursor() @@ -99,8 +105,16 @@ def test_add_machine_types(self): """SELECT MachineTypes.machine_type, Sites.site_name FROM MachineTypes JOIN Sites ON MachineTypes.site_id=Sites.site_id""" ) - for row in cursor: - self.assertEqual(row, (self.test_machine_type, self.test_site_name)) + machine_types = cursor.fetchall() + + self.assertEqual( + len(test_site_names), + len(machine_types), + msg="Number of rows added to the database is different from the" + " numbers of rows retrieved from the database!", + ) + for machine_type, site_name in zip(machine_types, test_site_names): + self.assertEqual(machine_type, (self.test_machine_type, site_name)) def test_add_site(self): registry = SqliteRegistry() From 7ebb4472eaaec3a8277d537449239ec33c7b284a Mon Sep 17 00:00:00 2001 From: Manuel Giffels Date: Mon, 29 Nov 2021 17:15:13 +0100 Subject: [PATCH 2/5] Allow same remote_resource_uuid on different sites and improve unittesting --- docs/source/changelog.rst | 4 +- tardis/plugins/sqliteregistry.py | 53 +++++-- tests/plugins_t/test_sqliteregistry.py | 197 ++++++++++++++++++------- 3 files changed, 182 insertions(+), 72 deletions(-) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 5fdafdce..4e2ea2f7 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,4 +1,4 @@ -.. Created by changelog.py at 2021-11-25, command +.. Created by changelog.py at 2021-11-29, command '/Users/giffler/.cache/pre-commit/repor6pnmwlm/py_env-default/bin/changelog docs/source/changes compile --output=docs/source/changelog.rst' based on the format of 'https://keepachangelog.com/' @@ -6,7 +6,7 @@ CHANGELOG ######### -[Unreleased] - 2021-11-25 +[Unreleased] - 2021-11-29 ========================= Changed diff --git a/tardis/plugins/sqliteregistry.py b/tardis/plugins/sqliteregistry.py index 8a90d739..d5f23375 100644 --- a/tardis/plugins/sqliteregistry.py +++ b/tardis/plugins/sqliteregistry.py @@ -5,6 +5,7 @@ from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager +from typing import List, Dict import asyncio import logging import sqlite3 @@ -33,25 +34,48 @@ def __init__(self): for machine_type in getattr(configuration, site.name).MachineTypes: self.add_machine_types(site.name, machine_type) - def add_machine_types(self, site_name: str, machine_type: str): + def add_machine_types(self, site_name: str, machine_type: str) -> None: + if self.get_machine_type(site_name, machine_type): + logger.debug( + f"{machine_type} is already present for {site_name} in database! Skipping insertion!" # noqa B950 + ) + return sql_query = """ - INSERT OR IGNORE INTO MachineTypes(machine_type, site_id) + INSERT OR ROLLBACK INTO MachineTypes(machine_type, site_id) SELECT :machine_type, Sites.site_id FROM Sites WHERE Sites.site_name = :site_name""" self.execute(sql_query, dict(site_name=site_name, machine_type=machine_type)) - def add_site(self, site_name: str): - sql_query = "INSERT OR IGNORE INTO Sites(site_name) VALUES (:site_name)" + def get_machine_type(self, site_name: str, machine_type: str) -> List[Dict]: + sql_query = """ + SELECT * FROM MachineTypes MT + JOIN Sites S ON MT.site_id = S.site_id + WHERE MT.machine_type = :machine_type AND S.site_name = :site_name""" + return self.execute( + sql_query, dict(site_name=site_name, machine_type=machine_type) + ) + + def add_site(self, site_name: str) -> None: + if self.get_site(site_name): + logger.debug( + f"{site_name} already present in database! Skipping insertion!" + ) + return + sql_query = "INSERT OR ROLLBACK INTO Sites(site_name) VALUES (:site_name)" self.execute(sql_query, dict(site_name=site_name)) - async def async_execute(self, sql_query: str, bind_parameters: dict): + def get_site(self, site_name: str) -> List[Dict]: + sql_query = "SELECT * FROM Sites WHERE site_name = :site_name" + return self.execute(sql_query, dict(site_name=site_name)) + + async def async_execute(self, sql_query: str, bind_parameters: dict) -> List[Dict]: loop = asyncio.get_event_loop() return await loop.run_in_executor( self.thread_pool_executor, self.execute, sql_query, bind_parameters ) @contextmanager - def connect(self): + def connect(self) -> None: con = sqlite3.connect( self._db_file, detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES ) @@ -61,7 +85,7 @@ def connect(self): finally: con.close() - def _deploy_db_schema(self): + def _deploy_db_schema(self) -> None: tables = { "MachineTypes": [ "machine_type_id INTEGER PRIMARY KEY AUTOINCREMENT", @@ -72,7 +96,7 @@ def _deploy_db_schema(self): ], "Resources": [ "id INTEGER PRIMARY KEY AUTOINCREMENT," - "remote_resource_uuid VARCHAR(255) UNIQUE", + "remote_resource_uuid VARCHAR(255)", "drone_uuid VARCHAR(255) UNIQUE", "state_id INTEGER", "site_id INTEGER", @@ -82,6 +106,7 @@ def _deploy_db_schema(self): "FOREIGN KEY(state_id) REFERENCES ResourceState(state_id)", "FOREIGN KEY(site_id) REFERENCES Sites(site_id)", "FOREIGN KEY(machine_type_id) REFERENCES MachineTypes(machine_type_id)", + "CONSTRAINT unique_remote_resource_uuid_per_site UNIQUE (site_id, remote_resource_uuid)", # noqa B950 ], "ResourceStates": [ "state_id INTEGER PRIMARY KEY AUTOINCREMENT", @@ -109,13 +134,13 @@ def _deploy_db_schema(self): (state,), ) - async def delete_resource(self, bind_parameters: dict): + async def delete_resource(self, bind_parameters: dict) -> None: sql_query = """DELETE FROM Resources WHERE drone_uuid = :drone_uuid AND site_id = (SELECT site_id from Sites WHERE site_name = :site_name)""" await self.async_execute(sql_query, bind_parameters) - def execute(self, sql_query: str, bind_parameters: dict): + def execute(self, sql_query: str, bind_parameters: dict) -> List[Dict]: with self.connect() as connection: connection.row_factory = lambda cur, row: { col[0]: row[idx] for idx, col in enumerate(cur.description) @@ -125,7 +150,7 @@ def execute(self, sql_query: str, bind_parameters: dict): logger.debug(f"{sql_query},{bind_parameters} executed") return cursor.fetchall() - def get_resources(self, site_name: str, machine_type: str): + def get_resources(self, site_name: str, machine_type: str) -> List[Dict]: sql_query = """ SELECT R.remote_resource_uuid, R.drone_uuid, RS.state, R.created, R.updated FROM Resources R @@ -137,9 +162,9 @@ def get_resources(self, site_name: str, machine_type: str): sql_query, dict(site_name=site_name, machine_type=machine_type) ) - async def insert_resource(self, bind_parameters: dict): + async def insert_resource(self, bind_parameters: dict) -> None: sql_query = """ - INSERT OR IGNORE INTO + INSERT OR ROLLBACK INTO Resources(remote_resource_uuid, drone_uuid, state_id, site_id, machine_type_id, created, updated) SELECT :remote_resource_uuid, :drone_uuid, RS.state_id, S.site_id, @@ -158,7 +183,7 @@ async def notify(self, state: State, resource_attributes: AttributeDict) -> None bind_parameters.update(resource_attributes) await self._dispatch_on_state.get(state, self.update_resource)(bind_parameters) - async def update_resource(self, bind_parameters: dict): + async def update_resource(self, bind_parameters: dict) -> None: sql_query = """UPDATE Resources SET updated = :updated, state_id = (SELECT state_id FROM ResourceStates WHERE state = :state) WHERE drone_uuid = :drone_uuid diff --git a/tests/plugins_t/test_sqliteregistry.py b/tests/plugins_t/test_sqliteregistry.py index 269a8de9..130acd7e 100644 --- a/tests/plugins_t/test_sqliteregistry.py +++ b/tests/plugins_t/test_sqliteregistry.py @@ -1,3 +1,5 @@ +import logging + from tardis.resources.dronestates import BootingState from tardis.resources.dronestates import IntegrateState from tardis.resources.dronestates import DownState @@ -91,21 +93,28 @@ def setUp(self): config.Sites = [AttributeDict(name=self.test_site_name)] getattr(config, self.test_site_name).MachineTypes = [self.test_machine_type] + self.registry = SqliteRegistry() + + def execute_db_query(self, sql_query): + with sqlite3.connect(self.test_db) as connection: + cursor = connection.cursor() + cursor.execute(sql_query) + + return cursor.fetchall() + def test_add_machine_types(self): - registry = SqliteRegistry() test_site_names = (self.test_site_name, self.other_test_site_name) for site_name in test_site_names: - registry.add_site(site_name) - registry.add_machine_types(site_name, self.test_machine_type) - - with sqlite3.connect(self.test_db) as connection: - cursor = connection.cursor() - cursor.execute( - """SELECT MachineTypes.machine_type, Sites.site_name FROM MachineTypes - JOIN Sites ON MachineTypes.site_id=Sites.site_id""" + self.registry.add_site(site_name) + self.registry.add_machine_types(site_name, self.test_machine_type) + + def check_db_content(): + machine_types = self.execute_db_query( + sql_query="""SELECT MachineTypes.machine_type, Sites.site_name + FROM MachineTypes + JOIN Sites ON MachineTypes.site_id=Sites.site_id""" ) - machine_types = cursor.fetchall() self.assertEqual( len(test_site_names), @@ -113,30 +122,52 @@ def test_add_machine_types(self): msg="Number of rows added to the database is different from the" " numbers of rows retrieved from the database!", ) - for machine_type, site_name in zip(machine_types, test_site_names): - self.assertEqual(machine_type, (self.test_machine_type, site_name)) + + self.assertListEqual( + [(self.test_machine_type, site_name) for site_name in test_site_names], + machine_types, + ) + + check_db_content() + + with self.assertLogs( + logger="cobald.runtime.tardis.plugins.sqliteregistry", level=logging.DEBUG + ): + self.registry.add_machine_types(self.test_site_name, self.test_machine_type) + + check_db_content() def test_add_site(self): - registry = SqliteRegistry() - registry.add_site(self.test_site_name) + test_site_names = (self.test_site_name, self.other_test_site_name) + self.registry.add_site(test_site_names[0]) - with sqlite3.connect(self.test_db) as connection: - cursor = connection.cursor() - cursor.execute("SELECT site_name FROM Sites") - for row in cursor: - self.assertEqual(row[0], self.test_site_name) + def check_db_content(): + for row, site_name in zip( + self.execute_db_query("SELECT site_name FROM Sites"), test_site_names + ): + self.assertEqual(row[0], site_name) - def test_connect(self): - SqliteRegistry() + check_db_content() - with sqlite3.connect(self.test_db) as connection: - cursor = connection.cursor() - cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") - created_tables = { - table_name[0] - for table_name in cursor.fetchall() - if table_name[0] != "sqlite_sequence" - } + with self.assertLogs( + logger="cobald.runtime.tardis.plugins.sqliteregistry", level=logging.DEBUG + ): + self.registry.add_site(test_site_names[0]) + + check_db_content() + + self.registry.add_site(test_site_names[1]) + + check_db_content() + + def test_connect(self): + created_tables = { + table_name[0] + for table_name in self.execute_db_query( + sql_query="SELECT name FROM sqlite_master WHERE type='table'" + ) + if table_name[0] != "sqlite_sequence" + } self.assertEqual(created_tables, self.tables_in_db) def test_double_schema_deployment(self): @@ -145,13 +176,12 @@ def test_double_schema_deployment(self): @patch("tardis.plugins.sqliteregistry.logging", Mock()) def test_get_resources(self): - registry = SqliteRegistry() - registry.add_site(self.test_site_name) - registry.add_machine_types(self.test_site_name, self.test_machine_type) - run_async(registry.notify, BootingState(), self.test_resource_attributes) + self.registry.add_site(self.test_site_name) + self.registry.add_machine_types(self.test_site_name, self.test_machine_type) + run_async(self.registry.notify, BootingState(), self.test_resource_attributes) self.assertListEqual( - registry.get_resources( + self.registry.get_resources( site_name=self.test_site_name, machine_type=self.test_machine_type ), [self.test_get_resources_result], @@ -159,43 +189,98 @@ def test_get_resources(self): @patch("tardis.plugins.sqliteregistry.logging", Mock()) def test_notify(self): - def fetch_row(db): - with sqlite3.connect(db) as connection: - cursor = connection.cursor() - cursor.execute( - """SELECT R.remote_resource_uuid, R.drone_uuid, RS.state, + def fetch_all(): + return self.execute_db_query( + sql_query="""SELECT R.remote_resource_uuid, R.drone_uuid, RS.state, S.site_name, MT.machine_type, R.created, R.updated FROM Resources R JOIN ResourceStates RS ON R.state_id = RS.state_id JOIN Sites S ON R.site_id = S.site_id JOIN MachineTypes MT ON R.machine_type_id = MT.machine_type_id""" - ) - return cursor.fetchone() + ) + + self.registry.add_site(self.test_site_name) + self.registry.add_machine_types(self.test_site_name, self.test_machine_type) + + run_async(self.registry.notify, BootingState(), self.test_resource_attributes) - registry = SqliteRegistry() - registry.add_site(self.test_site_name) - registry.add_machine_types(self.test_site_name, self.test_machine_type) + self.assertEqual([self.test_notify_result], fetch_all()) - run_async(registry.notify, BootingState(), self.test_resource_attributes) + with self.assertRaises(sqlite3.IntegrityError) as ie: + run_async( + self.registry.notify, BootingState(), self.test_resource_attributes + ) + self.assertTrue("UNIQUE constraint failed" in str(ie.exception)) + + run_async( + self.registry.notify, + IntegrateState(), + self.test_updated_resource_attributes, + ) - self.assertEqual(self.test_notify_result, fetch_row(self.test_db)) + self.assertEqual([self.test_updated_notify_result], fetch_all()) run_async( - registry.notify, IntegrateState(), self.test_updated_resource_attributes + self.registry.notify, DownState(), self.test_updated_resource_attributes ) - self.assertEqual(self.test_updated_notify_result, fetch_row(self.test_db)) + self.assertListEqual([], fetch_all()) - run_async(registry.notify, DownState(), self.test_updated_resource_attributes) + def test_insert_resources(self): + def fetch_all(): + return self.execute_db_query( + sql_query="""SELECT R.remote_resource_uuid, R.drone_uuid, RS.state, + S.site_name, MT.machine_type, R.created, R.updated + FROM Resources R + JOIN ResourceStates RS ON R.state_id = RS.state_id + JOIN Sites S ON R.site_id = S.site_id + JOIN MachineTypes MT ON R.machine_type_id = MT.machine_type_id""" + ) - self.assertIsNone(fetch_row(self.test_db)) + test_site_names = (self.test_site_name, self.other_test_site_name) + for site_name in test_site_names: + self.registry.add_site(site_name) + self.registry.add_machine_types(site_name, self.test_machine_type) - def test_resource_status(self): - SqliteRegistry() + bind_parameters = dict(state="BootingState", **self.test_resource_attributes) - with sqlite3.connect(self.test_db) as connection: - cursor = connection.cursor() - cursor.execute("SELECT state FROM ResourceStates") - status = {row[0] for row in cursor.fetchall()} + run_async(self.registry.insert_resource, bind_parameters) + + self.assertListEqual([self.test_notify_result], fetch_all()) + + with self.assertRaises(sqlite3.IntegrityError) as ie: + run_async(self.registry.insert_resource, bind_parameters) + self.assertTrue("UNIQUE constraint failed" in str(ie.exception)) + + self.assertListEqual([self.test_notify_result], fetch_all()) + + # Test same remote_resource_uuids on different sites + bind_parameters = dict(state="BootingState", **self.test_resource_attributes) + bind_parameters["drone_uuid"] = f"{self.other_test_site_name}-045285abef1" + bind_parameters["site_name"] = self.other_test_site_name + + run_async(self.registry.insert_resource, bind_parameters) + + other_test_notify_result = ( + self.test_resource_attributes["remote_resource_uuid"], + f"{self.other_test_site_name}-045285abef1", + str(BootingState()), + self.other_test_site_name, + self.test_resource_attributes["machine_type"], + str(self.test_resource_attributes["created"]), + str(self.test_resource_attributes["updated"]), + ) + + self.assertListEqual( + [self.test_notify_result, other_test_notify_result], fetch_all() + ) + + def test_resource_status(self): + status = { + row[0] + for row in self.execute_db_query( + sql_query="SELECT state FROM ResourceStates" + ) + } self.assertEqual(status, {state for state in State.get_all_states()}) From 62ad6d5a92a1751b9299a29d17d7f0e65743b192 Mon Sep 17 00:00:00 2001 From: Manuel Giffels Date: Mon, 29 Nov 2021 17:43:45 +0100 Subject: [PATCH 3/5] =?UTF-8?q?Time=20to=20get=20rid=20of=20Python3.6=20?= =?UTF-8?q?=F0=9F=98=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/plugins_t/test_sqliteregistry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/plugins_t/test_sqliteregistry.py b/tests/plugins_t/test_sqliteregistry.py index 130acd7e..883cb5b1 100644 --- a/tests/plugins_t/test_sqliteregistry.py +++ b/tests/plugins_t/test_sqliteregistry.py @@ -210,7 +210,7 @@ def fetch_all(): run_async( self.registry.notify, BootingState(), self.test_resource_attributes ) - self.assertTrue("UNIQUE constraint failed" in str(ie.exception)) + self.assertTrue("unique" in str(ie.exception).lower()) run_async( self.registry.notify, @@ -250,7 +250,7 @@ def fetch_all(): with self.assertRaises(sqlite3.IntegrityError) as ie: run_async(self.registry.insert_resource, bind_parameters) - self.assertTrue("UNIQUE constraint failed" in str(ie.exception)) + self.assertTrue("unique" in str(ie.exception).lower()) self.assertListEqual([self.test_notify_result], fetch_all()) From e6182c61b1e8c1214063b93b9cb9558b66725e1e Mon Sep 17 00:00:00 2001 From: Manuel Giffels Date: Tue, 30 Nov 2021 11:12:38 +0100 Subject: [PATCH 4/5] Add message to changelog --- docs/source/changelog.rst | 9 +++++++-- .../changes/220.fix_unique_constraints_db_schema.yaml | 9 +++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) create mode 100644 docs/source/changes/220.fix_unique_constraints_db_schema.yaml diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 4e2ea2f7..63736fb3 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,4 +1,4 @@ -.. Created by changelog.py at 2021-11-29, command +.. Created by changelog.py at 2021-11-30, command '/Users/giffler/.cache/pre-commit/repor6pnmwlm/py_env-default/bin/changelog docs/source/changes compile --output=docs/source/changelog.rst' based on the format of 'https://keepachangelog.com/' @@ -6,7 +6,7 @@ CHANGELOG ######### -[Unreleased] - 2021-11-29 +[Unreleased] - 2021-11-30 ========================= Changed @@ -14,6 +14,11 @@ Changed * SSHExecutor respects the remote MaxSessions via queueing +Fixed +----- + +* Unique constraints in database schema have been fixed to allow same machine_type and remote_resource_uuid on multiple sites + [0.6.0] - 2021-08-09 ==================== diff --git a/docs/source/changes/220.fix_unique_constraints_db_schema.yaml b/docs/source/changes/220.fix_unique_constraints_db_schema.yaml new file mode 100644 index 00000000..05bcedc0 --- /dev/null +++ b/docs/source/changes/220.fix_unique_constraints_db_schema.yaml @@ -0,0 +1,9 @@ +category: fixed +summary: "Unique constraints in database schema have been fixed to allow same machine_type and remote_resource_uuid on multiple sites" +description: | + The unique constraints in the datebase schema have been relaxed to allow the same machine_type and the same + remote_resource_uuid to be used on multiple sites. In addition, the unittest of the SqliteRegistry have been improved. +pull_requests: + - 220 +issues: + - 219 From 4c26b800aa8f7bb6b0bd6b0f7b231d20913270ee Mon Sep 17 00:00:00 2001 From: Manuel Giffels Date: Tue, 30 Nov 2021 15:15:06 +0100 Subject: [PATCH 5/5] Add suggestions from code review in #220 --- tardis/plugins/sqliteregistry.py | 42 ++++++++++++++------------ tests/plugins_t/test_sqliteregistry.py | 14 +++++++-- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/tardis/plugins/sqliteregistry.py b/tardis/plugins/sqliteregistry.py index d5f23375..f24bc8b0 100644 --- a/tardis/plugins/sqliteregistry.py +++ b/tardis/plugins/sqliteregistry.py @@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager -from typing import List, Dict +from typing import List, Dict, Generator import asyncio import logging import sqlite3 @@ -24,9 +24,11 @@ def __init__(self): configuration = Configuration() self._db_file = configuration.Plugins.SqliteRegistry.db_file self._deploy_db_schema() - self._dispatch_on_state = dict( - BootingState=self.insert_resource, DownState=self.delete_resource - ) + self._dispatch_on_state = { + "BootingState": self.insert_resource, + "DownState": self.delete_resource, + } + self.thread_pool_executor = ThreadPoolExecutor(max_workers=1) for site in configuration.Sites: @@ -35,7 +37,7 @@ def __init__(self): self.add_machine_types(site.name, machine_type) def add_machine_types(self, site_name: str, machine_type: str) -> None: - if self.get_machine_type(site_name, machine_type): + if self._get_machine_type(site_name, machine_type): logger.debug( f"{machine_type} is already present for {site_name} in database! Skipping insertion!" # noqa B950 ) @@ -44,38 +46,38 @@ def add_machine_types(self, site_name: str, machine_type: str) -> None: INSERT OR ROLLBACK INTO MachineTypes(machine_type, site_id) SELECT :machine_type, Sites.site_id FROM Sites WHERE Sites.site_name = :site_name""" - self.execute(sql_query, dict(site_name=site_name, machine_type=machine_type)) + self.execute(sql_query, {"site_name": site_name, "machine_type": machine_type}) - def get_machine_type(self, site_name: str, machine_type: str) -> List[Dict]: + def _get_machine_type(self, site_name: str, machine_type: str) -> List[Dict]: sql_query = """ SELECT * FROM MachineTypes MT JOIN Sites S ON MT.site_id = S.site_id WHERE MT.machine_type = :machine_type AND S.site_name = :site_name""" return self.execute( - sql_query, dict(site_name=site_name, machine_type=machine_type) + sql_query, {"site_name": site_name, "machine_type": machine_type} ) def add_site(self, site_name: str) -> None: - if self.get_site(site_name): + if self._get_site(site_name): logger.debug( f"{site_name} already present in database! Skipping insertion!" ) return sql_query = "INSERT OR ROLLBACK INTO Sites(site_name) VALUES (:site_name)" - self.execute(sql_query, dict(site_name=site_name)) + self.execute(sql_query, {"site_name": site_name}) - def get_site(self, site_name: str) -> List[Dict]: + def _get_site(self, site_name: str) -> List[Dict]: sql_query = "SELECT * FROM Sites WHERE site_name = :site_name" - return self.execute(sql_query, dict(site_name=site_name)) + return self.execute(sql_query, {"site_name": site_name}) - async def async_execute(self, sql_query: str, bind_parameters: dict) -> List[Dict]: + async def async_execute(self, sql_query: str, bind_parameters: Dict) -> List[Dict]: loop = asyncio.get_event_loop() return await loop.run_in_executor( self.thread_pool_executor, self.execute, sql_query, bind_parameters ) @contextmanager - def connect(self) -> None: + def connect(self) -> Generator[sqlite3.Connection, None, None]: con = sqlite3.connect( self._db_file, detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES ) @@ -134,13 +136,13 @@ def _deploy_db_schema(self) -> None: (state,), ) - async def delete_resource(self, bind_parameters: dict) -> None: + async def delete_resource(self, bind_parameters: Dict) -> None: sql_query = """DELETE FROM Resources WHERE drone_uuid = :drone_uuid AND site_id = (SELECT site_id from Sites WHERE site_name = :site_name)""" await self.async_execute(sql_query, bind_parameters) - def execute(self, sql_query: str, bind_parameters: dict) -> List[Dict]: + def execute(self, sql_query: str, bind_parameters: Dict) -> List[Dict]: with self.connect() as connection: connection.row_factory = lambda cur, row: { col[0]: row[idx] for idx, col in enumerate(cur.description) @@ -159,10 +161,10 @@ def get_resources(self, site_name: str, machine_type: str) -> List[Dict]: JOIN MachineTypes MT ON R.machine_type_id = MT.machine_type_id WHERE S.site_name = :site_name AND MT.machine_type = :machine_type""" return self.execute( - sql_query, dict(site_name=site_name, machine_type=machine_type) + sql_query, {"site_name": site_name, "machine_type": machine_type} ) - async def insert_resource(self, bind_parameters: dict) -> None: + async def insert_resource(self, bind_parameters: Dict) -> None: sql_query = """ INSERT OR ROLLBACK INTO Resources(remote_resource_uuid, drone_uuid, state_id, site_id, machine_type_id, @@ -179,11 +181,11 @@ async def insert_resource(self, bind_parameters: dict) -> None: async def notify(self, state: State, resource_attributes: AttributeDict) -> None: state = str(state) logger.debug(f"Drone: {str(resource_attributes)} has changed state to {state}") - bind_parameters = dict(state=state) + bind_parameters = {"state": state} bind_parameters.update(resource_attributes) await self._dispatch_on_state.get(state, self.update_resource)(bind_parameters) - async def update_resource(self, bind_parameters: dict) -> None: + async def update_resource(self, bind_parameters: Dict) -> None: sql_query = """UPDATE Resources SET updated = :updated, state_id = (SELECT state_id FROM ResourceStates WHERE state = :state) WHERE drone_uuid = :drone_uuid diff --git a/tests/plugins_t/test_sqliteregistry.py b/tests/plugins_t/test_sqliteregistry.py index 883cb5b1..0a0d4d47 100644 --- a/tests/plugins_t/test_sqliteregistry.py +++ b/tests/plugins_t/test_sqliteregistry.py @@ -109,6 +109,8 @@ def test_add_machine_types(self): self.registry.add_site(site_name) self.registry.add_machine_types(site_name, self.test_machine_type) + # Database content has to be checked several times + # Define inline function to re-use code def check_db_content(): machine_types = self.execute_db_query( sql_query="""SELECT MachineTypes.machine_type, Sites.site_name @@ -141,6 +143,8 @@ def test_add_site(self): test_site_names = (self.test_site_name, self.other_test_site_name) self.registry.add_site(test_site_names[0]) + # Database content has to be checked several times + # Define inline function to re-use code def check_db_content(): for row, site_name in zip( self.execute_db_query("SELECT site_name FROM Sites"), test_site_names @@ -189,6 +193,8 @@ def test_get_resources(self): @patch("tardis.plugins.sqliteregistry.logging", Mock()) def test_notify(self): + # Database has to be queried multiple times + # Define inline function to re-use code def fetch_all(): return self.execute_db_query( sql_query="""SELECT R.remote_resource_uuid, R.drone_uuid, RS.state, @@ -227,6 +233,8 @@ def fetch_all(): self.assertListEqual([], fetch_all()) def test_insert_resources(self): + # Database has to be queried multiple times + # Define inline function to re-use code def fetch_all(): return self.execute_db_query( sql_query="""SELECT R.remote_resource_uuid, R.drone_uuid, RS.state, @@ -242,7 +250,8 @@ def fetch_all(): self.registry.add_site(site_name) self.registry.add_machine_types(site_name, self.test_machine_type) - bind_parameters = dict(state="BootingState", **self.test_resource_attributes) + bind_parameters = {"state": "BootingState"} + bind_parameters.update(self.test_resource_attributes) run_async(self.registry.insert_resource, bind_parameters) @@ -255,7 +264,8 @@ def fetch_all(): self.assertListEqual([self.test_notify_result], fetch_all()) # Test same remote_resource_uuids on different sites - bind_parameters = dict(state="BootingState", **self.test_resource_attributes) + bind_parameters = {"state": "BootingState"} + bind_parameters.update(self.test_resource_attributes) bind_parameters["drone_uuid"] = f"{self.other_test_site_name}-045285abef1" bind_parameters["site_name"] = self.other_test_site_name