Skip to content

Commit

Permalink
Migrate to aiohttp for stability on resource fetching & bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mhdzumair committed Nov 1, 2024
1 parent fe87124 commit a87bab3
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 117 deletions.
3 changes: 2 additions & 1 deletion streaming_providers/alldebrid/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async def initialize_headers(self):
async def disable_access_token(self):
pass

async def _handle_service_specific_errors(self, error):
async def _handle_service_specific_errors(self, error_data: dict, status_code: int):
pass

async def _make_request(
Expand All @@ -30,6 +30,7 @@ async def _make_request(
params: Optional[dict] = None,
is_return_none: bool = False,
is_expected_to_fail: bool = False,
retry_count: int = 0,
) -> dict:
params = params or {}
params["agent"] = self.AGENT
Expand Down
152 changes: 93 additions & 59 deletions streaming_providers/debrid_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import traceback
from abc import abstractmethod
from base64 import b64encode, b64decode
from typing import Optional
from typing import Optional, Dict, Union

import httpx
import aiohttp
from aiohttp import ClientResponse, ClientTimeout, ContentTypeError

from streaming_providers.exceptions import ProviderException

Expand All @@ -13,10 +14,18 @@ class DebridClient:
def __init__(self, token: Optional[str] = None):
self.token = token
self.is_private_token = False
self.headers = {}
self.client: httpx.AsyncClient = httpx.AsyncClient(
timeout=18.0
) # Stremio timeout is 20s
self.headers: Dict[str, str] = {}
self._session: Optional[aiohttp.ClientSession] = None
self._timeout = ClientTimeout(total=15) # Stremio timeout is 20s

@property
def session(self) -> aiohttp.ClientSession:
if self._session is None:
self._session = aiohttp.ClientSession(
timeout=self._timeout,
connector=aiohttp.TCPConnector(ttl_dns_cache=300),
)
return self._session

async def __aenter__(self):
await self.initialize_headers()
Expand All @@ -28,7 +37,10 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.disable_access_token()
except ProviderException:
pass
await self.client.aclose()

if self._session:
await self._session.close()
self._session = None

async def _make_request(
self,
Expand All @@ -39,73 +51,95 @@ async def _make_request(
params: Optional[dict] = None,
is_return_none: bool = False,
is_expected_to_fail: bool = False,
) -> dict | list:
retry_count: int = 0,
) -> dict | list | str:
try:
response = await self.client.request(
async with self.session.request(
method, url, data=data, json=json, params=params, headers=self.headers
)
response.raise_for_status()
return await self._parse_response(response, is_return_none)
except httpx.RequestError as error:
await self._handle_request_error(error, is_expected_to_fail)
except httpx.HTTPStatusError as error:
return await self._handle_http_error(error, is_expected_to_fail)
) as response:
await self._check_response_status(response, is_expected_to_fail)
return await self._parse_response(
response, is_return_none, is_expected_to_fail
)

except aiohttp.ClientConnectorError as error:
if retry_count < 1: # Try one more time
return await self._make_request(
method,
url,
data=data,
json=json,
params=params,
is_return_none=is_return_none,
is_expected_to_fail=is_expected_to_fail,
retry_count=retry_count + 1,
)
await self._handle_request_error(error)
except aiohttp.ClientError as error:
await self._handle_request_error(error)
except Exception as error:
await self._handle_request_error(error, is_expected_to_fail)

@staticmethod
async def _handle_request_error(error: Exception, is_expected_to_fail: bool):
if isinstance(error, httpx.TimeoutException):
raise ProviderException("Request timed out.", "torrent_not_downloaded.mp4")
elif isinstance(error, httpx.TransportError):
raise ProviderException(
"Failed to connect to Debrid service.", "debrid_service_down_error.mp4"
)
elif not is_expected_to_fail:
raise ProviderException(f"Request error: {str(error)}", "api_error.mp4")
await self._handle_request_error(error)

async def _handle_http_error(
self, error: httpx.HTTPStatusError, is_expected_to_fail: bool
async def _check_response_status(
self, response: ClientResponse, is_expected_to_fail: bool
):
if error.response.status_code in [502, 503, 504]:
"""Check response status and handle HTTP errors."""
try:
response.raise_for_status()
except aiohttp.ClientResponseError as error:
if error.status in [502, 503, 504]:
raise ProviderException(
"Debrid service is down.", "debrid_service_down_error.mp4"
)
if is_expected_to_fail:
return

if response.headers.get("Content-Type") == "application/json":
error_content = await response.json()
await self._handle_service_specific_errors(error_content, error.status)
else:
error_content = await response.text()

if error.status == 401:
raise ProviderException("Invalid token", "invalid_token.mp4")

formatted_traceback = "".join(traceback.format_exception(error))
raise ProviderException(
"Debrid service is down.", "debrid_service_down_error.mp4"
f"API Error {error_content} \n{formatted_traceback}",
"api_error.mp4",
)

if is_expected_to_fail:
return (
error.response.json()
if error.response.headers.get("Content-Type") == "application/json"
else error.response.text
@staticmethod
async def _handle_request_error(error: Exception):
if isinstance(error, asyncio.TimeoutError):
raise ProviderException("Request timed out.", "torrent_not_downloaded.mp4")
elif isinstance(error, aiohttp.ClientConnectorError):
raise ProviderException(
"Failed to connect to Debrid service.", "debrid_service_down_error.mp4"
)

await self._handle_service_specific_errors(error)

if error.response.status_code == 401:
raise ProviderException("Invalid token", "invalid_token.mp4")

formatted_traceback = "".join(traceback.format_exception(error))
raise ProviderException(
f"API Error {error.response.text} \n{formatted_traceback}",
"api_error.mp4",
)
raise ProviderException(f"Request error: {str(error)}", "api_error.mp4")

@abstractmethod
async def _handle_service_specific_errors(self, error: httpx.HTTPStatusError):
async def _handle_service_specific_errors(self, error_data: dict, status_code: int):
"""
Service specific errors on api requests.
"""
raise NotImplementedError

@staticmethod
async def _parse_response(response: httpx.Response, is_return_none: bool):
async def _parse_response(
response: ClientResponse, is_return_none: bool, is_expected_to_fail: bool
) -> Union[dict, list, str]:
if is_return_none:
return {}
try:
return response.json()
except ValueError as error:
return await response.json()
except (ValueError, ContentTypeError) as error:
response_text = await response.text()
if is_expected_to_fail:
return response_text
raise ProviderException(
f"Failed to parse response error: {error}. \nresponse: {response.text}",
f"Failed to parse response error: {error}. \nresponse: {response_text}",
"api_error.mp4",
)

Expand All @@ -120,11 +154,11 @@ async def disable_access_token(self):
async def wait_for_status(
self,
torrent_id: str,
target_status: str | int,
target_status: Union[str, int],
max_retries: int,
retry_interval: int,
torrent_info: dict | None = None,
):
torrent_info: Optional[dict] = None,
) -> dict:
"""Wait for the torrent to reach a particular status."""
# if torrent_info is available, check the status from it
if torrent_info:
Expand All @@ -142,7 +176,7 @@ async def wait_for_status(
)

@abstractmethod
async def get_torrent_info(self, torrent_id):
async def get_torrent_info(self, torrent_id: str) -> dict:
raise NotImplementedError

@staticmethod
Expand All @@ -151,10 +185,10 @@ def encode_token_data(code: str, *args, **kwargs) -> str:
return b64encode(token.encode()).decode()

@staticmethod
def decode_token_str(token: str) -> str | None:
def decode_token_str(token: str) -> Optional[str]:
try:
_, code = b64decode(token).decode().split(":")
except (ValueError, UnicodeDecodeError):
# Assume as private token
return
return None
return code
9 changes: 3 additions & 6 deletions streaming_providers/debridlink/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from base64 import b64encode, b64decode
from typing import Any, Optional

from streaming_providers.debrid_client import DebridClient
Expand Down Expand Up @@ -33,10 +32,8 @@ def _handle_error_message(error_message):
"ip_not_allowed.mp4",
)

async def _handle_service_specific_errors(self, error):
if error.response.headers.get("content-type") == "application/json":
error_message = error.response.json().get("error")
self._handle_error_message(error_message)
async def _handle_service_specific_errors(self, error_data: dict, status_code: int):
self._handle_error_message(error_data.get("error"))

async def initialize_headers(self):
if self.token:
Expand Down Expand Up @@ -106,7 +103,7 @@ async def add_magnet_link(self, magnet_link):
is_expected_to_fail=True,
)
if response.get("error"):
await self._handle_error_message(response.get("error"))
self._handle_error_message(response.get("error"))
raise ProviderException(
f"Failed to add magnet link to Debrid-Link: {response.get('error')}",
"transfer_error.mp4",
Expand Down
56 changes: 39 additions & 17 deletions streaming_providers/offcloud/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from os import path
from typing import Optional, List

import httpx
import aiohttp


from db.models import TorrentStreams
from streaming_providers.debrid_client import DebridClient
Expand All @@ -23,10 +24,10 @@ async def initialize_headers(self):
async def disable_access_token(self):
pass

async def _handle_service_specific_errors(self, error: httpx.HTTPStatusError):
if error.response.status_code == 403:
async def _handle_service_specific_errors(self, error_data: dict, status_code: int):
if status_code == 403:
raise ProviderException("Invalid OffCloud API key", "invalid_token.mp4")
if error.response.status_code == 429:
if status_code == 429:
raise ProviderException(
"OffCloud rate limit exceeded", "too_many_requests.mp4"
)
Expand All @@ -40,6 +41,7 @@ async def _make_request(
params: Optional[dict] = None,
is_return_none: bool = False,
is_expected_to_fail: bool = False,
retry_count: int = 0,
delete: bool = False,
) -> dict | list:
params = params or {}
Expand Down Expand Up @@ -87,7 +89,7 @@ async def get_available_torrent(self, info_hash: str) -> Optional[dict]:
(
torrent
for torrent in available_torrents
if info_hash.casefold() in torrent["originalLink"].casefold()
if info_hash.casefold() in torrent.get("originalLink", "").casefold()
),
None,
)
Expand All @@ -96,18 +98,38 @@ async def explore_folder_links(self, request_id: str) -> List[str]:
return await self._make_request("GET", f"/cloud/explore/{request_id}")

