From 04a45535a62102a7a542bdf0b6366bc69a6e0657 Mon Sep 17 00:00:00 2001 From: Iwan Aucamp Date: Sun, 17 Oct 2021 16:13:16 +0200 Subject: [PATCH] Add type hints This commit only adds type hints and comments and does not make any changes that should affect runtime. The type hints added here derive from work done for #1418. --- rdflib/graph.py | 161 ++++++++++++------- rdflib/parser.py | 10 +- rdflib/plugin.py | 62 +++++-- rdflib/plugins/serializers/jsonld.py | 11 +- rdflib/plugins/serializers/n3.py | 2 +- rdflib/plugins/serializers/nquads.py | 15 +- rdflib/plugins/serializers/nt.py | 15 +- rdflib/plugins/serializers/rdfxml.py | 46 ++++-- rdflib/plugins/serializers/trig.py | 19 ++- rdflib/plugins/serializers/trix.py | 11 +- rdflib/plugins/sparql/results/csvresults.py | 9 +- rdflib/plugins/sparql/results/jsonresults.py | 5 +- rdflib/plugins/sparql/results/txtresults.py | 15 +- rdflib/plugins/sparql/results/xmlresults.py | 15 +- rdflib/query.py | 37 +++-- rdflib/serializer.py | 24 ++- rdflib/store.py | 22 ++- rdflib/term.py | 8 +- 18 files changed, 344 insertions(+), 143 deletions(-) diff --git a/rdflib/graph.py b/rdflib/graph.py index 805bb7c64c..ba462d544c 100644 --- a/rdflib/graph.py +++ b/rdflib/graph.py @@ -1,4 +1,15 @@ -from typing import Optional, Union, Type, cast, overload, Generator, Tuple +from typing import ( + IO, + Any, + Iterable, + Optional, + Union, + Type, + cast, + overload, + Generator, + Tuple, +) import logging from warnings import warn import random @@ -21,7 +32,7 @@ import tempfile import pathlib -from io import BytesIO, BufferedIOBase +from io import BytesIO from urllib.parse import urlparse assert Literal # avoid warning @@ -313,15 +324,19 @@ class Graph(Node): """ def __init__( - self, store="default", identifier=None, namespace_manager=None, base=None + self, + store: Union[Store, str] = "default", + identifier: Optional[Union[Node, str]] = None, + namespace_manager: Optional[NamespaceManager] = None, + base: Optional[str] = None, ): super(Graph, self).__init__() self.base = base - self.__identifier = identifier or BNode() - + self.__identifier: Node + self.__identifier = identifier or BNode() # type: ignore[assignment] if not isinstance(self.__identifier, Node): - self.__identifier = URIRef(self.__identifier) - + self.__identifier = URIRef(self.__identifier) # type: ignore[unreachable] + self.__store: Store if not isinstance(store, Store): # TODO: error handling self.__store = store = plugin.get(store, Store)() @@ -404,7 +419,7 @@ def close(self, commit_pending_transaction=False): """ return self.__store.close(commit_pending_transaction=commit_pending_transaction) - def add(self, triple): + def add(self, triple: Tuple[Node, Node, Node]): """Add a triple with self as context""" s, p, o = triple assert isinstance(s, Node), "Subject %s must be an rdflib term" % (s,) @@ -413,7 +428,7 @@ def add(self, triple): self.__store.add((s, p, o), self, quoted=False) return self - def addN(self, quads): + def addN(self, quads: Iterable[Tuple[Node, Node, Node, Any]]): """Add a sequence of triple with context""" self.__store.addN( @@ -434,7 +449,9 @@ def remove(self, triple): self.__store.remove(triple, context=self) return self - def triples(self, triple): + def triples( + self, triple: Tuple[Optional[Node], Union[None, Path, Node], Optional[Node]] + ): """Generator over the triple store Returns triples that match the given triple pattern. 
If triple pattern @@ -652,17 +669,17 @@ def set(self, triple): self.add((subject, predicate, object_)) return self - def subjects(self, predicate=None, object=None): + def subjects(self, predicate=None, object=None) -> Iterable[Node]: """A generator of subjects with the given predicate and object""" for s, p, o in self.triples((None, predicate, object)): yield s - def predicates(self, subject=None, object=None): + def predicates(self, subject=None, object=None) -> Iterable[Node]: """A generator of predicates with the given subject and object""" for s, p, o in self.triples((subject, None, object)): yield p - def objects(self, subject=None, predicate=None): + def objects(self, subject=None, predicate=None) -> Iterable[Node]: """A generator of objects with the given subject and predicate""" for s, p, o in self.triples((subject, predicate, None)): yield o @@ -1019,45 +1036,32 @@ def serialize( @overload def serialize( self, - *, destination: None = ..., format: str = ..., base: Optional[str] = ..., + *, encoding: str, **args, ) -> bytes: ... - # no destination and None positional encoding - @overload - def serialize( - self, - destination: None, - format: str, - base: Optional[str], - encoding: None, - **args, - ) -> str: - ... - - # no destination and None keyword encoding + # no destination and None encoding @overload def serialize( self, - *, destination: None = ..., format: str = ..., base: Optional[str] = ..., - encoding: None = None, + encoding: None = ..., **args, ) -> str: ... - # non-none destination + # non-None destination @overload def serialize( self, - destination: Union[str, BufferedIOBase, pathlib.PurePath], + destination: Union[str, pathlib.PurePath, IO[bytes]], format: str = ..., base: Optional[str] = ..., encoding: Optional[str] = ..., @@ -1069,21 +1073,21 @@ def serialize( @overload def serialize( self, - destination: Union[str, BufferedIOBase, pathlib.PurePath, None] = None, - format: str = "turtle", - base: Optional[str] = None, - encoding: Optional[str] = None, + destination: Optional[Union[str, pathlib.PurePath, IO[bytes]]] = ..., + format: str = ..., + base: Optional[str] = ..., + encoding: Optional[str] = ..., **args, ) -> Union[bytes, str, "Graph"]: ... def serialize( self, - destination: Union[str, BufferedIOBase, pathlib.PurePath, None] = None, + destination: Optional[Union[str, pathlib.PurePath, IO[bytes]]] = None, format: str = "turtle", base: Optional[str] = None, encoding: Optional[str] = None, - **args, + **args: Any, ) -> Union[bytes, str, "Graph"]: """Serialize the Graph to destination @@ -1104,7 +1108,7 @@ def serialize( base = self.base serializer = plugin.get(format, Serializer)(self) - stream: BufferedIOBase + stream: IO[bytes] if destination is None: stream = BytesIO() if encoding is None: @@ -1114,7 +1118,7 @@ def serialize( serializer.serialize(stream, base=base, encoding=encoding, **args) return stream.getvalue() if hasattr(destination, "write"): - stream = cast(BufferedIOBase, destination) + stream = cast(IO[bytes], destination) serializer.serialize(stream, base=base, encoding=encoding, **args) else: if isinstance(destination, pathlib.PurePath): @@ -1149,10 +1153,10 @@ def parse( self, source=None, publicID=None, - format=None, + format: Optional[str] = None, location=None, file=None, - data=None, + data: Optional[Union[str, bytes, bytearray]] = None, **args, ): """ @@ -1537,7 +1541,12 @@ class ConjunctiveGraph(Graph): All queries are carried out against the union of all graphs. 
""" - def __init__(self, store="default", identifier=None, default_graph_base=None): + def __init__( + self, + store: Union[Store, str] = "default", + identifier: Optional[Union[Node, str]] = None, + default_graph_base: Optional[str] = None, + ): super(ConjunctiveGraph, self).__init__(store, identifier=identifier) assert self.store.context_aware, ( "ConjunctiveGraph must be backed by" " a context aware store." @@ -1555,7 +1564,31 @@ def __str__(self): ) return pattern % self.store.__class__.__name__ - def _spoc(self, triple_or_quad, default=False): + @overload + def _spoc( + self, + triple_or_quad: Union[ + Tuple[Node, Node, Node, Optional[Any]], Tuple[Node, Node, Node] + ], + default: bool = False, + ) -> Tuple[Node, Node, Node, Optional[Graph]]: + ... + + @overload + def _spoc( + self, + triple_or_quad: None, + default: bool = False, + ) -> Tuple[None, None, None, Optional[Graph]]: + ... + + def _spoc( + self, + triple_or_quad: Optional[ + Union[Tuple[Node, Node, Node, Optional[Any]], Tuple[Node, Node, Node]] + ], + default: bool = False, + ) -> Tuple[Optional[Node], Optional[Node], Optional[Node], Optional[Graph]]: """ helper method for having methods that support either triples or quads @@ -1564,9 +1597,9 @@ def _spoc(self, triple_or_quad, default=False): return (None, None, None, self.default_context if default else None) if len(triple_or_quad) == 3: c = self.default_context if default else None - (s, p, o) = triple_or_quad + (s, p, o) = triple_or_quad # type: ignore[misc] elif len(triple_or_quad) == 4: - (s, p, o, c) = triple_or_quad + (s, p, o, c) = triple_or_quad # type: ignore[misc] c = self._graph(c) return s, p, o, c @@ -1577,7 +1610,7 @@ def __contains__(self, triple_or_quad): return True return False - def add(self, triple_or_quad): + def add(self, triple_or_quad: Union[Tuple[Node, Node, Node, Optional[Any]], Tuple[Node, Node, Node]]) -> "ConjunctiveGraph": # type: ignore[override] """ Add a triple or quad to the store. @@ -1591,7 +1624,15 @@ def add(self, triple_or_quad): self.store.add((s, p, o), context=c, quoted=False) return self - def _graph(self, c): + @overload + def _graph(self, c: Union[Graph, Node, str]) -> Graph: + ... + + @overload + def _graph(self, c: None) -> None: + ... + + def _graph(self, c: Optional[Union[Graph, Node, str]]) -> Optional[Graph]: if c is None: return None if not isinstance(c, Graph): @@ -1599,7 +1640,7 @@ def _graph(self, c): else: return c - def addN(self, quads): + def addN(self, quads: Iterable[Tuple[Node, Node, Node, Any]]): """Add a sequence of triples with context""" self.store.addN( @@ -1689,13 +1730,19 @@ def contexts(self, triple=None): else: yield self.get_context(context) - def get_context(self, identifier, quoted=False, base=None): + def get_context( + self, + identifier: Optional[Union[Node, str]], + quoted: bool = False, + base: Optional[str] = None, + ) -> Graph: """Return a context graph for the given identifier identifier must be a URIRef or BNode. """ + # TODO: FIXME - why is ConjunctiveGraph passed as namespace_manager? return Graph( - store=self.store, identifier=identifier, namespace_manager=self, base=base + store=self.store, identifier=identifier, namespace_manager=self, base=base # type: ignore[arg-type] ) def remove_context(self, context): @@ -1747,6 +1794,7 @@ def parse( context = Graph(store=self.store, identifier=g_id) context.remove((None, None, None)) # hmm ? context.parse(source, publicID=publicID, format=format, **args) + # TODO: FIXME: This should not return context, but self. 
return context def __reduce__(self): @@ -1977,7 +2025,7 @@ class QuotedGraph(Graph): def __init__(self, store, identifier): super(QuotedGraph, self).__init__(store, identifier) - def add(self, triple): + def add(self, triple: Tuple[Node, Node, Node]): """Add a triple with self as context""" s, p, o = triple assert isinstance(s, Node), "Subject %s must be an rdflib term" % (s,) @@ -1987,7 +2035,7 @@ def add(self, triple): self.store.add((s, p, o), self, quoted=True) return self - def addN(self, quads): + def addN(self, quads: Tuple[Node, Node, Node, Any]) -> "QuotedGraph": # type: ignore[override] """Add a sequence of triple with context""" self.store.addN( @@ -2261,7 +2309,7 @@ class BatchAddGraph(object): """ - def __init__(self, graph, batch_size=1000, batch_addn=False): + def __init__(self, graph: Graph, batch_size: int = 1000, batch_addn: bool = False): if not batch_size or batch_size < 2: raise ValueError("batch_size must be a positive number") self.graph = graph @@ -2278,7 +2326,10 @@ def reset(self): self.count = 0 return self - def add(self, triple_or_quad): + def add( + self, + triple_or_quad: Union[Tuple[Node, Node, Node], Tuple[Node, Node, Node, Any]], + ) -> "BatchAddGraph": """ Add a triple to the buffer @@ -2294,7 +2345,7 @@ def add(self, triple_or_quad): self.batch.append(triple_or_quad) return self - def addN(self, quads): + def addN(self, quads: Iterable[Tuple[Node, Node, Node, Any]]): if self.__batch_addn: for q in quads: self.add(q) diff --git a/rdflib/parser.py b/rdflib/parser.py index f0014150f6..1f8a490cde 100644 --- a/rdflib/parser.py +++ b/rdflib/parser.py @@ -16,6 +16,7 @@ import sys from io import BytesIO, TextIOBase, TextIOWrapper, StringIO, BufferedIOBase +from typing import Optional, Union from urllib.request import Request from urllib.request import url2pathname @@ -44,7 +45,7 @@ class Parser(object): def __init__(self): pass - def parse(self, source, sink): + def parse(self, source, sink, **args): pass @@ -214,7 +215,12 @@ def __repr__(self): def create_input_source( - source=None, publicID=None, location=None, file=None, data=None, format=None + source=None, + publicID=None, + location=None, + file=None, + data: Optional[Union[str, bytes, bytearray]] = None, + format=None, ): """ Return an appropriate InputSource instance for the given diff --git a/rdflib/plugin.py b/rdflib/plugin.py index 719c7eaf55..ac3a7fbd06 100644 --- a/rdflib/plugin.py +++ b/rdflib/plugin.py @@ -36,7 +36,21 @@ UpdateProcessor, ) from rdflib.exceptions import Error -from typing import Type, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + Iterator, + Optional, + Tuple, + Type, + TypeVar, + overload, +) + +if TYPE_CHECKING: + from pkg_resources import EntryPoint __all__ = ["register", "get", "plugins", "PluginException", "Plugin", "PKGPlugin"] @@ -51,42 +65,47 @@ "rdf.plugins.updateprocessor": UpdateProcessor, } -_plugins = {} +_plugins: Dict[Tuple[str, Type[Any]], "Plugin"] = {} class PluginException(Error): pass -class Plugin(object): - def __init__(self, name, kind, module_path, class_name): +PluginT = TypeVar("PluginT") + + +class Plugin(Generic[PluginT]): + def __init__( + self, name: str, kind: Type[PluginT], module_path: str, class_name: str + ): self.name = name self.kind = kind self.module_path = module_path self.class_name = class_name - self._class = None + self._class: Optional[Type[PluginT]] = None - def getClass(self): + def getClass(self) -> Type[PluginT]: if self._class is None: module = __import__(self.module_path, globals(), locals(), [""]) 
self._class = getattr(module, self.class_name) return self._class -class PKGPlugin(Plugin): - def __init__(self, name, kind, ep): +class PKGPlugin(Plugin[PluginT]): + def __init__(self, name: str, kind: Type[PluginT], ep: "EntryPoint"): self.name = name self.kind = kind self.ep = ep - self._class = None + self._class: Optional[Type[PluginT]] = None - def getClass(self): + def getClass(self) -> Type[PluginT]: if self._class is None: self._class = self.ep.load() return self._class -def register(name: str, kind, module_path, class_name): +def register(name: str, kind: Type[Any], module_path, class_name): """ Register the plugin for (name, kind). The module_path and class_name should be the path to a plugin class. @@ -95,16 +114,13 @@ def register(name: str, kind, module_path, class_name): _plugins[(name, kind)] = p -PluginT = TypeVar("PluginT") - - def get(name: str, kind: Type[PluginT]) -> Type[PluginT]: """ Return the class for the specified (name, kind). Raises a PluginException if unable to do so. """ try: - p = _plugins[(name, kind)] + p: Plugin[PluginT] = _plugins[(name, kind)] except KeyError: raise PluginException("No plugin registered for (%s, %s)" % (name, kind)) return p.getClass() @@ -121,7 +137,21 @@ def get(name: str, kind: Type[PluginT]) -> Type[PluginT]: _plugins[(ep.name, kind)] = PKGPlugin(ep.name, kind, ep) -def plugins(name=None, kind=None): +@overload +def plugins( + name: Optional[str] = ..., kind: Type[PluginT] = ... +) -> Iterator[Plugin[PluginT]]: + ... + + +@overload +def plugins(name: Optional[str] = ..., kind: None = ...) -> Iterator[Plugin]: + ... + + +def plugins( + name: Optional[str] = None, kind: Optional[Type[PluginT]] = None +) -> Iterator[Plugin]: """ A generator of the plugins. diff --git a/rdflib/plugins/serializers/jsonld.py b/rdflib/plugins/serializers/jsonld.py index 67f3b86232..f5067e2873 100644 --- a/rdflib/plugins/serializers/jsonld.py +++ b/rdflib/plugins/serializers/jsonld.py @@ -41,6 +41,7 @@ from rdflib.graph import Graph from rdflib.term import URIRef, Literal, BNode from rdflib.namespace import RDF, XSD +from typing import IO, Optional from ..shared.jsonld.context import Context, UNDEF from ..shared.jsonld.util import json @@ -53,10 +54,16 @@ class JsonLDSerializer(Serializer): - def __init__(self, store): + def __init__(self, store: Graph): super(JsonLDSerializer, self).__init__(store) - def serialize(self, stream, base=None, encoding=None, **kwargs): + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + **kwargs + ): # TODO: docstring w. 
args and return value encoding = encoding or "utf-8" if encoding not in ("utf-8", "utf-16"): diff --git a/rdflib/plugins/serializers/n3.py b/rdflib/plugins/serializers/n3.py index 6c4e2ec46d..806f445ef8 100644 --- a/rdflib/plugins/serializers/n3.py +++ b/rdflib/plugins/serializers/n3.py @@ -14,7 +14,7 @@ class N3Serializer(TurtleSerializer): short_name = "n3" - def __init__(self, store, parent=None): + def __init__(self, store: Graph, parent=None): super(N3Serializer, self).__init__(store) self.keywords.update({OWL.sameAs: "=", SWAP_LOG.implies: "=>"}) self.parent = parent diff --git a/rdflib/plugins/serializers/nquads.py b/rdflib/plugins/serializers/nquads.py index 54ee42ba12..e76c747d49 100644 --- a/rdflib/plugins/serializers/nquads.py +++ b/rdflib/plugins/serializers/nquads.py @@ -1,5 +1,7 @@ +from typing import IO, Optional import warnings +from rdflib.graph import ConjunctiveGraph, Graph from rdflib.term import Literal from rdflib.serializer import Serializer @@ -9,15 +11,22 @@ class NQuadsSerializer(Serializer): - def __init__(self, store): + def __init__(self, store: Graph): if not store.context_aware: raise Exception( "NQuads serialization only makes " "sense for context-aware stores!" ) super(NQuadsSerializer, self).__init__(store) - - def serialize(self, stream, base=None, encoding=None, **args): + self.store: ConjunctiveGraph + + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + **args + ): if base is not None: warnings.warn("NQuadsSerializer does not support base.") if encoding is not None and encoding.lower() != self.encoding.lower(): diff --git a/rdflib/plugins/serializers/nt.py b/rdflib/plugins/serializers/nt.py index bc265ee5f4..467de46134 100644 --- a/rdflib/plugins/serializers/nt.py +++ b/rdflib/plugins/serializers/nt.py @@ -3,6 +3,9 @@ See for details about the format. """ +from typing import IO, Optional + +from rdflib.graph import Graph from rdflib.term import Literal from rdflib.serializer import Serializer @@ -17,11 +20,17 @@ class NTSerializer(Serializer): Serializes RDF graphs to NTriples format. """ - def __init__(self, store): + def __init__(self, store: Graph): Serializer.__init__(self, store) self.encoding = "ascii" # n-triples are ascii encoded - def serialize(self, stream, base=None, encoding=None, **args): + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + **args + ): if base is not None: warnings.warn("NTSerializer does not support base.") if encoding is not None and encoding.lower() != self.encoding.lower(): @@ -39,7 +48,7 @@ class NT11Serializer(NTSerializer): Exactly like nt - only utf8 encoded. 
""" - def __init__(self, store): + def __init__(self, store: Graph): Serializer.__init__(self, store) # default to utf-8 diff --git a/rdflib/plugins/serializers/rdfxml.py b/rdflib/plugins/serializers/rdfxml.py index 72648afbac..901d911d91 100644 --- a/rdflib/plugins/serializers/rdfxml.py +++ b/rdflib/plugins/serializers/rdfxml.py @@ -1,9 +1,11 @@ +from typing import IO, Dict, Optional, Set, cast from rdflib.plugins.serializers.xmlwriter import XMLWriter from rdflib.namespace import Namespace, RDF, RDFS # , split_uri from rdflib.plugins.parsers.RDFVOC import RDFVOC -from rdflib.term import URIRef, Literal, BNode +from rdflib.graph import Graph +from rdflib.term import Identifier, URIRef, Literal, BNode from rdflib.util import first, more_than from rdflib.collection import Collection from rdflib.serializer import Serializer @@ -17,7 +19,7 @@ class XMLSerializer(Serializer): - def __init__(self, store): + def __init__(self, store: Graph): super(XMLSerializer, self).__init__(store) def __bindings(self): @@ -39,14 +41,20 @@ def __bindings(self): for prefix, namespace in bindings.items(): yield prefix, namespace - def serialize(self, stream, base=None, encoding=None, **args): + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + **args + ): # if base is given here, use that, if not and a base is set for the graph use that if base is not None: self.base = base elif self.store.base is not None: self.base = self.store.base self.__stream = stream - self.__serialized = {} + self.__serialized: Dict[Identifier, int] = {} encoding = self.encoding self.write = write = lambda uni: stream.write(uni.encode(encoding, "replace")) @@ -154,12 +162,18 @@ def fix(val): class PrettyXMLSerializer(Serializer): - def __init__(self, store, max_depth=3): + def __init__(self, store: Graph, max_depth=3): super(PrettyXMLSerializer, self).__init__(store) - self.forceRDFAbout = set() - - def serialize(self, stream, base=None, encoding=None, **args): - self.__serialized = {} + self.forceRDFAbout: Set[URIRef] = set() + + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + **args + ): + self.__serialized: Dict[Identifier, int] = {} store = self.store # if base is given here, use that, if not and a base is set for the graph use that if base is not None: @@ -190,8 +204,9 @@ def serialize(self, stream, base=None, encoding=None, **args): writer.namespaces(namespaces.items()) + subject: Identifier # Write out subjects that can not be inline - for subject in store.subjects(): + for subject in store.subjects(): # type: ignore[assignment] if (None, None, subject) in store: if (subject, None, subject) in store: self.subject(subject, 1) @@ -202,7 +217,7 @@ def serialize(self, stream, base=None, encoding=None, **args): # write out BNodes last (to ensure they can be inlined where possible) bnodes = set() - for subject in store.subjects(): + for subject in store.subjects(): # type: ignore[assignment] if isinstance(subject, BNode): bnodes.add(subject) continue @@ -217,9 +232,9 @@ def serialize(self, stream, base=None, encoding=None, **args): stream.write("\n".encode("latin-1")) # Set to None so that the memory can get garbage collected. 
- self.__serialized = None + self.__serialized = None # type: ignore[assignment] - def subject(self, subject, depth=1): + def subject(self, subject: Identifier, depth: int = 1): store = self.store writer = self.writer @@ -227,7 +242,7 @@ def subject(self, subject, depth=1): writer.push(RDFVOC.Description) writer.attribute(RDFVOC.about, self.relativize(subject)) writer.pop(RDFVOC.Description) - self.forceRDFAbout.remove(subject) + self.forceRDFAbout.remove(subject) # type: ignore[arg-type] elif subject not in self.__serialized: self.__serialized[subject] = 1 @@ -264,10 +279,11 @@ def subj_as_obj_more_than(ceil): writer.pop(element) elif subject in self.forceRDFAbout: + # TODO FIXME?: this looks like a duplicate of first condition writer.push(RDFVOC.Description) writer.attribute(RDFVOC.about, self.relativize(subject)) writer.pop(RDFVOC.Description) - self.forceRDFAbout.remove(subject) + self.forceRDFAbout.remove(subject) # type: ignore[arg-type] def predicate(self, predicate, object, depth=1): writer = self.writer diff --git a/rdflib/plugins/serializers/trig.py b/rdflib/plugins/serializers/trig.py index cdaedd4892..5a606e401c 100644 --- a/rdflib/plugins/serializers/trig.py +++ b/rdflib/plugins/serializers/trig.py @@ -4,9 +4,12 @@ """ from collections import defaultdict +from typing import IO, TYPE_CHECKING, Optional, Union +from rdflib.graph import ConjunctiveGraph, Graph from rdflib.plugins.serializers.turtle import TurtleSerializer -from rdflib.term import BNode +from rdflib.term import BNode, Node + __all__ = ["TrigSerializer"] @@ -16,8 +19,11 @@ class TrigSerializer(TurtleSerializer): short_name = "trig" indentString = 4 * " " - def __init__(self, store): + def __init__(self, store: Union[Graph, ConjunctiveGraph]): + self.default_context: Optional[Node] if store.context_aware: + if TYPE_CHECKING: + assert isinstance(store, ConjunctiveGraph) self.contexts = list(store.contexts()) self.default_context = store.default_context.identifier if store.default_context: @@ -48,7 +54,14 @@ def reset(self): super(TrigSerializer, self).reset() self._contexts = {} - def serialize(self, stream, base=None, encoding=None, spacious=None, **args): + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + spacious: Optional[bool] = None, + **args + ): self.reset() self.stream = stream # if base is given here, use that, if not and a base is set for the graph use that diff --git a/rdflib/plugins/serializers/trix.py b/rdflib/plugins/serializers/trix.py index 05b6f528f3..1612d815cc 100644 --- a/rdflib/plugins/serializers/trix.py +++ b/rdflib/plugins/serializers/trix.py @@ -1,3 +1,4 @@ +from typing import IO, Optional from rdflib.serializer import Serializer from rdflib.plugins.serializers.xmlwriter import XMLWriter @@ -15,14 +16,20 @@ class TriXSerializer(Serializer): - def __init__(self, store): + def __init__(self, store: Graph): super(TriXSerializer, self).__init__(store) if not store.context_aware: raise Exception( "TriX serialization only makes sense for context-aware stores" ) - def serialize(self, stream, base=None, encoding=None, **args): + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + **args + ): nm = self.store.namespace_manager diff --git a/rdflib/plugins/sparql/results/csvresults.py b/rdflib/plugins/sparql/results/csvresults.py index c87b6ea760..11a0b38165 100644 --- a/rdflib/plugins/sparql/results/csvresults.py +++ b/rdflib/plugins/sparql/results/csvresults.py @@ -9,6 +9,7 @@ import 
codecs import csv +from typing import IO from rdflib import Variable, BNode, URIRef, Literal @@ -61,7 +62,7 @@ def __init__(self, result): if result.type != "SELECT": raise Exception("CSVSerializer can only serialize select query results") - def serialize(self, stream, encoding="utf-8", **kwargs): + def serialize(self, stream: IO, encoding: str = "utf-8", **kwargs): # the serialiser writes bytes in the given encoding # in py3 csv.writer is unicode aware and writes STRINGS, @@ -69,15 +70,15 @@ def serialize(self, stream, encoding="utf-8", **kwargs): import codecs - stream = codecs.getwriter(encoding)(stream) + stream = codecs.getwriter(encoding)(stream) # type: ignore[assignment] out = csv.writer(stream, delimiter=self.delim) - vs = [self.serializeTerm(v, encoding) for v in self.result.vars] + vs = [self.serializeTerm(v, encoding) for v in self.result.vars] # type: ignore[union-attr] out.writerow(vs) for row in self.result.bindings: out.writerow( - [self.serializeTerm(row.get(v), encoding) for v in self.result.vars] + [self.serializeTerm(row.get(v), encoding) for v in self.result.vars] # type: ignore[union-attr] ) def serializeTerm(self, term, encoding): diff --git a/rdflib/plugins/sparql/results/jsonresults.py b/rdflib/plugins/sparql/results/jsonresults.py index 13a8da5eff..562f0ec075 100644 --- a/rdflib/plugins/sparql/results/jsonresults.py +++ b/rdflib/plugins/sparql/results/jsonresults.py @@ -1,4 +1,5 @@ import json +from typing import IO, Any, Dict, Optional, TextIO, Union from rdflib.query import Result, ResultException, ResultSerializer, ResultParser from rdflib import Literal, URIRef, BNode, Variable @@ -28,9 +29,9 @@ class JSONResultSerializer(ResultSerializer): def __init__(self, result): ResultSerializer.__init__(self, result) - def serialize(self, stream, encoding=None): + def serialize(self, stream: IO, encoding: str = None): # type: ignore[override] - res = {} + res: Dict[str, Any] = {} if self.result.type == "ASK": res["head"] = {} res["boolean"] = self.result.askAnswer diff --git a/rdflib/plugins/sparql/results/txtresults.py b/rdflib/plugins/sparql/results/txtresults.py index baa5316b48..3f41df9429 100644 --- a/rdflib/plugins/sparql/results/txtresults.py +++ b/rdflib/plugins/sparql/results/txtresults.py @@ -1,8 +1,11 @@ +from typing import IO, List, Optional from rdflib import URIRef, BNode, Literal from rdflib.query import ResultSerializer +from rdflib.namespace import NamespaceManager +from rdflib.term import Variable -def _termString(t, namespace_manager): +def _termString(t, namespace_manager: Optional[NamespaceManager]): if t is None: return "-" if namespace_manager: @@ -21,7 +24,13 @@ class TXTResultSerializer(ResultSerializer): A write only QueryResult serializer for text/ascii tables """ - def serialize(self, stream, encoding, namespace_manager=None): + # TODO FIXME: class specific args should be keyword only. 
+ def serialize( # type: ignore[override] + self, + stream: IO, + encoding: str, + namespace_manager: Optional[NamespaceManager] = None, + ): """ return a text table of query results """ @@ -43,7 +52,7 @@ def c(s, w): return "(no results)\n" else: - keys = self.result.vars + keys: List[Variable] = self.result.vars # type: ignore[assignment] maxlen = [0] * len(keys) b = [ [_termString(r[k], namespace_manager) for k in keys] diff --git a/rdflib/plugins/sparql/results/xmlresults.py b/rdflib/plugins/sparql/results/xmlresults.py index 8c77b50ad1..3869bc9e24 100644 --- a/rdflib/plugins/sparql/results/xmlresults.py +++ b/rdflib/plugins/sparql/results/xmlresults.py @@ -1,4 +1,5 @@ import logging +from typing import IO, Optional from xml.sax.saxutils import XMLGenerator from xml.dom import XML_NAMESPACE @@ -28,15 +29,17 @@ class XMLResultParser(ResultParser): - def parse(self, source, content_type=None): + # TODO FIXME: content_type should be a keyword only arg. + def parse(self, source, content_type: Optional[str] = None): # type: ignore[override] return XMLResult(source) class XMLResult(Result): - def __init__(self, source, content_type=None): + def __init__(self, source, content_type: Optional[str] = None): try: - parser = etree.XMLParser(huge_tree=True) + # try use as if etree is from lxml, and if not use it as normal. + parser = etree.XMLParser(huge_tree=True) # type: ignore[call-arg] tree = etree.parse(source, parser) except TypeError: tree = etree.parse(source) @@ -55,7 +58,7 @@ def __init__(self, source, content_type=None): if type_ == "SELECT": self.bindings = [] - for result in results: + for result in results: # type: ignore[union-attr] r = {} for binding in result: r[Variable(binding.get("name"))] = parseTerm(binding[0]) @@ -69,7 +72,7 @@ def __init__(self, source, content_type=None): ] else: - self.askAnswer = boolean.text.lower().strip() == "true" + self.askAnswer = boolean.text.lower().strip() == "true" # type: ignore[union-attr] def parseTerm(element): @@ -101,7 +104,7 @@ class XMLResultSerializer(ResultSerializer): def __init__(self, result): ResultSerializer.__init__(self, result) - def serialize(self, stream, encoding="utf-8"): + def serialize(self, stream: IO, encoding: str = "utf-8", **kwargs): writer = SPARQLXMLWriter(stream, encoding) if self.result.type == "ASK": diff --git a/rdflib/query.py b/rdflib/query.py index da174cd1f2..65ee141581 100644 --- a/rdflib/query.py +++ b/rdflib/query.py @@ -5,6 +5,7 @@ import warnings import types from typing import Optional, Union, cast +from typing import IO, TYPE_CHECKING, List, Optional, TextIO, Union, cast, overload from io import BytesIO, BufferedIOBase @@ -12,6 +13,10 @@ __all__ = ["Processor", "Result", "ResultParser", "ResultSerializer", "ResultException"] +if TYPE_CHECKING: + from .graph import Graph + from .term import Variable + class Processor(object): """ @@ -161,17 +166,17 @@ class Result(object): """ - def __init__(self, type_): + def __init__(self, type_: str): if type_ not in ("CONSTRUCT", "DESCRIBE", "SELECT", "ASK"): raise ResultException("Unknown Result type: %s" % type_) self.type = type_ - self.vars = None + self.vars: Optional[List[Variable]] = None self._bindings = None self._genbindings = None - self.askAnswer = None - self.graph = None + self.askAnswer: bool = None # type: ignore[assignment] + self.graph: "Graph" = None # type: ignore[assignment] def _get_bindings(self): if self._genbindings: @@ -192,7 +197,12 @@ def _set_bindings(self, b): ) @staticmethod - def parse(source=None, format=None, content_type=None, 
**kwargs): + def parse( + source=None, + format: Optional[str] = None, + content_type: Optional[str] = None, + **kwargs, + ): from rdflib import plugin if format: @@ -208,7 +218,7 @@ def parse(source=None, format=None, content_type=None, **kwargs): def serialize( self, - destination: Optional[Union[str, BufferedIOBase]] = None, + destination: Optional[Union[str, IO]] = None, encoding: str = "utf-8", format: str = "xml", **args, @@ -230,7 +240,7 @@ def serialize( :return: bytes """ if self.type in ("CONSTRUCT", "DESCRIBE"): - return self.graph.serialize( + return self.graph.serialize( # type: ignore[return-value] destination, encoding=encoding, format=format, **args ) @@ -241,10 +251,10 @@ def serialize( if destination is None: streamb: BytesIO = BytesIO() stream2 = EncodeOnlyUnicode(streamb) - serializer.serialize(stream2, encoding=encoding, **args) + serializer.serialize(stream2, encoding=encoding, **args) # type: ignore return streamb.getvalue() if hasattr(destination, "write"): - stream = cast(BufferedIOBase, destination) + stream = cast(IO[bytes], destination) serializer.serialize(stream, encoding=encoding, **args) else: location = cast(str, destination) @@ -339,9 +349,14 @@ def parse(self, source, **kwargs): class ResultSerializer(object): - def __init__(self, result): + def __init__(self, result: Result): self.result = result - def serialize(self, stream, encoding="utf-8", **kwargs): + def serialize( + self, + stream: IO, + encoding: str = "utf-8", + **kwargs, + ): """return a string properly serialized""" pass # abstract diff --git a/rdflib/serializer.py b/rdflib/serializer.py index ecb8da0a2b..16a47d55cd 100644 --- a/rdflib/serializer.py +++ b/rdflib/serializer.py @@ -10,21 +10,31 @@ """ +from typing import IO, TYPE_CHECKING, Optional from rdflib.term import URIRef +if TYPE_CHECKING: + from rdflib.graph import Graph + __all__ = ["Serializer"] -class Serializer(object): - def __init__(self, store): - self.store = store - self.encoding = "UTF-8" - self.base = None +class Serializer: + def __init__(self, store: "Graph"): + self.store: "Graph" = store + self.encoding: str = "UTF-8" + self.base: Optional[str] = None - def serialize(self, stream, base=None, encoding=None, **args): + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + **args + ) -> None: """Abstract method""" - def relativize(self, uri): + def relativize(self, uri: str): base = self.base if base is not None and uri.startswith(base): uri = URIRef(uri.replace(base, "", 1)) diff --git a/rdflib/store.py b/rdflib/store.py index a7aa8d0b09..d96f569eaf 100644 --- a/rdflib/store.py +++ b/rdflib/store.py @@ -1,6 +1,11 @@ from io import BytesIO import pickle from rdflib.events import Dispatcher, Event +from typing import Tuple, TYPE_CHECKING, Iterable, Optional + +if TYPE_CHECKING: + from .term import Node + from .graph import Graph """ ============ @@ -172,7 +177,7 @@ def __get_node_pickler(self): def create(self, configuration): self.dispatcher.dispatch(StoreCreatedEvent(configuration=configuration)) - def open(self, configuration, create=False): + def open(self, configuration, create: bool = False): """ Opens the store specified by the configuration string. 
If create is True a store will be created if it does not already @@ -204,7 +209,12 @@ def gc(self): pass # RDF APIs - def add(self, triple, context, quoted=False): + def add( + self, + triple: Tuple["Node", "Node", "Node"], + context: Optional["Graph"], + quoted: bool = False, + ): """ Adds the given statement to a specific context or to the model. The quoted argument is interpreted by formula-aware stores to indicate @@ -215,7 +225,7 @@ def add(self, triple, context, quoted=False): """ self.dispatcher.dispatch(TripleAddedEvent(triple=triple, context=context)) - def addN(self, quads): + def addN(self, quads: Iterable[Tuple["Node", "Node", "Node", "Graph"]]): """ Adds each item in the list of statements to a specific context. The quoted argument is interpreted by formula-aware stores to indicate this @@ -283,7 +293,11 @@ def triples_choices(self, triple, context=None): for (s1, p1, o1), cg in self.triples((subject, None, object_), context): yield (s1, p1, o1), cg - def triples(self, triple_pattern, context=None): + def triples( + self, + triple_pattern: Tuple[Optional["Node"], Optional["Node"], Optional["Node"]], + context=None, + ): """ A generator over all the triples matching the pattern. Pattern can include any objects for used for comparing against nodes in the store, diff --git a/rdflib/term.py b/rdflib/term.py index eb1f2cb6ca..307a2793ef 100644 --- a/rdflib/term.py +++ b/rdflib/term.py @@ -64,7 +64,7 @@ from urllib.parse import urlparse from decimal import Decimal -from typing import TYPE_CHECKING, Dict, Callable, Union, Type +from typing import TYPE_CHECKING, Dict, Callable, Optional, Union, Type if TYPE_CHECKING: from .paths import AlternativePath, InvPath, NegatedPath, SequencePath, Path @@ -231,10 +231,10 @@ class URIRef(Identifier): __neg__: Callable[["URIRef"], "NegatedPath"] __truediv__: Callable[["URIRef", Union["URIRef", "Path"]], "SequencePath"] - def __new__(cls, value, base=None): + def __new__(cls, value: str, base: Optional[str] = None): if base is not None: ends_in_hash = value.endswith("#") - value = urljoin(base, value, allow_fragments=1) + value = urljoin(base, value, allow_fragments=1) # type: ignore[arg-type] if ends_in_hash: if not value.endswith("#"): value += "#" @@ -248,7 +248,7 @@ def __new__(cls, value, base=None): try: rt = str.__new__(cls, value) except UnicodeDecodeError: - rt = str.__new__(cls, value, "utf-8") + rt = str.__new__(cls, value, "utf-8") # type: ignore[call-overload] return rt def toPython(self):
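
Note for reviewers (illustrative only, not part of the patch): the short
sketch below shows how the annotations added above are meant to surface to
callers when the code is checked with a type checker such as mypy. The
example IRIs and variable names are made up for demonstration; only public
rdflib API that already exists is used, and the annotated assignments merely
exercise the new overloads.

    # typing_demo.py -- a minimal sketch, assuming rdflib with this patch
    # applied and mypy (or a similar checker) run over the file.
    from typing import Type

    from rdflib import Graph, Literal, URIRef, plugin
    from rdflib.serializer import Serializer

    g = Graph()
    # Graph.add is now annotated as taking Tuple[Node, Node, Node].
    g.add(
        (
            URIRef("http://example.org/s"),  # hypothetical example IRI
            URIRef("http://example.org/p"),
            Literal("o"),
        )
    )

    # Graph.serialize: with no destination and no encoding the overloads
    # resolve to str; passing an encoding resolves to bytes.
    as_text: str = g.serialize(format="turtle")
    as_bytes: bytes = g.serialize(format="turtle", encoding="utf-8")

    # Graph.triples takes a pattern typed as
    # Tuple[Optional[Node], Union[None, Path, Node], Optional[Node]],
    # so wildcards are written as None.
    for s, p, o in g.triples((None, URIRef("http://example.org/p"), None)):
        print(s, o)

    # plugin.get is generic over the requested kind, so the result is
    # typed as Type[Serializer] rather than an untyped class object.
    serializer_cls: Type[Serializer] = plugin.get("turtle", Serializer)
    print(serializer_cls.__name__)

The destination type Union[str, pathlib.PurePath, IO[bytes]] used for
Graph.serialize mirrors what the implementation already accepts at runtime
(a path string, a path-like object, or a writable binary stream), so the
hints only narrow what the checker sees and do not change behaviour.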