From 572f63d861fd1766299afd644b43a08d40a22e51 Mon Sep 17 00:00:00 2001 From: Ivan Shapovalov Date: Fri, 8 Mar 2024 02:45:41 +0100 Subject: [PATCH] yq: use tomllib if available, slightly refactor exceptions --- yq/__init__.py | 69 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 49 insertions(+), 20 deletions(-) diff --git a/yq/__init__.py b/yq/__init__.py index 477dd78..19cd5be 100644 --- a/yq/__init__.py +++ b/yq/__init__.py @@ -13,6 +13,7 @@ import os import subprocess import sys +import typing from datetime import date, datetime, time import argcomplete @@ -28,6 +29,14 @@ __version__ = "0.0.0" +class DecodeError(RuntimeError): + pass + + +class EncodeError(RuntimeError): + pass + + class JSONDateTimeEncoder(json.JSONEncoder): def default(self, o): if isinstance(o, (datetime, date, time)): @@ -42,6 +51,32 @@ def decode_docs(jq_output, json_decoder): yield doc +def decode_toml(fp: typing.IO, *, output_format: str) -> typing.Any: + if output_format != 'toml': + try: + import tomllib + if isinstance(fp, io.StringIO): + # HACK: this is just for tests, which kinda defeats the purpose of the tests, + # but there is no StringIO.buffer so there is no way around it + return tomllib.loads(fp.read()) + return tomllib.load(fp.buffer if isinstance(fp, io.TextIOWrapper) else fp) + except ImportError: + pass + + import tomlkit + return tomlkit.load(fp) + + +def encode_toml(fp: typing.TextIO, docs: typing.Iterable[typing.Any]): + import tomlkit + for doc in docs: + if not isinstance(doc, dict): + raise EncodeError( + "Error converting JSON to TOML: cannot represent non-object types at top level." + ) + tomlkit.dump(doc, fp) + + def xq_cli(): cli(input_format="xml", program_name="xq") @@ -244,7 +279,7 @@ def yq( import xmltodict if xml_item_depth != 0: - raise Exception("xml_item_depth is not supported with xq -x") + raise DecodeError("xml_item_depth is not supported with xq -x") doc = xmltodict.parse( input_stream.buffer if isinstance(input_stream, io.TextIOWrapper) else input_stream.read(), @@ -254,13 +289,12 @@ def yq( json.dump(doc, json_buffer, cls=JSONDateTimeEncoder) json_buffer.write("\n") elif input_format == "toml": - import tomlkit - - doc = tomlkit.load(input_stream) # type: ignore + doc = decode_toml(input_stream, output_format=output_format) json.dump(doc, json_buffer, cls=JSONDateTimeEncoder) json_buffer.write("\n") else: - raise Exception("Unknown input format") + raise DecodeError("Unknown input format") + jq_out, jq_err = jq.communicate(json_buffer.getvalue()) json_decoder = json.JSONDecoder() if output_format == "yaml" or output_format == "annotated_yaml": @@ -286,11 +320,10 @@ def yq( if xml_root: doc = {xml_root: doc} # type: ignore elif not isinstance(doc, dict): - msg = ( - "{}: Error converting JSON to XML: cannot represent non-object types at top level. " + raise EncodeError( + "Error converting JSON to XML: cannot represent non-object types at top level. " "Use --xml-root=name to envelope your output with a root element." ) - exit_func(msg.format(program_name)) full_document = True if xml_dtd else False try: xmltodict.unparse( @@ -298,18 +331,13 @@ def yq( ) except ValueError as e: if "Document must have exactly one root" in str(e): - raise Exception(str(e) + " Use --xml-root=name to envelope your output with a root element") + raise EncodeError(str(e) + " Use --xml-root=name to envelope your output with a root element") else: raise output_stream.write("\n") elif output_format == "toml": - import tomlkit - - for doc in decode_docs(jq_out, json_decoder): - if not isinstance(doc, dict): - msg = "{}: Error converting JSON to TOML: cannot represent non-object types at top level." - exit_func(msg.format(program_name)) - tomlkit.dump(doc, output_stream) + docs = decode_docs(jq_out, json_decoder) + encode_toml(output_stream, docs) else: if input_format == "yaml": loader_class = get_loader( @@ -344,13 +372,12 @@ def emit_entry(path, entry): if doc: emit_entry(None, doc) elif input_format == "toml": - import tomlkit - for input_stream in input_streams: - json.dump(tomlkit.load(input_stream), jq.stdin, cls=JSONDateTimeEncoder) # type: ignore + doc = decode_toml(input_stream, output_format=output_format) + json.dump(doc, jq.stdin, cls=JSONDateTimeEncoder) # type: ignore jq.stdin.write("\n") # type: ignore else: - raise Exception("Unknown input format") + raise DecodeError("Unknown input format") try: jq.stdin.close() # type: ignore @@ -360,5 +387,7 @@ def emit_entry(path, entry): for input_stream in input_streams: input_stream.close() exit_func(jq.returncode) + except (DecodeError, EncodeError) as e: + exit_func("{}: {}".format(program_name, e)) except Exception as e: exit_func("{}: Error running jq: {}: {}.".format(program_name, type(e).__name__, e))