diff --git a/CHANGELOG.md b/CHANGELOG.md
index b3ed0c6..ac63ed9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,13 @@
- Setter method `TestCase.result` used to ignore values of invalid types. This method now throws a `ValueError` instead.
- Method `xunit2.TestCase.add_rerun_result` has been renamed to `add_interim_result` result to better reflect class hierarchy
of interim (rerun and flaky) results.
+- Methods `JUnitXml.fromfile`, `JUnitXml.fromstring`, `JUnitXml.fromroot` always return a `JUnitXml` instance.
+ Earlier versions return a `TestSuite` instance when the root of the file / string / element is a ``.
+ A `JUnitXml` instance has already been returned by earlier versions when the root of the file / string / element is a ``.
+
+ If you want to create a `TestSuite` instance from a `` element, use
+
+ TestSuite.fromelem(elem)
## [3.1.2] - 2024-08-31
### Fixed
diff --git a/junitparser/junitparser.py b/junitparser/junitparser.py
index a91c6a0..a3a68f3 100644
--- a/junitparser/junitparser.py
+++ b/junitparser/junitparser.py
@@ -743,25 +743,26 @@ def update_statistics(self):
self.time = round(time, 3)
@classmethod
- def fromroot(cls, root_elem: Element):
+ def fromroot(cls, root_elem: Element) -> "JUnitXml":
"""Construct JUnit objects from an elementTree root element."""
- if root_elem.tag == "testsuites":
- instance = cls()
- elif root_elem.tag == "testsuite":
- instance = cls.testsuite()
- else:
+ instance = cls()
+ if root_elem.tag == "testsuite":
+ testsuite_element = root_elem
+ root_elem = testsuite_element.makeelement("testsuites", {})
+ root_elem.append(testsuite_element)
+ if not root_elem.tag == "testsuites":
raise JUnitXmlError("Invalid format.")
instance._elem = root_elem
return instance
@classmethod
- def fromstring(cls, text: Union[str, bytes]):
+ def fromstring(cls, text: Union[str, bytes]) -> "JUnitXml":
"""Construct JUnit objects from an XML string (str or bytes)."""
root_elem = etree.fromstring(text) # nosec
return cls.fromroot(root_elem)
@classmethod
- def fromfile(cls, file: Union[str, IO], parse_func=None):
+ def fromfile(cls, file: Union[str, IO], parse_func=None) -> "JUnitXml":
"""
Construct JUnit objects from an XML file.
diff --git a/tests/test_fromfile.py b/tests/test_fromfile.py
index 31bcbee..605c39a 100644
--- a/tests/test_fromfile.py
+++ b/tests/test_fromfile.py
@@ -5,6 +5,7 @@
from unittest import skipIf
from junitparser import (
TestCase,
+ TestSuite,
Skipped,
Failure,
JUnitXmlError,
@@ -21,13 +22,19 @@
def do_test_fromfile(fromfile_arg):
xml = JUnitXml.fromfile(fromfile_arg)
+ assert isinstance(xml, JUnitXml)
suite1, suite2 = list(iter(xml))
+ assert isinstance(suite1, TestSuite)
+ assert isinstance(suite2, TestSuite)
assert len(list(suite1.properties())) == 0
assert len(list(suite2.properties())) == 3
assert len(suite2) == 3
assert suite2.name == "JUnitXmlReporter.constructor"
assert suite2.tests == 3
cases = list(suite2.iterchildren(TestCase))
+ assert isinstance(cases[0], TestCase)
+ assert isinstance(cases[1], TestCase)
+ assert isinstance(cases[2], TestCase)
assert isinstance(cases[0].result[0], Failure)
assert isinstance(cases[1].result[0], Skipped)
assert len(cases[2].result) == 0
@@ -98,13 +105,19 @@ def parse_func(file_path):
os.path.join(os.path.dirname(__file__), "data/normal.xml"),
parse_func=parse_func,
)
+ assert isinstance(xml, JUnitXml)
suite1, suite2 = list(iter(xml))
+ assert isinstance(suite1, TestSuite)
+ assert isinstance(suite2, TestSuite)
assert len(list(suite1.properties())) == 0
assert len(list(suite2.properties())) == 3
assert len(suite2) == 3
assert suite2.name == "JUnitXmlReporter.constructor"
assert suite2.tests == 3
cases = list(suite2.iterchildren(TestCase))
+ assert isinstance(cases[0], TestCase)
+ assert isinstance(cases[1], TestCase)
+ assert isinstance(cases[2], TestCase)
assert isinstance(cases[0].result[0], Failure)
assert isinstance(cases[1].result[0], Skipped)
assert len(cases[2].result) == 0
@@ -114,20 +127,31 @@ def test_fromfile_without_testsuites_tag():
xml = JUnitXml.fromfile(
os.path.join(os.path.dirname(__file__), "data/no_suites_tag.xml")
)
- cases = list(iter(xml))
- properties = list(iter(xml.properties()))
- assert len(properties) == 3
+ assert isinstance(xml, JUnitXml)
+ suites = list(iter(xml))
+ assert len(suites) == 1
+ suite = suites[0]
+ assert isinstance(suite, TestSuite)
+ assert suite.name == "JUnitXmlReporter.constructor"
+ assert suite.tests == 3
+ cases = list(iter(suite))
assert len(cases) == 3
- assert xml.name == "JUnitXmlReporter.constructor"
- assert xml.tests == 3
+ assert isinstance(cases[0], TestCase)
+ assert isinstance(cases[1], TestCase)
+ assert isinstance(cases[2], TestCase)
assert isinstance(cases[0].result[0], Failure)
assert isinstance(cases[1].result[0], Skipped)
assert len(cases[2].result) == 0
+ properties = list(iter(suite.properties()))
+ assert len(properties) == 3
def test_fromfile_with_testsuite_in_testsuite():
xml = JUnitXml.fromfile(os.path.join(os.path.dirname(__file__), "data/jenkins.xml"))
+ assert isinstance(xml, JUnitXml)
suite1, suite2 = list(iter(xml))
+ assert isinstance(suite1, TestSuite)
+ assert isinstance(suite2, TestSuite)
assert len(list(suite1.properties())) == 0
assert len(list(suite2.properties())) == 3
assert len(suite2) == 3
@@ -135,8 +159,11 @@ def test_fromfile_with_testsuite_in_testsuite():
assert suite2.tests == 3
direct_cases = list(suite2.iterchildren(TestCase))
assert len(direct_cases) == 1
+ assert isinstance(direct_cases[0], TestCase)
assert isinstance(direct_cases[0].result[0], Failure)
all_cases = list(suite2)
+ assert isinstance(all_cases[0], TestCase)
+ assert isinstance(all_cases[1], TestCase)
assert isinstance(all_cases[0].result[0], Failure)
assert isinstance(all_cases[1].result[0], Skipped)
assert len(all_cases[2].result) == 0
@@ -167,6 +194,9 @@ def test_multi_results_in_case():
"""
xml = JUnitXml.fromstring(text)
+ assert isinstance(xml, JUnitXml)
suite = next(iter(xml))
+ assert isinstance(suite, TestSuite)
case = next(iter(suite))
+ assert isinstance(case, TestCase)
assert len(case.result) == 2
diff --git a/tests/test_general.py b/tests/test_general.py
index d5ea93e..6f650d3 100644
--- a/tests/test_general.py
+++ b/tests/test_general.py
@@ -127,6 +127,7 @@ def test_fromstring(self):
"""
result = JUnitXml.fromstring(text)
+ assert isinstance(result, JUnitXml)
assert len(result) == 2
assert result.time == 0
@@ -135,6 +136,7 @@ def test_fromstring_no_testsuites(self):
"""
result = JUnitXml.fromstring(text)
+ assert isinstance(result, JUnitXml)
assert len(result) == 1
assert result.time == 0
@@ -178,10 +180,14 @@ def test_fromroot_testsuite(self):
"""
root_elemt = etree.fromstring(text)
result = JUnitXml.fromroot(root_elemt)
- assert isinstance(result, TestSuite)
- assert result.errors == 1
- assert result.skipped == 1
- cases = list(iter(result))
+ assert isinstance(result, JUnitXml)
+ suites = list(iter(result))
+ assert len(suites) == 1
+ suite = suites[0]
+ assert isinstance(suite, TestSuite)
+ assert suite.errors == 1
+ assert suite.skipped == 1
+ cases = list(iter(suite))
assert len(cases[0].result) == 0
assert len(cases[1].result) == 2
text = cases[1].result[1].text
@@ -203,6 +209,7 @@ def test_fromroot_testsuites(self):
"""
root_elemt = etree.fromstring(text)
result = JUnitXml.fromroot(root_elemt)
+ assert isinstance(result, JUnitXml)
assert result.errors == 1
assert result.skipped == 1
suite = list(iter(result))[0]
diff --git a/tests/test_write.py b/tests/test_write.py
index add40a5..ab55a25 100644
--- a/tests/test_write.py
+++ b/tests/test_write.py
@@ -22,7 +22,7 @@
python_minor = int(sys.version.split(".")[1])
-def get_expected_xml(test_case_name: str, test_suites: bool = True):
+def get_expected_xml(test_case_name: str, test_suites: bool = True, newlines: bool = False):
if python_major == 3 and python_minor <= 7 and not has_lxml:
expected_test_suite = ''
else:
@@ -42,9 +42,13 @@ def get_expected_xml(test_case_name: str, test_suites: bool = True):
start_test_suites = ""
end_test_suites = ""
+ eol = "\n" if newlines else ""
+ indent = " " if newlines else ""
return (
f"\n"
- f'{start_test_suites}{expected_test_suite}{end_test_suites}'
+ f'{start_test_suites}{expected_test_suite}{eol}'
+ f'{indent}{eol}'
+ f'{end_test_suites}'
)
@@ -174,6 +178,31 @@ def test_write_nonascii():
assert xmlfile.getvalue().decode("utf-8") == get_expected_xml("用例1")
+def test_write_no_testsuites():
+ # Has to be a binary string to include xml declarations.
+ text = b"""
+
+
+"""
+ xml = JUnitXml.fromstring(text)
+ assert isinstance(xml, JUnitXml)
+ suite = next(iter(xml))
+ assert isinstance(suite, TestSuite)
+ case = next(iter(suite))
+ assert isinstance(case, TestCase)
+ assert len(case.result) == 0
+
+ # writing this JUnitXml object contains a root element
+ xmlfile = BytesIO()
+ xml.write(xmlfile)
+ assert xmlfile.getvalue().decode("utf-8") == get_expected_xml("case1", test_suites=True, newlines=True)
+
+ # writing the inner testsuite reproduces the input string
+ xmlfile = BytesIO()
+ suite.write(xmlfile)
+ assert xmlfile.getvalue().decode("utf-8") == get_expected_xml("case1", test_suites=False, newlines=True)
+
+
def test_read_written_xml():
suite1 = TestSuite()
suite1.name = "suite1"
diff --git a/tests/test_xunit2.py b/tests/test_xunit2.py
index 6499a52..3621bbc 100644
--- a/tests/test_xunit2.py
+++ b/tests/test_xunit2.py
@@ -211,7 +211,11 @@ def test_suite_fromstring(self):
"""
- suite = JUnitXml.fromstring(text)
+ xml = JUnitXml.fromstring(text)
+ assert isinstance(xml, JUnitXml)
+ suites = list(xml)
+ assert len(suites) == 1
+ suite = suites[0]
assert isinstance(suite, TestSuite)
assert suite.name == "suite name"
cases = list(suite)