diff --git a/testinfra/plugin.py b/testinfra/plugin.py index 707ff505..d1371dbb 100644 --- a/testinfra/plugin.py +++ b/testinfra/plugin.py @@ -15,7 +15,7 @@ import sys import tempfile import time -from typing import AnyStr +from typing import AnyStr, cast import pytest @@ -25,19 +25,19 @@ @pytest.fixture(scope="module") -def _testinfra_host(request): - return request.param +def _testinfra_host(request: pytest.FixtureRequest) -> testinfra.host.Host: + return cast(testinfra.host.Host, request.param) @pytest.fixture(scope="module") -def host(_testinfra_host): +def host(_testinfra_host: testinfra.host.Host) -> testinfra.host.Host: return _testinfra_host host.__doc__ = testinfra.host.Host.__doc__ -def pytest_addoption(parser): +def pytest_addoption(parser: pytest.Parser) -> None: group = parser.getgroup("testinfra") group.addoption( "--connection", @@ -107,7 +107,7 @@ def pytest_addoption(parser): ) -def pytest_generate_tests(metafunc): +def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: if "_testinfra_host" in metafunc.fixturenames: if metafunc.config.option.hosts is not None: hosts = metafunc.config.option.hosts.split(",") @@ -141,7 +141,7 @@ def __init__(self, out): self.total_time = None self.out = out - def pytest_runtest_logreport(self, report): + def pytest_runtest_logreport(self, report: pytest.TestReport) -> None: if report.passed: if report.when == "call": # ignore setup/teardown self.passed += 1 @@ -150,7 +150,7 @@ def pytest_runtest_logreport(self, report): elif report.skipped: self.skipped += 1 - def report(self): + def report(self) -> int: if self.failed: status = b"CRITICAL" ret = 2