Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat adding async session support #4244

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion telethon/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,9 @@ async def log_out(self: 'TelegramClient') -> bool:
self._authorized = False

await self.disconnect()
self.session.delete()
delete = self.session.delete()
if inspect.isawaitable(delete):
await delete
self.session = None
return True

Expand Down
9 changes: 6 additions & 3 deletions telethon/client/downloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ async def _init(
config = await self.client(functions.help.GetConfigRequest())
for option in config.dc_options:
if option.ip_address == self.client.session.server_address:
self.client.session.set_dc(
option.id, option.ip_address, option.port)
self.client.session.save()
set_dc = self.client.session.set_dc(option.id, option.ip_address, option.port)
if inspect.isawaitable(set_dc):
await set_dc
save = self.client.session.save()
if inspect.isawaitable(save):
await save
break

# TODO Figure out why the session may have the wrong DC ID
Expand Down
123 changes: 84 additions & 39 deletions telethon/client/telegrambaseclient.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import inspect
import re
import asyncio
import collections
Expand Down Expand Up @@ -302,15 +303,6 @@ def __missing__(self, key):
'The given session must be a str or a Session instance.'
)

# ':' in session.server_address is True if it's an IPv6 address
if (not session.server_address or
(':' in session.server_address) != use_ipv6):
session.set_dc(
DEFAULT_DC_ID,
DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP,
DEFAULT_PORT
)

self.flood_sleep_threshold = flood_sleep_threshold

# TODO Use AsyncClassWrapper(session)
Expand Down Expand Up @@ -445,19 +437,7 @@ def __missing__(self, key):
self._message_box = MessageBox(self._log['messagebox'])
self._mb_entity_cache = MbEntityCache() # required for proper update handling (to know when to getDifference)
self._entity_cache_limit = entity_cache_limit

self._sender = MTProtoSender(
self.session.auth_key,
loggers=self._log,
retries=self._connection_retries,
delay=self._retry_delay,
auto_reconnect=self._auto_reconnect,
connect_timeout=self._timeout,
auth_key_callback=self._auth_key_callback,
updates_queue=self._updates_queue,
auto_reconnect_callback=self._handle_auto_reconnect
)

self._sender = None

# endregion

Expand Down Expand Up @@ -541,6 +521,30 @@ async def connect(self: 'TelegramClient') -> None:
elif self._loop != helpers.get_running_loop():
raise RuntimeError('The asyncio event loop must not change after connection (see the FAQ for details)')

# ':' in session.server_address is True if it's an IPv6 address
if (not self.session.server_address or
(':' in self.session.server_address) != self._use_ipv6):
set_dc = self.session.set_dc(
DEFAULT_DC_ID,
DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP,
DEFAULT_PORT
)
if inspect.isawaitable(set_dc):
await set_dc

if not self._sender:
self._sender = MTProtoSender(
self.session.auth_key,
loggers=self._log,
retries=self._connection_retries,
delay=self._retry_delay,
auto_reconnect=self._auto_reconnect,
connect_timeout=self._timeout,
auth_key_callback=self._auth_key_callback,
updates_queue=self._updates_queue,
auto_reconnect_callback=self._handle_auto_reconnect
)

