Skip to content

Commit

Permalink
feat: Add support for class inheritance
Browse files Browse the repository at this point in the history
  • Loading branch information
pawamoy committed Jun 8, 2020
1 parent e99b57e commit c07a4ca
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 40 deletions.
3 changes: 2 additions & 1 deletion src/pytkdocs/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@ def process_config(config: dict) -> dict:
path = obj_config.pop("path")
filters = obj_config.get("filters", [])
members = obj_config.get("members", set())
inherited_members = obj_config.get("inherited_members", False)
if isinstance(members, list):
members = set(members)
loader = Loader(filters=filters)
loader = Loader(filters=filters, inherited_members=inherited_members)

obj = loader.get_object_documentation(path, members)

Expand Down
107 changes: 71 additions & 36 deletions src/pytkdocs/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
import re
import textwrap
from functools import lru_cache
from itertools import chain
from pathlib import Path
from typing import Any, List, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, Union

from pytkdocs.objects import Attribute, Class, Function, Method, Module, Object, Source
from pytkdocs.parsers.attributes import get_class_attributes, get_instance_attributes, get_module_attributes
from pytkdocs.parsers.attributes import get_class_attributes, get_instance_attributes, get_module_attributes, merge
from pytkdocs.parsers.docstrings import PARSERS
from pytkdocs.properties import RE_SPECIAL

Expand Down Expand Up @@ -187,6 +188,7 @@ def __init__(
filters: Optional[List[str]] = None,
docstring_style: str = "google",
docstring_options: Optional[dict] = None,
inherited_members: bool = False,
) -> None:
"""
Initialization method.
Expand All @@ -195,13 +197,15 @@ def __init__(
filters: A list of regular expressions to fine-grain select members. It is applied recursively.
docstring_style: The style to use when parsing docstrings.
docstring_options: The options to pass to the docstrings parser.
inherited_members: Whether to select inherited members for classes.
"""
if not filters:
filters = []

self.filters = [(f, re.compile(f.lstrip("!"))) for f in filters]
self.docstring_parser = PARSERS[docstring_style](**(docstring_options or {})) # type: ignore
self.errors: List[str] = []
self.select_inherited_members = inherited_members

def get_object_documentation(self, dotted_path: str, members: Optional[Union[Set[str], bool]] = None) -> Object:
"""
Expand Down Expand Up @@ -243,13 +247,13 @@ def get_object_documentation(self, dotted_path: str, members: Optional[Union[Set

return root_object

def get_module_documentation(self, node: ObjectNode, members=None) -> Module:
def get_module_documentation(self, node: ObjectNode, select_members=None) -> Module:
"""
Get the documentation for a module and its children.
Arguments:
node: The node representing the module and its parents.
members: Explicit members to select.
select_members: Explicit members to select.
Return:
The documented module object.
Expand Down Expand Up @@ -277,17 +281,17 @@ def get_module_documentation(self, node: ObjectNode, members=None) -> Module:
name=name, path=path, file_path=node.file_path, docstring=inspect.getdoc(module), source=source
)

if members is False:
if select_members is False:
return root_object

# type_hints = get_type_hints(module)
members = members or set()
select_members = select_members or set()

attributes_data = get_module_attributes(module)
root_object.parse_docstring(self.docstring_parser, attributes=attributes_data)

for member_name, member in inspect.getmembers(module):
if self.select(member_name, members): # type: ignore
if self.select(member_name, select_members): # type: ignore
child_node = ObjectNode(member, member_name, parent=node)
if child_node.is_class() and node.root.obj is inspect.getmodule(member):
root_object.add_child(self.get_class_documentation(child_node))
Expand All @@ -302,19 +306,19 @@ def get_module_documentation(self, node: ObjectNode, members=None) -> Module:
pass
else:
for _, modname, _ in pkgutil.iter_modules(package_path):
if self.select(modname, members):
if self.select(modname, select_members):
leaf = get_object_tree(f"{path}.{modname}")
root_object.add_child(self.get_module_documentation(leaf))

return root_object

