Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve array support #8

Merged
merged 11 commits into from
Jun 2, 2022
2 changes: 2 additions & 0 deletions resource/_imports.pyi.em
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ first_party_imports = set()
generator(component, defined_classes, third_party_imports, first_party_imports)
}@
import typing
import array
import numpy as np
@[if len(third_party_imports) > 0]@

@[for statement in sorted(third_party_imports)]@
Expand Down
10 changes: 5 additions & 5 deletions resource/_msg.pyi.em
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ class Metaclass_@(message.structure.namespaced_type.name)(type):
) -> typing.Mapping[str, typing.Any]: ...
@[for constant in message.constants]@
@@property
def @(constant.name)(self) -> @(to_type_annotation(current_namespace, defined_classes, constant.type)): ...
def @(constant.name)(self) -> @(to_type_annotation(current_namespace, defined_classes, constant.type).getter): ...
@[end for]@
@[for member in message.structure.members]@
@[ if member.has_annotation('default')]@
@@property
def @(member.name.upper())__DEFAULT(cls) -> @(to_type_annotation(current_namespace, defined_classes, member.type)): ...
def @(member.name.upper())__DEFAULT(cls) -> @(to_type_annotation(current_namespace, defined_classes, member.type).getter): ...
@[ end if]@
@[end for]@

