
Commit

Fix Mypy errors
charmoniumQ committed Jul 17, 2024
1 parent bbd3133 commit b05f3f2
Showing 6 changed files with 114 additions and 52 deletions.
25 changes: 20 additions & 5 deletions probe_src/arena/parse_arena.py
@@ -4,7 +4,7 @@
import dataclasses
import pathlib
import ctypes
from typing import Sequence
from typing import Sequence, Iterator, overload

@dataclasses.dataclass(frozen=True)
class MemorySegment:
@@ -23,6 +23,12 @@ def _check(self) -> None:
def length(self) -> int:
return self.stop - self.start

@overload
def __getitem__(self, idx: slice) -> bytes: ...

@overload
def __getitem__(self, idx: int) -> int: ...

def __getitem__(self, idx: slice | int) -> bytes | int:
if isinstance(idx, slice):
if not (self.start <= idx.start <= idx.stop <= self.stop):
@@ -58,6 +64,12 @@ def __post_init__(self) -> None:
def _check(self) -> None:
assert sorted(self.segments, key=lambda segment: segment.start) == self.segments

@overload
def __getitem__(self, idx: slice) -> bytes: ...

@overload
def __getitem__(self, idx: int) -> int: ...

def __getitem__(self, idx: slice | int) -> bytes | int:
if isinstance(idx, slice):
buffr = b''
@@ -76,6 +88,9 @@ def __getitem__(self, idx: slice | int) -> bytes | int:
def __contains__(self, idx: int) -> bool:
return any(idx in segment for segment in self.segments)

def __iter__(self) -> Iterator[MemorySegment]:
return iter(self.segments)


class CArena(ctypes.Structure):
_fields_ = [
@@ -93,19 +108,19 @@ def parse_arena_buffer(buffr: bytes) -> MemorySegment:
return MemorySegment(buffr[ctypes.sizeof(CArena) : c_arena.used], start, stop)


def parse_arena_dir(arena_dir: pathlib.Path) -> Sequence[MemorySegment]:
def parse_arena_dir(arena_dir: pathlib.Path) -> MemorySegments:
memory_segments = []
for path in sorted(arena_dir.iterdir()):
assert path.name.endswith(".dat")
buffr = path.read_bytes()
memory_segments.append(parse_arena_buffer(buffr))
return memory_segments
return MemorySegments(memory_segments)


def parse_arena_dir_tar(
arena_dir_tar: tarfile.TarFile,
prefix: pathlib.Path = pathlib.Path(),
) -> Sequence[MemorySegment]:
) -> MemorySegments:
memory_segments = []
for member in sorted(arena_dir_tar, key=lambda member: member.name):
member_path = pathlib.Path(member.name)
Expand All @@ -116,7 +131,7 @@ def parse_arena_dir_tar(
buffr = extracted.read()
memory_segment = parse_arena_buffer(buffr)
memory_segments.append(memory_segment)
return memory_segments
return MemorySegments(memory_segments)


if __name__ == "__main__":
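Note: the @overload declarations added in this file let mypy resolve segment[i] to int and segment[a:b] to bytes instead of the union bytes | int. A minimal, self-contained sketch of the same pattern (the ByteView class below is hypothetical and not part of this repository):

from typing import overload

class ByteView:
    """Hypothetical class illustrating the slice/int __getitem__ overloads."""

    def __init__(self, data: bytes) -> None:
        self._data = data

    @overload
    def __getitem__(self, idx: slice) -> bytes: ...

    @overload
    def __getitem__(self, idx: int) -> int: ...

    def __getitem__(self, idx: slice | int) -> bytes | int:
        # bytes[int] is an int and bytes[slice] is bytes, so the overloads let
        # mypy pick the precise return type at each call site.
        return self._data[idx]

view = ByteView(b"probe")
first: int = view[0]      # checks: int overload selected
head: bytes = view[0:2]   # checks: slice overload selected

Only the final implementation runs at runtime; the two ... declarations exist purely for the type checker.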
6 changes: 3 additions & 3 deletions probe_src/probe_py/analysis.py
@@ -1,5 +1,5 @@
import typing
import networkx as nx
import networkx as nx # type: ignore
from .parse_probe_log import ProvLog, CloneOp, ExecOp, WaitOp, OpenOp, CloseOp ,CLONE_THREAD
from enum import IntEnum

@@ -96,7 +96,7 @@ def last(pid: int, exid: int, tid: int) -> Node:
for node in nodes:
process_graph.add_node(node)

def add_edges(edges:list[tuple[Node, Node]], label:EdgeLabels):
def add_edges(edges:list[tuple[Node, Node]], label:EdgeLabels) -> None:
for node0, node1 in edges:
process_graph.add_edge(node0, node1, label=label)

@@ -117,7 +117,7 @@ def digraph_to_pydot_string(process_graph: nx.DiGraph) -> str:
label:EdgeLabels = attrs['label']
process_graph[node0][node1]['color'] = label_color_map[label]
pydot_graph = nx.drawing.nx_pydot.to_pydot(process_graph)
dot_string = pydot_graph.to_string()
dot_string = typing.cast(str, pydot_graph.to_string())
return dot_string


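Note: networkx has no type stubs here (hence the type: ignore on its import), so pydot_graph.to_string() is Any to mypy; the typing.cast added above pins it to str with no runtime effect. A small stand-alone sketch of the idea (untyped_library_call is a hypothetical stand-in for the unstubbed call):

import typing

def untyped_library_call() -> typing.Any:
    # Stand-in for a call into an unstubbed library (such as pydot's
    # to_string), which mypy can only see as Any.
    return "digraph {}"

def render() -> str:
    # typing.cast is a no-op at runtime; it only tells the type checker what
    # type to assume, so the enclosing function can keep a precise -> str.
    return typing.cast(str, untyped_library_call())

print(render())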
4 changes: 2 additions & 2 deletions probe_src/probe_py/cli.py
@@ -32,7 +32,7 @@ def record(
debug: bool = typer.Option(default=False, help="Run verbose & debug build of libprobe"),
make: bool = typer.Option(default=False, help="Run make prior to executing"),
output: pathlib.Path = pathlib.Path("probe_log"),
):
) -> None:
"""
Execute CMD... and record its provenance into OUTPUT.
"""
@@ -108,7 +108,7 @@ def process_graph(
@app.command()
def dump(
input: pathlib.Path = pathlib.Path("probe_log"),
):
) -> None:
"""
Write the data from PROBE_LOG in a human-readable manner.
"""
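Note: the only change to this file is adding -> None to the CLI commands; under strict mypy settings such as disallow_untyped_defs, every def needs a return annotation, and commands that work purely by side effect return None. A trimmed, hypothetical illustration (plain function, no typer dependency):

import pathlib

# Before: def dump(input: pathlib.Path = pathlib.Path("probe_log")):
# mypy (strict) flags the missing return type. Adding the explicit -> None
# satisfies disallow_untyped_defs with no behavior change, since the function
# already returned None implicitly.
def dump(input: pathlib.Path = pathlib.Path("probe_log")) -> None:
    print(f"Would dump the contents of {input} in a human-readable manner.")

dump()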
40 changes: 27 additions & 13 deletions probe_src/probe_py/parse_probe_log.py
@@ -28,22 +28,36 @@
# echo '#define _GNU_SOURCE\n#include <sched.h>\nCLONE_THREAD' | cpp | tail --lines=1
CLONE_THREAD = 0x00010000

COp = c_types[("struct", "Op")]
Op: typing.TypeAlias = py_types[("struct", "Op")]
InitExecEpochOp: typing.TypeAlias = py_types[("struct", "InitExecEpochOp")]
InitThreadOp: typing.TypeAlias = py_types[("struct", "InitThreadOp")]
CloneOp: typing.TypeAlias = py_types[("struct", "CloneOp")]
ExecOp: typing.TypeAlias = py_types[("struct", "ExecOp")]
WaitOp: typing.TypeAlias = py_types[("struct", "WaitOp")]
OpenOp: typing.TypeAlias = py_types[("struct", "OpenOp")]
CloseOp: typing.TypeAlias = py_types[("struct", "CloseOp")]
OpCode: enum.EnumType = py_types[("enum", "OpCode")]
TaskType: enum.EnumType = py_types[("enum", "TaskType")]

if typing.TYPE_CHECKING:
COp: typing.Any = object
Op: typing.Any = object
InitExecEpochOp: typing.Any = object
InitThreadOp: typing.Any = object
CloneOp: typing.Any = object
ExecOp: typing.Any = object
WaitOp: typing.Any = object
OpenOp: typing.Any = object
CloseOp: typing.Any = object
OpCode: typing.Any = object
TaskType: typing.Any = object
else:
COp = c_types[("struct", "Op")]
Op: typing.TypeAlias = py_types[("struct", "Op")]
InitExecEpochOp: typing.TypeAlias = py_types[("struct", "InitExecEpochOp")]
InitThreadOp: typing.TypeAlias = py_types[("struct", "InitThreadOp")]
CloneOp: typing.TypeAlias = py_types[("struct", "CloneOp")]
ExecOp: typing.TypeAlias = py_types[("struct", "ExecOp")]
WaitOp: typing.TypeAlias = py_types[("struct", "WaitOp")]
OpenOp: typing.TypeAlias = py_types[("struct", "OpenOp")]
CloseOp: typing.TypeAlias = py_types[("struct", "CloseOp")]
OpCode: enum.EnumType = py_types[("enum", "OpCode")]
TaskType: enum.EnumType = py_types[("enum", "TaskType")]

@dataclasses.dataclass
class ThreadProvLog:
tid: int
ops: typing.Sequence[Op] # type: ignore
ops: typing.Sequence[Op]


@dataclasses.dataclass
@@ -113,7 +127,7 @@ def parse_probe_log_tar(probe_log_tar: tarfile.TarFile) -> ProvLog:
])
threads = collections.defaultdict[int, dict[int, dict[int, ThreadProvLog]]](
lambda: collections.defaultdict[int, dict[int, ThreadProvLog]](
collections.defaultdict[dict[int, ThreadProvLog]]
dict[int, ThreadProvLog]
)
)
for member in member_paths:
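Note: Op, CloneOp, and the other op types are generated at import time from parsed C headers, so mypy cannot see them as real classes. The change above hands mypy opaque typing.Any placeholders inside an if typing.TYPE_CHECKING: block while keeping the generated classes at runtime, which is also why the type: ignore on typing.Sequence[Op] could be dropped. A self-contained sketch of the pattern (all names below are hypothetical):

import typing

def _generate_types_at_runtime() -> dict[str, type]:
    # Stand-in for the struct_parser machinery, which builds Python classes
    # from C struct definitions when the module is imported.
    return {"OpenOp": type("OpenOp", (), {"path": "/etc/hosts"})}

if typing.TYPE_CHECKING:
    # mypy never executes the generator above, so it gets an opaque alias;
    # annotations that mention OpenOp then check as Any instead of erroring.
    OpenOp: typing.Any = object
else:
    OpenOp = _generate_types_at_runtime()["OpenOp"]

def describe(op: OpenOp) -> str:
    return f"{type(op).__name__} opened {op.path}"

print(describe(OpenOp()))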
44 changes: 26 additions & 18 deletions probe_src/probe_py/struct_parser.py
@@ -7,6 +7,7 @@
import typing
import pycparser # type: ignore


_T = typing.TypeVar("_T")

# CType: typing.TypeAlias = type[ctypes._CData]
@@ -97,7 +98,7 @@ def _normalize_name(name: tuple[str, ...]) -> tuple[str, ...]:
def int_representing_pointer(inner_c_type: CType) -> CType:
class PointerStruct(ctypes.Structure):
_fields_ = [("value", ctypes.c_ulong)]
PointerStruct.inner_c_type = inner_c_type
PointerStruct.inner_c_type = inner_c_type # type: ignore
return PointerStruct


@@ -118,7 +119,12 @@ def _lookup_type(
return c_type, py_type


def eval_compile_time_int(c_types, py_types, typ: pycparser.c_ast.Node, name: str) -> int | Exception:
def eval_compile_time_int(
c_types: CTypeDict,
py_types: PyTypeDict,
typ: pycparser.c_ast.Node,
name: str,
) -> int | Exception:
if False:
pass
elif isinstance(typ, pycparser.c_ast.Constant):
@@ -134,7 +140,7 @@ def eval_compile_time_int(c_types, py_types, typ: pycparser.c_ast.Node, name: st
else:
return ctypes.sizeof(c_type)
else:
return eval(f"{typ.op} {eval_compile_time_int(c_types, py_types, typ.expr, name)}")
return int(eval(f"{typ.op} {eval_compile_time_int(c_types, py_types, typ.expr, name)}"))
elif isinstance(typ, pycparser.c_ast.BinaryOp):
left = eval_compile_time_int(c_types, py_types, typ.left, name + "_left")
right = eval_compile_time_int(c_types, py_types, typ.right, name + "_right")
@@ -172,10 +178,11 @@ def ast_to_cpy_type(
if isinstance(inner_py_type, Exception):
c_type = inner_py_type
else:
py_type: type[object]
if inner_c_type == ctypes.c_char:
py_type = str
else:
py_type = list[inner_py_type]
py_type = list[inner_py_type] # type: ignore
return c_type, py_type
elif isinstance(typ, pycparser.c_ast.ArrayDecl):
repetitions = eval_compile_time_int(c_types, py_types, typ.dim, name)
@@ -189,7 +196,7 @@ def ast_to_cpy_type(
if isinstance(inner_py_type, Exception):
array_py_type = inner_py_type
else:
array_py_type = tuple[(inner_py_type,)]
array_py_type = tuple[(inner_py_type,)] # type: ignore
return array_c_type, array_py_type
elif isinstance(typ, pycparser.c_ast.Enum):
if typ.values is None:
@@ -358,7 +365,8 @@ def c_type_to_c_source(c_type: CType, top_level: bool = True) -> str:
elif isinstance(c_type, CArrayType):
return c_type_to_c_source(c_type._type_, False) + "[" + str(c_type._length_) + "]"
elif isinstance(c_type, type(ctypes._Pointer)):
return c_type_to_c_source(c_type._type_, False) + "*"
typ: ctypes._CData = c_type._type_ # type: ignore
return c_type_to_c_source(typ, False) + "*"
elif isinstance(c_type, type(ctypes._SimpleCData)):
name = c_type.__name__
return {
@@ -416,25 +424,25 @@ def convert_c_obj_to_py_obj(
info: typing.Any,
memory: MemoryMapping,
depth: int = 0,
) -> PyType:
) -> PyType | None:
if verbose:
print(depth * " ", c_obj, py_type, info)
if False:
pass
elif c_obj.__class__.__name__ == "PointerStruct":
assert py_type.__name__ == "list" or py_type is str, (type(c_obj), py_type)
if py_type.__name__ == "list":
inner_py_type = py_type.__args__[0]
inner_py_type = py_type.__args__[0] # type: ignore
else:
inner_py_type = str
inner_c_type = c_obj.inner_c_type
size = ctypes.sizeof(inner_c_type)
pointer_int = _expect_type(int, c_obj.value)
if pointer_int == 0:
return None
return None
if pointer_int not in memory:
raise ValueError(f"Pointer {pointer_int:08x} is outside of memory {memory!s}")
lst: inner_py_type = []
lst: inner_py_type = [] # type: ignore
while True:
cont, sub_info = (memory[pointer_int : pointer_int + 1] != b'\0', None) if info is None else info[0](memory, pointer_int)
if cont:
@@ -446,12 +454,12 @@
memory,
depth + 1,
)
lst.append(inner_py_obj)
lst.append(inner_py_obj) # type: ignore
pointer_int += size
else:
break
if py_type is str:
return "".join(lst)
return "".join(lst) # type: ignore
else:
return lst
elif isinstance(c_obj, ctypes.Array):
@@ -482,7 +490,7 @@ def convert_c_obj_to_py_obj(
memory,
depth + 1,
)
return py_type(**fields)
return py_type(**fields) # type: ignore
elif isinstance(c_obj, ctypes.Union):
if not dataclasses.is_dataclass(py_type):
raise TypeError(f"If {type(c_obj)} is a union, then {py_type} should be a dataclass")
@@ -501,16 +509,16 @@
elif isinstance(c_obj, ctypes._SimpleCData):
if isinstance(py_type, enum.EnumType):
assert isinstance(c_obj.value, int)
return py_type(c_obj.value)
return py_type(c_obj.value) # type: ignore
elif py_type is str:
assert isinstance(c_obj, ctypes.c_char)
return c_obj.value.decode()
return c_obj.value.decode() # type: ignore
else:
ret = c_obj.value
return _expect_type(py_type, ret)
return _expect_type(py_type, ret) # type: ignore
elif isinstance(c_obj, py_type):
return c_obj
return c_obj # type: ignore
elif isinstance(c_obj, int) and isinstance(py_type, enum.EnumType):
return py_type(c_obj)
return py_type(c_obj) # type: ignore
else:
raise TypeError(f"{c_obj!r} of c_type {type(c_obj)!r} cannot be converted to py_type {py_type!r}")
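Note: where a dynamic attribute or an unstubbed value is unavoidable, the changes above reach for a narrow type: ignore rather than loosening the annotations, for example when stashing inner_c_type on a generated ctypes.Structure subclass. A minimal sketch of that case (hypothetical function name; the error-code form of the ignore is optional but keeps it targeted):

import ctypes

def pointer_struct_for(inner_c_type: type) -> type[ctypes.Structure]:
    # A struct holding a raw address as an unsigned long; the pointee's ctype
    # is stashed as an extra class attribute that ctypes.Structure does not
    # declare, so mypy needs a targeted ignore on that assignment.
    class PointerStruct(ctypes.Structure):
        _fields_ = [("value", ctypes.c_ulong)]

    PointerStruct.inner_c_type = inner_c_type  # type: ignore[attr-defined]
    return PointerStruct

CharPointer = pointer_struct_for(ctypes.c_char)
print(CharPointer.inner_c_type)  # type: ignore[attr-defined]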
