Skip to content

Commit

Permalink
Merge pull request scitokens#196 from djw8605/vynpt-master
Browse files Browse the repository at this point in the history
Vynpt master
  • Loading branch information
djw8605 authored May 31, 2024
2 parents fc20c15 + c171a73 commit 9bfbbba
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 83 deletions.
5 changes: 3 additions & 2 deletions src/scitokens/tools/admin_add_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,19 @@ def add_args():
parser = argparse.ArgumentParser(description='Remove a local cached token')
parser.add_argument('issuer', help='issuer')
parser.add_argument('key_id', help='key_id')
parser.add_argument('-f', '--force', action='store_true', help='Force add')
args = parser.parse_args()
return args

def main():
args = add_args()
keycache = KeyCache()
res = keycache.add_key(args.issuer, args.key_id)
res = keycache.add_key(args.issuer, args.key_id, force_refresh=args.force)
if res != None:
print("Successfully added token with issuer = {} and key_id = {}!".format(args.issuer, args.key_id))
print(res)
else:
print("Invalid issuer = {} and key_id = {} !".format(args.issuer, args.key_id))
print("Cannot add issuer = {} and key_id = {} !".format(args.issuer, args.key_id))

if __name__ == "__main__":
main()
207 changes: 130 additions & 77 deletions src/scitokens/utils/keycache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import time
import re
import logging
from urllib.error import URLError

try:
import urllib.request as request
Expand Down Expand Up @@ -75,21 +76,25 @@ def addkeyinfo(self, issuer, key_id, public_key, cache_timer=0, next_update=0):
if next_update == 0:
next_update = 3600

conn = sqlite3.connect(self.cache_location)
conn.row_factory = sqlite3.Row
curs = conn.cursor()
curs.execute("DELETE FROM keycache WHERE issuer = '{}' AND key_id = '{}'".format(issuer, key_id))
KeyCache._addkeyinfo(curs, issuer, key_id, public_key, cache_timer=cache_timer, next_update=next_update)
conn.commit()
conn.close()
try:
conn = sqlite3.connect(self.cache_location)
conn.row_factory = sqlite3.Row
curs = conn.cursor()
curs.execute("DELETE FROM keycache WHERE issuer = '{}' AND key_id = '{}'".format(issuer, key_id))
KeyCache._addkeyinfo(curs, issuer, key_id, public_key, cache_timer=cache_timer, next_update=next_update)
conn.commit()
conn.close()
except Exception as ex:
logger = logging.getLogger("scitokens")
logger.error(f'Keycache file is immutable. Detailed error: {ex}')

