diff --git a/changes/783.bugfix.md b/changes/783.bugfix.md new file mode 100644 index 0000000000..490b2f7ebb --- /dev/null +++ b/changes/783.bugfix.md @@ -0,0 +1,7 @@ +Improve typing for message objects and message update methods +- Fix the use of `typing.Optional` where `undefined.UndefinedOr` should have been used +- Remove trying to acquire guild_id from the cached channel on PartialMessage + - Instead, clearly document the issue Discord imposes by not sending the guild_id +- `is_webhook` will now return `undefined.UNDEFINED` if the information is not available +- Fix logic in `is_human` to account for the changes in the typing +- Set `PartialMessage.member` to `undefined.UNDEFINED` when Discord edit the message to display an embed/attachment diff --git a/hikari/events/message_events.py b/hikari/events/message_events.py index 93b333c012..67f50c52d6 100644 --- a/hikari/events/message_events.py +++ b/hikari/events/message_events.py @@ -46,6 +46,7 @@ from hikari import intents from hikari import snowflakes from hikari import traits +from hikari import undefined from hikari.events import base_events from hikari.events import shard_events from hikari.internal import attr_extensions @@ -54,7 +55,6 @@ from hikari import embeds as embeds_ from hikari import guilds from hikari import messages - from hikari import undefined from hikari import users from hikari.api import shard as shard_ @@ -343,33 +343,23 @@ def app(self) -> traits.RESTAware: return self.message.app @property - def author(self) -> typing.Optional[users.User]: + def author(self) -> undefined.UndefinedOr[users.User]: """User that sent the message. - Returns - ------- - typing.Optional[hikari.users.User] - The user that sent the message. - - This will be `builtins.None` in some cases, such as when Discord - updates a message with an embed for a URL preview. + This will be `hikari.undefined.UNDEFINED` in some cases such as when Discord + updates a message with an embed URL preview. """ return self.message.author @property - def author_id(self) -> typing.Optional[snowflakes.Snowflake]: + def author_id(self) -> undefined.UndefinedOr[snowflakes.Snowflake]: """ID of the author that triggered this event. - Returns - ------- - typing.Optional[hikari.snowflakes.Snowflake] - The ID of the author that triggered this event concerns. - - This will be `builtins.None` in some cases, such as - when Discord updates a message with an embed for a URL preview. + This will be `hikari.undefined.UNDEFINED` in some cases such as when Discord + updates a message with an embed URL preview. """ author = self.message.author - return author.id if author is not None else None + return author.id if author is not undefined.UNDEFINED else undefined.UNDEFINED @property def channel_id(self) -> snowflakes.Snowflake: @@ -403,7 +393,7 @@ def embeds(self) -> undefined.UndefinedOr[typing.Sequence[embeds_.Embed]]: return self.message.embeds @property - def is_bot(self) -> typing.Optional[bool]: + def is_bot(self) -> undefined.UndefinedOr[bool]: """Return `builtins.True` if the message is from a bot. Returns @@ -413,14 +403,15 @@ def is_bot(self) -> typing.Optional[bool]: If the author is not known, due to the update event being caused by Discord adding an embed preview to accompany a URL, then this - will return `builtins.None` instead. + will return `hikari.undefined.UNDEFINED` instead. """ - if (author := self.message.author) is not None: + if (author := self.message.author) is not undefined.UNDEFINED: return author.is_bot - return None + + return undefined.UNDEFINED @property - def is_human(self) -> typing.Optional[bool]: + def is_human(self) -> undefined.UndefinedOr[bool]: """Return `builtins.True` if the message was created by a human. Returns @@ -430,20 +421,20 @@ def is_human(self) -> typing.Optional[bool]: If the author is not known, due to the update event being caused by Discord adding an embed preview to accompany a URL, then this - may return `builtins.None` instead. + may return `hikari.undefined.UNDEFINED` instead. """ # Not second-guessing some weird edge case will occur in the future with this, # so I am being safe rather than sorry. - if self.message.webhook_id is not None: - return False + if (webhook_id := self.message.webhook_id) is not undefined.UNDEFINED: + return webhook_id is None - if (author := self.message.author) is not None: + if (author := self.message.author) is not undefined.UNDEFINED: return not author.is_bot - return None + return undefined.UNDEFINED @property - def is_webhook(self) -> bool: + def is_webhook(self) -> undefined.UndefinedOr[bool]: """Return `builtins.True` if the message was created by a webhook. Returns @@ -451,7 +442,10 @@ def is_webhook(self) -> bool: builtins.bool `builtins.True` if from a webhook, or `builtins.False` otherwise. """ - return self.message.webhook_id is not None + if (webhook_id := self.message.webhook_id) is not undefined.UNDEFINED: + return webhook_id is not None + + return undefined.UNDEFINED @property @abc.abstractmethod @@ -500,28 +494,13 @@ class GuildMessageUpdateEvent(MessageUpdateEvent): # <> @property - def author(self) -> typing.Optional[users.User]: - """User that sent the message. - - Returns - ------- - typing.Union[builtins.None, hikari.users.User, hikari.guilds.Member] - The user that sent the message. - - This will be `builtins.None` in some cases, such as when Discord - updates a message with an embed for a URL preview or if the message - was sent by a webhook. - """ - return self.message.author - - @property - def member(self) -> typing.Optional[guilds.Member]: + def member(self) -> undefined.UndefinedNoneOr[guilds.Member]: """Member that sent the message if provided by the event. - !!! note - This will be `builtins.None` in some cases, such as when Discord - updates a message with an embed for a URL preview or if the message - was sent by a webhook. + If the message is not in a guild, this will be `builtins.None`. + + This will also be `hikari.undefined.UNDEFINED` in some cases such as when Discord + updates a message with an embed URL preview. """ return self.message.member @@ -533,7 +512,7 @@ def get_member(self) -> typing.Optional[guilds.Member]: typing.Optional[hikari.guilds.Member] Cached object of the member that sent the message if found. """ - if self.message.author is not None and isinstance(self.app, traits.CacheAware): + if self.message.author is not undefined.UNDEFINED and isinstance(self.app, traits.CacheAware): return self.app.cache.get_member(self.guild_id, self.message.author.id) return None diff --git a/hikari/impl/entity_factory.py b/hikari/impl/entity_factory.py index 58a3d90847..610a1408ea 100644 --- a/hikari/impl/entity_factory.py +++ b/hikari/impl/entity_factory.py @@ -2103,17 +2103,20 @@ def _deserialize_message_interaction(self, payload: data_binding.JSONObject) -> def deserialize_partial_message( # noqa CFQ001 - Function too long self, payload: data_binding.JSONObject ) -> message_models.PartialMessage: - author: typing.Optional[user_models.User] = None + author: undefined.UndefinedOr[user_models.User] = undefined.UNDEFINED if author_pl := payload.get("author"): author = self.deserialize_user(author_pl) guild_id: typing.Optional[snowflakes.Snowflake] = None - member: typing.Optional[guild_models.Member] = None + member: undefined.UndefinedNoneOr[guild_models.Member] = None if "guild_id" in payload: guild_id = snowflakes.Snowflake(payload["guild_id"]) - if author is not None and (member_pl := payload.get("member")): + if member_pl := payload.get("member"): + assert author is not None, "received message with a member object without a user object" member = self.deserialize_member(member_pl, user=author, guild_id=guild_id) + else: + member = undefined.UNDEFINED timestamp: undefined.UndefinedOr[datetime.datetime] = undefined.UNDEFINED if "timestamp" in payload: @@ -2246,13 +2249,17 @@ def deserialize_partial_message( # noqa CFQ001 - Function too long def deserialize_message( # noqa CFQ001 - Function too long self, payload: data_binding.JSONObject ) -> message_models.Message: - guild_id = snowflakes.Snowflake(payload["guild_id"]) if "guild_id" in payload else None author = self.deserialize_user(payload["author"]) - member: typing.Optional[guild_models.Member] = None - if "member" in payload: - assert guild_id is not None - member = self.deserialize_member(payload["member"], guild_id=guild_id, user=author) + guild_id: typing.Optional[snowflakes.Snowflake] = None + member: undefined.UndefinedNoneOr[guild_models.Member] = None + if "guild_id" in payload: + guild_id = snowflakes.Snowflake(payload["guild_id"]) + + if member_pl := payload.get("member"): + member = self.deserialize_member(member_pl, user=author, guild_id=guild_id) + else: + member = undefined.UNDEFINED edited_timestamp: typing.Optional[datetime.datetime] = None if (raw_edited_timestamp := payload["edited_timestamp"]) is not None: diff --git a/hikari/messages.py b/hikari/messages.py index 68064421e2..df7074c43f 100644 --- a/hikari/messages.py +++ b/hikari/messages.py @@ -741,25 +741,32 @@ class PartialMessage(snowflakes.Unique): channel_id: snowflakes.Snowflake = attr.field(hash=False, eq=False, repr=True) """The ID of the channel that the message was sent in.""" - _guild_id: typing.Optional[snowflakes.Snowflake] = attr.field(hash=False, eq=False, repr=True) - #: Try to determine this best-effort in the property defined further - #: down. + guild_id: typing.Optional[snowflakes.Snowflake] = attr.field(hash=False, eq=False, repr=True) + """The ID of the guild that the message was sent in or `builtins.None` for messages out of guilds. - author: typing.Optional[users_.User] = attr.field(hash=False, eq=False, repr=True) + !!! warning + This will also be `builtins.None` for messages received from the REST API. + This is a Discord limitation as stated here https://github.com/discord/discord-api-docs/issues/912 + """ + + author: undefined.UndefinedOr[users_.User] = attr.field(hash=False, eq=False, repr=True) """The author of this message. - This will be `builtins.None` in some cases such as when Discord - updates a message with an embed URL preview. + This will also be `hikari.undefined.UNDEFINED` in some cases such as when Discord + updates a message with an embed URL preview or in messages fetched from the REST API. """ - member: typing.Optional[guilds.Member] = attr.field(hash=False, eq=False, repr=False) + member: undefined.UndefinedNoneOr[guilds.Member] = attr.field(hash=False, eq=False, repr=False) """The member for the author who created the message. If the message is not in a guild, this will be `builtins.None`. - This will also be `builtins.None` in some cases such as when Discord updates - a message with an embed URL preview, in messages fetched from the - REST API or messages sent by discord. + This will also be `hikari.undefined.UNDEFINED` in some cases such as when Discord + updates a message with an embed URL preview. + + !!! warning + This will also be `builtins.None` for messages received from the REST API. + This is a Discord limitation as stated here https://github.com/discord/discord-api-docs/issues/912 """ content: undefined.UndefinedNoneOr[str] = attr.field(hash=False, eq=False, repr=False) @@ -860,27 +867,6 @@ class PartialMessage(snowflakes.Unique): components: undefined.UndefinedOr[typing.Sequence[PartialComponent]] = attr.field(hash=False, repr=False) """Sequence of the components attached to this message.""" - @property # TODO: update this while refactoring message structure - def guild_id(self) -> typing.Optional[snowflakes.Snowflake]: - """ID of the guild that the message was sent in. - - This will not be present on REST API responses if the application is - stateless or missing the `GUILDS` intent. - """ - if self._guild_id: - return self._guild_id - - if not isinstance(self.app, traits.CacheAware): - return None - # Don't check the member, as if the guild_id is missing, the member - # will always be missing too. - channel = self.app.cache.get_guild_channel(self.channel_id) - - if channel is None: - return None - - return channel.guild_id - def make_link(self, guild: typing.Optional[snowflakes.SnowflakeishOr[guilds.PartialGuild]]) -> str: """Generate a jump link to this message. diff --git a/tests/hikari/events/test_message_events.py b/tests/hikari/events/test_message_events.py index 138a031ea6..964f62454f 100644 --- a/tests/hikari/events/test_message_events.py +++ b/tests/hikari/events/test_message_events.py @@ -120,7 +120,10 @@ def test_author_property(self, event, author): @pytest.mark.parametrize( ("author", "expected_id"), - [(mock.Mock(spec_set=users.User, id=91827), 91827), (None, None)], + [ + (mock.Mock(spec_set=users.User, id=91827), 91827), + (undefined.UNDEFINED, undefined.UNDEFINED), + ], ) def test_author_id_property(self, event, author, expected_id): event.message.author = author @@ -141,18 +144,18 @@ def test_is_bot_property(self, event, is_bot): assert event.is_bot is is_bot def test_is_bot_property_if_no_author(self, event): - event.message.author = None - assert event.is_bot is None + event.message.author = undefined.UNDEFINED + assert event.is_bot is undefined.UNDEFINED @pytest.mark.parametrize( ("author", "webhook_id", "expected_is_human"), [ (mock.Mock(spec_set=users.User, is_bot=True), 123, False), - (mock.Mock(spec_set=users.User, is_bot=True), None, False), + (mock.Mock(spec_set=users.User, is_bot=True), undefined.UNDEFINED, False), (mock.Mock(spec_set=users.User, is_bot=False), 123, False), - (mock.Mock(spec_set=users.User, is_bot=False), None, True), - (None, 123, False), - (None, None, None), + (mock.Mock(spec_set=users.User, is_bot=False), undefined.UNDEFINED, True), + (undefined.UNDEFINED, 123, False), + (undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED), ], ) def test_is_human_property(self, event, author, webhook_id, expected_is_human): diff --git a/tests/hikari/impl/test_entity_factory.py b/tests/hikari/impl/test_entity_factory.py index 99b58e8472..5bbd7bd9d0 100644 --- a/tests/hikari/impl/test_entity_factory.py +++ b/tests/hikari/impl/test_entity_factory.py @@ -3816,10 +3816,10 @@ def test_deserialize_partial_message( def test_deserialize_partial_message_with_partial_fields(self, entity_factory_impl, message_payload): message_payload["content"] = "" message_payload["edited_timestamp"] = None - message_payload["member"] = None message_payload["application"]["primary_sku_id"] = None message_payload["application"]["icon"] = None message_payload["referenced_message"] = None + del message_payload["member"] del message_payload["message_reference"]["message_id"] del message_payload["message_reference"]["guild_id"] del message_payload["application"]["cover_image"] @@ -3829,7 +3829,7 @@ def test_deserialize_partial_message_with_partial_fields(self, entity_factory_im assert partial_message.content is None assert partial_message.edited_timestamp is None assert partial_message.guild_id is not None - assert partial_message.member is None + assert partial_message.member is undefined.UNDEFINED assert partial_message.application.primary_sku_id is None assert partial_message.application.icon_hash is None assert partial_message.application.cover_image_hash is None @@ -3844,7 +3844,7 @@ def test_deserialize_partial_message_with_unset_fields(self, entity_factory_impl assert partial_message.id == 123 assert partial_message.channel_id == 456 assert partial_message.guild_id is None - assert partial_message.author is None + assert partial_message.author is undefined.UNDEFINED assert partial_message.member is None assert partial_message.content is undefined.UNDEFINED assert partial_message.timestamp is undefined.UNDEFINED diff --git a/tests/hikari/test_messages.py b/tests/hikari/test_messages.py index 54350e70b2..c8c985f9a6 100644 --- a/tests/hikari/test_messages.py +++ b/tests/hikari/test_messages.py @@ -166,21 +166,6 @@ def test_make_link_when_guild_is_none(self, message): message.channel_id = 456 assert message.make_link(None) == "https://discord.com/channels/@me/456/789" - def test_guild_id_when_guild_is_not_none(self, message): - message._guild_id = 123 - - assert message.guild_id == 123 - - def test_guild_id_when_guild_is_none(self, message): - message.app = mock.Mock() - message._guild_id = None - message.channel_id = 890 - message.app.cache.get_guild_channel = mock.Mock(return_value=mock.Mock(guild_id=456)) - - assert message.guild_id == 456 - - message.app.cache.get_guild_channel.assert_called_once_with(890) - @pytest.mark.asyncio() class TestAsyncMessage: