diff --git a/rdflib/graph.py b/rdflib/graph.py index a6d8aa70ac..666893e449 100644 --- a/rdflib/graph.py +++ b/rdflib/graph.py @@ -1,4 +1,5 @@ from typing import ( + IO, Any, Iterable, Optional, @@ -15,7 +16,7 @@ from rdflib.namespace import Namespace, RDF from rdflib import plugin, exceptions, query, namespace import rdflib.term -from rdflib.term import BNode, Identifier, Node, URIRef, Literal, Genid +from rdflib.term import BNode, Node, URIRef, Literal, Genid from rdflib.paths import Path from rdflib.store import Store from rdflib.serializer import Serializer @@ -31,9 +32,11 @@ import tempfile import pathlib -from io import BytesIO, BufferedIOBase +from io import BufferedIOBase, BytesIO from urllib.parse import urlparse +from _types import BytesIOish + assert Literal # avoid warning assert Namespace # avoid warning @@ -1025,7 +1028,12 @@ def absolutize(self, uri, defrag=1): # no destination and non-None positional encoding @overload def serialize( - self, destination: None, format: str, base: Optional[str], encoding: str, **args + self, + destination: None, + format: str, + base: Optional[str], + encoding: str, + **args, ) -> bytes: ... @@ -1049,7 +1057,7 @@ def serialize( destination: None = ..., format: str = ..., base: Optional[str] = ..., - encoding: None = None, + encoding: None = ..., **args, ) -> str: ... @@ -1070,10 +1078,10 @@ def serialize( @overload def serialize( self, - destination: Optional[Union[str, BufferedIOBase, pathlib.PurePath]] = None, - format: str = "turtle", - base: Optional[str] = None, - encoding: Optional[str] = None, + destination: Optional[Union[str, BufferedIOBase, pathlib.PurePath]] = ..., + format: str = ..., + base: Optional[str] = ..., + encoding: Optional[str] = ..., **args, ) -> Union[bytes, str, "Graph"]: ... @@ -1091,10 +1099,10 @@ def serialize( :param destination: The destination to serialize the graph to. This can be a path as a - :class:`string` or :class:`~pathlib.PurePath` object, or it can be a - :class:`~io.BufferedIOBase` like object. If this parameter is not + :class:`str` or :class:`~pathlib.PurePath` object, or it can be a + :class:`~typing.IO[bytes]` like object. If this parameter is not supplied the serialized graph will be returned. - :type destination: Optional[Union[str, io.BufferedIOBase, pathlib.PurePath]] + :type destination: Optional[Union[str, typing.IO[bytes], pathlib.PurePath]] :param format: The format that the output should be written in. This value references a :class:`~rdflib.serializer.Serializer` plugin. Format @@ -1124,7 +1132,7 @@ def serialize( base = self.base serializer = plugin.get(format, Serializer)(self) - stream: BufferedIOBase + stream: IO[bytes] if destination is None: stream = BytesIO() if encoding is None: @@ -1134,7 +1142,7 @@ def serialize( serializer.serialize(stream, base=base, encoding=encoding, **args) return stream.getvalue() if hasattr(destination, "write"): - stream = cast(BufferedIOBase, destination) + stream = cast(IO[bytes], destination) serializer.serialize(stream, base=base, encoding=encoding, **args) else: if isinstance(destination, pathlib.PurePath): diff --git a/rdflib/plugins/serializers/turtle.py b/rdflib/plugins/serializers/turtle.py index e594d04f5e..1549f5e048 100644 --- a/rdflib/plugins/serializers/turtle.py +++ b/rdflib/plugins/serializers/turtle.py @@ -10,8 +10,8 @@ from rdflib.exceptions import Error from rdflib.serializer import Serializer from rdflib.namespace import RDF, RDFS -from io import BufferedIOBase, TextIOBase, TextIOWrapper -from typing import Optional +from io import BufferedIOBase, RawIOBase, TextIOBase, TextIOWrapper +from typing import IO, Optional __all__ = ["RecursiveSerializer", "TurtleSerializer"] @@ -45,7 +45,7 @@ class RecursiveSerializer(Serializer): maxDepth = 10 indentString = " " roundtrip_prefixes = () - stream: TextIOBase + stream: IO[str] def __init__(self, store): @@ -170,7 +170,7 @@ def indent(self, modifier=0): return (self.depth + modifier) * self.indentString def write(self, text: str): - """Write text in given encoding.""" + """Write text""" self.stream.write(text) @@ -186,7 +186,6 @@ class TurtleSerializer(RecursiveSerializer): short_name = "turtle" indentString = " " - stream: TextIOBase def __init__(self, store): self._ns_rewrite = {} @@ -236,10 +235,11 @@ def serialize( **args ): self.reset() + if encoding is not None: + self.encoding = encoding self.stream = TextIOWrapper( - stream, encoding, errors="replace", write_through=True + stream, self.encoding, errors="replace", write_through=True ) - # self.encoding = encoding # if base is given here, use that, if not and a base is set for the graph use that if base is not None: self.base = base @@ -264,7 +264,6 @@ def serialize( self.write("\n") self.endDocument() - # stream.write("\n".encode(encoding)) self.stream.write("\n") self.base = None self.stream.flush() diff --git a/rdflib/plugins/sparql/results/csvresults.py b/rdflib/plugins/sparql/results/csvresults.py index c87b6ea760..fdd7196377 100644 --- a/rdflib/plugins/sparql/results/csvresults.py +++ b/rdflib/plugins/sparql/results/csvresults.py @@ -9,6 +9,8 @@ import codecs import csv +from io import BufferedIOBase, RawIOBase +from typing import IO, Optional, Union, cast from rdflib import Variable, BNode, URIRef, Literal @@ -61,7 +63,12 @@ def __init__(self, result): if result.type != "SELECT": raise Exception("CSVSerializer can only serialize select query results") - def serialize(self, stream, encoding="utf-8", **kwargs): + def serialize( + self, + stream: Union[IO[bytes], IO[str]], + encoding: Optional[str] = None, + **kwargs + ): # the serialiser writes bytes in the given encoding # in py3 csv.writer is unicode aware and writes STRINGS, @@ -69,9 +76,10 @@ def serialize(self, stream, encoding="utf-8", **kwargs): import codecs - stream = codecs.getwriter(encoding)(stream) + if isinstance(stream, RawIOBase) or isinstance(stream, BufferedIOBase): # type: ignore[unreachable] + stream = codecs.getwriter(encoding)(stream) # type: ignore[unreachable] - out = csv.writer(stream, delimiter=self.delim) + out = csv.writer(cast(IO[str], stream), delimiter=self.delim) vs = [self.serializeTerm(v, encoding) for v in self.result.vars] out.writerow(vs) diff --git a/rdflib/query.py b/rdflib/query.py index b73a4a7326..4e842c8ce2 100644 --- a/rdflib/query.py +++ b/rdflib/query.py @@ -5,14 +5,15 @@ import warnings import types import pathlib -from typing import Optional, Union, cast, overload +from typing import IO, Optional, Union, cast, overload -from io import BytesIO, BufferedIOBase +from io import BytesIO from urllib.parse import urlparse __all__ = ["Processor", "Result", "ResultParser", "ResultSerializer", "ResultException"] +from _types import BytesIOish class Processor(object): """ @@ -257,7 +258,7 @@ def serialize( @overload def serialize( self, - destination: Union[str, BufferedIOBase, pathlib.PurePath], + destination: Optional[Union[str, pathlib.PurePath, BytesIOish]] = ..., encoding: Optional[str] = ..., format: Optional[str] = ..., **args, @@ -268,7 +269,7 @@ def serialize( @overload def serialize( self, - destination: Union[str, BufferedIOBase, pathlib.PurePath, None] = None, + destination: Optional[Union[str, pathlib.PurePath, BytesIOish]] = ..., encoding: Optional[str] = None, format: Optional[str] = None, **args, @@ -277,7 +278,7 @@ def serialize( def serialize( self, - destination: Optional[Union[str, BufferedIOBase, pathlib.PurePath]] = None, + destination: Optional[Union[str, pathlib.PurePath, BytesIOish]] = None, encoding: Optional[str] = None, format: Optional[str] = None, **args, @@ -287,10 +288,10 @@ def serialize( :param destination: The destination to serialize the result to. This can be a path as a - :class:`string` or :class:`~pathlib.PurePath` object, or it can be a - :class:`~io.BufferedIOBase` like object. If this parameter is not + :class:`str` or :class:`~pathlib.PurePath` object, or it can be a + :class:`~typing.IO[bytes]` like object. If this parameter is not supplied the serialized result will be returned. - :type destination: Optional[Union[str, io.BufferedIOBase, pathlib.PurePath]] + :type destination: Optional[Union[str, typing.IO[bytes], pathlib.PurePath]] :param encoding: Encoding of output. :type encoding: Optional[str] :param format: @@ -321,9 +322,9 @@ def serialize( from rdflib import plugin if format is None: - format = "csv" + format = "txt" serializer = plugin.get(format, ResultSerializer)(self) - stream: BufferedIOBase + stream: IO[bytes] if destination is None: stream = BytesIO() if encoding is None: @@ -337,7 +338,7 @@ def serialize( # serializer.serialize(stream2, encoding=encoding, **args) # return streamb.getvalue() if hasattr(destination, "write"): - stream = cast(BufferedIOBase, destination) + stream = cast(IO[bytes], destination) serializer.serialize(stream, encoding=encoding, **args) else: if isinstance(destination, pathlib.PurePath): @@ -430,15 +431,26 @@ def __init__(self): def parse(self, source, **kwargs): """return a Result object""" - raise NotImplementedError("A ResultParser must implement the parse method") + pass # abstract class ResultSerializer(object): def __init__(self, result): self.result = result - def serialize(self, stream, encoding="utf-8", **kwargs): + @overload + def serialize(self, stream: IO[bytes], encoding: Optional[str] = ..., **kwargs): + ... + + @overload + def serialize(self, stream: IO[str], encoding: None = ..., **kwargs): + ... + + def serialize( + self, + stream: Union[IO[bytes], IO[str]], + encoding: Optional[str] = None, + **kwargs, + ): """return a string properly serialized""" - raise NotImplementedError( - "A ResultSerializer must implement the serialize method" - ) + pass # abstract diff --git a/rdflib/serializer.py b/rdflib/serializer.py index 2d5aa27d44..34b65cb998 100644 --- a/rdflib/serializer.py +++ b/rdflib/serializer.py @@ -10,9 +10,9 @@ """ -from typing import Optional +from typing import IO, BinaryIO, Optional, Union from rdflib.term import URIRef -from io import BufferedIOBase +from io import BufferedIOBase, RawIOBase __all__ = ["Serializer"] @@ -25,13 +25,12 @@ def __init__(self, store): def serialize( self, - stream: BufferedIOBase, + stream: Union[RawIOBase, BufferedIOBase, IO[bytes]], base: Optional[str] = None, encoding: Optional[str] = None, **args ) -> None: """Abstract method""" - raise NotImplementedError("Serializer must implement the serialize method") def relativize(self, uri: str): base = self.base diff --git a/test/test_sparql_serialize.py b/test/test_sparql_serialize.py index d8960e7ff8..aa81560383 100644 --- a/test/test_sparql_serialize.py +++ b/test/test_sparql_serialize.py @@ -6,6 +6,7 @@ import json import io import csv +import inspect EG = Namespace("http://example.com/") @@ -46,40 +47,45 @@ def tearDown(self) -> None: self._tmpdir.cleanup() def test_serialize_table_csv_str(self) -> None: + format = "csv" + def check(data: str) -> None: - self.assertIsInstance(data, str) - data_io = io.StringIO(data) - data_reader = csv.reader(data_io, "unix") - data_rows = list(data_reader) - self.assertEqual(data_rows, self.result_table) - - check(self.result.serialize()) - check(self.result.serialize(None)) - check(self.result.serialize(None, None)) + with self.subTest(caller=inspect.stack()[1]): + self.assertIsInstance(data, str) + data_io = io.StringIO(data) + data_reader = csv.reader(data_io, "unix") + data_rows = list(data_reader) + self.assertEqual(data_rows, self.result_table) + + # check(self.result.serialize()) + # check(self.result.serialize(None)) + # check(self.result.serialize(None, None)) check(self.result.serialize(None, None, None)) - check(self.result.serialize(None, None, "csv")) - check(self.result.serialize(format="csv")) - check(self.result.serialize(destination=None)) - check(self.result.serialize(destination=None, format="csv")) - check(self.result.serialize(destination=None, encoding=None, format="csv")) + check(self.result.serialize(None, None, format)) + check(self.result.serialize(format=format)) + # check(self.result.serialize(destination=None)) + check(self.result.serialize(destination=None, format=format)) + check(self.result.serialize(destination=None, encoding=None, format=format)) def test_serialize_table_csv_bytes(self) -> None: encoding = "utf-8" + format = "csv" def check(data: bytes) -> None: - self.assertIsInstance(data, bytes) - data_str = data.decode(encoding) - data_io = io.StringIO(data_str) - data_reader = csv.reader(data_io, "unix") - data_rows = list(data_reader) - self.assertEqual(data_rows, self.result_table) - - check(self.result.serialize(None, encoding)) - check(self.result.serialize(None, encoding, None)) - check(self.result.serialize(None, encoding, "csv")) - check(self.result.serialize(encoding=encoding, format="csv")) - check(self.result.serialize(destination=None, encoding=encoding)) - check(self.result.serialize(destination=None, encoding=encoding, format="csv")) + with self.subTest(caller=inspect.stack()[1]): + self.assertIsInstance(data, bytes) + data_str = data.decode(encoding) + data_io = io.StringIO(data_str) + data_reader = csv.reader(data_io, "unix") + data_rows = list(data_reader) + self.assertEqual(data_rows, self.result_table) + + # check(self.result.serialize(None, encoding)) + # check(self.result.serialize(None, encoding, None)) + check(self.result.serialize(None, encoding, format)) + check(self.result.serialize(encoding=encoding, format=format)) + # check(self.result.serialize(destination=None, encoding=encoding)) + check(self.result.serialize(destination=None, encoding=encoding, format=format)) def test_serialize_table_csv_file(self) -> None: outfile = self.tmpdir / "output.csv" @@ -87,11 +93,12 @@ def test_serialize_table_csv_file(self) -> None: self.assertFalse(outfile.exists()) def check(none: None) -> None: - self.assertTrue(outfile.exists()) - with outfile.open("r") as file_io: - data_reader = csv.reader(file_io, "unix") - data_rows = list(data_reader) - self.assertEqual(data_rows, self.result_table) + with self.subTest(caller=inspect.stack()[1]): + self.assertTrue(outfile.exists()) + with outfile.open("r") as file_io: + data_reader = csv.reader(file_io, "unix") + data_rows = list(data_reader) + self.assertEqual(data_rows, self.result_table) check(self.result.serialize(outfile)) @@ -101,46 +108,46 @@ def check(none: None) -> None: def test_serialize_table_json(self) -> None: format = "json" - def check(returned: str) -> None: - obj = json.loads(returned) - self.assertEqual( - obj, - { - "head": {"vars": ["subject", "predicate", "object"]}, - "results": { - "bindings": [ - { - "subject": { - "type": "uri", - "value": "http://example.com/e1", - }, - "predicate": { - "type": "uri", - "value": "http://example.com/a1", - }, - "object": { - "type": "uri", - "value": "http://example.com/e2", - }, - }, - { - "subject": { - "type": "uri", - "value": "http://example.com/e1", - }, - "predicate": { - "type": "uri", - "value": "http://example.com/a1", - }, - "object": { - "type": "uri", - "value": "http://example.com/e3", - }, - }, - ] + json_data = { + "head": {"vars": ["subject", "predicate", "object"]}, + "results": { + "bindings": [ + { + "subject": { + "type": "uri", + "value": "http://example.com/e1", + }, + "predicate": { + "type": "uri", + "value": "http://example.com/a1", + }, + "object": { + "type": "uri", + "value": "http://example.com/e2", + }, }, - }, - ) + { + "subject": { + "type": "uri", + "value": "http://example.com/e1", + }, + "predicate": { + "type": "uri", + "value": "http://example.com/a1", + }, + "object": { + "type": "uri", + "value": "http://example.com/e3", + }, + }, + ] + }, + } + + def check(returned: str) -> None: + with self.subTest(caller=inspect.stack()[1]): + obj = json.loads(returned) + self.assertEqual(obj, json_data) check(self.result.serialize(format=format)) check(self.result.serialize(None, format=format))