Skip to content

Commit

Permalink
Add suggestions from code review in #220
Browse files Browse the repository at this point in the history
  • Loading branch information
giffels committed Nov 30, 2021
1 parent e6182c6 commit 4c26b80
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 22 deletions.
42 changes: 22 additions & 20 deletions tardis/plugins/sqliteregistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import List, Dict
from typing import List, Dict, Generator
import asyncio
import logging
import sqlite3
Expand All @@ -24,9 +24,11 @@ def __init__(self):
configuration = Configuration()
self._db_file = configuration.Plugins.SqliteRegistry.db_file
self._deploy_db_schema()
self._dispatch_on_state = dict(
BootingState=self.insert_resource, DownState=self.delete_resource
)
self._dispatch_on_state = {
"BootingState": self.insert_resource,
"DownState": self.delete_resource,
}

self.thread_pool_executor = ThreadPoolExecutor(max_workers=1)

for site in configuration.Sites:
Expand All @@ -35,7 +37,7 @@ def __init__(self):
self.add_machine_types(site.name, machine_type)

def add_machine_types(self, site_name: str, machine_type: str) -> None:
if self.get_machine_type(site_name, machine_type):
if self._get_machine_type(site_name, machine_type):
logger.debug(
f"{machine_type} is already present for {site_name} in database! Skipping insertion!" # noqa B950
)
Expand All @@ -44,38 +46,38 @@ def add_machine_types(self, site_name: str, machine_type: str) -> None:
INSERT OR ROLLBACK INTO MachineTypes(machine_type, site_id)
SELECT :machine_type, Sites.site_id FROM Sites
WHERE Sites.site_name = :site_name"""
self.execute(sql_query, dict(site_name=site_name, machine_type=machine_type))
self.execute(sql_query, {"site_name": site_name, "machine_type": machine_type})

def get_machine_type(self, site_name: str, machine_type: str) -> List[Dict]:
def _get_machine_type(self, site_name: str, machine_type: str) -> List[Dict]:
sql_query = """
SELECT * FROM MachineTypes MT
JOIN Sites S ON MT.site_id = S.site_id
WHERE MT.machine_type = :machine_type AND S.site_name = :site_name"""
return self.execute(
sql_query, dict(site_name=site_name, machine_type=machine_type)
sql_query, {"site_name": site_name, "machine_type": machine_type}
)

def add_site(self, site_name: str) -> None:
if self.get_site(site_name):
if self._get_site(site_name):
logger.debug(
f"{site_name} already present in database! Skipping insertion!"
)
return
sql_query = "INSERT OR ROLLBACK INTO Sites(site_name) VALUES (:site_name)"
self.execute(sql_query, dict(site_name=site_name))
self.execute(sql_query, {"site_name": site_name})

def get_site(self, site_name: str) -> List[Dict]:
def _get_site(self, site_name: str) -> List[Dict]:
sql_query = "SELECT * FROM Sites WHERE site_name = :site_name"
return self.execute(sql_query, dict(site_name=site_name))
return self.execute(sql_query, {"site_name": site_name})

async def async_execute(self, sql_query: str, bind_parameters: dict) -> List[Dict]:
async def async_execute(self, sql_query: str, bind_parameters: Dict) -> List[Dict]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.thread_pool_executor, self.execute, sql_query, bind_parameters
)

@contextmanager
def connect(self) -> None:
def connect(self) -> Generator[sqlite3.Connection, None, None]:
con = sqlite3.connect(
self._db_file, detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
)
Expand Down Expand Up @@ -134,13 +136,13 @@ def _deploy_db_schema(self) -> None:
(state,),
)

async def delete_resource(self, bind_parameters: dict) -> None:
async def delete_resource(self, bind_parameters: Dict) -> None:
sql_query = """DELETE FROM Resources
WHERE drone_uuid = :drone_uuid
AND site_id = (SELECT site_id from Sites WHERE site_name = :site_name)"""
await self.async_execute(sql_query, bind_parameters)

def execute(self, sql_query: str, bind_parameters: dict) -> List[Dict]:
def execute(self, sql_query: str, bind_parameters: Dict) -> List[Dict]:
with self.connect() as connection:
connection.row_factory = lambda cur, row: {
col[0]: row[idx] for idx, col in enumerate(cur.description)
Expand All @@ -159,10 +161,10 @@ def get_resources(self, site_name: str, machine_type: str) -> List[Dict]:
JOIN MachineTypes MT ON R.machine_type_id = MT.machine_type_id
WHERE S.site_name = :site_name AND MT.machine_type = :machine_type"""
return self.execute(
sql_query, dict(site_name=site_name, machine_type=machine_type)
sql_query, {"site_name": site_name, "machine_type": machine_type}
)

