From 88a70485137db5ae9b5eb0c9113d16a48fad1c1a Mon Sep 17 00:00:00 2001 From: Justin Boswell Date: Thu, 30 Nov 2023 23:07:03 -0800 Subject: [PATCH] Added support for template deduction guides * Added DeductionGuide as a language element --- cxxheaderparser/parser.py | 51 ++++++++++++++++++----- cxxheaderparser/simple.py | 9 ++++ cxxheaderparser/types.py | 18 ++++++++ cxxheaderparser/visitor.py | 13 ++++++ tests/test_template.py | 84 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 164 insertions(+), 11 deletions(-) diff --git a/cxxheaderparser/parser.py b/cxxheaderparser/parser.py index a8cdf95..5bef044 100644 --- a/cxxheaderparser/parser.py +++ b/cxxheaderparser/parser.py @@ -25,6 +25,7 @@ Concept, DecltypeSpecifier, DecoratedType, + DeductionGuide, EnumDecl, Enumerator, Field, @@ -1868,10 +1869,9 @@ def _parse_parameters( _auto_return_typename = PQName([AutoSpecifier()]) def _parse_trailing_return_type( - self, fn: typing.Union[Function, FunctionType] - ) -> None: + self, return_type: typing.Optional[DecoratedType] + ) -> DecoratedType: # entry is "->" - return_type = fn.return_type if not ( isinstance(return_type, Type) and not return_type.const @@ -1890,8 +1890,7 @@ def _parse_trailing_return_type( dtype = self._parse_cv_ptr(parsed_type) - fn.has_trailing_return = True - fn.return_type = dtype + return dtype def _parse_fn_end(self, fn: Function) -> None: """ @@ -1918,7 +1917,9 @@ def _parse_fn_end(self, fn: Function) -> None: fn.raw_requires = self._parse_requires(rtok) if self.lex.token_if("ARROW"): - self._parse_trailing_return_type(fn) + return_type = self._parse_trailing_return_type(fn.return_type) + fn.has_trailing_return = True + fn.return_type = return_type if self.lex.token_if("{"): self._discard_contents("{", "}") @@ -1966,7 +1967,9 @@ def _parse_method_end(self, method: Method) -> None: elif tok_value in ("&", "&&"): method.ref_qualifier = tok_value elif tok_value == "->": - self._parse_trailing_return_type(method) + return_type = self._parse_trailing_return_type(method.return_type) + method.has_trailing_return = True + method.return_type = return_type if self.lex.token_if("{"): self._discard_contents("{", "}") method.has_body = True @@ -2000,6 +2003,7 @@ def _parse_function( is_friend: bool, is_typedef: bool, msvc_convention: typing.Optional[LexToken], + is_guide: bool = False, ) -> bool: """ Assumes the caller has already consumed the return type and name, this consumes the @@ -2076,7 +2080,21 @@ def _parse_function( self.visitor.on_method_impl(state, method) return method.has_body or method.has_trailing_return - + elif is_guide: + assert isinstance(state, (ExternBlockState, NamespaceBlockState)) + if not self.lex.token_if("ARROW"): + raise self._parse_error(None, expected="Trailing return type") + return_type = self._parse_trailing_return_type( + Type(PQName([AutoSpecifier()])) + ) + guide = DeductionGuide( + return_type, + name=pqname, + parameters=params, + doxygen=doxygen, + ) + self.visitor.on_deduction_guide(state, guide) + return False else: assert return_type is not None fn = Function( @@ -2210,7 +2228,9 @@ def _parse_cv_ptr_or_fn( assert not isinstance(dtype, FunctionType) dtype = dtype_fn = FunctionType(dtype, fn_params, vararg) if self.lex.token_if("ARROW"): - self._parse_trailing_return_type(dtype_fn) + return_type = self._parse_trailing_return_type(dtype_fn.return_type) + dtype_fn.has_trailing_return = True + dtype_fn.return_type = return_type else: msvc_convention = None @@ -2391,6 +2411,7 @@ def _parse_decl( destructor = False op = None msvc_convention = None + is_guide = False # If we have a leading (, that's either an obnoxious grouping # paren or it's a constructor @@ -2441,8 +2462,15 @@ def _parse_decl( # grouping paren like "void (name(int x));" toks = self._consume_balanced_tokens(tok) - # .. not sure what it's grouping, so put it back? - self.lex.return_tokens(toks[1:-1]) + # check to see if the next token is an arrow, and thus a trailing return + if self.lex.token_peek_if("ARROW"): + self.lex.return_tokens(toks) + # the leading name of the class/ctor has been parsed as a type before the parens + pqname = parsed_type.typename + is_guide = True + else: + # .. not sure what it's grouping, so put it back? + self.lex.return_tokens(toks[1:-1]) if dtype: msvc_convention = self.lex.token_if_val(*self._msvc_conventions) @@ -2473,6 +2501,7 @@ def _parse_decl( is_friend, is_typedef, msvc_convention, + is_guide, ) elif msvc_convention: raise self._parse_error(msvc_convention) diff --git a/cxxheaderparser/simple.py b/cxxheaderparser/simple.py index 2538683..02bb56e 100644 --- a/cxxheaderparser/simple.py +++ b/cxxheaderparser/simple.py @@ -35,6 +35,7 @@ from .types import ( ClassDecl, Concept, + DeductionGuide, EnumDecl, Field, ForwardDecl, @@ -123,6 +124,9 @@ class NamespaceScope: #: Child namespaces namespaces: typing.Dict[str, "NamespaceScope"] = field(default_factory=dict) + #: Deduction guides + deduction_guides: typing.List[DeductionGuide] = field(default_factory=list) + Block = typing.Union[ClassScope, NamespaceScope] @@ -317,6 +321,11 @@ def on_class_friend(self, state: SClassBlockState, friend: FriendDecl) -> None: def on_class_end(self, state: SClassBlockState) -> None: pass + def on_deduction_guide( + self, state: SNonClassBlockState, guide: DeductionGuide + ) -> None: + state.user_data.deduction_guides.append(guide) + def parse_string( content: str, diff --git a/cxxheaderparser/types.py b/cxxheaderparser/types.py index d15b76f..1aa0b99 100644 --- a/cxxheaderparser/types.py +++ b/cxxheaderparser/types.py @@ -896,3 +896,21 @@ class UsingAlias: #: Documentation if present doxygen: typing.Optional[str] = None + + +@dataclass +class DeductionGuide: + """ + .. code-block:: c++ + + template + MyClass(T) -> MyClass(int); + """ + + #: Only constructors and destructors don't have a return type + result_type: typing.Optional[DecoratedType] + + name: PQName + parameters: typing.List[Parameter] + + doxygen: typing.Optional[str] = None diff --git a/cxxheaderparser/visitor.py b/cxxheaderparser/visitor.py index df0129d..000f8d0 100644 --- a/cxxheaderparser/visitor.py +++ b/cxxheaderparser/visitor.py @@ -9,6 +9,7 @@ from .types import ( Concept, + DeductionGuide, EnumDecl, Field, ForwardDecl, @@ -236,6 +237,13 @@ def on_class_end(self, state: ClassBlockState) -> None: ``on_variable`` for each instance declared. """ + def on_deduction_guide( + self, state: NonClassBlockState, guide: DeductionGuide + ) -> None: + """ + Called when a deduction guide is encountered + """ + class NullVisitor: """ @@ -318,5 +326,10 @@ def on_class_method(self, state: ClassBlockState, method: Method) -> None: def on_class_end(self, state: ClassBlockState) -> None: return None + def on_deduction_guide( + self, state: NonClassBlockState, guide: DeductionGuide + ) -> None: + return None + null_visitor = NullVisitor() diff --git a/tests/test_template.py b/tests/test_template.py index 344e98f..ffdf8e9 100644 --- a/tests/test_template.py +++ b/tests/test_template.py @@ -5,6 +5,7 @@ BaseClass, ClassDecl, DecltypeSpecifier, + DeductionGuide, Field, ForwardDecl, Function, @@ -2163,3 +2164,86 @@ def test_member_class_template_specialization() -> None: ] ) ) + + +def test_template_deduction_guide() -> None: + content = """ + template > + Error(std::basic_string_view) -> Error; + """ + data = parse_string(content, cleandoc=True) + + assert data == ParsedData( + namespace=NamespaceScope( + deduction_guides=[ + DeductionGuide( + result_type=Type( + typename=PQName( + segments=[ + NameSpecifier( + name="Error", + specialization=TemplateSpecialization( + args=[ + TemplateArgument( + arg=Type( + typename=PQName( + segments=[ + NameSpecifier(name="std"), + NameSpecifier( + name="string" + ), + ] + ) + ) + ) + ] + ), + ) + ] + ) + ), + name=PQName(segments=[NameSpecifier(name="Error")]), + parameters=[ + Parameter( + type=Type( + typename=PQName( + segments=[ + NameSpecifier(name="std"), + NameSpecifier( + name="basic_string_view", + specialization=TemplateSpecialization( + args=[ + TemplateArgument( + arg=Type( + typename=PQName( + segments=[ + NameSpecifier( + name="CharT" + ) + ] + ) + ) + ), + TemplateArgument( + arg=Type( + typename=PQName( + segments=[ + NameSpecifier( + name="Traits" + ) + ] + ) + ) + ), + ] + ), + ), + ] + ) + ) + ) + ], + ) + ] + ) + )