Skip to content

Commit

Permalink
Change GuildBanIterator to match other lazy iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonxslays committed Apr 9, 2022
1 parent 4c1f264 commit 784a691
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 95 deletions.
20 changes: 11 additions & 9 deletions hikari/api/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5501,9 +5501,10 @@ async def fetch_ban(
def fetch_bans(
self,
guild: snowflakes.SnowflakeishOr[guilds.PartialGuild],
/,
*,
before: undefined.UndefinedOr[snowflakes.SnowflakeishOr[users.PartialUser]] = undefined.UNDEFINED,
after: undefined.UndefinedOr[snowflakes.SnowflakeishOr[users.PartialUser]] = undefined.UNDEFINED,
newest_first: bool = False,
start_at: undefined.UndefinedOr[snowflakes.SearchableSnowflakeishOr[users.PartialUser]] = undefined.UNDEFINED,
) -> iterators.LazyIterator[guilds.GuildBan]:
"""Fetch the bans of a guild.
Expand All @@ -5515,14 +5516,15 @@ def fetch_bans(
Other Parameters
----------------
before : hikari.undefined.UndefinedOr[hikari.snowflakes.SnowflakeishOr[hikari.users.PartialUser]]
If provided, filter to only actions before this snowflake or user.
after : hikari.undefined.UndefinedOr[hikari.snowflakes.SnowflakeishOr[hikari.users.PartialUser]]
If provided, filter to only actions after this snowflake or user.
newest_first : builtins.bool
Whether to fetch the newest first or the oldest first.
!!! note
Bans will always be returned in ascending order by user ID.
If both before and after are provided, only before is respected.
Defaults to `builtins.False`.
start_at : undefined.UndefinedOr[snowflakes.SearchableSnowflakeishOr[users.PartialUser]]
If provided, will start at this snowflake. If you provide
a datetime object, it will be transformed into a snowflake. This
may also be a scheduled event object object. In this case, the
date the object was first created will be used.
Returns
-------
Expand Down
18 changes: 11 additions & 7 deletions hikari/impl/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2768,16 +2768,20 @@ async def fetch_ban(
def fetch_bans(
self,
guild: snowflakes.SnowflakeishOr[guilds.PartialGuild],
/,
*,
before: undefined.UndefinedOr[snowflakes.SnowflakeishOr[users.PartialUser]] = undefined.UNDEFINED,
after: undefined.UndefinedOr[snowflakes.SnowflakeishOr[users.PartialUser]] = undefined.UNDEFINED,
newest_first: bool = False,
start_at: undefined.UndefinedOr[snowflakes.SearchableSnowflakeishOr[users.PartialUser]] = undefined.UNDEFINED,
) -> iterators.LazyIterator[guilds.GuildBan]:
if start_at is undefined.UNDEFINED:
start_at = snowflakes.Snowflake.max() if newest_first else snowflakes.Snowflake.min()
elif isinstance(start_at, datetime.datetime):
start_at = snowflakes.Snowflake.from_datetime(start_at)
else:
start_at = int(start_at)

return special_endpoints_impl.GuildBanIterator(
entity_factory=self._entity_factory,
request_call=self._request,
guild=guild,
first_id=after,
last_id=before,
self._entity_factory, self._request, guild, newest_first, str(start_at)
)

async def fetch_roles(
Expand Down
15 changes: 9 additions & 6 deletions hikari/impl/special_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ class GuildBanIterator(iterators.BufferedLazyIterator["guilds.GuildBan"]):
"_request_call",
"_route",
"_first_id",
"_last_id",
"_newest_first",
)

def __init__(
Expand All @@ -631,21 +631,20 @@ def __init__(
..., typing.Coroutine[None, None, typing.Union[None, data_binding.JSONObject, data_binding.JSONArray]]
],
guild: snowflakes.SnowflakeishOr[guilds.PartialGuild],
first_id: undefined.UndefinedOr[snowflakes.SnowflakeishOr[users.PartialUser]],
last_id: undefined.UndefinedOr[snowflakes.SnowflakeishOr[users.PartialUser]],
newest_first: bool,
first_id: str,
) -> None:
super().__init__()
self._guild_id = snowflakes.Snowflake(str(int(guild)))
self._route = routes.GET_GUILD_BANS.compile(guild=guild)
self._request_call = request_call
self._entity_factory = entity_factory
self._first_id = first_id
self._last_id = last_id
self._newest_first = newest_first

async def _next_chunk(self) -> typing.Optional[typing.Generator[guilds.GuildBan, typing.Any, None]]:
query = data_binding.StringMapBuilder()
query.put("after", self._first_id)
query.put("before", self._last_id)
query.put("before" if self._newest_first else "after", self._first_id)
query.put("limit", 1000)

chunk = await self._request_call(compiled_route=self._route, query=query)
Expand All @@ -654,6 +653,10 @@ async def _next_chunk(self) -> typing.Optional[typing.Generator[guilds.GuildBan,
if not chunk:
return None

if self._newest_first:
# These are always returned in ascending order by `.user.id`.
chunk.reverse()

self._first_id = chunk[-1]["user"]["id"]
return (self._entity_factory.deserialize_guild_member_ban(b) for b in chunk)

Expand Down
101 changes: 45 additions & 56 deletions tests/hikari/impl/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,69 +1043,58 @@ def test_unban_member(self, rest_client):
assert reason is mock_unban_user.return_value
mock_unban_user.assert_called_once_with(123, 321, reason="ayaya")

def test_fetch_bans_when_before_is_undefined(self, rest_client):
guild = StubModel(123)
after = StubModel(789)
stub_iterator = mock.Mock()

with mock.patch.object(special_endpoints, "GuildBanIterator", return_value=stub_iterator) as iterator:
assert rest_client.fetch_bans(guild, after=after) == stub_iterator

iterator.assert_called_once_with(
entity_factory=rest_client._entity_factory,
request_call=rest_client._request,
guild=guild,
first_id=after,
last_id=undefined.UNDEFINED,
)

def test_fetch_bans_when_after_is_undefined(self, rest_client):
guild = StubModel(123)
before = StubModel(456)
stub_iterator = mock.Mock()
def test_fetch_bans(self, rest_client: rest.RESTClientImpl):
with mock.patch.object(special_endpoints, "GuildBanIterator") as iterator_cls:
iterator = rest_client.fetch_bans(187, newest_first=True, start_at=StubModel(65652342134))

with mock.patch.object(special_endpoints, "GuildBanIterator", return_value=stub_iterator) as iterator:
assert rest_client.fetch_bans(guild, before=before) == stub_iterator

iterator.assert_called_once_with(
entity_factory=rest_client._entity_factory,
request_call=rest_client._request,
guild=guild,
last_id=before,
first_id=undefined.UNDEFINED,
)
iterator_cls.assert_called_once_with(
rest_client._entity_factory,
rest_client._request,
187,
True,
"65652342134",
)
assert iterator is iterator_cls.return_value

def test_fetch_bans_when_before_and_after_are_undefined(self, rest_client):
guild = StubModel(123)
stub_iterator = mock.Mock()
def test_fetch_bans_when_datetime_for_start_at(self, rest_client: rest.RESTClientImpl):
start_at = datetime.datetime(2022, 3, 6, 12, 1, 58, 415625, tzinfo=datetime.timezone.utc)
with mock.patch.object(special_endpoints, "GuildBanIterator") as iterator_cls:
iterator = rest_client.fetch_bans(9000, newest_first=True, start_at=start_at)

with mock.patch.object(special_endpoints, "GuildBanIterator", return_value=stub_iterator) as iterator:
assert rest_client.fetch_bans(guild) == stub_iterator
iterator_cls.assert_called_once_with(
rest_client._entity_factory,
rest_client._request,
9000,
True,
"950000286338908160",
)
assert iterator is iterator_cls.return_value

iterator.assert_called_once_with(
entity_factory=rest_client._entity_factory,
request_call=rest_client._request,
guild=guild,
last_id=undefined.UNDEFINED,
first_id=undefined.UNDEFINED,
)
def test_fetch_bans_when_start_at_undefined(self, rest_client: rest.RESTClientImpl):
with mock.patch.object(special_endpoints, "GuildBanIterator") as iterator_cls:
iterator = rest_client.fetch_bans(8844)

def test_fetch_bans_when_before_and_after_are_provided(self, rest_client):
guild = StubModel(123)
before = StubModel(456)
after = StubModel(789)
stub_iterator = mock.Mock()
iterator_cls.assert_called_once_with(
rest_client._entity_factory,
rest_client._request,
8844,
False,
str(snowflakes.Snowflake.min()),
)
assert iterator is iterator_cls.return_value

with mock.patch.object(special_endpoints, "GuildBanIterator", return_value=stub_iterator) as iterator:
assert rest_client.fetch_bans(guild, before=before, after=after) == stub_iterator
def test_fetch_bans_when_start_at_undefined_and_newest_first(self, rest_client: rest.RESTClientImpl):
with mock.patch.object(special_endpoints, "GuildBanIterator") as iterator_cls:
iterator = rest_client.fetch_bans(3848, newest_first=True)

iterator.assert_called_once_with(
entity_factory=rest_client._entity_factory,
request_call=rest_client._request,
guild=guild,
last_id=before,
first_id=after,
)
iterator_cls.assert_called_once_with(
rest_client._entity_factory,
rest_client._request,
3848,
True,
str(snowflakes.Snowflake.max()),
)
assert iterator is iterator_cls.return_value

def test_command_builder(self, rest_client):
with warnings.catch_warnings():
Expand Down
93 changes: 76 additions & 17 deletions tests/hikari/impl/test_special_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,19 +168,18 @@ async def test_aiter_when_empty_chunk(self, newest_first: bool):
class TestGuildBanIterator:
@pytest.mark.asyncio()
async def test_aiter(self):
mock_payload_1 = {"user": {"id": "123321123123"}}
mock_payload_2 = {"user": {"id": "123321123666"}}
mock_payload_3 = {"user": {"id": "123321124123"}}
mock_payload_4 = {"user": {"id": "123321124567"}}
mock_payload_5 = {"user": {"id": "12332112432234"}}
guild_id = 59320
expected_route = routes.GET_GUILD_BANS.compile(guild=10000)
mock_entity_factory = mock.Mock()
mock_payload_1 = {"user": {"id": "45234"}}
mock_payload_2 = {"user": {"id": "452745"}}
mock_payload_3 = {"user": {"id": "45237656"}}
mock_payload_4 = {"user": {"id": "452345666"}}
mock_payload_5 = {"user": {"id": "4523456744"}}
mock_result_1 = mock.Mock()
mock_result_2 = mock.Mock()
mock_result_3 = mock.Mock()
mock_result_4 = mock.Mock()
mock_result_5 = mock.Mock()
expected_route = routes.GET_GUILD_BANS.compile(guild=guild_id)
mock_entity_factory = mock.Mock()
mock_entity_factory.deserialize_guild_member_ban.side_effect = [
mock_result_1,
mock_result_2,
Expand All @@ -192,7 +191,11 @@ async def test_aiter(self):
side_effect=[[mock_payload_1, mock_payload_2, mock_payload_3], [mock_payload_4, mock_payload_5], []]
)
iterator = special_endpoints.GuildBanIterator(
mock_entity_factory, mock_request, guild=guild_id, first_id=720, last_id=undefined.UNDEFINED
entity_factory=mock_entity_factory,
request_call=mock_request,
guild=10000,
newest_first=False,
first_id="0",
)

result = await iterator
Expand All @@ -209,27 +212,83 @@ async def test_aiter(self):
)
mock_request.assert_has_awaits(
[
mock.call(compiled_route=expected_route, query={"after": "720", "limit": "1000"}),
mock.call(compiled_route=expected_route, query={"after": "123321124123", "limit": "1000"}),
mock.call(compiled_route=expected_route, query={"after": "12332112432234", "limit": "1000"}),
mock.call(compiled_route=expected_route, query={"after": "0", "limit": "1000"}),
mock.call(compiled_route=expected_route, query={"after": "45237656", "limit": "1000"}),
mock.call(compiled_route=expected_route, query={"after": "4523456744", "limit": "1000"}),
]
)

@pytest.mark.asyncio()
async def test_aiter_when_newest_first(self):
expected_route = routes.GET_GUILD_BANS.compile(guild=10000)
mock_entity_factory = mock.Mock()
mock_payload_1 = {"user": {"id": "432234"}}
mock_payload_2 = {"user": {"id": "1233211"}}
mock_payload_3 = {"user": {"id": "12332112"}}
mock_payload_4 = {"user": {"id": "1233"}}
mock_payload_5 = {"user": {"id": "54334"}}
mock_result_1 = mock.Mock()
mock_result_2 = mock.Mock()
mock_result_3 = mock.Mock()
mock_result_4 = mock.Mock()
mock_result_5 = mock.Mock()
mock_entity_factory.deserialize_guild_member_ban.side_effect = [
mock_result_1,
mock_result_2,
mock_result_3,
mock_result_4,
mock_result_5,
]
mock_request = mock.AsyncMock(
side_effect=[[mock_payload_1, mock_payload_2, mock_payload_3], [mock_payload_4, mock_payload_5], []]
)
iterator = special_endpoints.GuildBanIterator(
entity_factory=mock_entity_factory,
request_call=mock_request,
guild=10000,
newest_first=True,
first_id="321123321",
)

result = await iterator

assert result == [mock_result_1, mock_result_2, mock_result_3, mock_result_4, mock_result_5]
mock_entity_factory.deserialize_guild_member_ban.assert_has_calls(
[
mock.call(mock_payload_3),
mock.call(mock_payload_2),
mock.call(mock_payload_1),
mock.call(mock_payload_5),
mock.call(mock_payload_4),
]
)
mock_request.assert_has_awaits(
[
mock.call(compiled_route=expected_route, query={"before": "321123321", "limit": "1000"}),
mock.call(compiled_route=expected_route, query={"before": "432234", "limit": "1000"}),
mock.call(compiled_route=expected_route, query={"before": "1233", "limit": "1000"}),
]
)

@pytest.mark.parametrize("newest_first", [True, False])
@pytest.mark.asyncio()
async def test_aiter_when_empty_chunk(self):
guild_id = 88574
expected_route = routes.GET_GUILD_BANS.compile(guild=guild_id)
async def test_aiter_when_empty_chunk(self, newest_first: bool):
expected_route = routes.GET_GUILD_BANS.compile(guild=10000)
mock_entity_factory = mock.Mock()
mock_request = mock.AsyncMock(return_value=[])
iterator = special_endpoints.GuildBanIterator(
mock_entity_factory, mock_request, guild=guild_id, first_id=45, last_id=780
entity_factory=mock_entity_factory,
request_call=mock_request,
guild=10000,
newest_first=newest_first,
first_id="54234123123",
)

result = await iterator

assert result == []
mock_entity_factory.deserialize_guild_member_ban.assert_not_called()
query = {"before": "780", "after": "45", "limit": "1000"}
query = {"before" if newest_first else "after": "54234123123", "limit": "1000"}
mock_request.assert_awaited_once_with(compiled_route=expected_route, query=query)


Expand Down

0 comments on commit 784a691

Please sign in to comment.