Skip to content

Commit

Permalink
test_engine_junit.py: add type annotations (#1194)
Browse files Browse the repository at this point in the history
* test_engine_junit.py: add type annotations

* test_engine_junit.py: remove obsolete class inheritance from object

* test_engine_junit.py: don't use case keyword as internal variable

* test_engine_junit.py: don't mix different types in a single variable
  • Loading branch information
berquist authored Jan 15, 2025
1 parent 5e06091 commit 6556841
Showing 1 changed file with 114 additions and 63 deletions.
177 changes: 114 additions & 63 deletions src/sst/core/testingframework/test_engine_junit.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@
import re
import xml.etree.ElementTree as ET
import xml.dom.minidom
from typing import IO, List, Mapping, Optional
from typing import DefaultDict, Dict, IO, List, Mapping, Optional, Union

Entry = Dict[str, Optional[str]]

################################################################################

Expand Down Expand Up @@ -88,15 +90,27 @@

################################################################################

class JUnitTestSuite(object):
class JUnitTestSuite:
"""
Suite of test cases.
Can handle unicode strings or binary strings if their encoding is provided.
"""

def __init__(self, name, test_cases=None, hostname=None, id=None,
package=None, timestamp=None, properties=None, file=None,
log=None, url=None, stdout=None, stderr=None):
def __init__(
self,
name: str,
test_cases: Optional[List["JUnitTestCase"]] = None,
hostname: Optional[str] = None,
id: Optional[str] = None,
package: Optional[str] = None,
timestamp: Optional[str] = None,
properties: Optional[Mapping[str, str]] = None,
file: Optional[str] = None,
log: Optional[str] = None,
url: Optional[str] = None,
stdout: Optional[str] = None,
stderr: Optional[str] = None,
) -> None:
self.name = name
if not test_cases:
test_cases = []
Expand All @@ -118,7 +132,7 @@ def __init__(self, name, test_cases=None, hostname=None, id=None,

####

def junit_build_xml_doc(self, encoding=None):
def junit_build_xml_doc(self, encoding: Optional[str] = None) -> ET.Element:
"""
Builds the XML document for the JUnit test suite.
Produces clean unicode strings and decodes non-unicode with the help of encoding.
Expand Down Expand Up @@ -184,35 +198,35 @@ def junit_build_xml_doc(self, encoding=None):
stderr_element.text = _junit_decode(self.stderr, encoding)

# test cases
for case in self.test_cases:
for test_case in self.test_cases:
test_case_attributes = dict()
test_case_attributes["name"] = _junit_decode(case.name, encoding)
if case.assertions:
test_case_attributes["name"] = _junit_decode(test_case.name, encoding)
if test_case.assertions:
# Number of assertions in the test case
test_case_attributes["assertions"] = "%d" % case.assertions
if case.elapsed_sec:
test_case_attributes["time"] = "%f" % case.elapsed_sec
if case.timestamp:
test_case_attributes["timestamp"] = _junit_decode(case.timestamp, encoding)
if case.classname:
test_case_attributes["classname"] = _junit_decode(case.classname, encoding)
if case.status:
test_case_attributes["status"] = _junit_decode(case.status, encoding)
if case.category:
test_case_attributes["class"] = _junit_decode(case.category, encoding)
if case.file:
test_case_attributes["file"] = _junit_decode(case.file, encoding)
if case.line:
test_case_attributes["line"] = _junit_decode(case.line, encoding)
if case.log:
test_case_attributes["log"] = _junit_decode(case.log, encoding)
if case.url:
test_case_attributes["url"] = _junit_decode(case.url, encoding)
test_case_attributes["assertions"] = "%d" % test_case.assertions # type: ignore [str-format]
if test_case.elapsed_sec:
test_case_attributes["time"] = "%f" % test_case.elapsed_sec
if test_case.timestamp:
test_case_attributes["timestamp"] = _junit_decode(test_case.timestamp, encoding)
if test_case.classname:
test_case_attributes["classname"] = _junit_decode(test_case.classname, encoding)
if test_case.status:
test_case_attributes["status"] = _junit_decode(test_case.status, encoding)
if test_case.category:
test_case_attributes["class"] = _junit_decode(test_case.category, encoding)
if test_case.file:
test_case_attributes["file"] = _junit_decode(test_case.file, encoding)
if test_case.line:
test_case_attributes["line"] = _junit_decode(test_case.line, encoding)
if test_case.log:
test_case_attributes["log"] = _junit_decode(test_case.log, encoding)
if test_case.url:
test_case_attributes["url"] = _junit_decode(test_case.url, encoding)

test_case_element = ET.SubElement(xml_element, "testcase", test_case_attributes)

# failures
for failure in case.failures:
for failure in test_case.failures:
if failure["output"] or failure["message"]:
attrs = {"type": "failure"}
if failure["message"]:
Expand All @@ -225,7 +239,7 @@ def junit_build_xml_doc(self, encoding=None):
test_case_element.append(failure_element)

# errors
for error in case.errors:
for error in test_case.errors:
if error["message"] or error["output"]:
attrs = {"type": "error"}
if error["message"]:
Expand All @@ -238,7 +252,7 @@ def junit_build_xml_doc(self, encoding=None):
test_case_element.append(error_element)

# skippeds
for skipped in case.skipped:
for skipped in test_case.skipped:
attrs = {"type": "skipped"}
if skipped["message"]:
attrs["message"] = _junit_decode(skipped["message"], encoding)
Expand All @@ -248,28 +262,41 @@ def junit_build_xml_doc(self, encoding=None):
test_case_element.append(skipped_element)

# test stdout
if case.stdout:
if test_case.stdout:
stdout_element = ET.Element("system-out")
stdout_element.text = _junit_decode(case.stdout, encoding)
stdout_element.text = _junit_decode(test_case.stdout, encoding)
test_case_element.append(stdout_element)

# test stderr
if case.stderr:
if test_case.stderr:
stderr_element = ET.Element("system-err")
stderr_element.text = _junit_decode(case.stderr, encoding)
stderr_element.text = _junit_decode(test_case.stderr, encoding)
test_case_element.append(stderr_element)

return xml_element

####

class JUnitTestCase(object):
class JUnitTestCase:
"""A JUnit test case with a result and possibly some stdout or stderr"""

def __init__(self, name, classname=None, elapsed_sec=None, stdout=None,
stderr=None, assertions=None, timestamp=None, status=None,
category=None, file=None, line=None, log=None, url=None,
allow_multiple_subelements=False):
def __init__(
self,
name: str,
classname: Optional[str] = None,
elapsed_sec: Optional[float] = None,
stdout: Optional[str] = None,
stderr: Optional[str] = None,
assertions: Optional[str] = None,
timestamp: Optional[str] = None,
status: Optional[str] = None,
category: Optional[str] = None,
file: Optional[str] = None,
line: Optional[str] = None,
log: Optional[str] = None,
url: Optional[str] = None,
allow_multiple_subelements: bool = False,
) -> None:
self.name = name
self.assertions = assertions
self.elapsed_sec = elapsed_sec
Expand All @@ -284,18 +311,23 @@ def __init__(self, name, classname=None, elapsed_sec=None, stdout=None,
self.stdout = stdout
self.stderr = stderr
self.is_enabled = True
self.errors = []
self.failures = []
self.skipped = []
self.allow_multiple_subalements = allow_multiple_subelements

def junit_add_error_info(self, message=None, output=None, error_type=None):
self.errors: List[Entry] = []
self.failures: List[Entry] = []
self.skipped: List[Entry] = []
self.allow_multiple_subelements = allow_multiple_subelements

def junit_add_error_info(
self,
message: Optional[str] = None,
output: Optional[str] = None,
error_type: Optional[str] = None,
) -> None:
"""Adds an error message, output, or both to the test case"""
error = {}
error["message"] = message
error["output"] = output
error["type"] = error_type
if self.allow_multiple_subalements:
if self.allow_multiple_subelements:
if message or output:
self.errors.append(error)
elif not len(self.errors):
Expand All @@ -308,13 +340,18 @@ def junit_add_error_info(self, message=None, output=None, error_type=None):
if error_type:
self.errors[0]["type"] = error_type

def junit_add_failure_info(self, message=None, output=None, failure_type=None):
def junit_add_failure_info(
self,
message: Optional[str] = None,
output: Optional[str] = None,
failure_type: Optional[str] = None,
) -> None:
"""Adds a failure message, output, or both to the test case"""
failure = {}
failure["message"] = message
failure["output"] = output
failure["type"] = failure_type
if self.allow_multiple_subalements:
if self.allow_multiple_subelements:
if message or output:
self.failures.append(failure)
elif not len(self.failures):
Expand All @@ -327,12 +364,16 @@ def junit_add_failure_info(self, message=None, output=None, failure_type=None):
if failure_type:
self.failures[0]["type"] = failure_type

def junit_add_skipped_info(self, message=None, output=None):
def junit_add_skipped_info(
self,
message: Optional[str] = None,
output: Optional[str] = None,
) -> None:
"""Adds a skipped message, output, or both to the test case"""
skipped = {}
skipped["message"] = message
skipped["output"] = output
if self.allow_multiple_subalements:
if self.allow_multiple_subelements:
if message or output:
self.skipped.append(skipped)
elif not len(self.skipped):
Expand All @@ -343,25 +384,29 @@ def junit_add_skipped_info(self, message=None, output=None):
if output:
self.skipped[0]["output"] = output

def junit_add_elapsed_sec(self, elapsed_sec):
def junit_add_elapsed_sec(self, elapsed_sec: float) -> None:
"""Add the elapsed time to the testcase"""
self.elapsed_sec = elapsed_sec

def junit_is_failure(self):
def junit_is_failure(self) -> bool:
"""returns true if this test case is a failure"""
return sum(1 for f in self.failures if f["message"] or f["output"]) > 0

def junit_is_error(self):
def junit_is_error(self) -> bool:
"""returns true if this test case is an error"""
return sum(1 for e in self.errors if e["message"] or e["output"]) > 0

def junit_is_skipped(self):
def junit_is_skipped(self) -> bool:
"""returns true if this test case has been skipped"""
return len(self.skipped) > 0

####

def junit_to_xml_report_string(test_suites, prettyprint=True, encoding=None):
def junit_to_xml_report_string(
test_suites: List["JUnitTestSuite"],
prettyprint: bool = True,
encoding: Optional[str] = None,
) -> str:
"""
Returns the string representation of the JUnit XML document.
@param encoding: The encoding of the input.
Expand All @@ -374,7 +419,7 @@ def junit_to_xml_report_string(test_suites, prettyprint=True, encoding=None):
raise TypeError("test_suites must be a list of test suites")

xml_element = ET.Element("testsuites")
attributes = defaultdict(int)
attributes: DefaultDict[str, Union[int, float]] = defaultdict(int)
for ts in test_suites:
ts_xml = ts.junit_build_xml_doc(encoding=encoding)
for key in ["disabled", "errors", "failures", "tests"]:
Expand All @@ -396,18 +441,24 @@ def junit_to_xml_report_string(test_suites, prettyprint=True, encoding=None):
if prettyprint:
# minidom.parseString() works just on correctly encoded binary strings
xml_string = xml_string.encode(encoding or "utf-8")
xml_string = xml.dom.minidom.parseString(xml_string)
xml_string_document = xml.dom.minidom.parseString(xml_string)
# toprettyxml() produces unicode if no encoding is being passed
# or binary string with an encoding
xml_string = xml_string.toprettyxml(encoding=encoding)
if encoding:
xml_string = xml_string_document.toprettyxml(encoding=encoding)
if isinstance(xml_string, bytes):
assert encoding is not None
xml_string = xml_string.decode(encoding)
# is unicode now
return xml_string

####

def junit_to_xml_report_file(file_descriptor, test_suites, prettyprint=True, encoding=None):
def junit_to_xml_report_file(
file_descriptor: IO[str],
test_suites: List["JUnitTestSuite"],
prettyprint: bool = True,
encoding: Optional[str] = None,
) -> None:
"""
Writes the JUnit XML document to a file.
"""
Expand All @@ -417,15 +468,15 @@ def junit_to_xml_report_file(file_descriptor, test_suites, prettyprint=True, enc

####

def _junit_decode(var, encoding):
def _junit_decode(var: Optional[str], encoding: Optional[str]) -> str:
"""
If not already unicode, decode it.
"""
return str(var)

####

def _junit_clean_illegal_xml_chars(string_to_clean):
def _junit_clean_illegal_xml_chars(string_to_clean: str) -> str:
"""
Removes any illegal unicode characters from the given XML string.
@see: http://stackoverflow.com/questions/1707890/fast-way-to-filter-
Expand Down

0 comments on commit 6556841

Please sign in to comment.