diff --git a/tardis/plugins/sqliteregistry.py b/tardis/plugins/sqliteregistry.py index d5f23375..f24bc8b0 100644 --- a/tardis/plugins/sqliteregistry.py +++ b/tardis/plugins/sqliteregistry.py @@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager -from typing import List, Dict +from typing import List, Dict, Generator import asyncio import logging import sqlite3 @@ -24,9 +24,11 @@ def __init__(self): configuration = Configuration() self._db_file = configuration.Plugins.SqliteRegistry.db_file self._deploy_db_schema() - self._dispatch_on_state = dict( - BootingState=self.insert_resource, DownState=self.delete_resource - ) + self._dispatch_on_state = { + "BootingState": self.insert_resource, + "DownState": self.delete_resource, + } + self.thread_pool_executor = ThreadPoolExecutor(max_workers=1) for site in configuration.Sites: @@ -35,7 +37,7 @@ def __init__(self): self.add_machine_types(site.name, machine_type) def add_machine_types(self, site_name: str, machine_type: str) -> None: - if self.get_machine_type(site_name, machine_type): + if self._get_machine_type(site_name, machine_type): logger.debug( f"{machine_type} is already present for {site_name} in database! Skipping insertion!" # noqa B950 ) @@ -44,38 +46,38 @@ def add_machine_types(self, site_name: str, machine_type: str) -> None: INSERT OR ROLLBACK INTO MachineTypes(machine_type, site_id) SELECT :machine_type, Sites.site_id FROM Sites WHERE Sites.site_name = :site_name""" - self.execute(sql_query, dict(site_name=site_name, machine_type=machine_type)) + self.execute(sql_query, {"site_name": site_name, "machine_type": machine_type}) - def get_machine_type(self, site_name: str, machine_type: str) -> List[Dict]: + def _get_machine_type(self, site_name: str, machine_type: str) -> List[Dict]: sql_query = """ SELECT * FROM MachineTypes MT JOIN Sites S ON MT.site_id = S.site_id WHERE MT.machine_type = :machine_type AND S.site_name = :site_name""" return self.execute( - sql_query, dict(site_name=site_name, machine_type=machine_type) + sql_query, {"site_name": site_name, "machine_type": machine_type} ) def add_site(self, site_name: str) -> None: - if self.get_site(site_name): + if self._get_site(site_name): logger.debug( f"{site_name} already present in database! Skipping insertion!" ) return sql_query = "INSERT OR ROLLBACK INTO Sites(site_name) VALUES (:site_name)" - self.execute(sql_query, dict(site_name=site_name)) + self.execute(sql_query, {"site_name": site_name}) - def get_site(self, site_name: str) -> List[Dict]: + def _get_site(self, site_name: str) -> List[Dict]: sql_query = "SELECT * FROM Sites WHERE site_name = :site_name" - return self.execute(sql_query, dict(site_name=site_name)) + return self.execute(sql_query, {"site_name": site_name}) - async def async_execute(self, sql_query: str, bind_parameters: dict) -> List[Dict]: + async def async_execute(self, sql_query: str, bind_parameters: Dict) -> List[Dict]: loop = asyncio.get_event_loop() return await loop.run_in_executor( self.thread_pool_executor, self.execute, sql_query, bind_parameters ) @contextmanager - def connect(self) -> None: + def connect(self) -> Generator[sqlite3.Connection, None, None]: con = sqlite3.connect( self._db_file, detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES ) @@ -134,13 +136,13 @@ def _deploy_db_schema(self) -> None: (state,), ) - async def delete_resource(self, bind_parameters: dict) -> None: + async def delete_resource(self, bind_parameters: Dict) -> None: sql_query = """DELETE FROM Resources WHERE drone_uuid = :drone_uuid AND site_id = (SELECT site_id from Sites WHERE site_name = :site_name)""" await self.async_execute(sql_query, bind_parameters) - def execute(self, sql_query: str, bind_parameters: dict) -> List[Dict]: + def execute(self, sql_query: str, bind_parameters: Dict) -> List[Dict]: with self.connect() as connection: connection.row_factory = lambda cur, row: { col[0]: row[idx] for idx, col in enumerate(cur.description) @@ -159,10 +161,10 @@ def get_resources(self, site_name: str, machine_type: str) -> List[Dict]: JOIN MachineTypes MT ON R.machine_type_id = MT.machine_type_id WHERE S.site_name = :site_name AND MT.machine_type = :machine_type""" return self.execute( - sql_query, dict(site_name=site_name, machine_type=machine_type) + sql_query, {"site_name": site_name, "machine_type": machine_type} ) - async def insert_resource(self, bind_parameters: dict) -> None: + async def insert_resource(self, bind_parameters: Dict) -> None: sql_query = """ INSERT OR ROLLBACK INTO Resources(remote_resource_uuid, drone_uuid, state_id, site_id, machine_type_id, @@ -179,11 +181,11 @@ async def insert_resource(self, bind_parameters: dict) -> None: async def notify(self, state: State, resource_attributes: AttributeDict) -> None: state = str(state) logger.debug(f"Drone: {str(resource_attributes)} has changed state to {state}") - bind_parameters = dict(state=state) + bind_parameters = {"state": state} bind_parameters.update(resource_attributes) await self._dispatch_on_state.get(state, self.update_resource)(bind_parameters) - async def update_resource(self, bind_parameters: dict) -> None: + async def update_resource(self, bind_parameters: Dict) -> None: sql_query = """UPDATE Resources SET updated = :updated, state_id = (SELECT state_id FROM ResourceStates WHERE state = :state) WHERE drone_uuid = :drone_uuid diff --git a/tests/plugins_t/test_sqliteregistry.py b/tests/plugins_t/test_sqliteregistry.py index 883cb5b1..0a0d4d47 100644 --- a/tests/plugins_t/test_sqliteregistry.py +++ b/tests/plugins_t/test_sqliteregistry.py @@ -109,6 +109,8 @@ def test_add_machine_types(self): self.registry.add_site(site_name) self.registry.add_machine_types(site_name, self.test_machine_type) + # Database content has to be checked several times + # Define inline function to re-use code def check_db_content(): machine_types = self.execute_db_query( sql_query="""SELECT MachineTypes.machine_type, Sites.site_name @@ -141,6 +143,8 @@ def test_add_site(self): test_site_names = (self.test_site_name, self.other_test_site_name) self.registry.add_site(test_site_names[0]) + # Database content has to be checked several times + # Define inline function to re-use code def check_db_content(): for row, site_name in zip( self.execute_db_query("SELECT site_name FROM Sites"), test_site_names @@ -189,6 +193,8 @@ def test_get_resources(self): @patch("tardis.plugins.sqliteregistry.logging", Mock()) def test_notify(self): + # Database has to be queried multiple times + # Define inline function to re-use code def fetch_all(): return self.execute_db_query( sql_query="""SELECT R.remote_resource_uuid, R.drone_uuid, RS.state, @@ -227,6 +233,8 @@ def fetch_all(): self.assertListEqual([], fetch_all()) def test_insert_resources(self): + # Database has to be queried multiple times + # Define inline function to re-use code def fetch_all(): return self.execute_db_query( sql_query="""SELECT R.remote_resource_uuid, R.drone_uuid, RS.state, @@ -242,7 +250,8 @@ def fetch_all(): self.registry.add_site(site_name) self.registry.add_machine_types(site_name, self.test_machine_type) - bind_parameters = dict(state="BootingState", **self.test_resource_attributes) + bind_parameters = {"state": "BootingState"} + bind_parameters.update(self.test_resource_attributes) run_async(self.registry.insert_resource, bind_parameters) @@ -255,7 +264,8 @@ def fetch_all(): self.assertListEqual([self.test_notify_result], fetch_all()) # Test same remote_resource_uuids on different sites - bind_parameters = dict(state="BootingState", **self.test_resource_attributes) + bind_parameters = {"state": "BootingState"} + bind_parameters.update(self.test_resource_attributes) bind_parameters["drone_uuid"] = f"{self.other_test_site_name}-045285abef1" bind_parameters["site_name"] = self.other_test_site_name