async def insert_resource(self, bind_parameters: dict) -> None:
async def insert_resource(self, bind_parameters: Dict) -> None:
sql_query = """
INSERT OR ROLLBACK INTO
Resources(remote_resource_uuid, drone_uuid, state_id, site_id, machine_type_id,
Expand All @@ -179,11 +181,11 @@ async def insert_resource(self, bind_parameters: dict) -> None:
async def notify(self, state: State, resource_attributes: AttributeDict) -> None:
state = str(state)
logger.debug(f"Drone: {str(resource_attributes)} has changed state to {state}")
bind_parameters = dict(state=state)
bind_parameters = {"state": state}
bind_parameters.update(resource_attributes)
await self._dispatch_on_state.get(state, self.update_resource)(bind_parameters)

async def update_resource(self, bind_parameters: dict) -> None:
async def update_resource(self, bind_parameters: Dict) -> None:
sql_query = """UPDATE Resources SET updated = :updated,
state_id = (SELECT state_id FROM ResourceStates WHERE state = :state)
WHERE drone_uuid = :drone_uuid
Expand Down
14 changes: 12 additions & 2 deletions tests/plugins_t/test_sqliteregistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def test_add_machine_types(self):
self.registry.add_site(site_name)
self.registry.add_machine_types(site_name, self.test_machine_type)

# Database content has to be checked several times
# Define inline function to re-use code
def check_db_content():
machine_types = self.execute_db_query(
sql_query="""SELECT MachineTypes.machine_type, Sites.site_name
Expand Down Expand Up @@ -141,6 +143,8 @@ def test_add_site(self):
test_site_names = (self.test_site_name, self.other_test_site_name)
self.registry.add_site(test_site_names[0])

# Database content has to be checked several times
# Define inline function to re-use code
def check_db_content():
for row, site_name in zip(
self.execute_db_query("SELECT site_name FROM Sites"), test_site_names
Expand Down Expand Up @@ -189,6 +193,8 @@ def test_get_resources(self):

@patch("tardis.plugins.sqliteregistry.logging", Mock())
def test_notify(self):
# Database has to be queried multiple times
# Define inline function to re-use code
def fetch_all():
return self.execute_db_query(
sql_query="""SELECT R.remote_resource_uuid, R.drone_uuid, RS.state,
Expand Down Expand Up @@ -227,6 +233,8 @@ def fetch_all():
self.assertListEqual([], fetch_all())

def test_insert_resources(self):
# Database has to be queried multiple times
# Define inline function to re-use code
def fetch_all():
return self.execute_db_query(
sql_query="""SELECT R.remote_resource_uuid, R.drone_uuid, RS.state,
Expand All @@ -242,7 +250,8 @@ def fetch_all():
self.registry.add_site(site_name)
self.registry.add_machine_types(site_name, self.test_machine_type)

bind_parameters = dict(state="BootingState", **self.test_resource_attributes)
bind_parameters = {"state": "BootingState"}
bind_parameters.update(self.test_resource_attributes)

run_async(self.registry.insert_resource, bind_parameters)

Expand All @@ -255,7 +264,8 @@ def fetch_all():
self.assertListEqual([self.test_notify_result], fetch_all())

# Test same remote_resource_uuids on different sites
bind_parameters = dict(state="BootingState", **self.test_resource_attributes)
bind_parameters = {"state": "BootingState"}
bind_parameters.update(self.test_resource_attributes)
bind_parameters["drone_uuid"] = f"{self.other_test_site_name}-045285abef1"
bind_parameters["site_name"] = self.other_test_site_name

Expand Down

0 comments on commit 4c26b80

Please sign in to comment.