if not await self._sender.connect(self._connection(
self.session.server_address,
self.session.port,
Expand All @@ -553,12 +557,19 @@ async def connect(self: 'TelegramClient') -> None:
return

self.session.auth_key = self._sender.auth_key
self.session.save()
save = self.session.save()
if inspect.isawaitable(save):
await save

try:
# See comment when saving entities to understand this hack
self_id = self.session.get_input_entity(0).access_hash
self_entity = self.session.get_input_entity(0)
if inspect.isawaitable(self_entity):
self_entity = await self_entity
self_id = self_entity.access_hash
self_user = self.session.get_input_entity(self_id)
if inspect.isawaitable(self_user):
self_user = await self_user
self._mb_entity_cache.set_self_user(self_id, None, self_user.access_hash)
except ValueError:
pass
Expand All @@ -567,7 +578,10 @@ async def connect(self: 'TelegramClient') -> None:
ss = SessionState(0, 0, False, 0, 0, 0, 0, None)
cs = []

for entity_id, state in self.session.get_update_states():
update_states = self.session.get_update_states()
if inspect.isawaitable(update_states):
update_states = await update_states
for entity_id, state in update_states:
if entity_id == 0:
# TODO current session doesn't store self-user info but adding that is breaking on downstream session impls
ss = SessionState(0, 0, False, state.pts, state.qts, int(state.date.timestamp()), state.seq, None)
Expand All @@ -578,6 +592,8 @@ async def connect(self: 'TelegramClient') -> None:
for state in cs:
try:
entity = self.session.get_input_entity(state.channel_id)
if inspect.isawaitable(entity):
entity = await entity
except ValueError:
self._log[__name__].warning(
'No access_hash in cache for channel %s, will not catch up', state.channel_id)
Expand Down Expand Up @@ -681,23 +697,37 @@ def set_proxy(self: 'TelegramClient', proxy: typing.Union[tuple, dict]):
else:
connection._proxy = proxy

def _save_states_and_entities(self: 'TelegramClient'):
async def _save_states_and_entities(self: 'TelegramClient'):
entities = self._mb_entity_cache.get_all_entities()

# Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities.
# It doesn't matter if we put users in the list of chats.
self.session.process_entities(types.contacts.ResolvedPeer(None, [e._as_input_peer() for e in entities], []))
process_entities = self.session.process_entities(
types.contacts.ResolvedPeer(None, [e._as_input_peer() for e in entities], [])
)
if inspect.isawaitable(process_entities):
await process_entities

# As a hack to not need to change the session files, save ourselves with ``id=0`` and ``access_hash`` of our ``id``.
# This way it is possible to determine our own ID by querying for 0. However, whether we're a bot is not saved.
if self._mb_entity_cache.self_id:
self.session.process_entities(types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], []))
process_entities = self.session.process_entities(
types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], [])
)
if inspect.isawaitable(process_entities):
await process_entities

ss, cs = self._message_box.session_state()
self.session.set_update_state(0, types.updates.State(**ss, unread_count=0))
update_state = self.session.set_update_state(0, types.updates.State(**ss, unread_count=0))
if inspect.isawaitable(update_state):
await update_state
now = datetime.datetime.now() # any datetime works; channels don't need it
for channel_id, pts in cs.items():
self.session.set_update_state(channel_id, types.updates.State(pts, 0, now, 0, unread_count=0))
update_state = self.session.set_update_state(
channel_id, types.updates.State(pts, 0, now, 0, unread_count=0)
)
if inspect.isawaitable(update_state):
await update_state

async def _disconnect_coro(self: 'TelegramClient'):
if self.session is None:
Expand Down Expand Up @@ -729,9 +759,11 @@ async def _disconnect_coro(self: 'TelegramClient'):
await asyncio.wait(self._event_handler_tasks)
self._event_handler_tasks.clear()

self._save_states_and_entities()
await self._save_states_and_entities()

self.session.close()
close = self.session.close()
if inspect.isawaitable(close):
await close

async def _disconnect(self: 'TelegramClient'):
"""
Expand All @@ -740,7 +772,8 @@ async def _disconnect(self: 'TelegramClient'):
file; user disconnects however should close it since it means that
their job with the client is complete and we should clean it up all.
"""
await self._sender.disconnect()
if self._sender:
await self._sender.disconnect()
await helpers._cancel(self._log[__name__],
updates_handle=self._updates_handle,
keepalive_handle=self._keepalive_handle)
Expand All @@ -749,25 +782,33 @@ async def _switch_dc(self: 'TelegramClient', new_dc):
"""
Permanently switches the current connection to the new data center.
"""
if not self._sender:
raise RuntimeError('Cant switch dc if not connected')
self._log[__name__].info('Reconnecting to new data center %s', new_dc)
dc = await self._get_dc(new_dc)

self.session.set_dc(dc.id, dc.ip_address, dc.port)
set_dc = self.session.set_dc(dc.id, dc.ip_address, dc.port)
if inspect.isawaitable(set_dc):
await set_dc
# auth_key's are associated with a server, which has now changed
# so it's not valid anymore. Set to None to force recreating it.
self._sender.auth_key.key = None
self.session.auth_key = None
self.session.save()
save = self.session.save()
if inspect.isawaitable(save):
await save
await self._disconnect()
return await self.connect()

def _auth_key_callback(self: 'TelegramClient', auth_key):
async def _auth_key_callback(self: 'TelegramClient', auth_key):
"""
Callback from the sender whenever it needed to generate a
new authorization key. This means we are not authorized.
"""
self.session.auth_key = auth_key
self.session.save()
save = self.session.save()
if inspect.isawaitable(save):
await save

