Skip to content

Commit

Permalink
Improve get_channel logic (#112)
Browse files Browse the repository at this point in the history
* Improve get_channel logic

* Update fixtures

* Style changes

Co-authored-by: jack1142 <[email protected]>
  • Loading branch information
Drapersniper and Jackenmen authored Mar 3, 2022
1 parent 915060b commit 5649f70
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 25 deletions.
1 change: 0 additions & 1 deletion lavalink/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
"NodeStats",
"Stats",
"user_id",
"channel_finder_func",
"Player",
"PlayerManager",
"initialize",
Expand Down
1 change: 0 additions & 1 deletion lavalink/lavalink.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ async def initialize(
_loop = bot.loop

player_manager.user_id = bot.user.id
player_manager.channel_finder_func = bot.get_channel
register_event_listener(_handle_event)
register_update_listener(_handle_update)

Expand Down
37 changes: 26 additions & 11 deletions lavalink/player_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import datetime
from random import shuffle
from typing import KeysView, Optional, TYPE_CHECKING, ValuesView
from typing import KeysView, Optional, Tuple, TYPE_CHECKING, ValuesView

import discord
from discord.backoff import ExponentialBackoff
Expand All @@ -13,10 +13,9 @@
if TYPE_CHECKING:
from . import node

__all__ = ["user_id", "channel_finder_func", "Player", "PlayerManager"]
__all__ = ["user_id", "Player", "PlayerManager"]

user_id = None
channel_finder_func = lambda channel_id: None


class Player(RESTClient):
Expand Down Expand Up @@ -489,15 +488,20 @@ def get_player(self, guild_id: int) -> Player:
return self._player_dict[guild_id]
raise KeyError("No such player for that guild.")

def _ensure_player(self, channel_id: int):
channel = channel_finder_func(channel_id)
def _ensure_player(
self, guild_id: int, channel_id: int
) -> Optional[Tuple[Player, discord.TextChannel]]:
guild: discord.Guild = self.bot.get_guild(guild_id)
if guild is None:
return None
channel = guild.get_channel(channel_id)
if channel is not None:
try:
p = self.get_player(channel.guild.id)
p = self.get_player(guild_id)
except KeyError:
log.debug("Received voice channel connection without a player.")
p = Player(self, channel)
self._player_dict[channel.guild.id] = p
self._player_dict[guild_id] = p
return p, channel

async def _remove_player(self, guild_id: int):
Expand Down Expand Up @@ -564,10 +568,21 @@ async def on_socket_response(self, data):

else:
# After initial connection, get session ID
p, channel = self._ensure_player(int(channel_id))
if channel != p.channel:
if p.channel:
p._last_channel_id = p.channel.id
response = self._ensure_player(int(guild_id), int(channel_id))
if response is None:
# We disconnected
p = self._player_dict.get(guild_id)
msg = "Received voice disconnect from discord, removing player."
if p:
msg += f" {p}"
ws_rll_log.info(msg)
self.voice_states[guild_id] = {}
await self._remove_player(int(guild_id))
else:
p, channel = response
if channel != p.channel:
if p.channel:
p._last_channel_id = p.channel.id
p.channel = channel

session_id = data["d"]["session_id"]
Expand Down
43 changes: 33 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import namedtuple
from types import SimpleNamespace
from typing import NamedTuple

import pytest
import asyncio
Expand Down Expand Up @@ -51,26 +52,48 @@ def closed(self):
return self._closed


@pytest.fixture
class Guild:
def __init__(self, id, name):
self.id = id
self.name = name


class VoiceChannel:
def __init__(self, id, name):
self.id = id
self.name = name


@pytest.fixture(scope="session")
def user():
User = namedtuple("User", "id")
return User(1234567890)


@pytest.fixture
def guild():
Guild = namedtuple("Guild", "id name")
@pytest.fixture(scope="session")
def _guild():
return Guild(987654321, "Testing")


@pytest.fixture()
def voice_channel(guild):
VoiceChannel = namedtuple("VoiceChannel", "id guild name")
return VoiceChannel(9999999999, guild, "Testing VC")
@pytest.fixture(scope="session")
def _voice_channel():
return VoiceChannel(9999999999, "Testing VC")


@pytest.fixture(scope="session")
def guild(_guild, _voice_channel):
_guild.get_channel = lambda channel_id: _voice_channel
return _guild


@pytest.fixture(scope="session")
def voice_channel(guild, _voice_channel):
_voice_channel.guild = guild
return _voice_channel


@pytest.fixture
async def bot(event_loop, user, voice_channel):
async def bot(event_loop, user, guild, voice_channel):
async def voice_state(guild_id=None, channel_id=None):
pass

Expand All @@ -90,7 +113,7 @@ def closed(self):
bot_.loop = event_loop
bot_._connection = conn
bot_.user = user
bot_.get_channel = lambda channel_id: voice_channel
bot_.get_guild = lambda guild_id: guild
bot_.shard_count = 1

yield bot_
Expand Down
2 changes: 0 additions & 2 deletions tests/test_lavalink.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ async def test_initialize(bot):
await lavalink.initialize(bot, "localhost", "password", 2333, 2333)

assert lavalink.player_manager.user_id == bot.user.id
assert lavalink.player_manager.channel_finder_func == bot.get_channel

assert len(lavalink.node._nodes) == bot.shard_count

bot.add_listener.assert_called()

0 comments on commit 5649f70

Please sign in to comment.