async def update_file_sizes(self, files_data: list[dict]):
responses = await asyncio.gather(
*[
self.client.head(file_data["link"], timeout=5)
for file_data in files_data
],
return_exceptions=True,
)
for file_data, response in zip(files_data, responses):
if isinstance(response, Exception):
continue
if response.status_code == 200:
file_data["size"] = int(response.headers.get("Content-Length", 0))
"""
Update file sizes for a list of files by making HEAD requests.
Args:
files_data (list[dict]): List of file data dictionaries containing 'link' keys
Note:
This method modifies the input files_data list in-place, adding 'size' keys
where the HEAD request was successful.
"""

async def get_file_size(file_data: dict) -> tuple[dict, Optional[int]]:
"""Helper function to get file size for a single file."""
try:
async with self.session.head(
file_data["link"],
timeout=aiohttp.ClientTimeout(total=5),
allow_redirects=True,
) as response:
if response.status == 200:
return file_data, int(response.headers.get("Content-Length", 0))
except (aiohttp.ClientError, asyncio.TimeoutError):
pass
return file_data, 0

# Gather all HEAD requests with proper concurrency
tasks = [get_file_size(file_data) for file_data in files_data]
results = await asyncio.gather(*tasks, return_exceptions=False)

# Update file sizes in the original data
for file_data, size in results:
file_data["size"] = size

async def create_download_link(
self,
Expand Down
3 changes: 2 additions & 1 deletion streaming_providers/offcloud/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,6 @@ async def delete_all_torrents_from_oc(user_data: UserData, **kwargs):
async with OffCloud(token=user_data.streaming_provider.token) as oc_client:
torrents = await oc_client.get_user_torrent_list()
await asyncio.gather(
*[oc_client.delete_torrent(torrent["requestId"]) for torrent in torrents]
*[oc_client.delete_torrent(torrent["requestId"]) for torrent in torrents],
return_exceptions=True,
)
Loading

0 comments on commit a87bab3

Please sign in to comment.