Skip to content

Commit

Permalink
Improve typing for messages and message update methods (#783)
Browse files Browse the repository at this point in the history
* Improve typing for messages and message update methods

* Add changelog fragment
  • Loading branch information
davfsa authored Sep 16, 2021
1 parent cfcdb0b commit 81e93bd
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 115 deletions.
7 changes: 7 additions & 0 deletions changes/783.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Improve typing for message objects and message update methods
- Fix the use of `typing.Optional` where `undefined.UndefinedOr` should have been used
- Remove trying to acquire guild_id from the cached channel on PartialMessage
- Instead, clearly document the issue Discord imposes by not sending the guild_id
- `is_webhook` will now return `undefined.UNDEFINED` if the information is not available
- Fix logic in `is_human` to account for the changes in the typing
- Set `PartialMessage.member` to `undefined.UNDEFINED` when Discord edit the message to display an embed/attachment
81 changes: 30 additions & 51 deletions hikari/events/message_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from hikari import intents
from hikari import snowflakes
from hikari import traits
from hikari import undefined
from hikari.events import base_events
from hikari.events import shard_events
from hikari.internal import attr_extensions
Expand All @@ -54,7 +55,6 @@
from hikari import embeds as embeds_
from hikari import guilds
from hikari import messages
from hikari import undefined
from hikari import users
from hikari.api import shard as shard_

Expand Down Expand Up @@ -343,33 +343,23 @@ def app(self) -> traits.RESTAware:
return self.message.app

@property
def author(self) -> typing.Optional[users.User]:
def author(self) -> undefined.UndefinedOr[users.User]:
"""User that sent the message.
Returns
-------
typing.Optional[hikari.users.User]
The user that sent the message.
This will be `builtins.None` in some cases, such as when Discord
updates a message with an embed for a URL preview.
This will be `hikari.undefined.UNDEFINED` in some cases such as when Discord
updates a message with an embed URL preview.
"""
return self.message.author

@property
def author_id(self) -> typing.Optional[snowflakes.Snowflake]:
def author_id(self) -> undefined.UndefinedOr[snowflakes.Snowflake]:
"""ID of the author that triggered this event.
Returns
-------
typing.Optional[hikari.snowflakes.Snowflake]
The ID of the author that triggered this event concerns.
This will be `builtins.None` in some cases, such as
when Discord updates a message with an embed for a URL preview.
This will be `hikari.undefined.UNDEFINED` in some cases such as when Discord
updates a message with an embed URL preview.
"""
author = self.message.author
return author.id if author is not None else None
return author.id if author is not undefined.UNDEFINED else undefined.UNDEFINED

@property
def channel_id(self) -> snowflakes.Snowflake:
Expand Down Expand Up @@ -403,7 +393,7 @@ def embeds(self) -> undefined.UndefinedOr[typing.Sequence[embeds_.Embed]]:
return self.message.embeds

@property
def is_bot(self) -> typing.Optional[bool]:
def is_bot(self) -> undefined.UndefinedOr[bool]:
"""Return `builtins.True` if the message is from a bot.
Returns
Expand All @@ -413,14 +403,15 @@ def is_bot(self) -> typing.Optional[bool]:
If the author is not known, due to the update event being caused
by Discord adding an embed preview to accompany a URL, then this
will return `builtins.None` instead.
will return `hikari.undefined.UNDEFINED` instead.
"""
if (author := self.message.author) is not None:
if (author := self.message.author) is not undefined.UNDEFINED:
return author.is_bot
return None

return undefined.UNDEFINED

@property
def is_human(self) -> typing.Optional[bool]:
def is_human(self) -> undefined.UndefinedOr[bool]:
"""Return `builtins.True` if the message was created by a human.
Returns
Expand All @@ -430,28 +421,31 @@ def is_human(self) -> typing.Optional[bool]:
If the author is not known, due to the update event being caused
by Discord adding an embed preview to accompany a URL, then this
may return `builtins.None` instead.
may return `hikari.undefined.UNDEFINED` instead.
"""
# Not second-guessing some weird edge case will occur in the future with this,
# so I am being safe rather than sorry.
if self.message.webhook_id is not None:
return False
if (webhook_id := self.message.webhook_id) is not undefined.UNDEFINED:
return webhook_id is None

if (author := self.message.author) is not None:
if (author := self.message.author) is not undefined.UNDEFINED:
return not author.is_bot

return None
return undefined.UNDEFINED

@property
def is_webhook(self) -> bool:
def is_webhook(self) -> undefined.UndefinedOr[bool]:
"""Return `builtins.True` if the message was created by a webhook.
Returns
-------
builtins.bool
`builtins.True` if from a webhook, or `builtins.False` otherwise.
"""
return self.message.webhook_id is not None
if (webhook_id := self.message.webhook_id) is not undefined.UNDEFINED:
return webhook_id is not None

return undefined.UNDEFINED

@property
@abc.abstractmethod
Expand Down Expand Up @@ -500,28 +494,13 @@ class GuildMessageUpdateEvent(MessageUpdateEvent):
# <<inherited docstring from ShardEvent>>

@property
def author(self) -> typing.Optional[users.User]:
"""User that sent the message.
Returns
-------
typing.Union[builtins.None, hikari.users.User, hikari.guilds.Member]
The user that sent the message.
This will be `builtins.None` in some cases, such as when Discord
updates a message with an embed for a URL preview or if the message
was sent by a webhook.
"""
return self.message.author

@property
def member(self) -> typing.Optional[guilds.Member]:
def member(self) -> undefined.UndefinedNoneOr[guilds.Member]:
"""Member that sent the message if provided by the event.
!!! note
This will be `builtins.None` in some cases, such as when Discord
updates a message with an embed for a URL preview or if the message
was sent by a webhook.
If the message is not in a guild, this will be `builtins.None`.
This will also be `hikari.undefined.UNDEFINED` in some cases such as when Discord
updates a message with an embed URL preview.
"""
return self.message.member

Expand All @@ -533,7 +512,7 @@ def get_member(self) -> typing.Optional[guilds.Member]:
typing.Optional[hikari.guilds.Member]
Cached object of the member that sent the message if found.
"""
if self.message.author is not None and isinstance(self.app, traits.CacheAware):
if self.message.author is not undefined.UNDEFINED and isinstance(self.app, traits.CacheAware):
return self.app.cache.get_member(self.guild_id, self.message.author.id)

return None
Expand Down
23 changes: 15 additions & 8 deletions hikari/impl/entity_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2103,17 +2103,20 @@ def _deserialize_message_interaction(self, payload: data_binding.JSONObject) ->
def deserialize_partial_message( # noqa CFQ001 - Function too long
self, payload: data_binding.JSONObject
) -> message_models.PartialMessage:
author: typing.Optional[user_models.User] = None
author: undefined.UndefinedOr[user_models.User] = undefined.UNDEFINED
if author_pl := payload.get("author"):
author = self.deserialize_user(author_pl)

guild_id: typing.Optional[snowflakes.Snowflake] = None
member: typing.Optional[guild_models.Member] = None
member: undefined.UndefinedNoneOr[guild_models.Member] = None
if "guild_id" in payload:
guild_id = snowflakes.Snowflake(payload["guild_id"])

if author is not None and (member_pl := payload.get("member")):
if member_pl := payload.get("member"):
assert author is not None, "received message with a member object without a user object"
member = self.deserialize_member(member_pl, user=author, guild_id=guild_id)
else:
member = undefined.UNDEFINED

timestamp: undefined.UndefinedOr[datetime.datetime] = undefined.UNDEFINED
if "timestamp" in payload:
Expand Down Expand Up @@ -2246,13 +2249,17 @@ def deserialize_partial_message( # noqa CFQ001 - Function too long
def deserialize_message( # noqa CFQ001 - Function too long
self, payload: data_binding.JSONObject
) -> message_models.Message:
guild_id = snowflakes.Snowflake(payload["guild_id"]) if "guild_id" in payload else None
author = self.deserialize_user(payload["author"])

member: typing.Optional[guild_models.Member] = None
if "member" in payload:
assert guild_id is not None
member = self.deserialize_member(payload["member"], guild_id=guild_id, user=author)
guild_id: typing.Optional[snowflakes.Snowflake] = None
member: undefined.UndefinedNoneOr[guild_models.Member] = None
if "guild_id" in payload:
guild_id = snowflakes.Snowflake(payload["guild_id"])

if member_pl := payload.get("member"):
member = self.deserialize_member(member_pl, user=author, guild_id=guild_id)
else:
member = undefined.UNDEFINED

edited_timestamp: typing.Optional[datetime.datetime] = None
if (raw_edited_timestamp := payload["edited_timestamp"]) is not None:
Expand Down
48 changes: 17 additions & 31 deletions hikari/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,25 +741,32 @@ class PartialMessage(snowflakes.Unique):
channel_id: snowflakes.Snowflake = attr.field(hash=False, eq=False, repr=True)
"""The ID of the channel that the message was sent in."""

_guild_id: typing.Optional[snowflakes.Snowflake] = attr.field(hash=False, eq=False, repr=True)
#: Try to determine this best-effort in the property defined further
#: down.
guild_id: typing.Optional[snowflakes.Snowflake] = attr.field(hash=False, eq=False, repr=True)
"""The ID of the guild that the message was sent in or `builtins.None` for messages out of guilds.
author: typing.Optional[users_.User] = attr.field(hash=False, eq=False, repr=True)
!!! warning
This will also be `builtins.None` for messages received from the REST API.
This is a Discord limitation as stated here https://github.com/discord/discord-api-docs/issues/912
"""

author: undefined.UndefinedOr[users_.User] = attr.field(hash=False, eq=False, repr=True)
"""The author of this message.
This will be `builtins.None` in some cases such as when Discord
updates a message with an embed URL preview.
This will also be `hikari.undefined.UNDEFINED` in some cases such as when Discord
updates a message with an embed URL preview or in messages fetched from the REST API.
"""

member: typing.Optional[guilds.Member] = attr.field(hash=False, eq=False, repr=False)
member: undefined.UndefinedNoneOr[guilds.Member] = attr.field(hash=False, eq=False, repr=False)
"""The member for the author who created the message.
If the message is not in a guild, this will be `builtins.None`.
This will also be `builtins.None` in some cases such as when Discord updates
a message with an embed URL preview, in messages fetched from the
REST API or messages sent by discord.
This will also be `hikari.undefined.UNDEFINED` in some cases such as when Discord
updates a message with an embed URL preview.
!!! warning
This will also be `builtins.None` for messages received from the REST API.
This is a Discord limitation as stated here https://github.com/discord/discord-api-docs/issues/912
"""

content: undefined.UndefinedNoneOr[str] = attr.field(hash=False, eq=False, repr=False)
Expand Down Expand Up @@ -860,27 +867,6 @@ class PartialMessage(snowflakes.Unique):
components: undefined.UndefinedOr[typing.Sequence[PartialComponent]] = attr.field(hash=False, repr=False)
"""Sequence of the components attached to this message."""

@property # TODO: update this while refactoring message structure
def guild_id(self) -> typing.Optional[snowflakes.Snowflake]:
"""ID of the guild that the message was sent in.
This will not be present on REST API responses if the application is
stateless or missing the `GUILDS` intent.
"""
if self._guild_id:
return self._guild_id

if not isinstance(self.app, traits.CacheAware):
return None
# Don't check the member, as if the guild_id is missing, the member
# will always be missing too.
channel = self.app.cache.get_guild_channel(self.channel_id)

if channel is None:
return None

return channel.guild_id

def make_link(self, guild: typing.Optional[snowflakes.SnowflakeishOr[guilds.PartialGuild]]) -> str:
"""Generate a jump link to this message.
Expand Down
17 changes: 10 additions & 7 deletions tests/hikari/events/test_message_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ def test_author_property(self, event, author):

@pytest.mark.parametrize(
("author", "expected_id"),
[(mock.Mock(spec_set=users.User, id=91827), 91827), (None, None)],
[
(mock.Mock(spec_set=users.User, id=91827), 91827),
(undefined.UNDEFINED, undefined.UNDEFINED),
],
)
def test_author_id_property(self, event, author, expected_id):
event.message.author = author
Expand All @@ -141,18 +144,18 @@ def test_is_bot_property(self, event, is_bot):
assert event.is_bot is is_bot

def test_is_bot_property_if_no_author(self, event):
event.message.author = None
assert event.is_bot is None
event.message.author = undefined.UNDEFINED
assert event.is_bot is undefined.UNDEFINED

@pytest.mark.parametrize(
("author", "webhook_id", "expected_is_human"),
[
(mock.Mock(spec_set=users.User, is_bot=True), 123, False),
(mock.Mock(spec_set=users.User, is_bot=True), None, False),
(mock.Mock(spec_set=users.User, is_bot=True), undefined.UNDEFINED, False),
(mock.Mock(spec_set=users.User, is_bot=False), 123, False),
(mock.Mock(spec_set=users.User, is_bot=False), None, True),
(None, 123, False),
(None, None, None),
(mock.Mock(spec_set=users.User, is_bot=False), undefined.UNDEFINED, True),
(undefined.UNDEFINED, 123, False),
(undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED),
],
)
def test_is_human_property(self, event, author, webhook_id, expected_is_human):
Expand Down
6 changes: 3 additions & 3 deletions tests/hikari/impl/test_entity_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3816,10 +3816,10 @@ def test_deserialize_partial_message(
def test_deserialize_partial_message_with_partial_fields(self, entity_factory_impl, message_payload):
message_payload["content"] = ""
message_payload["edited_timestamp"] = None
message_payload["member"] = None
message_payload["application"]["primary_sku_id"] = None
message_payload["application"]["icon"] = None
message_payload["referenced_message"] = None
del message_payload["member"]
del message_payload["message_reference"]["message_id"]
del message_payload["message_reference"]["guild_id"]
del message_payload["application"]["cover_image"]
Expand All @@ -3829,7 +3829,7 @@ def test_deserialize_partial_message_with_partial_fields(self, entity_factory_im
assert partial_message.content is None
assert partial_message.edited_timestamp is None
assert partial_message.guild_id is not None
assert partial_message.member is None
assert partial_message.member is undefined.UNDEFINED
assert partial_message.application.primary_sku_id is None
assert partial_message.application.icon_hash is None
assert partial_message.application.cover_image_hash is None
Expand All @@ -3844,7 +3844,7 @@ def test_deserialize_partial_message_with_unset_fields(self, entity_factory_impl
assert partial_message.id == 123
assert partial_message.channel_id == 456
assert partial_message.guild_id is None
assert partial_message.author is None
assert partial_message.author is undefined.UNDEFINED
assert partial_message.member is None
assert partial_message.content is undefined.UNDEFINED
assert partial_message.timestamp is undefined.UNDEFINED
Expand Down
15 changes: 0 additions & 15 deletions tests/hikari/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,21 +166,6 @@ def test_make_link_when_guild_is_none(self, message):
message.channel_id = 456
assert message.make_link(None) == "https://discord.com/channels/@me/456/789"

def test_guild_id_when_guild_is_not_none(self, message):
message._guild_id = 123

assert message.guild_id == 123

def test_guild_id_when_guild_is_none(self, message):
message.app = mock.Mock()
message._guild_id = None
message.channel_id = 890
message.app.cache.get_guild_channel = mock.Mock(return_value=mock.Mock(guild_id=456))

assert message.guild_id == 456

message.app.cache.get_guild_channel.assert_called_once_with(890)


@pytest.mark.asyncio()
class TestAsyncMessage:
Expand Down

0 comments on commit 81e93bd

Please sign in to comment.