def get_class_documentation(self, node: ObjectNode, members=None) -> Class:
def get_class_documentation(self, node: ObjectNode, select_members=None) -> Class:
"""
Get the documentation for a class and its children.
Arguments:
node: The node representing the class and its parents.
members: Explicit members to select.
select_members: Explicit members to select.
Return:
The documented class object.
Expand All @@ -323,52 +327,83 @@ def get_class_documentation(self, node: ObjectNode, members=None) -> Class:
docstring = textwrap.dedent(class_.__doc__ or "")
root_object = Class(name=node.name, path=node.dotted_path, file_path=node.file_path, docstring=docstring)

if members is False:
return root_object

members = members or set()

attributes_data = get_class_attributes(class_)
context = {"attributes": attributes_data}
# Even if we don't select members, we want to correctly parse the docstring
attributes_data: Dict[str, Dict[str, Any]] = {}
for cls in reversed(class_.__mro__[:-1]):
merge(attributes_data, get_class_attributes(cls))
context: Dict[str, Any] = {"attributes": attributes_data}
if "__init__" in class_.__dict__:
attributes_data.update(get_instance_attributes(class_.__init__))
context["signature"] = inspect.signature(class_.__init__)
root_object.parse_docstring(self.docstring_parser, attributes=attributes_data)

for member_name, member in class_.__dict__.items():
if member is type or member is object:
continue

if not self.select(member_name, members): # type: ignore
continue
if select_members is False:
return root_object

child_node = ObjectNode(getattr(class_, member_name), member_name, parent=node)
select_members = select_members or set()

# Build the list of members
members = {}
inherited = set()
direct_members = class_.__dict__
all_members = dict(inspect.getmembers(class_))
for member_name, member in all_members.items():
if not (member is type or member is object) and self.select(member_name, select_members):
if member_name not in direct_members:
if self.select_inherited_members:
members[member_name] = member
inherited.add(member_name)
else:
members[member_name] = member

# Iterate on the selected members
child: Object
for member_name, member in members.items():
child_node = ObjectNode(member, member_name, parent=node)
if child_node.is_class():
root_object.add_child(self.get_class_documentation(child_node))
child = self.get_class_documentation(child_node)
elif child_node.is_classmethod():
root_object.add_child(self.get_classmethod_documentation(child_node))
child = self.get_classmethod_documentation(child_node)
elif child_node.is_staticmethod():
root_object.add_child(self.get_staticmethod_documentation(child_node))
child = self.get_staticmethod_documentation(child_node)
elif child_node.is_method():
root_object.add_child(self.get_regular_method_documentation(child_node))
child = self.get_regular_method_documentation(child_node)
elif child_node.is_property():
root_object.add_child(self.get_property_documentation(child_node))
child = self.get_property_documentation(child_node)
elif member_name in attributes_data:
root_object.add_child(self.get_attribute_documentation(child_node, attributes_data[member_name]))
child = self.get_attribute_documentation(child_node, attributes_data[member_name])
else:
continue
if member_name in inherited:
child.properties.append("inherited")
root_object.add_child(child)

# First check if this is Pydantic compatible
if "__fields__" in class_.__dict__:
if "__fields__" in direct_members or (self.select_inherited_members and "__fields__" in all_members):
root_object.properties = ["pydantic"]
for field_name, model_field in class_.__dict__.get("__fields__", {}).items():
if self.select(field_name, members): # type: ignore
for field_name, model_field in all_members["__fields__"].items():
if self.select(field_name, select_members) and ( # type: ignore
self.select_inherited_members
# When we don't select inherited members, one way to tell if a field was inherited
# is to check if it exists in parent classes __fields__ attributes.
# We don't check the current class, nor the top one (object), hence __mro__[1:-1]
or field_name not in chain(*(getattr(cls, "__fields__", {}).keys() for cls in class_.__mro__[1:-1]))
):
child_node = ObjectNode(obj=model_field, name=field_name, parent=node)
root_object.add_child(self.get_pydantic_field_documentation(child_node))

# Handle dataclasses
elif "__dataclass_fields__" in class_.__dict__:
elif "__dataclass_fields__" in direct_members or (
self.select_inherited_members and "__fields__" in all_members
):
root_object.properties = ["dataclass"]
for field_name, annotation in class_.__dict__.get("__annotations__", {}).items():
if self.select(field_name, members): # type: ignore
for field_name, annotation in all_members["__annotations__"].items():
if self.select(field_name, select_members) and ( # type: ignore
self.select_inherited_members
# Same comment as for Pydantic models
or field_name
not in chain(*(getattr(cls, "__dataclass_fields__", {}).keys() for cls in class_.__mro__[1:-1]))
):
child_node = ObjectNode(obj=annotation, name=field_name, parent=node)
root_object.add_child(self.get_annotated_dataclass_field(child_node))

Expand Down
21 changes: 19 additions & 2 deletions src/pytkdocs/parsers/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def node_to_annotation(node) -> Union[str, object]:
def get_nodes(obj):
try:
source = inspect.getsource(obj)
except OSError:
except (OSError, TypeError):
source = ""
return ast.parse(dedent(source)).body

Expand Down Expand Up @@ -97,6 +97,17 @@ def combine(docstrings, type_hints):
}


def merge(base, extra):
for attr_name, data in extra.items():
if attr_name not in base:
base[attr_name] = data
else:
if data["annotation"] is not inspect.Signature.empty:
base[attr_name]["annotation"] = data["annotation"]
if data["docstring"] is not None:
base[attr_name]["docstring"] = data["docstring"]


@lru_cache()
def get_module_attributes(module):
return combine(get_module_or_class_attributes(get_nodes(module)), get_type_hints(module))
Expand All @@ -107,7 +118,13 @@ def get_class_attributes(cls):
nodes = get_nodes(cls)
if not nodes:
return {}
return combine(get_module_or_class_attributes(nodes[0].body), get_type_hints(cls))
try:
type_hints = get_type_hints(cls)
except NameError:
# The __config__ attribute (a class) of Pydantic models trigger this error:
# NameError: name 'SchemaExtraCallable' is not defined
type_hints = {}
return combine(get_module_or_class_attributes(nodes[0].body), type_hints)


def pick_target(target):
Expand Down
27 changes: 27 additions & 0 deletions tests/fixtures/inherited_members.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from pydantic import BaseModel as PydanticModel


class Base:
V1 = "v1"
"""Variable 1."""

def method1(self):
"""Method 1."""
pass


class Child(Base):
V2 = "v2"
"""Variable 2."""

def method2(self):
"""Method 2."""
pass


class BaseModel(PydanticModel):
a: int


class ChildModel(BaseModel):
b: str
42 changes: 41 additions & 1 deletion tests/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ def test_loading_pydantic_model():
obj = loader.get_object_documentation("tests.fixtures.pydantic.Person")
assert obj.docstring == "Simple Pydantic Model for a person's information"
assert "pydantic" in obj.properties
assert len(obj.attributes) == 2
name_attr = next(attr for attr in obj.attributes if attr.name == "name")
assert name_attr.type == str
assert name_attr.docstring == "The person's name"
Expand Down Expand Up @@ -312,3 +311,44 @@ def test_loading_members_set_at_import_time():
assert len(obj.classes) == 1
class_ = obj.classes[0]
assert class_.methods


def test_loading_inherited_members():
"""Select inherited members."""
loader = Loader(inherited_members=True)
obj = loader.get_object_documentation("tests.fixtures.inherited_members.Child")
for child_name in ("method1", "method2", "V1", "V2"):
assert child_name in (child.name for child in obj.children)


def test_not_loading_inherited_members():
"""Do not select inherited members."""
loader = Loader(inherited_members=False)
obj = loader.get_object_documentation("tests.fixtures.inherited_members.Child")
for child_name in ("method1", "V1"):
assert child_name not in (child.name for child in obj.children)
for child_name in ("method2", "V2"):
assert child_name in (child.name for child in obj.children)


def test_loading_selected_inherited_members():
"""Select specific members, some of them being inherited."""
loader = Loader(inherited_members=True)
obj = loader.get_object_documentation("tests.fixtures.inherited_members.Child", members={"V1", "V2"})
for child_name in ("V1", "V2"):
assert child_name in (child.name for child in obj.children)


def test_loading_pydantic_inherited_members():
"""Select inherited members in Pydantic models."""
loader = Loader(inherited_members=True)
obj = loader.get_object_documentation("tests.fixtures.inherited_members.ChildModel")
for child_name in ("a", "b"):
assert child_name in (child.name for child in obj.children)


def test_not_loading_pydantic_inherited_members():
"""Do not select inherited members in Pydantic models."""
loader = Loader(inherited_members=False)
obj = loader.get_object_documentation("tests.fixtures.inherited_members.ChildModel")
assert "a" not in (child.name for child in obj.children)

0 comments on commit c07a4ca

Please sign in to comment.