Skip to content

Commit

Permalink
feat(users): added function to check if database is in use (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tbaile authored Jan 7, 2025
1 parent 50aec86 commit d0771e1
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
38 changes: 37 additions & 1 deletion src/nethsec/users/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import os
import subprocess
import secrets

from uci import UciExceptionNotFound

from nethsec import utils
from nethsec.ldif import LDIFParser
from urllib.parse import urlparse
Expand Down Expand Up @@ -436,7 +439,8 @@ def list_databases(uci):
elif uci.get('users', db, default='') == "ldap":
ret.append({"name": db, "type": "ldap", "description": uci.get('users', db, 'description', default=''),
"schema": uci.get('users', db, 'schema', default=''),
"uri": uci.get('users', db, 'uri', default='')})
"uri": uci.get('users', db, 'uri', default=''),
"used": used_by(uci, db)})
return ret

def get_database(uci, name):
Expand Down Expand Up @@ -871,3 +875,35 @@ def is_admin(uci, username):
if logins[l].get("username") == username:
return True
return False


def used_by(uci, database_name):
"""
Checks if the database is used by VPN or other services
Arguments:
- uci -- EUci pointer
- database_name -- Database identifier
Returns:
- dict containing the service that the database is used by
"""
results = []
try:
for instance in uci.get_all('openvpn'):
if uci.get('openvpn', instance, 'ns_user_db', default='') == database_name:
results.append('openvpn')
break
except UciExceptionNotFound:
pass

try:
for instance in uci.get_all('network'):
# could filter by proto = 'wireguard' but the performance is not an issue
if uci.get('network', instance, 'ns_user_db', default='') == database_name:
results.append('wireguard')
break
except UciExceptionNotFound:
pass

return results
4 changes: 2 additions & 2 deletions tests/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ def test_list_databases(tmp_path):
db_list = users.list_databases(_setup_db(tmp_path))
assert {"name": "main", "description": "Main local database", "type": "local"} in db_list
assert {"name": "second", "description": "Secondary local database", "type": "local"} in db_list
assert {"name": "ldap1", "description": "Remote OpenLDAP server", "type": "ldap", "schema": "rfc2307", "uri": "ldaps://192.168.100.234"} in db_list
assert {"name": "ad1", "description": "Remote AD server", "type": "ldap", "schema": "ad", "uri": "ldaps://ad.nethserver.org"} in db_list
assert {"name": "ldap1", "description": "Remote OpenLDAP server", "type": "ldap", "schema": "rfc2307", "uri": "ldaps://192.168.100.234", "used": []} in db_list
assert {"name": "ad1", "description": "Remote AD server", "type": "ldap", "schema": "ad", "uri": "ldaps://ad.nethserver.org", "used": []} in db_list

def test_add_local_database(tmp_path):
u = _setup_db(tmp_path)
Expand Down

0 comments on commit d0771e1

Please sign in to comment.