Expand All @@ -69,7 +69,7 @@ class @(message.structure.namespaced_type.name)(metaclass=Metaclass_@(message.st
self,
*,
@[for name, annotation, noqa_string in members]@
@(name): @(annotation) = ...,@(noqa_string)
@(name): @(annotation.getter) = ...,@(noqa_string)
@[end for]@
**kwargs: typing.Any,
) -> None: ...
Expand All @@ -80,7 +80,7 @@ class @(message.structure.namespaced_type.name)(metaclass=Metaclass_@(message.st
# Members
@[for name, annotation, noqa_string in members]@
@@property@(noqa_string)
def @(name)(self) -> @(annotation): ...@(noqa_string)
def @(name)(self) -> @(annotation.getter): ...@(noqa_string)
@@@(name).setter@(noqa_string)
def @(name)(self, value: @(annotation)) -> None: ...@(noqa_string)
def @(name)(self, value: @(annotation.setter)) -> None: ...@(noqa_string)
@[end for]@
59 changes: 50 additions & 9 deletions rosidl_generator_mypy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import pathlib
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, List, NamedTuple, Optional, Set, Tuple

from rosidl_cmake import (
convert_camel_case_to_lower_case_underscore,
Expand All @@ -22,6 +22,13 @@
)
from rosidl_parser.parser import parse_idl_file

SPECIAL_NESTED_BASIC_TYPES = ["int", "float"]


class Annotation(NamedTuple):
getter: str
setter: str


def generate(generator_arguments_file: str) -> List[str]:
mapping = {
Expand Down Expand Up @@ -102,33 +109,67 @@ def get_defined_classes(content: IdlContent) -> Set[str]:

def to_type_annotation(
current_namespace: NamespacedType, defined_classes: Set[str], type_: AbstractType
) -> str:
) -> Annotation:
if isinstance(type_, NamespacedType):
if type_.namespaces == current_namespace.namespaces:
if type_.name in defined_classes:
# member is defined in the same module, so no need to add namespaces
return '"{}"'.format(type_.name)
annotation = '"{}"'.format(type_.name)
return Annotation(annotation, annotation)

# NOTE: We export .pyi files, which don't affect the Python code execution at all.
# As mypy solves the import cycles properly,
# we import classes from not a module but a package.
# (i.e. in the same way as imports for other packages)

return "{}.{}".format(".".join(type_.namespaces), type_.name)
annotation = "{}.{}".format(".".join(type_.namespaces), type_.name)
return Annotation(annotation, annotation)

try:
ret = generate_py_impl.get_python_type(type_)
if ret is not None:
return str(ret)
return Annotation(str(ret), str(ret))
except Exception:
pass

if isinstance(type_, (AbstractSequence, Array)):
return "typing.Sequence[{}]".format(
to_type_annotation(current_namespace, defined_classes, type_.value_type)
if isinstance(type_, Array):
# The type_ will be Array for bounded lists
type_annotation = to_type_annotation(
current_namespace, defined_classes, type_.value_type
)
if type_annotation.getter in SPECIAL_NESTED_BASIC_TYPES:
# eg: int64[4]
return Annotation(
"np.ndarray",
"typing.Union[typing.Sequence[{}], np.ndarray]".format(
type_annotation.setter
),
)

# eg: std_msgs/Header[4]
return Annotation(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I read the specification.
https://design.ros2.org/articles/idl_interface_definition.html

According to the Type Mapping section on the document, shouldn't we check type_annotation.getter in SPECIAL_NESTED_BASIC_TYPES in the same way as AbstractSequence?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure... I've tested this against most of the standard ros interfaces (eg sensor_msgs) and I think everything looked correct.
It could be prudent to write some tests which colcon build a bunch of messages and make assertions about the pyi code it generates but I'm not sure I'll have time to do that right now.

Copy link
Member

@bonprosoft bonprosoft May 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I referred to the rosidl_python implementation, and it actually follows the specification.

I think the member initialization part is helpful to understand this behavior:
https://github.com/ros2/rosidl_python/blob/6675823dd6123360912e233a830d1edab89c7651/rosidl_generator_py/resource/_msg.py.em#L320-L332

I've also tested with the following message definition and got the same result that I expected.

Definition

int64[] integer_array
int64[4] integer_fixed_array

std_msgs/Header[] header_array
std_msgs/Header[4] header_fixed_array

integer_array

    @integer_array.setter
    def integer_array(self, value):
        if isinstance(value, array.array):
            assert value.typecode == 'q', \
                "The 'integer_array' array.array() must have the type code of 'q'"
            self._integer_array = value
            return
        if __debug__:
            from collections.abc import Sequence
            from collections.abc import Set
            from collections import UserList
            from collections import UserString
            assert \
                ((isinstance(value, Sequence) or
                  isinstance(value, Set) or
                  isinstance(value, UserList)) and
                 not isinstance(value, str) and
                 not isinstance(value, UserString) and
                 all(isinstance(v, int) for v in value) and
                 all(val >= -9223372036854775808 and val < 9223372036854775808 for val in value)), \
                "The 'integer_array' field must be a set or sequence and each value of type 'int' and each integer in [-9223372036854775808, 9223372036854775807]"
        self._integer_array = array.array('q', value)

integer_fixed_array

    @integer_fixed_array.setter
    def integer_fixed_array(self, value):
        if isinstance(value, numpy.ndarray):
            assert value.dtype == numpy.int64, \
                "The 'integer_fixed_array' numpy.ndarray() must have the dtype of 'numpy.int64'"
            assert value.size == 4, \
                "The 'integer_fixed_array' numpy.ndarray() must have a size of 4"
            self._integer_fixed_array = value
            return
        if __debug__:
            from collections.abc import Sequence
            from collections.abc import Set
            from collections import UserList
            from collections import UserString
            assert \
                ((isinstance(value, Sequence) or
                  isinstance(value, Set) or
                  isinstance(value, UserList)) and
                 not isinstance(value, str) and
                 not isinstance(value, UserString) and
                 len(value) == 4 and
                 all(isinstance(v, int) for v in value) and
                 all(val >= -9223372036854775808 and val < 9223372036854775808 for val in value)), \
                "The 'integer_fixed_array' field must be a set or sequence with length 4 and each value of type 'int' and each integer in [-9223372036854775808, 9223372036854775807]"
        self._integer_fixed_array = numpy.array(value, dtype=numpy.int64)

header_array

    @header_array.setter
    def header_array(self, value):
        if __debug__:
            from std_msgs.msg import Header
            from collections.abc import Sequence
            from collections.abc import Set
            from collections import UserList
            from collections import UserString
            assert \
                ((isinstance(value, Sequence) or
                  isinstance(value, Set) or
                  isinstance(value, UserList)) and
                 not isinstance(value, str) and
                 not isinstance(value, UserString) and
                 all(isinstance(v, Header) for v in value) and
                 True), \
                "The 'header_array' field must be a set or sequence and each value of type 'Header'"
        self._header_array = value

header_fixed_array

    @header_fixed_array.setter
    def header_fixed_array(self, value):
        if __debug__:
            from std_msgs.msg import Header
            from collections.abc import Sequence
            from collections.abc import Set
            from collections import UserList
            from collections import UserString
            assert \
                ((isinstance(value, Sequence) or
                  isinstance(value, Set) or
                  isinstance(value, UserList)) and
                 not isinstance(value, str) and
                 not isinstance(value, UserString) and
                 len(value) == 4 and
                 all(isinstance(v, Header) for v in value) and
                 True), \
                "The 'header_fixed_array' field must be a set or sequence with length 4 and each value of type 'Header'"
        self._header_fixed_array = value

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MrBlenny I also encountered the same issue as you 😅
Actually, this PR is great, and It would be wonderful to merge this PR soon.
If you don't have time to update this PR, please let me know. I am happy to help you.

"typing.Sequence[{}]".format(type_annotation.getter),
"typing.Sequence[{}]".format(type_annotation.setter),
)
if isinstance(type_, AbstractSequence):
# The type_ will be AbstractSequence for unbounded lists
type_annotation = to_type_annotation(
current_namespace, defined_classes, type_.value_type
)
if type_annotation.getter in SPECIAL_NESTED_BASIC_TYPES:
# eg: int64[]
return Annotation(
"array.array[{}]".format(type_annotation.getter),
"typing.Sequence[{}]".format(type_annotation.setter),
)

# eg: std_msgs/Header[]
return Annotation(
"typing.Sequence[{}]".format(type_annotation.getter),
"typing.Sequence[{}]".format(type_annotation.setter),
)

return str(type_)
return Annotation(str(type_), str(type_))


def _get_import_statement(
Expand Down