Skip to content

Commit

Permalink
Fix C++ apigen page naming for variable template specializations
Browse files Browse the repository at this point in the history
Previously, the page names for variable template specializations
incorrectly included the template arguments within angle brackets,
unlike class template specializations where such arguments were
stripped.

With this commit, template arguments are stripped and the entities are
distinguished based on the user-specified `id`.
  • Loading branch information
jbms committed Oct 31, 2024
1 parent fcc5f76 commit 0ec412c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 4 deletions.
6 changes: 4 additions & 2 deletions sphinx_immaterial/apidoc/cpp/api_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _make_replacement_pattern(
)

SPECIAL_GROUP_COMMAND_PATTERN = re.compile(
r"^(?:\\|@)(ingroup|relates|membergroup|id)\s+(.+[^\s])\s*$", re.MULTILINE
r"^(?:\\|@)(ingroup|relates|membergroup|id)\s+(.*[^\s])\s*$", re.MULTILINE
)


Expand Down Expand Up @@ -2206,7 +2206,9 @@ def _format_template_arguments(entity: CppApiEntity) -> str:

def _get_entity_base_page_name_component(entity: CppApiEntity) -> str:
base_name = entity["name"]
if entity["kind"] == "class" and entity.get("specializes"):
if (entity["kind"] == "class" or entity["kind"] == "var") and entity.get(
"specializes"
):
# Strip any template arguments
base_name = re.sub("([^<]*).*", r"\1", base_name)
elif entity["kind"] == "conversion_function":
Expand Down
43 changes: 41 additions & 2 deletions tests/cpp_api_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
def test_basic():
config = api_parser.Config(
input_path="a.cpp",
input_content=b"""
input_content=rb"""
/// This is the doc.
int foo(bool x, int y);
Expand All @@ -17,6 +17,7 @@ def test_basic():
)

output = api_parser.generate_output(config)
assert not output.get("errors")
entities = list(output["entities"].values())
assert len(entities) == 1

Expand Down Expand Up @@ -45,6 +46,7 @@ class DimensionIdentifier {
)

output = api_parser.generate_output(config)
assert not output.get("errors")
print(output)
entities = list(output["entities"].values())
assert len(entities) == 2
Expand All @@ -62,6 +64,8 @@ def test_enable_if_transform():
}
/// This is the doc.
///
/// \ingroup X
template <typename U, typename T, typename = std::enable_if_t<std::is_convertible_v<U(*)[], T(*)[]>>>
int foo(T x);
Expand All @@ -76,7 +80,8 @@ def test_enable_if_transform():
)

output = api_parser.generate_output(config)

assert not output.get("errors")
assert not output.get("warnings")
entities = list(output["entities"].values())
assert len(entities) == 1
requires = entities[0].get("requires")
Expand Down Expand Up @@ -170,6 +175,7 @@ def test_comment_styles(doc_str: bytes, expected: str):
input_content=doc_str,
)
output = api_parser.generate_output(config)
assert not output.get("errors")
doc_strings = [
cast(api_parser.JsonDocComment, v["doc"])["text"]
for v in output.get("entities", {}).values()
Expand Down Expand Up @@ -198,6 +204,7 @@ def test_function_fields():
)

output = api_parser.generate_output(config)
assert not output.get("errors")
entities = output.get("entities", {})
doc_str = ""
for entity in entities.values():
Expand Down Expand Up @@ -238,9 +245,41 @@ def test_unnamed_template_parameter():
)

output = api_parser.generate_output(config)
assert not output.get("errors")
assert not output.get("warnings")
entities = output.get("entities", {})
assert len(entities) == 1
entity = list(entities.values())[0]
tparams = entity["template_parameters"]
assert tparams is not None
assert tparams[0]["name"] == ""


def test_variable_template_specialization():
config = api_parser.Config(
input_path="a.cpp",
compiler_flags=["-std=c++17", "-x", "c++"],
input_content=rb"""
/// Check if it has A.
///
/// \ingroup Array
template <typename T>
constexpr inline bool HasA = false;
/// Specializes HasA for int.
/// \ingroup Array
/// \id int
template <>
constexpr inline bool HasA<int> = true;
""",
)

output = api_parser.generate_output(config)
assert not output.get("errors")
assert not output.get("warnings")
entities = output.get("entities", {})
assert len(entities) == 2
assert sorted([entity["page_name"] for entity in entities.values()]) == [
"HasA",
"HasA-int",
]

0 comments on commit 0ec412c

Please sign in to comment.