Skip to content

Commit

Permalink
sst_unittest.py: add type annotations to run_sst (#1197)
Browse files Browse the repository at this point in the history
* test_LookupTable.py: remove unused imports

* testsuite_default_UnitAlgebra.py: idiomatic boolean comparison

* sst_unittest.py: missing multiprocessing import

* sst_unittest.py: avoid shadowing variable names

* sst_unittest.py: add type annotations

* sst_unittest.py: enforce run_sst timeout_sec as integer
  • Loading branch information
berquist authored Jan 22, 2025
1 parent 7a2b424 commit 629e9ab
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 15 deletions.
39 changes: 26 additions & 13 deletions src/sst/core/testingframework/sst_unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import threading
import signal
import time
import multiprocessing
from typing import Optional

import test_engine_globals
Expand Down Expand Up @@ -58,11 +59,10 @@ class SSTTestCase(unittest.TestCase):
def __init__(self, methodName: str) -> None:
# NOTE: __init__ is called at startup for all tests before any
# setUpModules(), setUpClass(), setUp() and the like are called.
super(SSTTestCase, self).__init__(methodName)
super().__init__(methodName)
self.testname = methodName
parent_module_path: str = os.path.dirname(sys.modules[self.__class__.__module__].__file__) # type: ignore
parent_module_path: str = os.path.dirname(sys.modules[self.__class__.__module__].__file__) # type: ignore [assignment,type-var]
self._testsuite_dirpath = parent_module_path
#log_forced("SSTTestCase: __init__() - {0}".format(self.testname))
self.initializeClass(self.testname)
self._start_test_time = time.time()
self._stop_test_time = time.time()
Expand Down Expand Up @@ -195,7 +195,7 @@ def get_testsuite_dir(self) -> str:
""" Return the directory path of the testsuite that is being run
Returns:
(str)The path of the testsite directory
(str) The path of the testsite directory
"""
return self._testsuite_dirpath

Expand Down Expand Up @@ -235,9 +235,23 @@ def get_test_runtime_sec(self) -> float:
### Method to run an SST simulation
################################################################################

def run_sst(self, sdl_file, out_file, err_file=None, set_cwd=None, mpi_out_files="",
other_args="", num_ranks=None, num_threads=None, global_args=None,
timeout_sec=120, expected_rc=0, check_sdl_file=True, send_signal=signal.NSIG, signal_sec=3):
def run_sst(
self,
sdl_file: str,
out_file: str,
err_file: Optional[str] = None,
set_cwd: Optional[str] = None,
mpi_out_files: str = "",
other_args: str = "",
num_ranks: Optional[int] = None,
num_threads: Optional[int] = None,
global_args: Optional[str] = None,
timeout_sec: int = 120,
expected_rc: int = 0,
check_sdl_file: bool = True,
send_signal: int = signal.NSIG,
signal_sec: int = 3
) -> str:
""" Launch sst with with the command line and send output to the
output file. The SST execution will be monitored for result errors and
timeouts. On an error or timeout, a SSTTestCase.assert() will be generated
Expand Down Expand Up @@ -288,8 +302,7 @@ def run_sst(self, sdl_file, out_file, err_file=None, set_cwd=None, mpi_out_files
check_param_type("num_threads", num_threads, int)
if global_args is not None:
check_param_type("global_args", global_args, str)
if not (isinstance(timeout_sec, (int, float)) and not isinstance(timeout_sec, bool)):
raise ValueError("ERROR: Timeout_sec must be a postive int or a float")
check_param_type("timeout_sec", timeout_sec, int)
if expected_rc is not None:
check_param_type("expected_rc", expected_rc, int)

Expand Down Expand Up @@ -331,8 +344,8 @@ def run_sst(self, sdl_file, out_file, err_file=None, set_cwd=None, mpi_out_files
numa_param = ""
if num_ranks > 1:
# Check to see if mpirun is available
rtn = os.system("which mpirun > /dev/null 2>&1")
if rtn == 0:
rtn_mpirun = os.system("which mpirun > /dev/null 2>&1")
if rtn_mpirun == 0:
mpi_avail = True

numa_param = "-map-by numa:PE={0}".format(num_threads)
Expand Down Expand Up @@ -433,7 +446,7 @@ def tearDownModule() -> None:

###################

def setUpModuleConcurrent(test):
def setUpModuleConcurrent(test: SSTTestCase) -> None:
""" Perform setup functions before the testing Module loads.
This function is called by the Frameworks before tests in any TestCase
Expand Down Expand Up @@ -461,7 +474,7 @@ def setUpModuleConcurrent(test):

###

def tearDownModuleConcurrent(test):
def tearDownModuleConcurrent(test: SSTTestCase) -> None:
""" Perform teardown functions immediately after a testing Module finishes.
This function is called by the Frameworks after all tests in all TestCases
Expand Down
2 changes: 1 addition & 1 deletion tests/test_LookupTable.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# information, see the LICENSE file in the top level directory of the
# distribution.
import sst
import inspect, os, sys
import inspect

currentframe = inspect.currentframe()
assert currentframe is not None
Expand Down
2 changes: 1 addition & 1 deletion tests/testsuite_default_UnitAlgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def unitalgebra_test_template(self, testtype):

# Perform the test
cmp_result = testing_compare_sorted_diff(testtype, outfile, reffile)
if (cmp_result == False):
if not cmp_result:
diffdata = testing_get_diff_data(testtype)
log_failure(diffdata)
self.assertTrue(cmp_result, "Output/Compare file {0} does not match Reference File {1}".format(outfile, reffile))

0 comments on commit 629e9ab

Please sign in to comment.