# endregion

Expand Down Expand Up @@ -892,7 +933,11 @@ async def _get_cdn_client(self: 'TelegramClient', cdn_redirect):
if not session:
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
session = self.session.clone()
session.set_dc(dc.id, dc.ip_address, dc.port)
if inspect.isawaitable(session):
session = await session
set_dc = session.set_dc(dc.id, dc.ip_address, dc.port)
if inspect.isawaitable(set_dc):
await set_dc
self._exported_sessions[cdn_redirect.dc_id] = session

self._log[__name__].info('Creating new CDN client')
Expand All @@ -907,7 +952,7 @@ async def _get_cdn_client(self: 'TelegramClient', cdn_redirect):
# We won't be calling GetConfigRequest because it's only called
# when needed by ._get_dc, and also it's static so it's likely
# set already. Avoid invoking non-CDN methods by not syncing updates.
client.connect(_sync_updates=False)
await client.connect(_sync_updates=False)
return client

# endregion
Expand Down
8 changes: 5 additions & 3 deletions telethon/client/updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ async def _update_loop(self: 'TelegramClient'):
len(self._mb_entity_cache),
self._entity_cache_limit
)
self._save_states_and_entities()
await self._save_states_and_entities()
self._mb_entity_cache.retain(lambda id: id == self._mb_entity_cache.self_id or id in self._message_box.map)
if len(self._mb_entity_cache) >= self._entity_cache_limit:
warnings.warn('in-memory entities exceed entity_cache_limit after flushing; consider setting a larger limit')
Expand Down Expand Up @@ -514,9 +514,11 @@ async def _keepalive_loop(self: 'TelegramClient'):
# inserted because this is a rather expensive operation
# (default's sqlite3 takes ~0.1s to commit changes). Do
# it every minute instead. No-op if there's nothing new.
self._save_states_and_entities()
await self._save_states_and_entities()

self.session.save()
save = self.session.save()
if inspect.isawaitable(save):
await save

async def _dispatch_update(self: 'TelegramClient', update):
# TODO only used for AlbumHack, and MessageBox is not really designed for this
Expand Down
20 changes: 15 additions & 5 deletions telethon/client/users.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import datetime
import inspect
import itertools
import time
import typing
Expand Down Expand Up @@ -75,7 +76,9 @@ async def _call(self: 'TelegramClient', sender, request, ordered=False, flood_sl
exceptions.append(e)
results.append(None)
continue
self.session.process_entities(result)
process_entities = self.session.process_entities(result)
if inspect.isawaitable(process_entities):
await process_entities
exceptions.append(None)
results.append(result)
request_index += 1
Expand All @@ -85,7 +88,9 @@ async def _call(self: 'TelegramClient', sender, request, ordered=False, flood_sl
return results
else:
result = await future
self.session.process_entities(result)
process_entities = self.session.process_entities(result)
if inspect.isawaitable(process_entities):
await process_entities
return result
except (errors.ServerError, errors.RpcCallFailError,
errors.RpcMcgetFailError, errors.InterdcCallErrorError,
Expand Down Expand Up @@ -428,7 +433,10 @@ async def get_input_entity(

# No InputPeer, cached peer, or known string. Fetch from disk cache
try:
return self.session.get_input_entity(peer)
input_entity = self.session.get_input_entity(peer)
if inspect.isawaitable(input_entity):
input_entity = await input_entity
return input_entity
except ValueError:
pass

Expand Down Expand Up @@ -567,8 +575,10 @@ async def _get_entity_from_string(self: 'TelegramClient', string):
pass
try:
# Nobody with this username, maybe it's an exact name/title
return await self.get_entity(
self.session.get_input_entity(string))
input_entity = self.session.get_input_entity(string)
if inspect.isawaitable(input_entity):
input_entity = await input_entity
return await self.get_entity(input_entity)
except ValueError:
pass

Expand Down
2 changes: 1 addition & 1 deletion telethon/network/mtprotosender.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ async def _try_gen_auth_key(self, attempt):
# notify whenever we change it. This is crucial when we
# switch to different data centers.
if self._auth_key_callback:
self._auth_key_callback(self.auth_key)
await self._auth_key_callback(self.auth_key)

self._log.debug('auth_key generation success!')
return True
Expand Down