Skip to content

Commit

Permalink
Merge pull request #87 from justinboswell/ctad
Browse files Browse the repository at this point in the history
Added support for template deduction guides
  • Loading branch information
virtuald authored Dec 5, 2023
2 parents 64c5290 + 88a7048 commit 29b71ab
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 11 deletions.
51 changes: 40 additions & 11 deletions cxxheaderparser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Concept,
DecltypeSpecifier,
DecoratedType,
DeductionGuide,
EnumDecl,
Enumerator,
Field,
Expand Down Expand Up @@ -1868,10 +1869,9 @@ def _parse_parameters(
_auto_return_typename = PQName([AutoSpecifier()])

def _parse_trailing_return_type(
self, fn: typing.Union[Function, FunctionType]
) -> None:
self, return_type: typing.Optional[DecoratedType]
) -> DecoratedType:
# entry is "->"
return_type = fn.return_type
if not (
isinstance(return_type, Type)
and not return_type.const
Expand All @@ -1890,8 +1890,7 @@ def _parse_trailing_return_type(

dtype = self._parse_cv_ptr(parsed_type)

fn.has_trailing_return = True
fn.return_type = dtype
return dtype

def _parse_fn_end(self, fn: Function) -> None:
"""
Expand All @@ -1918,7 +1917,9 @@ def _parse_fn_end(self, fn: Function) -> None:
fn.raw_requires = self._parse_requires(rtok)

if self.lex.token_if("ARROW"):
self._parse_trailing_return_type(fn)
return_type = self._parse_trailing_return_type(fn.return_type)
fn.has_trailing_return = True
fn.return_type = return_type

if self.lex.token_if("{"):
self._discard_contents("{", "}")
Expand Down Expand Up @@ -1966,7 +1967,9 @@ def _parse_method_end(self, method: Method) -> None:
elif tok_value in ("&", "&&"):
method.ref_qualifier = tok_value
elif tok_value == "->":
self._parse_trailing_return_type(method)
return_type = self._parse_trailing_return_type(method.return_type)
method.has_trailing_return = True
method.return_type = return_type
if self.lex.token_if("{"):
self._discard_contents("{", "}")
method.has_body = True
Expand Down Expand Up @@ -2000,6 +2003,7 @@ def _parse_function(
is_friend: bool,
is_typedef: bool,
msvc_convention: typing.Optional[LexToken],
is_guide: bool = False,
) -> bool:
"""
Assumes the caller has already consumed the return type and name, this consumes the
Expand Down Expand Up @@ -2076,7 +2080,21 @@ def _parse_function(
self.visitor.on_method_impl(state, method)

return method.has_body or method.has_trailing_return

elif is_guide:
assert isinstance(state, (ExternBlockState, NamespaceBlockState))
if not self.lex.token_if("ARROW"):
raise self._parse_error(None, expected="Trailing return type")
return_type = self._parse_trailing_return_type(
Type(PQName([AutoSpecifier()]))
)
guide = DeductionGuide(
return_type,
name=pqname,
parameters=params,
doxygen=doxygen,
)
self.visitor.on_deduction_guide(state, guide)
return False
else:
assert return_type is not None
fn = Function(
Expand Down Expand Up @@ -2210,7 +2228,9 @@ def _parse_cv_ptr_or_fn(
assert not isinstance(dtype, FunctionType)
dtype = dtype_fn = FunctionType(dtype, fn_params, vararg)
if self.lex.token_if("ARROW"):
self._parse_trailing_return_type(dtype_fn)
return_type = self._parse_trailing_return_type(dtype_fn.return_type)
dtype_fn.has_trailing_return = True
dtype_fn.return_type = return_type

else:
msvc_convention = None
Expand Down Expand Up @@ -2391,6 +2411,7 @@ def _parse_decl(
destructor = False
op = None
msvc_convention = None
is_guide = False

# If we have a leading (, that's either an obnoxious grouping
# paren or it's a constructor
Expand Down Expand Up @@ -2441,8 +2462,15 @@ def _parse_decl(
# grouping paren like "void (name(int x));"
toks = self._consume_balanced_tokens(tok)

# .. not sure what it's grouping, so put it back?
self.lex.return_tokens(toks[1:-1])
# check to see if the next token is an arrow, and thus a trailing return
if self.lex.token_peek_if("ARROW"):
self.lex.return_tokens(toks)
# the leading name of the class/ctor has been parsed as a type before the parens
pqname = parsed_type.typename
is_guide = True
else:
# .. not sure what it's grouping, so put it back?
self.lex.return_tokens(toks[1:-1])

if dtype:
msvc_convention = self.lex.token_if_val(*self._msvc_conventions)
Expand Down Expand Up @@ -2473,6 +2501,7 @@ def _parse_decl(
is_friend,
is_typedef,
msvc_convention,
is_guide,
)
elif msvc_convention:
raise self._parse_error(msvc_convention)
Expand Down
9 changes: 9 additions & 0 deletions cxxheaderparser/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .types import (
ClassDecl,
Concept,
DeductionGuide,
EnumDecl,
Field,
ForwardDecl,
Expand Down Expand Up @@ -123,6 +124,9 @@ class NamespaceScope:
#: Child namespaces
namespaces: typing.Dict[str, "NamespaceScope"] = field(default_factory=dict)

#: Deduction guides
deduction_guides: typing.List[DeductionGuide] = field(default_factory=list)


Block = typing.Union[ClassScope, NamespaceScope]

Expand Down Expand Up @@ -317,6 +321,11 @@ def on_class_friend(self, state: SClassBlockState, friend: FriendDecl) -> None:
def on_class_end(self, state: SClassBlockState) -> None:
pass

def on_deduction_guide(
self, state: SNonClassBlockState, guide: DeductionGuide
) -> None:
state.user_data.deduction_guides.append(guide)


def parse_string(
content: str,
Expand Down
18 changes: 18 additions & 0 deletions cxxheaderparser/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,3 +896,21 @@ class UsingAlias:

#: Documentation if present
doxygen: typing.Optional[str] = None


@dataclass
class DeductionGuide:
"""
.. code-block:: c++
template <class T>
MyClass(T) -> MyClass(int);
"""

#: Only constructors and destructors don't have a return type
result_type: typing.Optional[DecoratedType]

name: PQName
parameters: typing.List[Parameter]

doxygen: typing.Optional[str] = None
13 changes: 13 additions & 0 deletions cxxheaderparser/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from .types import (
Concept,
DeductionGuide,
EnumDecl,
Field,
ForwardDecl,
Expand Down Expand Up @@ -236,6 +237,13 @@ def on_class_end(self, state: ClassBlockState) -> None:
``on_variable`` for each instance declared.
"""

def on_deduction_guide(
self, state: NonClassBlockState, guide: DeductionGuide
) -> None:
"""
Called when a deduction guide is encountered
"""


class NullVisitor:
"""
Expand Down Expand Up @@ -318,5 +326,10 @@ def on_class_method(self, state: ClassBlockState, method: Method) -> None:
def on_class_end(self, state: ClassBlockState) -> None:
return None

def on_deduction_guide(
self, state: NonClassBlockState, guide: DeductionGuide
) -> None:
return None


null_visitor = NullVisitor()
84 changes: 84 additions & 0 deletions tests/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
BaseClass,
ClassDecl,
DecltypeSpecifier,
DeductionGuide,
Field,
ForwardDecl,
Function,
Expand Down Expand Up @@ -2163,3 +2164,86 @@ def test_member_class_template_specialization() -> None:
]
)
)


def test_template_deduction_guide() -> None:
content = """
template <class CharT, class Traits = std::char_traits<CharT>>
Error(std::basic_string_view<CharT, Traits>) -> Error<std::string>;
"""
data = parse_string(content, cleandoc=True)

assert data == ParsedData(
namespace=NamespaceScope(
deduction_guides=[
DeductionGuide(
result_type=Type(
typename=PQName(
segments=[
NameSpecifier(
name="Error",
specialization=TemplateSpecialization(
args=[
TemplateArgument(
arg=Type(
typename=PQName(
segments=[
NameSpecifier(name="std"),
NameSpecifier(
name="string"
),
]
)
)
)
]
),
)
]
)
),
name=PQName(segments=[NameSpecifier(name="Error")]),
parameters=[
Parameter(
type=Type(
typename=PQName(
segments=[
NameSpecifier(name="std"),
NameSpecifier(
name="basic_string_view",
specialization=TemplateSpecialization(
args=[
TemplateArgument(
arg=Type(
typename=PQName(
segments=[
NameSpecifier(
name="CharT"
)
]
)
)
),
TemplateArgument(
arg=Type(
typename=PQName(
segments=[
NameSpecifier(
name="Traits"
)
]
)
)
),
]
),
),
]
)
)
)
],
)
]
)
)

0 comments on commit 29b71ab

Please sign in to comment.