Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jwt token support #33

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions pyhon/connection/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import re
import secrets
import urllib
import base64
from contextlib import suppress
from dataclasses import dataclass
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional, Any, List
from urllib import parse
from urllib.parse import quote
Expand Down Expand Up @@ -37,10 +38,26 @@ class HonAuthData:
cognito_token: str = ""
id_token: str = ""

def decode_jwt(token: str) -> dict[str, str]:
if token == "":
return {}

def base64url_decode(input_str: str) -> bytes:
# Add padding if necessary
input_str += '=' * (4 - len(input_str) % 4)
return base64.urlsafe_b64decode(input_str)

# Split the token into parts
_, payload_b64, _ = token.split('.')

if not payload_b64:
raise Exception("Invalid JWT token!")

return json.loads(base64url_decode(payload_b64))


class HonAuth:
_TOKEN_EXPIRES_AFTER_HOURS = 8
_TOKEN_EXPIRE_WARNING_HOURS = 7
_TOKEN_EXPIRE_WARNING_HOURS = 1

def __init__(
self,
Expand All @@ -55,9 +72,15 @@ def __init__(
self._login_data.email = email
self._login_data.password = password
self._device = device
self._expires: datetime = datetime.utcnow()
self._expires: datetime = datetime.now(timezone.utc)
self._auth = HonAuthData()

@property
def expires(self) -> datetime:
if self.id_token == "":
return datetime.fromtimestamp(0, timezone.utc)
return datetime.fromtimestamp(float(decode_jwt(self.id_token).get("exp", "0")), timezone.utc)

@property
def cognito_token(self) -> str:
return self._auth.cognito_token
Expand All @@ -74,12 +97,12 @@ def access_token(self) -> str:
def refresh_token(self) -> str:
return self._auth.refresh_token

def _check_token_expiration(self, hours: int) -> bool:
return datetime.utcnow() >= self._expires + timedelta(hours=hours)
def _check_token_expiration(self, hours: int = 0) -> bool:
return datetime.now(timezone.utc) >= self.expires - timedelta(hours=hours)

@property
def token_is_expired(self) -> bool:
return self._check_token_expiration(self._TOKEN_EXPIRES_AFTER_HOURS)
return self._check_token_expiration()

@property
def token_expires_soon(self) -> bool:
Expand Down Expand Up @@ -119,7 +142,7 @@ async def _introduce(self) -> str:
url = f"{const.AUTH_API}/services/oauth2/authorize/expid_Login?{params_encode}"
async with self._request.get(url) as response:
text = await response.text()
self._expires = datetime.utcnow()
self._expires = datetime.now(timezone.utc)
login_url: List[str] = re.findall("(?:url|href) ?= ?'(.+?)'", text)
if not login_url:
if "oauth/done#access_token=" in text:
Expand Down
2 changes: 1 addition & 1 deletion pyhon/connection/handler/hon.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def create(self) -> Self:
return self

async def _check_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
if self._refresh_token:
if self._refresh_token and self.auth.token_expires_soon:
await self.auth.refresh(self._refresh_token)
if not (self.auth.cognito_token and self.auth.id_token):
await self.auth.authenticate()
Expand Down
2 changes: 1 addition & 1 deletion pyhon/parameter/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ def value(self, value: str) -> None:
self._value = value
self.check_trigger(value)
else:
raise ValueError(f"Allowed values: {self._values} But was: {value}")
raise ValueError(f"Allowed values for {self.key}: {self._values} But was: {value}")