@staticmethod
def _addkeyinfo(curs, issuer, key_id, public_key, cache_timer=0, next_update=0):
"""
Given an open database cursor to a key cache, insert a key.
"""
# Add the key to the cache
insert_key_statement = "INSERT INTO keycache VALUES('{issuer}', '{expiration}', '{key_id}', \
insert_key_statement = "INSERT OR REPLACE INTO keycache VALUES('{issuer}', '{expiration}', '{key_id}', \
'{keydata}', '{next_update}')"
keydata = {
'pub_key': public_key.public_bytes(Encoding.PEM, PublicFormat.SubjectPublicKeyInfo).decode('ascii'),
Expand Down Expand Up @@ -129,15 +134,74 @@ def _delete_cache_entry(self, issuer, key_id):
Delete a cache entry
"""
# Open the connection to the database
conn = sqlite3.connect(self.cache_location)
curs = conn.cursor()
curs.execute("DELETE FROM keycache WHERE issuer = '{}' AND key_id = '{}'".format(issuer,
key_id))
conn.commit()
conn.close()
try:
conn = sqlite3.connect(self.cache_location)
curs = conn.cursor()
curs.execute("DELETE FROM keycache WHERE issuer = '{}' AND key_id = '{}'".format(issuer,
key_id))
conn.commit()
conn.close()
except Exception as ex:
logger = logging.getLogger("scitokens")
logger.error(f'Keycache file is immutable. Detailed error: {ex}')


def _add_negative_cache_entry(self, issuer, key_id, cache_retry_interval):
"""
Add a negative cache entry
"""
try:
conn = sqlite3.connect(self.cache_location)
conn.row_factory = sqlite3.Row
curs = conn.cursor()
insert_key_statement = "INSERT OR REPLACE INTO keycache VALUES('{issuer}', '{expiration}', '{key_id}', \
'{keydata}', '{next_update}')"
keydata = ''
curs.execute(insert_key_statement.format(issuer=issuer, expiration=time.time()+cache_retry_interval, key_id=key_id,
keydata=keydata, next_update=time.time()+cache_retry_interval))
if curs.rowcount != 1:
logger = logging.getLogger("scitokens")
logger.error(UnableToWriteKeyCache("Unable to insert into key cache"))
conn.commit()
conn.close()
except Exception as ex:
logger = logging.getLogger("scitokens")
logger.error(f'Keycache file is immutable. Detailed error: {ex}')


def getkeyinfo(self, issuer, key_id=None, insecure=False, force_refresh=False):
def _download_and_add_key(self, issuer, key_id, insecure, force_refresh, cache_retry_interval):
"""
Download key data and add key (if possible)
"""
logger = logging.getLogger("scitokens")
try:
public_key, cache_timer = self._get_issuer_publickey(issuer, key_id, insecure)
except ValueError as ex:
logger.error(ex)
raise ex
except URLError as ex:
logger.error("Unable to get key from issuer.\n{0}".format(str(ex)))
raise ex
except Exception as ex:
logger.error("No key was found in keycache and unable to get key: {0}".format(str(ex)))
# Create negative cache
if not force_refresh:
# If NOT forced, create negative cache
try:
self._add_negative_cache_entry(issuer, key_id, cache_retry_interval)
except Exception as ex:
logger.error(ex)
raise MissingKeyException(ex)

# Separate download and add key to avoid keycache deadlocks
try:
self.addkeyinfo(issuer, key_id, public_key, cache_timer)
except Exception as ex:
logger.error("Unable to add new key data to keycache.\n{0}".format(ex))
return public_key


def getkeyinfo(self, issuer, key_id=None, insecure=False, force_refresh=False, cache_retry_interval=300):
"""
Get the key information
Expand All @@ -146,22 +210,40 @@ def getkeyinfo(self, issuer, key_id=None, insecure=False, force_refresh=False):
:param bool insecure: Whether insecure methods are acceptable (defaults to False).
:returns: None if no key is found. Else, returns the public key
"""
# Setup log configuration
logger = logging.getLogger("scitokens")

# Check the sql database
key_query = ("SELECT * FROM keycache WHERE "
"issuer = '{issuer}'")
if key_id != None:
key_query += " AND key_id = '{key_id}'"
conn = sqlite3.connect(self.cache_location)
conn.row_factory = sqlite3.Row
curs = conn.cursor()
curs.execute(key_query.format(issuer=issuer, key_id=key_id))
row = None
try:
conn = sqlite3.connect(self.cache_location)
conn.row_factory = sqlite3.Row
curs = conn.cursor()
curs.execute(key_query.format(issuer=issuer, key_id=key_id))
row = curs.fetchone()
conn.commit()
conn.close()
except Exception as ex:
logger.error(f'Keycache file is immutable. Detailed error: {ex}')

row = curs.fetchone()
conn.commit()
conn.close()
if row != None:
# Check if record is negative cache
if row['keydata'] == '':
# Negative Cache Handling
if not force_refresh and row['next_update'] > time.time():
logger.warning("Retry in {} seconds".format(int(row['next_update'] - time.time())))
return None
else:
# Force refresh or cache_retry_interval is over
self._delete_cache_entry(row['issuer'], row['key_id'])
row = None

# If it's time to update the key, but the key is still valid
if int(row['next_update']) < time.time() and self._check_validity(row):
if row and int(row['next_update']) < time.time() and self._check_validity(row):
# Try to update the key, but if it doesn't work, just return the saved one
try:
# Get the public key, probably from a webserver
Expand All @@ -171,78 +253,39 @@ def getkeyinfo(self, issuer, key_id=None, insecure=False, force_refresh=False):
self.addkeyinfo(issuer, key_id, public_key, cache_timer)
return public_key
except Exception as ex:
logger = logging.getLogger("scitokens")
logger.warning("Unable to get key triggered by next update: {0}".format(str(ex)))
keydata = self._parse_key_data(row['issuer'], row['key_id'], row['keydata'])
# Upgrade proof
if keydata:
return load_pem_public_key(keydata.encode(), backend=backends.default_backend())

# If it's not time to update the key, but the key is still valid
elif self._check_validity(row):
elif row and self._check_validity(row):
# If force_refresh is set, then update the key
if force_refresh:
try:
# update the keycache
public_key, cache_timer = self._get_issuer_publickey(issuer, key_id, insecure)
self.addkeyinfo(issuer, key_id, public_key, cache_timer)
return public_key
except ValueError as ex:
logging.exception("Unable to parse JSON stored in keycache. "
"This likely means the database format needs"
"to be updated, which we will now do automatically.\n{0}".format(str(ex)))
self._delete_cache_entry(issuer, key_id)
raise ex
except URLError as ex:
raise URLError("Unable to get key from issuer.\n{0}".format(str(ex)))
except MissingKeyException as ex:
raise MissingKeyException("Unable to force refresh key. \n{0}".format(str(ex)))

public_key = self._download_and_add_key(issuer, key_id, insecure, force_refresh, cache_retry_interval)

keydata = self._parse_key_data(row['issuer'], row['key_id'], row['keydata'])
if keydata:
return load_pem_public_key(keydata.encode(), backend=backends.default_backend())

# update the keycache
try:
public_key, cache_timer = self._get_issuer_publickey(issuer, key_id, insecure)
self.addkeyinfo(issuer, key_id, public_key, cache_timer)
return public_key
except ValueError as ex:
logging.exception("Unable to parse JSON stored in keycache. "
"This likely means the database format needs"
"to be updated, which we will now do automatically.\n{0}".format(str(ex)))
self._delete_cache_entry(issuer, key_id)
raise ex
except URLError as ex:
raise URLError("Unable to get key from issuer.\n{0}".format(str(ex)))
except Exception as ex:
raise MissingKeyException("Key in keycache is expired and unable to get a new key.\n{0}".format(str(ex)))
# If local key not valid, update the keycache
public_key = self._download_and_add_key(issuer, key_id, insecure, force_refresh, cache_retry_interval)
return public_key


# If it's not time to update the key, and the key is not valid
else:
elif row:

# Delete the row
# If it gets to this point, then there is a row for the key, but it's:
# - Not valid anymore
self._delete_cache_entry(row['issuer'], row['key_id'])
# If key is a negative cache

# If it reaches here, then no key was found in the SQL
# Try checking the issuer (negative cache?)
try:
public_key, cache_timer = self._get_issuer_publickey(issuer, key_id, insecure)
self.addkeyinfo(issuer, key_id, public_key, cache_timer)
return public_key
except ValueError as ex:
logging.exception("Unable to parse JSON stored in keycache. "
"This likely means the database format needs"
"to be updated, which we will now do automatically.\n{0}".format(str(ex)))
self._delete_cache_entry(issuer, key_id)
raise ex
except URLError as ex:
raise URLError("Unable to get key from issuer.\n{0}".format(str(ex)))
except Exception as ex:
raise MissingKeyException("No key was found in keycache and unable to get key.\n{0}".format(str(ex)))
public_key = self._download_and_add_key(issuer, key_id, insecure, force_refresh, cache_retry_interval)
return public_key


@classmethod
Expand Down Expand Up @@ -358,7 +401,7 @@ def _get_cache_file(self):
2. $XDG_CACHE_HOME
3. .cache subdirectory of home directory as returned by the password database
"""

logger = logging.getLogger("scitokens")
config_cache_location = config.get('cache_location')
xdg_cache_home = os.environ.get("XDG_CACHE_HOME", None)
home_dir = os.path.expanduser("~")
Expand All @@ -374,14 +417,19 @@ def _get_cache_file(self):
try:
os.makedirs(cache_dir)
except OSError as ose:
raise UnableToCreateCache("Unable to create cache: {}".format(str(ose)))
# Unable to create a cache is not a fatal error
logger.warning("Unable to create cache directory at {}: {}".format(cache_dir, str(ose)))
# If we couldn't create the cache directory, just return, nothing more to do here
return None

keycache_dir = os.path.join(cache_dir, "scitokens")
try:
if not os.path.exists(keycache_dir):
os.makedirs(keycache_dir)
except OSError as ose:
raise UnableToCreateCache("Unable to create cache: {}".format(str(ose)))
# Unable to create directories is not a fatal error
logger.warning("Unable to create cache directory at {}: {}".format(cache_dir, str(ose)))
return None

keycache_file = os.path.join(keycache_dir, CACHE_FILENAME)
if not os.path.exists(keycache_file):
Expand Down Expand Up @@ -473,6 +521,11 @@ def update_all_keys(self, force_refresh=False):

res = []
for issuer, key_id in tokens:
updated = self.add_key(issuer, key_id, force_refresh=force_refresh)
res.append(updated)
try:
updated = self.add_key(issuer, key_id, force_refresh=force_refresh)
res.append(updated)
except Exception as ex:
logger = logging.getLogger("scitokens")
logger.error("Unable to update key: {0} {1}".format(issuer, key_id))
logger.error(ex)
return res
Loading

0 comments on commit 9bfbbba

Please sign in to comment.