diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b18b5f4c..f07ab743 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -86,6 +86,8 @@ jobs: echo "${{ secrets.CMOD_LOGIN }}" \ | tee ~/logbook.sybase_login \ | sha256sum + echo DISPY_TOKAMAK=CMOD \ + | tee -a "$GITHUB_ENV" - name: Setup DIII-D if: ${{ matrix.tokamak == 'DIII-D' }} @@ -107,7 +109,7 @@ jobs: | sha256sum echo -e "[FreeTDS]\nDescription = FreeTDS\nDriver = $TDS" \ | sudo tee -a /etc/odbcinst.ini - echo DIIID_TEST=1 \ + echo DISPY_TOKAMAK=D3D \ | tee -a "$GITHUB_ENV" - name: Setup Python @@ -139,12 +141,8 @@ jobs: - name: Test EFIT run: python examples/efit.py - - name: Test quick - run: pytest -v tests/test_quick.py - - - name: Test features - if: ${{ matrix.tokamak == 'C-MOD' }} - run: pytest -v --durations=0 tests/test_cmod_features.py + - name: Run all tests + run: pytest -v --durations=0 tests - name: Close tunnel run: xargs -a ssh.pid kill -TERM diff --git a/disruption_py/cli/evaluate_methods.py b/disruption_py/cli/evaluate_methods.py index 82a38b19..0745d4d8 100644 --- a/disruption_py/cli/evaluate_methods.py +++ b/disruption_py/cli/evaluate_methods.py @@ -1,141 +1,33 @@ import argparse from contextlib import contextmanager -from typing import Dict, List import numpy as np -import pandas as pd -from pandas.api.types import is_numeric_dtype import logging -from disruption_py.handlers.cmod_handler import CModHandler -from disruption_py.settings.log_settings import LogSettings from disruption_py.settings.shot_ids_request import ShotIdsRequestParams, shot_ids_request_runner -from disruption_py.settings.shot_settings import ShotSettings -from disruption_py.utils.constants import TIME_CONST +from disruption_py.utils.eval.eval_against_sql import eval_against_sql, get_failure_statistics_string from disruption_py.utils.mappings.mappings_helpers import map_string_to_enum from disruption_py.utils.mappings.tokamak import Tokamak -from disruption_py.utils.mappings.tokamak_helpers import get_tokamak_from_environment -from disruption_py.utils.math_utils import matlab_gradient_1d_vectorized +from disruption_py.utils.mappings.tokamak_helpers import get_tokamak_from_environment, get_tokamak_test_expected_failure_columns, get_tokamak_handler, get_tokamak_test_shot_ids -CMOD_TEST_SHOTS = [ - 1150805012, # Flattop Disruption - 1150805013, # No Disruption - 1150805014, # No Disruption - 1150805015, # Rampdown Disruption - 1150805016, # Rampdown Disruption - 1150805017, # Rampdown Disruption - 1150805019, # Rampdown Disruption - 1150805020, # Rampdown Disruption - 1150805021, # Rampdown Disruption - 1150805022 # Flattop Disruption -] -TIME_EPSILON = 0.05 # Tolerance for taking the difference between two times [s] -IP_EPSILON = 1e5 # Tolerance for taking the difference between two ip values [A] - -VAL_TOLERANCE = 0.01 # Tolerance for comparing values between MDSplus and SQL -MATCH_FRACTION = 0.95 # Fraction of signals that must match between MDSplus and SQL - -def get_mdsplus_data(handler, shot_list): - shot_settings = ShotSettings( - efit_tree_name="efit18", - set_times_request="efit", - log_settings=LogSettings( - console_log_level=logging.ERROR - ) - ) - return handler.get_shots_data( - shot_ids_request=shot_list, - shot_settings=shot_settings, - output_type_request="dict", - ) - -def get_sql_data(handler, mdsplus_data : Dict, shot_list): - shot_data = {} - for shot_id in shot_list: - times = mdsplus_data[shot_id]['time'] - sql_data = handler.database.get_shots_data([shot_id]) - shot_data[shot_id] = pd.merge_asof(times.to_frame(), sql_data, on='time', direction='nearest', tolerance=TIME_CONST) - return shot_data - - -def test_data_match(sql_shot_df : pd.DataFrame, mdsplus_shot_df : pd.DataFrame, data_column : str): - - sql_column_data = sql_shot_df[data_column].astype(np.float64) - mds_column_data = mdsplus_shot_df[data_column].astype(np.float64) - - # copare data for numeric differences - relative_difference = np.where( - sql_column_data != 0, - np.abs((mds_column_data - sql_column_data) / sql_column_data), - np.where(mds_column_data != 0, np.inf, np.nan) - ) - numeric_anomalies_mask = np.greater(relative_difference, VAL_TOLERANCE) - - # compare data for nan differences - sql_is_nan_ = pd.isnull(sql_column_data) - mdsplus_is_nan = pd.isnull(mds_column_data) - nan_anomalies_mask = (sql_is_nan_ != mdsplus_is_nan) - - anomalies = np.argwhere(numeric_anomalies_mask | nan_anomalies_mask) - - return not (len(anomalies) / len(relative_difference) > 1 - MATCH_FRACTION) - -def evaluate_cmod_accuracy(shot_list : List = None): - """ - Evaluate the accuracy of CMod methods. - - Prints a short report on the methods that have suceeded and failed. - Success criteria is having more that 95% of results within 1% of known results. - """ - print("Evaluating accuracy of CMod methods...") - if shot_list is None or len(shot_list) == 0: - shot_list = CMOD_TEST_SHOTS +def evaluate_accuracy(tokamak : Tokamak, shot_ids : list[int], fail_quick : bool = False, data_columns : list[str] = None): + handler = get_tokamak_handler(tokamak) + if shot_ids is None or len(shot_ids) == 0: + shot_ids = get_tokamak_test_shot_ids(tokamak) else: - shot_list = [int(shot_id) for shot_id in shot_list] - - @contextmanager - def monkey_patch_numpy_gradient(): - original_function = np.gradient - np.gradient = matlab_gradient_1d_vectorized - try: - yield - finally: - np.gradient = original_function - - with monkey_patch_numpy_gradient(): - return _evaluate_cmod_accuracy(shot_list) - -def _evaluate_cmod_accuracy(shot_list : List = None): - cmod_handler = CModHandler() - print("Getting data from MDSplus") - mdsplus_data = get_mdsplus_data(cmod_handler, shot_list) - print("Getting data from sql table") - sql_data = get_sql_data(cmod_handler, mdsplus_data, shot_list) - - success_columns = set() - unknown_columns = set() - failure_columns = set() - - for shot_id in shot_list: - - mdsplus_shot_df : pd.DataFrame = mdsplus_data[shot_id] - sql_shot_df : pd.DataFrame = sql_data[shot_id] + shot_ids = [int(shot_id) for shot_id in shot_ids] - for data_column in mdsplus_shot_df.columns: - if data_column not in sql_shot_df.columns or not is_numeric_dtype(mdsplus_shot_df[data_column]): - unknown_columns.add(data_column) - continue - - if test_data_match(sql_shot_df, mdsplus_shot_df, data_column): - success_columns.add(data_column) - else: - failure_columns.add(data_column) + expected_failure_columns = get_tokamak_test_expected_failure_columns(tokamak) + + data_differences = eval_against_sql( + handler=handler, + shot_ids=shot_ids, + expected_failure_columns=expected_failure_columns, + fail_quick=fail_quick, + test_columns=data_columns + ) - success_columns = success_columns.difference(failure_columns) - unknown_columns = unknown_columns.difference(success_columns).difference(failure_columns) - print(f"Successful Columns (failure criteria not met for any shot): {success_columns}") - print(f"Columns with a failure: {failure_columns}") - print(f"Columns that lacked testing data: {unknown_columns}") - return success_columns, unknown_columns, failure_columns + print(get_failure_statistics_string(data_differences)) + def main(args): """ @@ -183,13 +75,15 @@ def main(args): ) all_shot_ids = shot_ids_request_runner(args.shotlist, shot_ids_request_params) - if tokamak == Tokamak.CMOD: - evaluate_cmod_accuracy(all_shot_ids) - else: - print("Sorry, this tokamak is not currently supported.") - + data_columns = [args.data_column] if args.data_column else None + + print("Running evaluation...") + evaluate_accuracy(tokamak=tokamak, shot_ids=all_shot_ids, fail_quick=args.fail_quick, data_columns=data_columns) + def get_parser(): parser = argparse.ArgumentParser(description='Evaluate the accuracy of DisruptionPy methods on a Tokamak.') parser.add_argument('--shotlist', type=str, help='Path to file specifying a shotlist, leave blank for interactive mode', default=None) + parser.add_argument('--fail_quick', action='store_true', help='Fail quickly', default=False) + parser.add_argument('--data_column', type=str, help='Data column to test', default=None) return parser \ No newline at end of file diff --git a/disruption_py/handlers/__init__.py b/disruption_py/handlers/__init__.py index 4148e204..93b3ff8b 100644 --- a/disruption_py/handlers/__init__.py +++ b/disruption_py/handlers/__init__.py @@ -1 +1 @@ -from .cmod_handler import CModHandler \ No newline at end of file +from .handler import Handler \ No newline at end of file diff --git a/disruption_py/mdsplus_integration/mds_connection.py b/disruption_py/mdsplus_integration/mds_connection.py index 603d783d..2bc1a81a 100644 --- a/disruption_py/mdsplus_integration/mds_connection.py +++ b/disruption_py/mdsplus_integration/mds_connection.py @@ -3,6 +3,8 @@ import numpy as np import MDSplus +from disruption_py.utils.utils import safe_cast + class ProcessMDSConnection(): """ Abstract class for connecting to MDSplus. @@ -142,7 +144,7 @@ def get_data( data = self.conn.get("_sig=" + path, arguments).data() if astype: - data = data.astype(astype, copy=False) + data = safe_cast(data, astype) return data @@ -185,9 +187,9 @@ def get_data_with_dims( dims = [self.conn.get(f"dim_of(_sig,{dim_num})").data() for dim_num in dim_nums] if astype: - data = data.astype(astype, copy=False) + data = safe_cast(data, astype) if cast_all: - dims = [dim.astype(astype, copy=False) for dim in dims] + dims = [safe_cast(dim, astype) for dim in dims] return data, *dims @@ -226,7 +228,7 @@ def get_dims( dims = [self.conn.get(f"dim_of({path},{d})").data() for d in dim_nums] if astype: - dims = [dim.astype(astype, copy=False) for dim in dims] + dims = [safe_cast(dim, astype) for dim in dims] return dims diff --git a/disruption_py/settings/output_type_request.py b/disruption_py/settings/output_type_request.py index 043dc63c..f1c8ffcf 100644 --- a/disruption_py/settings/output_type_request.py +++ b/disruption_py/settings/output_type_request.py @@ -203,7 +203,7 @@ def get_results(self, params: FinishOutputTypeRequestParams): return self.results def stream_output_cleanup(self, params: FinishOutputTypeRequestParams): - self.results = [] + self.results = {} class DataFrameOutputRequest(OutputTypeRequest): """ @@ -213,7 +213,8 @@ def __init__(self): self.results : pd.DataFrame = pd.DataFrame() def _output_shot(self, params : ResultOutputTypeRequestParams): - self.results = pd.concat([self.results, params.result], ignore_index=True) + if not params.result.empty and not params.result.isna().all().all(): + self.results = pd.concat([self.results, params.result], ignore_index=True) def get_results(self, params: FinishOutputTypeRequestParams): return self.results diff --git a/disruption_py/settings/set_times_request.py b/disruption_py/settings/set_times_request.py index 0ba6f088..22ee5384 100644 --- a/disruption_py/settings/set_times_request.py +++ b/disruption_py/settings/set_times_request.py @@ -272,6 +272,7 @@ def _get_times(self, params : SetTimesRequestParams) -> np.ndarray: _set_times_request_mappings: Dict[str, SetTimesRequest] = { "efit" : EfitSetTimesRequest(), "disruption" : DisruptionSetTimesRequest(), + "disruption_warning": {Tokamak.CMOD: EfitSetTimesRequest(), Tokamak.D3D: DisruptionSetTimesRequest()}, "ip" : IpSetTimesRequest(), } # --8<-- [end:set_times_request_dict] diff --git a/disruption_py/settings/shot_settings.py b/disruption_py/settings/shot_settings.py index ee6f65f9..5dbe09f2 100644 --- a/disruption_py/settings/shot_settings.py +++ b/disruption_py/settings/shot_settings.py @@ -53,7 +53,7 @@ class or in an included shot_data_request. All methods with at least one include set_times_request : SetTimesRequest The set times request to be used when setting the timebase for the shot. The retrieved data will be interpolated to this timebase. Can pass any SetTimesRequestType that resolves to a SetTimesRequest. - See SetTimesRequest for more details. Defaults to "efit". + See SetTimesRequest for more details. Defaults to "disruption_warning". signal_domain : SignalDomain The domain of the timebase that should be used when retrieving data for the shot. Either "full", "flattop", or "rampup_and_flattop". Can pass either a SignalDomain or the associated string. Defaults @@ -91,7 +91,7 @@ class or in an included shot_data_request. All methods with at least one include shot_data_requests : List[ShotDataRequest] = field(default_factory=list) # Timebase setting - set_times_request : SetTimesRequest = "efit" + set_times_request : SetTimesRequest = "disruption_warning" signal_domain : SignalDomain = "full" use_existing_data_timebase : bool = False interpolation_method : InterpolationMethod = "linear" diff --git a/disruption_py/shots/parameter_methods/cmod/basic_parameter_methods.py b/disruption_py/shots/parameter_methods/cmod/basic_parameter_methods.py index 48193130..a91b931f 100644 --- a/disruption_py/shots/parameter_methods/cmod/basic_parameter_methods.py +++ b/disruption_py/shots/parameter_methods/cmod/basic_parameter_methods.py @@ -6,7 +6,7 @@ from disruption_py.settings.shot_data_request import ShotDataRequest, ShotDataRequestParams from disruption_py.utils.mappings.tokamak import Tokamak from disruption_py.utils.math_utils import gaussian_fit, interp1, smooth -from disruption_py.utils.utils import without_duplicates +from disruption_py.utils.utils import safe_cast, without_duplicates from disruption_py.shots.helpers.method_caching import cached_method, parameter_cached_method try: from MDSplus import mdsExceptions @@ -1434,7 +1434,7 @@ def efit_rz2psi(params : ShotDataRequestParams, r, z, t, tree='analysis'): r = r.flatten() z = z.flatten() psi = np.full((len(r), len(t)), np.nan) - z = z.astype('float32') # TODO: Ask if this change is necessary + z = safe_cast(z, 'float32') # TODO: Ask if this change is necessary psirz, rgrid, zgrid, times = params.mds_conn.get_data_with_dims(r'\efit_geqdsk:psirz', tree_name=tree, dim_nums=[0, 1, 2]) rgrid, zgrid = np.meshgrid(rgrid, zgrid) #, indexing='ij') diff --git a/disruption_py/shots/parameter_methods/d3d/basic_parameter_methods.py b/disruption_py/shots/parameter_methods/d3d/basic_parameter_methods.py index 08a87dc2..f8cd7bec 100644 --- a/disruption_py/shots/parameter_methods/d3d/basic_parameter_methods.py +++ b/disruption_py/shots/parameter_methods/d3d/basic_parameter_methods.py @@ -988,8 +988,8 @@ def _get_ne_te(params : ShotDataRequestParams, data_source="blessed", ts_systems # Place NaNs for broken channels lasers[laser]['te'][lasers[laser]['te'] == 0] = np.nan lasers[laser]['ne'][np.where(lasers[laser]['ne'] == 0)] = np.nan - params.logger.debug("_get_ne_te: Core bins", lasers['core']['te'].shape) - params.logger.debug("_get_ne_te: Tangential bins", lasers['tangential']['te'].shape) + params.logger.debug("_get_ne_te: Core bins {}".format(lasers['core']['te'].shape)) + params.logger.debug("_get_ne_te: Tangential bins {}".format(lasers['tangential']['te'].shape)) # If both systems/lasers available, combine them and interpolate the data # from the tangential system onto the finer (core) timebase if 'tangential' in lasers and lasers['tangential'] is not None: diff --git a/disruption_py/utils/constants.py b/disruption_py/utils/constants.py index f9364009..69cecc2e 100644 --- a/disruption_py/utils/constants.py +++ b/disruption_py/utils/constants.py @@ -13,6 +13,10 @@ MAX_SHOT_TIME = 7.0 # [s] <-- used to detect if shot times are using ms +# Used for testing +VAL_TOLERANCE = 0.01 # Tolerance for comparing values between MDSplus and SQL +MATCH_FRACTION = 0.95 # Fraction of signals that must match between MDSplus and SQL +VERBOSE_OUTPUT = False DEFAULT_COLS = ['time', 'time_until_disrupt','shot'] PAPER_COLS = [ @@ -32,3 +36,67 @@ 'ip-exp-10-none', 'ip-exp-50-none', ] + + +TEST_SHOTS = { + "cmod": { + "flattop1_fast": 1150805012, + "no_disrup1_full": 1150805013, + "no_disrup2_full": 1150805014, + "rampdown1_full": 1150805015, + "rampdown2_full": 1150805016, + "rampdown3_full": 1150805017, + "rampdown4_full": 1150805019, + "rampdown5_full": 1150805020, + "rampdown6_full": 1150805021, + "flattop2_full": 1150805022, + }, + "d3d": { + "disrup1_fast": 161228, + "disrup2_full": 161237, + "no_disrup1_full": 166177, + "no_disrup2_full": 166253, + } +} + +TEST_COLUMNS = { + "cmod": [ + 'I_efc', 'sxr', 'time_until_disrupt', 'beta_n', 'beta_p', 'kappa', 'li', + 'upper_gap', 'lower_gap', 'q0', 'qstar', 'q95', 'v_loop_efit', 'Wmhd', + 'ssep', 'n_over_ncrit', 'tritop', 'tribot', 'a_minor', 'rmagx', 'chisq', + 'dbetap_dt', 'dli_dt', 'dWmhd_dt', 'V_surf', 'kappa_area', 'Te_width', + 'ne_peaking', 'Te_peaking', 'pressure_peaking', 'n_e', 'dn_dt', + 'Greenwald_fraction', 'n_equal_1_mode', 'n_equal_1_normalized', + 'n_equal_1_phase', 'BT', 'prad_peaking', 'v_0', 'ip', 'dip_dt', + 'dip_smoothed', 'ip_prog', 'dipprog_dt', 'ip_error', 'z_error', + 'z_prog', 'zcur', 'v_z', 'z_times_v_z', 'p_oh', 'v_loop', 'p_rad', + 'dprad_dt', 'p_lh', 'p_icrf', 'p_input', 'radiated_fraction', 'time', + 'shot', 'commit_hash' + ], + "d3d": [ + "H98", "ip", "q95", "squareness", "zcur_normalized", "q0", "ip_error", "beta_p", + "time_until_disrupt", "z_error", "li_RT", "beta_p_RT", "n1rms_normalized", "qstar", + "ip_error_RT", "H_alpha", "n_e_RT", "Wmhd_RT", "ip_RT", "dli_dt", "dbetap_dt", + "dn_dt", "shot", "n_equal_1_mode", "dip_dt", "upper_gap", "n_equal_1_normalized", + "q95_RT", "zcur", "lower_gap", "Greenwald_fraction_RT", "kappa", "kappa_area", + "power_supply_railed", "n_e", "delta", "Greenwald_fraction", "dWmhd_dt", "Wmhd", + "aminor", "time", "li", "beta_n", "dipprog_dt", "dipprog_dt_RT" + ] +} + +EXPECTED_FAILURE_COLUMNS = { + "cmod": [ + 'Te_width', 'z_error', 'z_prog', 'zcur', 'v_z', 'z_times_v_z', + 'dipprog_dt', 'ip_error', 'sxr', 'tritop', 'tribot', 'a_minor', + 'rmagx', 'chisq', 'V_surf', 'ne_peaking', 'Te_peaking', + 'pressure_peaking', 'Greenwald_fraction', 'n_equal_1_phase', + 'BT', 'prad_peaking', 'dip_smoothed', 'ip_prog', 'p_input' + ], + "d3d": [ + 'kappa', 'H_alpha', 'dipprog_dt_RT', 'li', 'dWmhd_dt', 'beta_p', 'dn_dt', + 'Greenwald_fraction', 'li_RT', 'n_equal_1_normalized', 'zcur', 'dli_dt', + 'Greenwald_fraction_RT', 'H98', 'q95_RT', 'zcur_normalized', 'qstar', + 'Wmhd', 'lower_gap', 'beta_p_RT', 'n1rms_normalized', 'dbetap_dt', + 'n_equal_1_mode', 'q95', 'upper_gap', 'q0', 'n_e', 'beta_n', 'kappa_area' + ] +} \ No newline at end of file diff --git a/disruption_py/utils/eval/data_difference.py b/disruption_py/utils/eval/data_difference.py new file mode 100644 index 00000000..a5091c91 --- /dev/null +++ b/disruption_py/utils/eval/data_difference.py @@ -0,0 +1,98 @@ +from dataclasses import dataclass, field +from disruption_py.utils.constants import MATCH_FRACTION, VAL_TOLERANCE, VERBOSE_OUTPUT +import numpy as np +import pandas as pd + +from disruption_py.utils.utils import safe_cast + +@dataclass +class DataDifference: + """ + Data difference between MDSplus and SQL. + """ + shot_id : int + data_column : str + + missing_sql_data : bool + missing_mdsplus_data : bool + + anomalies : np.ndarray = field(init=False) # 1 if anomaly, 0 o.w. + relative_difference : np.ndarray = field(init=False) + mdsplus_column_data : pd.Series + sql_column_data : pd.Series + expect_failure : bool + + def __post_init__(self): + self.anomalies, self.relative_difference = self.compute_numeric_anomalies() + + @property + def num_anomalies(self) -> int: + return np.sum(self.anomalies) + + @property + def timebase_length(self) -> int: + return len(self.anomalies) + + @property + def missing_data(self) -> bool: + return self.missing_sql_data or self.missing_mdsplus_data + + @property + def failed(self) -> str: + if self.missing_data: + return True + return self.num_anomalies / self.timebase_length > 1 - MATCH_FRACTION + + @property + def failure_ratio_string(self) -> str: + return f"{self.num_anomalies / self.timebase_length:.4f}" + + @property + def column_mismatch_string(self) -> str: + return f"Shot {self.shot_id} column {self.data_column} with arrays:\n{self.difference_df.to_string()}" + + @property + def difference_df(self) -> pd.DataFrame: + indexes = np.arange(self.timebase_length) if VERBOSE_OUTPUT else self.anomalies.flatten() + anomaly = self.anomalies[indexes] + return pd.DataFrame({ + 'MDSplus Data': self.mdsplus_column_data.iloc[indexes], + 'Reference Data (SQL)': self.sql_column_data.iloc[indexes], + 'Relative difference': self.relative_difference[indexes], + 'Anomaly': anomaly + }) + + def compute_numeric_anomalies(self): + """ + Get the indices of the data where numeric differences exist between the sql and mdsplus data. + """ + + # handle missing data case + if self.missing_mdsplus_data or self.missing_sql_data: + if self.missing_mdsplus_data and self.missing_sql_data: + missing_timebase_length = 0 + elif self.missing_mdsplus_data: + missing_timebase_length = len(self.sql_column_data) + elif self.missing_sql_data: + missing_timebase_length = len(self.mdsplus_column_data) + return np.ones(missing_timebase_length, dtype=bool), np.zeros(missing_timebase_length) + + + sql_is_nan = pd.isnull(self.sql_column_data) + mdsplus_is_nan = pd.isnull(self.mdsplus_column_data) + + # handle case where both arrays are all null + if sql_is_nan.all() and mdsplus_is_nan.all(): + return np.zeros(len(self.mdsplus_column_data), dtype=bool), np.zeros(len(self.mdsplus_column_data)) + + relative_difference = safe_cast(np.where( + self.sql_column_data != 0, + np.abs((self.mdsplus_column_data - self.sql_column_data) / self.sql_column_data), + np.where(self.mdsplus_column_data != 0, np.inf, np.nan), + ), 'float64') # necessary in case all produced values are nan + + numeric_anomalies_mask = np.where(np.isnan(relative_difference), False, relative_difference > VAL_TOLERANCE) + nan_anomalies_mask = (sql_is_nan != mdsplus_is_nan) + anomalies : pd.Series = numeric_anomalies_mask | nan_anomalies_mask + + return anomalies.to_numpy(), relative_difference \ No newline at end of file diff --git a/disruption_py/utils/eval/eval_against_sql.py b/disruption_py/utils/eval/eval_against_sql.py new file mode 100644 index 00000000..710e5f71 --- /dev/null +++ b/disruption_py/utils/eval/eval_against_sql.py @@ -0,0 +1,237 @@ +from contextlib import contextmanager +import inspect + +import numpy as np +from disruption_py.handlers import Handler +from disruption_py.settings import LogSettings, ShotSettings +from disruption_py.utils.math_utils import matlab_gradient_1d_vectorized + +import pandas as pd + + +import logging +from typing import Callable, Dict, List + +from disruption_py.utils.constants import TIME_CONST +from disruption_py.utils.eval.data_difference import DataDifference + + +def get_mdsplus_data(handler : Handler, shot_ids : List[int]) -> Dict[int, pd.DataFrame]: + """ + Get MDSplus data for a list of shots. + + Returns + ------- + Dict[int, pd.DataFrame] + Dictionary mapping shot IDs to retrieved MDSplus data. + """ + shot_settings = ShotSettings( + efit_tree_name="efit18", + set_times_request="disruption_warning", + log_settings=LogSettings( + log_to_console=False, + log_file_path="tests/cmod.log", + log_file_write_mode="w", + file_log_level=logging.DEBUG + ) + ) + shot_data = handler.get_shots_data( + shot_ids_request=shot_ids, + shot_settings=shot_settings, + output_type_request="dict", + ) + return shot_data + + +def get_sql_data_for_mdsplus(handler : Handler, shot_ids : List[int], mdsplus_data : Dict[int, pd.DataFrame]) -> Dict[int, pd.DataFrame]: + """ + Get SQL data for a list of shots and map onto the timebase of the supplied MDSplus data. + + Returns + ------- + Dict[int, pd.DataFrame] + Dictionary mapping shot IDs to retrieved SQL data. + """ + shot_data = {} + for shot_id in shot_ids: + times = mdsplus_data[shot_id]['time'] + sql_data = handler.database.get_shots_data([shot_id]) + shot_data[shot_id] = pd.merge_asof(times.to_frame(), sql_data, on='time', direction='nearest', tolerance=TIME_CONST) + return shot_data + +def eval_shots_against_sql( + shot_ids : List[int], + mdsplus_data : Dict[int, pd.DataFrame], + sql_data : Dict[int, pd.DataFrame], + data_columns : List[str], + fail_quick : bool = False, + expected_failure_columns : List[str] = None, +) -> List["DataDifference"]: + """ + Test if the difference between the two data is within tolerance. + """ + if expected_failure_columns is None: + expected_failure_columns = [] + + data_differences : List[DataDifference] = [] + for data_column in data_columns: + for shot_id in shot_ids: + mdsplus_shot_data, sql_shot_data = mdsplus_data[shot_id], sql_data[shot_id] + expect_failure = data_column in expected_failure_columns + + data_difference = eval_shot_against_sql( + shot_id=shot_id, + mdsplus_shot_data=mdsplus_shot_data, + sql_shot_data=sql_shot_data, + data_column=data_column, + fail_quick = fail_quick, + expect_failure = expect_failure + ) + data_differences.append(data_difference) + return data_differences + +def eval_shot_against_sql( + shot_id : int, + mdsplus_shot_data : pd.DataFrame, + sql_shot_data : pd.DataFrame, + data_column : str, + fail_quick : bool = False, + expect_failure : bool = False, +) -> "DataDifference": + """ + Test if the difference between the two data is within tolerance. + """ + missing_mdsplus_data = (data_column not in mdsplus_shot_data) + missing_sql_data = (data_column not in sql_shot_data) + data_difference = DataDifference( + shot_id=shot_id, + data_column=data_column, + mdsplus_column_data=mdsplus_shot_data.get(data_column, None), + sql_column_data=sql_shot_data.get(data_column, None), + missing_mdsplus_data=missing_mdsplus_data, + missing_sql_data=missing_sql_data, + expect_failure=expect_failure, + ) + + if fail_quick and not (missing_mdsplus_data or missing_sql_data): + if expect_failure: + assert data_difference.failed, "Expected failure but succeeded:\n{}".format(data_difference.column_mismatch_string) + else: + assert not data_difference.failed, "Expected success but failed:\n{}".format(data_difference.column_mismatch_string) + + return data_difference + +def get_failure_statistics_string(data_differences : list["DataDifference"], data_column=None): + data_difference_by_column = {} + for data_difference in data_differences: + data_difference_by_column.setdefault(data_difference.data_column, []).append(data_difference) + + failure_strings = {} + failed_columns, succeeded_columns, missing_data_columns = set(), set(), set() + matches_expected_failures_columns, not_matches_expected_failures_columns = set(), set() + for ratio_data_column, column_data_differences in data_difference_by_column.items(): + failures = [data_difference.shot_id for data_difference in column_data_differences if data_difference.failed] + failed = len(failures) > 0 + + all_missing_data = all([data_difference.missing_data for data_difference in column_data_differences]) + + anomaly_count = sum([data_difference.num_anomalies for data_difference in column_data_differences]) + timebase_count = sum([data_difference.timebase_length for data_difference in column_data_differences]) + + expect_failure = any([data_difference.expect_failure for data_difference in column_data_differences]) + matches_expected_failure = expect_failure == failed + + # failure string + failure_string_lines = [ + f"Column {ratio_data_column} {'FAILED' if failed else 'SUCCEEDED'}", + f"Matches expected failures: {matches_expected_failure}", + f"Total Entry Failure Rate: {anomaly_count / timebase_count * 100:.2f}%", + ] + failure_string = "\n".join(failure_string_lines) + + # condition string + conditions : Dict[str, Callable[[DataDifference], bool]] = { + "Shots expected to fail that failed": lambda data_difference: data_difference.expect_failure and data_difference.failed, + "Shots expected to succeed that failed": lambda data_difference: not data_difference.expect_failure and data_difference.failed, + "Shots expected to fail that succeeded": lambda data_difference: data_difference.expect_failure and not data_difference.failed, + "Shots expected to succeed that succeeded": lambda data_difference: not data_difference.expect_failure and not data_difference.failed, + "Shots missing sql data": lambda data_difference: data_difference.missing_sql_data, + "Shots missing mdsplus data": lambda data_difference: data_difference.missing_mdsplus_data, + } + condition_results = {} + for condition_name, condition in conditions.items(): + shot_ids = [data_difference.shot_id for data_difference in column_data_differences if condition(data_difference)] + if len(shot_ids) > 0: + condition_results[condition_name] = shot_ids + condition_string = "\n".join([f"{condition_name} ({len(condition_result)} shots): {condition_result}" for condition_name, condition_result in condition_results.items()]) + + # combine the string parts together + failure_strings[ratio_data_column] = failure_string + "\n" + condition_string + + if all_missing_data: + missing_data_columns.add(ratio_data_column) + elif failed: + failed_columns.add(ratio_data_column) + else: + succeeded_columns.add(ratio_data_column) + + if matches_expected_failure: + matches_expected_failures_columns.add(ratio_data_column) + else: + not_matches_expected_failures_columns.add(ratio_data_column) + + if data_column is not None: + return failure_strings.get(data_column, "") + else: + summary_string = f"""\ + ___________________________________________________________________________________________________ + SUMMARY + Columns with a failure: + {"None" if len(failed_columns) == 0 else ""}{", ".join(failed_columns)} + + Columns without a failure: + {"None" if len(succeeded_columns) == 0 else ""}{", ".join(succeeded_columns)} + + Columns lacking data for comparison from sql or mdsplus sources: + {"None" if len(missing_data_columns) == 0 else ""}{", ".join(missing_data_columns)} + ___________________________________________________________________________________________________ + + Columns that match expected failures: + {"None" if len(matches_expected_failures_columns) == 0 else ""}{", ".join(matches_expected_failures_columns)} + + Columns that do not match expected failures: + {"None" if len(not_matches_expected_failures_columns) == 0 else ""}{", ".join(not_matches_expected_failures_columns)} + """ + return '\n\n'.join(failure_strings.values()) + '\n\n' + inspect.cleandoc(summary_string) + +def eval_against_sql(handler : Handler, shot_ids : List[int], expected_failure_columns : List[str], fail_quick : bool, test_columns = None,) -> Dict[int, pd.DataFrame]: + @contextmanager + def monkey_patch_numpy_gradient(): + original_function = np.gradient + np.gradient = matlab_gradient_1d_vectorized + try: + yield + finally: + np.gradient = original_function + + with monkey_patch_numpy_gradient(): + mdsplus_data = get_mdsplus_data(handler, shot_ids) + sql_data = get_sql_data_for_mdsplus(handler, shot_ids, mdsplus_data) + + if test_columns is None: + mdsplus_columns = set().union(*(df.columns for df in mdsplus_data.values())) + sql_columns = set().union(*(df.columns for df in sql_data.values())) + test_columns = sorted(mdsplus_columns.intersection(sql_columns)) + + data_differences = eval_shots_against_sql( + shot_ids=shot_ids, + mdsplus_data=mdsplus_data, + sql_data=sql_data, + data_columns=test_columns, + fail_quick=fail_quick, + expected_failure_columns=expected_failure_columns, + ) + + return data_differences + + diff --git a/disruption_py/utils/mappings/tokamak_helpers.py b/disruption_py/utils/mappings/tokamak_helpers.py index 92a8507e..745b7029 100644 --- a/disruption_py/utils/mappings/tokamak_helpers.py +++ b/disruption_py/utils/mappings/tokamak_helpers.py @@ -1,8 +1,9 @@ import os from disruption_py.databases import D3DDatabase, CModDatabase from disruption_py.utils.mappings.tokamak import Tokamak - -DATABASE_HANDLERS = {Tokamak.D3D: D3DDatabase, Tokamak.CMOD: CModDatabase, Tokamak.EAST: None} +from disruption_py.handlers.cmod_handler import CModHandler +from disruption_py.handlers.d3d_handler import D3DHandler +from disruption_py.utils.constants import EXPECTED_FAILURE_COLUMNS, TEST_COLUMNS, TEST_SHOTS def get_tokamak_from_shot_id(shot_id): if isinstance(shot_id, str): @@ -23,11 +24,46 @@ def get_tokamak_from_shot_id(shot_id): f"Unable to handle shot_id of length {shot_len}") def get_tokamak_from_environment(): - if os.environ.get('CMOD_MONITOR') is not None: + if "DISPY_TOKAMAK" in os.environ: + return Tokamak[os.environ["DISPY_TOKAMAK"]] + if os.path.exists("/usr/local/mfe/disruptions"): return Tokamak.CMOD - else: - return None + if os.path.exists("/fusion/projects/disruption_warning"): + return Tokamak.D3D + return None def get_database_for_shot_id(shot_id : int): tokamak = get_tokamak_from_shot_id(shot_id) - return DATABASE_HANDLERS.get(tokamak, None) + return get_tokamak_database(tokamak) + +def get_tokamak_handler(tokamak : Tokamak): + if tokamak is Tokamak.CMOD: + return CModHandler() + elif tokamak is Tokamak.D3D: + return D3DHandler() + else: + raise ValueError("Tokamak {} not supported for this test".format(tokamak)) + +def get_tokamak_database(tokamak : Tokamak): + if tokamak == Tokamak.CMOD: + return CModDatabase.default() + elif tokamak == Tokamak.D3D: + return D3DDatabase.default() + else: + raise ValueError("Tokamak {} not supported for this test".format(tokamak)) + +def get_tokamak_test_expected_failure_columns(tokamak : Tokamak): + return EXPECTED_FAILURE_COLUMNS.get(tokamak.value) + + +def get_tokamak_test_shot_ids(tokamak : Tokamak) -> list[int]: + shot_id_dict = TEST_SHOTS.get(tokamak.value) + + if "GITHUB_ACTIONS" in os.environ: + shot_id_dict = {key: value for key, value in shot_id_dict.items() if "_fast" in key} + + return list(shot_id_dict.values()) + + +def get_tokamak_test_columns(tokamak : Tokamak): + return TEST_COLUMNS.get(tokamak.value) diff --git a/disruption_py/utils/utils.py b/disruption_py/utils/utils.py index 5d53b58f..d1bd1816 100644 --- a/disruption_py/utils/utils.py +++ b/disruption_py/utils/utils.py @@ -1,4 +1,7 @@ from typing import Callable, List +import warnings + +import numpy as np def instantiate_classes(l : List): """ @@ -32,4 +35,9 @@ def without_duplicates(l : List): The list l with duplicates removed. """ seen = set() - return [x for x in l if not (x in seen or seen.add(x))] \ No newline at end of file + return [x for x in l if not (x in seen or seen.add(x))] + +def safe_cast(array : np.ndarray, dtype, copy = False): + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=RuntimeWarning) + return array.astype(dtype, copy = copy) \ No newline at end of file diff --git a/examples/efit.py b/examples/efit.py index 0ab9fa08..4bec4677 100644 --- a/examples/efit.py +++ b/examples/efit.py @@ -4,28 +4,30 @@ execute a simple workflow to fetch EFIT parameters. """ -import os -from disruption_py.handlers.cmod_handler import CModHandler -from disruption_py.handlers.d3d_handler import D3DHandler from disruption_py.settings import ShotSettings, LogSettings +from disruption_py.utils.mappings.tokamak import Tokamak +from disruption_py.utils.mappings.tokamak_helpers import ( + get_tokamak_from_environment, + get_tokamak_handler, +) +tokamak = get_tokamak_from_environment() +handler = get_tokamak_handler(tokamak) -if os.getenv("DIIID_TEST", False) or os.path.exists("/fusion/projects/disruption_warning"): - handler = D3DHandler() +if tokamak is Tokamak.D3D: shot_ids_request = [161228] - set_times_request = "disruption" run_methods = ["_get_efit_parameters"] shape = (247, 16) -else: - handler = CModHandler() +elif tokamak is Tokamak.CMOD: shot_ids_request = [1150805012] - set_times_request = "efit" run_methods = ["_get_EFIT_parameters"] shape = (62, 25) +else: + raise ValueError(f"Unspecified or unsupported tokamak: {tokamak}.") + print(f"Initialized handler: {handler.get_tokamak().value}") shot_settings = ShotSettings( - set_times_request=set_times_request, log_settings=LogSettings(console_log_level=0), run_tags=[], run_methods=run_methods, diff --git a/examples/mdsplus.py b/examples/mdsplus.py index a3140aac..19253741 100644 --- a/examples/mdsplus.py +++ b/examples/mdsplus.py @@ -4,18 +4,24 @@ execute a simple fetch to test MDSplus connection. """ -import os -from disruption_py.handlers.cmod_handler import CModHandler -from disruption_py.handlers.d3d_handler import D3DHandler +from disruption_py.utils.mappings.tokamak import Tokamak +from disruption_py.utils.mappings.tokamak_helpers import ( + get_tokamak_from_environment, + get_tokamak_handler, +) -if os.getenv("DIIID_TEST", False) or os.path.exists("/fusion/projects/disruption_warning"): - handler = D3DHandler() +tokamak = get_tokamak_from_environment() +handler = get_tokamak_handler(tokamak) + +if tokamak is Tokamak.D3D: shot = 161228 shape = (196,) -else: - handler = CModHandler() +elif tokamak is Tokamak.CMOD: shot = 1150805012 shape = (2400,) +else: + raise ValueError(f"Unspecified or unsupported tokamak: {tokamak}.") + mds = handler.mds_connection.conn print(f"Initialized MDSplus: {mds.hostspec}") diff --git a/examples/quick.py b/examples/quick.py deleted file mode 100644 index 38446b5e..00000000 --- a/examples/quick.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python3 - -""" -quick test: fetch efit parameters from established shot, and check array shape. -""" - -from disruption_py.handlers import CModHandler -from disruption_py.settings import ShotSettings, LogSettings - -handler = CModHandler() - -shot_settings = ShotSettings( - log_settings=LogSettings(console_log_level=0), - run_tags=[], - run_methods=["_get_EFIT_parameters"], -) - -result = handler.get_shots_data( - shot_ids_request=[1150805012], - shot_settings=shot_settings, - output_type_request="dataframe", -) - -print(result) - -assert result.shape == (62, 25) diff --git a/examples/sql.py b/examples/sql.py index 10b6482b..221204fe 100644 --- a/examples/sql.py +++ b/examples/sql.py @@ -5,7 +5,11 @@ """ import os -from disruption_py.databases import CModDatabase, D3DDatabase +from disruption_py.utils.mappings.tokamak import Tokamak +from disruption_py.utils.mappings.tokamak_helpers import ( + get_tokamak_from_environment, + get_tokamak_database, +) queries = [ "select count(distinct shot) from disruption_warning", @@ -15,13 +19,16 @@ + " where shot in (select shot from disruptions)", "select count(distinct shot) from disruptions", ] +tokamak = get_tokamak_from_environment() +db = get_tokamak_database(tokamak) -if os.getenv("DIIID_TEST", False) or os.path.exists("/fusion/projects/disruption_warning"): - db = D3DDatabase.default() +if tokamak is Tokamak.D3D: vals = [13245, 8055, 5190, 24219] -else: - db = CModDatabase.default() +elif tokamak is Tokamak.CMOD: vals = [10435, 6640, 3795, 13785] +else: + raise ValueError(f"Unspecified or unsupported tokamak: {tokamak}.") + print(f"Initialized DB: {db.user}@{db.host}/{db.db_name}") while queries: @@ -40,7 +47,9 @@ print() continue - if not __debug__ or "PYTEST_CURRENT_TEST" in os.environ: + if not __debug__ or any( + k in os.environ for k in ["PYTEST_CURRENT_TEST", "GITHUB_ACTIONS"] + ): break try: diff --git a/tests/conftest.py b/tests/conftest.py index df56eb04..aecb3f01 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,23 +1,52 @@ from unittest.mock import patch import pytest +from disruption_py.utils.mappings.tokamak_helpers import get_tokamak_from_environment, get_tokamak_test_expected_failure_columns, get_tokamak_handler, get_tokamak_test_shot_ids, get_tokamak_test_columns from disruption_py.utils.math_utils import matlab_gradient_1d_vectorized -@pytest.fixture(scope='session', autouse=True) -def mock_numpy_gradient(): - with patch('numpy.gradient', new=matlab_gradient_1d_vectorized): - # The patch will be in place for the duration of the test session - yield - - def pytest_addoption(parser): parser.addoption("--verbose_output", action="store_true", help="More testing information.") - parser.addoption("--fail_slow", action="store_true", help="Finish test and report statistics instead of failing fast.") + parser.addoption("--fail_quick", action="store_true", help="Finish test and report statistics instead of failing fast.") @pytest.fixture(scope="session") def verbose_output(pytestconfig): return pytestconfig.getoption("verbose_output") @pytest.fixture(scope="session") -def fail_slow(pytestconfig): - return pytestconfig.getoption("fail_slow") \ No newline at end of file +def fail_quick(pytestconfig): + return pytestconfig.getoption("fail_quick") + +def pytest_generate_tests(metafunc): + tokamak = get_tokamak_from_environment() + + # parameterized across tests + if "data_column" in metafunc.fixturenames: + test_columns = get_tokamak_test_columns(tokamak) + metafunc.parametrize("data_column", test_columns) + +@pytest.fixture(scope="session") +def tokamak(): + return get_tokamak_from_environment() + +@pytest.fixture(scope="module") +def handler(tokamak): + return get_tokamak_handler(tokamak) + +@pytest.fixture(scope="module") +def shotlist(tokamak): + return get_tokamak_test_shot_ids(tokamak) + +@pytest.fixture(scope="module") +def data_columns(tokamak): + return get_tokamak_test_columns(tokamak) + +@pytest.fixture(scope="module") +def expected_failure_columns(tokamak): + return get_tokamak_test_expected_failure_columns(tokamak) + +# for testing against sql, values generated with matlab use a different gradient method that must be patched for testing +@pytest.fixture(scope='session', autouse=True) +def mock_numpy_gradient(): + with patch('numpy.gradient', new=matlab_gradient_1d_vectorized): + # The patch will be in place for the duration of the test session + yield \ No newline at end of file diff --git a/tests/test_against_sql.py b/tests/test_against_sql.py new file mode 100644 index 00000000..cb75ad2a --- /dev/null +++ b/tests/test_against_sql.py @@ -0,0 +1,100 @@ +"""Unit tests for workflows involving get_dataset_df() for obtaining CMOD data. + +Expects to be run on the MFE workstations. +Expects MDSplus to be installed and configured. +Expects SQL credentials to be configured. +""" +import argparse +from typing import Dict, List +import pytest + +import pandas as pd +from disruption_py.handlers.cmod_handler import Handler +from disruption_py.utils.mappings.tokamak_helpers import get_tokamak_from_environment, get_tokamak_handler, get_tokamak_test_expected_failure_columns, get_tokamak_test_shot_ids +from disruption_py.utils.eval.eval_against_sql import eval_shots_against_sql, get_failure_statistics_string, get_mdsplus_data, get_sql_data_for_mdsplus, eval_against_sql + +@pytest.fixture(scope='module') +def mdsplus_data(handler : Handler, shotlist : List[int]) -> Dict[int, pd.DataFrame]: + return get_mdsplus_data(handler, shotlist) + +@pytest.fixture(scope='module') +def sql_data(handler : Handler, shotlist : List[int], mdsplus_data : Dict[int, pd.DataFrame]) -> Dict[int, pd.DataFrame]: + return get_sql_data_for_mdsplus(handler, shotlist, mdsplus_data) + +def test_data_columns(shotlist : List[int], mdsplus_data : Dict[int, pd.DataFrame], sql_data : Dict[int, pd.DataFrame], data_column, expected_failure_columns : List[str], fail_quick : bool): + """ + Test that the data columns are the same between MDSplus and SQL across specified data columns. + + Data column is parameterized in pytest_generate_tests. + """ + # if data_column in expected_failure_columns: + # request.node.add_marker(pytest.mark.xfail(reason='column expected failure')) + data_differences = eval_shots_against_sql( + shot_ids=shotlist, + mdsplus_data=mdsplus_data, + sql_data=sql_data, + data_columns=[data_column], + expected_failure_columns=expected_failure_columns, # we use xfail instead of manually expecting for column failures + fail_quick=fail_quick + ) + if not fail_quick: + expected_failure = any(data_difference.expect_failure for data_difference in data_differences) + if expected_failure: + pytest.xfail(reason='matches expected data failures') # stops execution of test + else: + assert all(not data_difference.failed for data_difference in data_differences), get_failure_statistics_string( + data_differences, data_column=data_column) + + +def test_other_values(shotlist : List[int], mdsplus_data : Dict[int, pd.DataFrame], sql_data : Dict[int, pd.DataFrame], data_columns : List[str], expected_failure_columns : List[str], fail_quick : bool): + """ + Ensure that all parameters are calculated correctly in the MDSplus shot object. + """ + + mdsplus_columns = set().union(*(df.columns for df in mdsplus_data.values())) + sql_columns = set().union(*(df.columns for df in sql_data.values())) + + test_columns = mdsplus_columns.intersection(sql_columns).difference(data_columns) + + data_differences = eval_shots_against_sql( + shot_ids=shotlist, + mdsplus_data=mdsplus_data, + sql_data=sql_data, + data_columns=test_columns, + fail_quick=fail_quick, + expected_failure_columns=expected_failure_columns + ) + + if not fail_quick: + matches_expected = all(data_difference.matches_expected for data_difference in data_differences) + expected_failure = any(data_difference.expect_failure for data_difference in data_differences) + if matches_expected and expected_failure: + pytest.xfail(reason='matches expected data failures') # stops execution of test + else: + assert all(not data_difference.failed for data_difference in data_differences), get_failure_statistics_string(data_differences) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--fail-slow', action='store_true', help="Get summary of column failures, for specified column(s)") + parser.add_argument('--data-column', default=None, help='Data column to test, use all data columns if not specified') + args = parser.parse_args() + + fail_quick = not args.fail_slow + data_columns = [args.data_column] if args.data_column else None + tokamak = get_tokamak_from_environment() + + handler = get_tokamak_handler(tokamak) + shot_ids = get_tokamak_test_shot_ids(tokamak) + expected_failure_columns = get_tokamak_test_expected_failure_columns(tokamak) + + data_differences = eval_against_sql( + handler=handler, + shot_ids=shot_ids, + expected_failure_columns=expected_failure_columns, + fail_quick=fail_quick, + test_columns=data_columns + ) + + print(get_failure_statistics_string(data_differences)) + diff --git a/tests/test_cmod.py b/tests/test_cmod.py deleted file mode 100644 index 1267811b..00000000 --- a/tests/test_cmod.py +++ /dev/null @@ -1,330 +0,0 @@ -"""Unit tests for workflows involving get_dataset_df() for obtaining CMOD data. - -Expects to be run on the MFE workstations. -Expects MDSplus to be installed and configured. -Expects SQL credentials to be configured. -""" -from typing import Dict -import pytest - -import numpy as np -import pandas as pd -import logging -from disruption_py.handlers.cmod_handler import CModHandler -from disruption_py.settings import ShotSettings, LogSettings -from disruption_py.utils.constants import TIME_CONST - -# Shot list used for testing -# Mix of disruptive and non-disruptive shots present in SQL and MDSplus -TEST_SHOTS = [ - 1150805012, # Flattop Disruption - 1150805013, # No Disruption - 1150805014, # No Disruption - 1150805015, # Rampdown Disruption - 1150805016, # Rampdown Disruption - 1150805017, # Rampdown Disruption - 1150805019, # Rampdown Disruption - 1150805020, # Rampdown Disruption - 1150805021, # Rampdown Disruption - 1150805022 # Flattop Disruption -] - -TEST_COLUMNS = [ - 'I_efc', 'sxr', 'time_until_disrupt', 'beta_n', 'beta_p', 'kappa', 'li', - 'upper_gap', 'lower_gap', 'q0', 'qstar', 'q95', 'v_loop_efit', 'Wmhd', - 'ssep', 'n_over_ncrit', 'tritop', 'tribot', 'a_minor', 'rmagx', 'chisq', - 'dbetap_dt', 'dli_dt', 'dWmhd_dt', 'V_surf', 'kappa_area', 'Te_width', - 'ne_peaking', 'Te_peaking', 'pressure_peaking', 'n_e', 'dn_dt', - 'Greenwald_fraction', 'n_equal_1_mode', 'n_equal_1_normalized', - 'n_equal_1_phase', 'BT', 'prad_peaking', 'v_0', 'ip', 'dip_dt', - 'dip_smoothed', 'ip_prog', 'dipprog_dt', 'ip_error', 'z_error', - 'z_prog', 'zcur', 'v_z', 'z_times_v_z', 'p_oh', 'v_loop', 'p_rad', - 'dprad_dt', 'p_lh', 'p_icrf', 'p_input', 'radiated_fraction', 'time', - 'shot', 'commit_hash' -] - -KNOWN_FAILURE_COLUMNS = [ - 'lower_gap', 'upper_gap', 'ssep', 'dipprog_dt', 'n_over_ncrit', # constant factor scaling error - 'ip_error' # constant error -] - -# TEST_COLUMNS = list(set(TEST_COLUMNS).difference(KNOWN_FAILURE_COLUMNS)) - -TIME_EPSILON = 0.05 # Tolerance for taking the difference between two times [s] -IP_EPSILON = 1e5 # Tolerance for taking the difference between two ip values [A] - -VAL_TOLERANCE = 0.01 # Tolerance for comparing values between MDSplus and SQL -MATCH_FRACTION = 0.95 # Fraction of signals that must match between MDSplus and SQL - -@pytest.fixture(scope='module') -def cmod_handler(): - return CModHandler() - -@pytest.fixture(scope='module') -def shotlist(): - return TEST_SHOTS - -@pytest.fixture(scope='module') -def mdsplus_data(cmod_handler : CModHandler, shotlist) -> Dict: - shot_settings = ShotSettings( - efit_tree_name="efit18", - set_times_request="efit", - log_settings=LogSettings( - log_to_console=False, - log_file_path="tests/cmod.log", - log_file_write_mode="w", - file_log_level=logging.DEBUG - ) - ) - shot_data = cmod_handler.get_shots_data( - shot_ids_request=shotlist, - shot_settings=shot_settings, - output_type_request="dict", - ) - return shot_data - -@pytest.fixture(scope='module') -def sql_data(cmod_handler : CModHandler, shotlist, mdsplus_data : Dict): - shot_data = {} - for shot_id in shotlist: - times = mdsplus_data[shot_id]['time'] - sql_data =cmod_handler.database.get_shots_data([shot_id]) - shot_data[shot_id] = pd.merge_asof(times.to_frame(), sql_data, on='time', direction='nearest', tolerance=TIME_CONST) - assert ( - len(times) == len(shot_data[shot_id]), - f"Shot {shot_id} has {len(times)} rows but SQL has {len(shot_data[shot_id])} rows" - ) - return shot_data - -@pytest.mark.parametrize("data_column", TEST_COLUMNS) -def test_data_columns(shotlist, mdsplus_data : Dict, sql_data : Dict, data_column, verbose_output, fail_slow): - anomaly_ratios = [] - for shot_id in shotlist: - mdsplus_shot_data, sql_shot_data = mdsplus_data[shot_id], sql_data[shot_id] - - if data_column not in mdsplus_shot_data: - print(f"Column {data_column} missing from MDSPlus for shot {shot_id}") - continue - - if data_column not in sql_shot_data: - print(f"Column {data_column} missing from SQL for shot {shot_id}") - continue - - anomaly_ratio = evaluate_differences( - shot_id=shot_id, - sql_shot_data=sql_shot_data, - mdsplus_shot_data=mdsplus_shot_data, - data_column=data_column, - verbose_output=verbose_output, - fail_slow=fail_slow, - ) - anomaly_ratios.append(anomaly_ratio) - - if any(anomaly_ratio['failed'] for anomaly_ratio in anomaly_ratios): - raise AssertionError(get_failure_statistics_string(anomaly_ratios, verbose_output, data_column=data_column)) - -def test_other_values(shotlist, mdsplus_data : Dict, sql_data : Dict, verbose_output, fail_slow): - """ - Ensure that all parameters are calculated correctly in the MDSplus shot object. - """ - anomaly_ratios = [] - for shot_id in shotlist: - mdsplus_shot_data, sql_shot_data = mdsplus_data[shot_id], sql_data[shot_id] - mdsplus_unmatched_cols = list(mdsplus_shot_data.columns.difference(sql_shot_data.columns)) - print(f"Shot {shot_id} is missing {mdsplus_unmatched_cols} from SQL source") - sql_unmatched_cols = list(sql_shot_data.columns.difference(mdsplus_shot_data.columns)) - print(f"Shot {shot_id} is missing {sql_unmatched_cols} from MDSPlus source") - - for data_column in sql_shot_data.columns.intersection(mdsplus_shot_data.columns): - - if data_column in TEST_COLUMNS: - continue - - # check if the col of the shot is all nan - if mdsplus_shot_data[data_column].isna().all() and sql_shot_data[data_column].isna().all(): - continue - - anomaly_ratio = evaluate_differences( - shot_id=shot_id, - sql_shot_data=sql_shot_data, - mdsplus_shot_data=mdsplus_shot_data, - data_column=data_column, - verbose_output=verbose_output, - fail_slow=fail_slow, - ) - anomaly_ratios.append(anomaly_ratio) - - - if any(anomaly_ratio['failed'] for anomaly_ratio in anomaly_ratios): - - raise AssertionError(get_failure_statistics_string(anomaly_ratios, verbose_output)) - -def evaluate_differences(shot_id, sql_shot_data, mdsplus_shot_data, data_column, verbose_output, fail_slow): - # Compare percentage diff between MDSplus and SQL. In the case that the SQL value is 0, inf should be the diff if the MDSplus value is non-zero and nan if the MDSplus value is 0 - relative_difference = np.where( - sql_shot_data[data_column] != 0, - np.abs((mdsplus_shot_data[data_column] - sql_shot_data[data_column]) / sql_shot_data[data_column]), - np.where(mdsplus_shot_data[data_column] != 0, np.inf, np.nan) - ) - numeric_anomalies_mask = (relative_difference > VAL_TOLERANCE) - - sql_is_nan_ = pd.isnull(sql_shot_data[data_column]) - mdsplus_is_nan = pd.isnull(mdsplus_shot_data[data_column]) - nan_anomalies_mask = (sql_is_nan_ != mdsplus_is_nan) - - anomalies = np.argwhere(numeric_anomalies_mask | nan_anomalies_mask) - - if len(anomalies) / len(relative_difference) > 1 - MATCH_FRACTION: - if fail_slow: - failed = True - else: - indexes = np.arange(len(relative_difference)) if verbose_output else anomalies.flatten() - anomaly = np.where(relative_difference > VAL_TOLERANCE, 1, 0)[indexes] - difference_df = pd.DataFrame({ - 'MDSplus Data': mdsplus_shot_data[data_column].iloc[indexes], - 'Reference Data (SQL)': sql_shot_data[data_column].iloc[indexes], - 'Relative difference': relative_difference[indexes], - 'Anomaly': anomaly - }) - pd.options.display.max_rows = None if verbose_output else 10 - raise AssertionError(f"Shot {shot_id} column {data_column} failed for arrays:\n{difference_df}") - else: - failed = False - - anomaly_ratio = { - 'failed': failed, - 'shot': shot_id, - 'data_column': data_column, - 'anomalies': len(anomalies), - 'timebase_length': len(relative_difference), - 'failure_percentage' : f"{len(anomalies) / len(relative_difference*100):.2f}", - } - return anomaly_ratio - -def get_failure_statistics_string(anomaly_ratios, verbose_output, data_column=None): - anomaly_ratio_by_column = {} - for anomaly_ratio in anomaly_ratios: - anomaly_ratio_by_column.setdefault(anomaly_ratio['data_column'], []).append(anomaly_ratio) - - failure_strings = {} - for ratio_data_column, column_anomaly_ratios in anomaly_ratio_by_column.items(): - failures = [anomaly_ratio['shot'] for anomaly_ratio in column_anomaly_ratios if anomaly_ratio['failed']] - failed = len(failures) > 0 - if not verbose_output and not failed: - continue - successes = [anomaly_ratio['shot'] for anomaly_ratio in column_anomaly_ratios if not anomaly_ratio['failed']] - anomaly_count = sum([anomaly_ratio['anomalies'] for anomaly_ratio in column_anomaly_ratios]) - timebase_count = sum([anomaly_ratio['timebase_length'] for anomaly_ratio in column_anomaly_ratios]) - failure_strings[ratio_data_column] = f""" - Column {ratio_data_column} {"FAILED" if failed else "SUCCEEDED"} for shot {anomaly_ratio['shot']} - Failed for {len(failures)} shots: {failures} - Succeeded for {len(successes)} shots: {successes} - Total Entry Failure Rate: {anomaly_count / timebase_count * 100:.2f}% - """ - - if data_column is not None: - return failure_strings.get(data_column, "") - else: - return '\n'.join(failure_strings.values()) - - -# Other tests for MDSplus -# TODO: Refactor these tests - -# @pytest.mark.parametrize("shot_id", TEST_SHOTS[:1]) -# def test_flattop_times(cmod, shot_id): -# """ -# Ensure that the flattop time matches the change in programmed ip in the SQL dataset. -# """ -# sql_df = get_sql_data(cmod, shot_id)[['time', 'dipprog_dt']] -# # Find the first time where dipprog_dt is zero -# sql_flattop_time = sql_df['time'].loc[sql_df['dipprog_dt'] == 0].iloc[0] -# # Find the last time where dipprog_dt is zero -# sql_flattop_end_time = sql_df['time'].loc[sql_df['dipprog_dt'] == 0].iloc[-1] - -# shot_settings = ShotSettings( -# set_times_request='efit', -# signal_domain='flattop' -# ) - -# flattop_df = cmod.get_shots_data(shot_id, shot_settings=shot_settings)[0][['time', 'dipprog_dt']] - -# # Find the first time in the flattop signal -# mds_flattop_time = flattop_df['time'].iloc[0] -# # Find the last time in the flattop signal -# mds_flattop_end_time = flattop_df['time'].iloc[-1] - -# assert mds_flattop_time == pytest.approx(sql_flattop_time, abs=TIME_EPSILON) -# assert mds_flattop_end_time == pytest.approx(sql_flattop_end_time, abs=TIME_EPSILON) - - -# test specific column error in detail -# def test_derrivatives(shotlists): -# test_shots, expected_shots = shotlists -# for shot_id, test_shot_data, expected_shot_data in zip(TEST_SHOTS, test_shots, expected_shots): -# difference_dfs = [] -# for col in ['beta_p', 'dbetap_dt']: -# diff = np.where(expected_shot_data[col] != 0, -# np.abs((test_shot_data[col] - expected_shot_data[col]) / expected_shot_data[col]), -# np.where(test_shot_data[col] != 0, np.inf, np.nan)) -# indexes = np.arange(len(diff)) # anomalies.flatten() -# anomaly_differences = diff[indexes] -# test_shot_data_differences = test_shot_data[col].iloc[indexes] -# expected_shot_data_differences = expected_shot_data[col].iloc[indexes] -# anomaly = np.where(diff > VAL_TOLERANCE, 1, 0)[indexes] -# difference_df = pd.DataFrame({f'Test_{col}': test_shot_data_differences, f'Expected_{col}': expected_shot_data_differences, f'Difference_{col}': anomaly_differences, f'Anomaly_{col}': anomaly}) -# difference_dfs.append(difference_df) - -# total_difference_df = pd.concat(difference_dfs, axis=1) -# # total_difference_df['evals'] = np.gradient(total_difference_df['Expected_beta_p'], np.round(expected_shot_data['time'], 3), edge_order=1) -# # total_difference_df['evals_diff'] = np.where(total_difference_df['Expected_beta_p'] != 0, -# # np.abs((total_difference_df['evals'] - total_difference_df['Expected_dbetap_dt']) / total_difference_df['Expected_dbetap_dt']), -# # np.where(total_difference_df['evals'] != 0, np.inf, np.nan)) -# # total_difference_df['evals_anomaly'] = np.where(total_difference_df['evals_diff'] > VAL_TOLERANCE, 1, 0) -# total_difference_df['beta_p_diff'] = total_difference_df['Expected_beta_p'].diff() -# total_difference_df['time'] = expected_shot_data['time'] -# total_difference_df['time_diff'] = expected_shot_data['time'].diff() - -# total_difference_df.to_csv(f"tests/cmod_failed_values_{shot_id}_dbetap_dt.csv") -# raise AssertionError( -# f"Shot {shot_id} condition failed. Arrays:\n{total_difference_df}" -# ) - - -# test specified columns - -# TEST_COLS = ['beta_n', 'beta_p', 'kappa', 'li', 'upper_gap', 'lower_gap', 'q0', -# 'qstar', 'q95', 'v_loop_efit', 'Wmhd', 'ssep', 'n_over_ncrit', 'V_surf', -# 'R0', 'tritop', 'tribot', 'a_minor', 'chisq', 'dbetap_dt', 'dli_dt', -# 'dWmhd_dt', 'H98', 'Te_width', 'n_e', 'dn_dt', 'Greenwald_fraction', -# 'I_efc', 'ip', 'dip_dt', 'dip_smoothed', 'ip_prog', 'dipprog_dt', -# 'ip_error', 'kappa_area', 'n_equal_1_mode', 'n_equal_1_normalized', -# 'n_equal_1_phase', 'BT', 'p_oh', 'v_loop', 'p_rad', 'dprad_dt', 'p_lh', -# 'p_icrf', 'p_input', 'radiated_fraction', 'v_0', 'sxr', -# 'time_until_disrupt'] - -# def test_get_all_columns(shotlists): -# test_shots, _ = shotlists -# for col in TEST_COLS: -# for shot_data in test_shots: -# assert col in shot_data.columns, f"Shot {shot_data} is missing {col} from MDSplus" - -# @pytest.mark.parametrize("col", TEST_COLS) -# def test_parameter_calculations(shotlists, col): -# """ -# Ensure that all parameters are calculated correctly in the MDSplus shot object. -# """ -# test_shots, expected_shots = shotlists -# for shot_id, test_shot_data, expected_shot_data in zip(TEST_SHOTS, test_shots, expected_shots): -# if col not in expected_shot_data.columns: -# continue -# # check if the col of the shot is all nan -# if test_shot_data[col].isna().all() and expected_shot_data[col].isna().all(): -# continue - -# # Compare percentage diff between MDSplus and SQL. In the case that the SQL value is 0, inf should be the diff if the MDSplus value is non-zero and nan if the MDSplus value is 0 -# diff = np.where(expected_shot_data != 0, -# np.abs((test_shot_data[col] - expected_shot_data[col]) / expected_shot_data[col]), -# np.where(test_shot_data[col] != 0, np.inf, np.nan)) -# anomalies = np.argwhere(diff > 1.e-2) -# assert len(anomalies) / len(diff) < MATCH_FRACTION, f"Shot {shot_id} has {len(anomalies)} anomalies for {col} out of {len(diff)}" diff --git a/tests/test_d3d.py b/tests/test_d3d.py deleted file mode 100644 index 54cf0ba5..00000000 --- a/tests/test_d3d.py +++ /dev/null @@ -1,220 +0,0 @@ -"""Unit tests for workflows involving get_dataset_df() for obtaining D3D data. - -Expects to be run on the MFE workstations. -Expects MDSplus to be installed and configured. -Expects SQL credentials to be configured. -""" -from typing import Dict -import pytest - -import numpy as np -import pandas as pd -import logging -from disruption_py.handlers.d3d_handler import D3DHandler -from disruption_py.settings import ShotSettings, LogSettings -from disruption_py.utils.constants import TIME_CONST - -# Shot list used for testing -# Mix of disruptive and non-disruptive shots present in SQL and MDSplus -TEST_SHOTS = [ - 161228, # disruptive - # 161237, # disruptive - # 166177, # non disruptive - # 166253 -] - -TEST_COLUMNS = [ - 'shot', 'time', 'time_until_disrupt', 'ip_error', 'dip_dt', - 'beta_p', 'beta_n', 'li', 'n_equal_1_mode_IRLM', 'z_error', 'v_z', - 'kappa', 'H98', 'q0', 'qstar', 'q95', 'dn_dt', 'radiated_fraction', - 'power_supply_railed', 'lower_gap', 'upper_gap', 'dbetap_dt', 'dli_dt', - 'ip', 'zcur', 'n_e', 'dipprog_dt', 'v_loop', 'p_rad', 'dWmhd_dt', - 'dprad_dt', 'p_nbi', 'p_ech', 'p_ohm', 'intentional_disruption', - 'Greenwald_fraction', 'Te_HWHM', 'other_hardware_failure', 'Te_HWHM_RT', - 'v_loop_RT', 'n_e_RT', 'Greenwald_fraction_RT', 'ip_error_RT', 'ip_RT', - 'dipprog_dt_RT', 'Wmhd_RT', 'Wmhd', 'n_equal_1_mode', - 'n_equal_1_normalized', 'Te_width_normalized', 'Te_width_normalized_RT', - 'q95_RT', 'li_RT', 'beta_p_RT', 'oeamp1em', 'oeamp1om', 'oefrq1em', - 'oefrq1om', 'oeamp1e', 'oeamp1o', 'oefrq1e', 'oefrq1o', 'delta', - 'squareness', 'zcur_normalized', 'aminor', 'n1rms_normalized', - 'kappa_area', 'Te_peaking_CVA_RT', 'ne_peaking_CVA_RT', - 'Prad_peaking_CVA_RT', 'Prad_peaking_XDIV_RT', 'H_alpha', -] - -KNOWN_FAILURE_COLUMNS = [] - -# TEST_COLUMNS = list(set(TEST_COLUMNS).difference(KNOWN_FAILURE_COLUMNS)) - -TIME_EPSILON = 0.05 # Tolerance for taking the difference between two times [s] -IP_EPSILON = 1e5 # Tolerance for taking the difference between two ip values [A] - -VAL_TOLERANCE = 0.01 # Tolerance for comparing values between MDSplus and SQL -MATCH_FRACTION = 0.95 # Fraction of signals that must match between MDSplus and SQL - -@pytest.fixture(scope='module') -def d3d_handler(): - return D3DHandler() - -@pytest.fixture(scope='module') -def shotlist(): - return TEST_SHOTS - -@pytest.fixture(scope='module') -def mdsplus_data(d3d_handler : D3DHandler, shotlist) -> Dict: - shot_settings = ShotSettings( - set_times_request="disruption", - log_settings=LogSettings( - log_to_console=False, - log_file_path="tests/d3d.log", - log_file_write_mode="w", - file_log_level=logging.DEBUG - ) - ) - shot_data = d3d_handler.get_shots_data( - shot_ids_request=shotlist, - shot_settings=shot_settings, - output_type_request="dict", - ) - return shot_data - -@pytest.fixture(scope='module') -def sql_data(d3d_handler : D3DHandler, shotlist, mdsplus_data : Dict): - shot_data = {} - for shot_id in shotlist: - times = mdsplus_data[shot_id]['time'] - sql_data =d3d_handler.database.get_shots_data([shot_id]) - assert len(times) == len(sql_data), f"Shot {shot_id} has {len(times)} rows but SQL has {len(sql_data)} rows" - - shot_data[shot_id] = pd.merge_asof(times.to_frame(), sql_data, on='time', direction='nearest', tolerance=TIME_CONST) - assert len(times) == len(shot_data[shot_id]), f"After projecting timebase shot {shot_id} has {len(times)} rows but SQL has {len(shot_data[shot_id])} rows" - - return shot_data - -@pytest.mark.parametrize("data_column", TEST_COLUMNS) -def test_data_columns(shotlist, mdsplus_data : Dict, sql_data : Dict, data_column, verbose_output, fail_slow): - anomaly_ratios = [] - for shot_id in shotlist: - mdsplus_shot_data, sql_shot_data = mdsplus_data[shot_id], sql_data[shot_id] - - if data_column not in mdsplus_shot_data: - raise AssertionError(f"Column {data_column} missing from MDSPlus for shot {shot_id}") - - if data_column not in sql_shot_data: - raise AssertionError(f"Column {data_column} missing from SQL for shot {shot_id}") - - anomaly_ratio = evaluate_differences( - shot_id=shot_id, - sql_shot_data=sql_shot_data, - mdsplus_shot_data=mdsplus_shot_data, - data_column=data_column, - verbose_output=verbose_output, - fail_slow=fail_slow, - ) - anomaly_ratios.append(anomaly_ratio) - - if any(anomaly_ratio['failed'] for anomaly_ratio in anomaly_ratios): - raise AssertionError(get_failure_statistics_string(anomaly_ratios, verbose_output, data_column=data_column)) - -def test_other_values(shotlist, mdsplus_data : Dict, sql_data : Dict, verbose_output, fail_slow): - """ - Ensure that all parameters are calculated correctly in the MDSplus shot object. - """ - anomaly_ratios = [] - for shot_id in shotlist: - mdsplus_shot_data, sql_shot_data = mdsplus_data[shot_id], sql_data[shot_id] - mdsplus_unmatched_cols = list(mdsplus_shot_data.columns.difference(sql_shot_data.columns)) - print(f"Shot {shot_id} is missing {mdsplus_unmatched_cols} from SQL source") - sql_unmatched_cols = list(sql_shot_data.columns.difference(mdsplus_shot_data.columns)) - print(f"Shot {shot_id} is missing {sql_unmatched_cols} from MDSPlus source") - - for data_column in sql_shot_data.columns.intersection(mdsplus_shot_data.columns): - - if data_column in TEST_COLUMNS: - continue - - # check if the col of the shot is all nan - if mdsplus_shot_data[data_column].isna().all() and sql_shot_data[data_column].isna().all(): - continue - - anomaly_ratio = evaluate_differences( - shot_id=shot_id, - sql_shot_data=sql_shot_data, - mdsplus_shot_data=mdsplus_shot_data, - data_column=data_column, - verbose_output=verbose_output, - fail_slow=fail_slow, - ) - anomaly_ratios.append(anomaly_ratio) - - - if any(anomaly_ratio['failed'] for anomaly_ratio in anomaly_ratios): - - raise AssertionError(get_failure_statistics_string(anomaly_ratios, verbose_output)) - -def evaluate_differences(shot_id, sql_shot_data, mdsplus_shot_data, data_column, verbose_output, fail_slow): - # Compare percentage diff between MDSplus and SQL. In the case that the SQL value is 0, inf should be the diff if the MDSplus value is non-zero and nan if the MDSplus value is 0 - relative_difference = np.where( - sql_shot_data[data_column] != 0, - np.abs((mdsplus_shot_data[data_column] - sql_shot_data[data_column]) / sql_shot_data[data_column]), - np.where(mdsplus_shot_data[data_column] != 0, np.inf, np.nan) - ) - numeric_anomalies_mask = (relative_difference > VAL_TOLERANCE) - - sql_is_nan_ = pd.isnull(sql_shot_data[data_column]) - mdsplus_is_nan = pd.isnull(mdsplus_shot_data[data_column]) - nan_anomalies_mask = (sql_is_nan_ != mdsplus_is_nan) - - anomalies = np.argwhere(numeric_anomalies_mask | nan_anomalies_mask) - - if len(anomalies) / len(relative_difference) > 1 - MATCH_FRACTION: - if fail_slow: - failed = True - else: - indexes = np.arange(len(relative_difference)) if verbose_output else anomalies.flatten() - anomaly = np.where(relative_difference > VAL_TOLERANCE, 1, 0)[indexes] - difference_df = pd.DataFrame({ - 'MDSplus Data': mdsplus_shot_data[data_column].iloc[indexes], - 'Reference Data (SQL)': sql_shot_data[data_column].iloc[indexes], - 'Relative difference': relative_difference[indexes], - 'Anomaly': anomaly - }) - pd.options.display.max_rows = None if verbose_output else 10 - raise AssertionError(f"Shot {shot_id} column {data_column} failed for arrays:\n{difference_df}") - else: - failed = False - - anomaly_ratio = { - 'failed': failed, - 'shot': shot_id, - 'data_column': data_column, - 'anomalies': len(anomalies), - 'timebase_length': len(relative_difference), - 'failure_percentage' : f"{len(anomalies) / len(relative_difference*100):.2f}", - } - return anomaly_ratio - -def get_failure_statistics_string(anomaly_ratios, verbose_output, data_column=None): - anomaly_ratio_by_column = {} - for anomaly_ratio in anomaly_ratios: - anomaly_ratio_by_column.setdefault(anomaly_ratio['data_column'], []).append(anomaly_ratio) - - failure_strings = {} - for ratio_data_column, column_anomaly_ratios in anomaly_ratio_by_column.items(): - failures = [anomaly_ratio['shot'] for anomaly_ratio in column_anomaly_ratios if anomaly_ratio['failed']] - failed = len(failures) > 0 - if not verbose_output and not failed: - continue - successes = [anomaly_ratio['shot'] for anomaly_ratio in column_anomaly_ratios if not anomaly_ratio['failed']] - anomaly_count = sum([anomaly_ratio['anomalies'] for anomaly_ratio in column_anomaly_ratios]) - timebase_count = sum([anomaly_ratio['timebase_length'] for anomaly_ratio in column_anomaly_ratios]) - failure_strings[ratio_data_column] = f""" - Column {ratio_data_column} {"FAILED" if failed else "SUCCEEDED"} for shot {anomaly_ratio['shot']} - Failed for {len(failures)} shots: {failures} - Succeeded for {len(successes)} shots: {successes} - Total Entry Failure Rate: {anomaly_count / timebase_count * 100:.2f}% - """ - - if data_column is not None: - return failure_strings.get(data_column, "") - else: - return '\n'.join(failure_strings.values()) diff --git a/tests/test_cmod_features.py b/tests/test_features.py similarity index 60% rename from tests/test_cmod_features.py rename to tests/test_features.py index 47041d91..3c54611c 100644 --- a/tests/test_cmod_features.py +++ b/tests/test_features.py @@ -3,18 +3,10 @@ import os import pandas as pd -from disruption_py.handlers.cmod_handler import CModHandler +from disruption_py.handlers.handler import Handler from disruption_py.settings.log_settings import LogSettings from disruption_py.settings.shot_settings import ShotSettings -TEST_SHOTS = { - "flattop_fast": 1150805012, - "nodisrup1_full": 1150805013, - "nodisrup2_full": 1150805014, - "rampdown1_full": 1150805015, - "rampdown2_full": 1150805016, -} - TEST_SETTINGS = { "default_fast": ShotSettings(), "sqlcache_full": ShotSettings(existing_data_request="sql"), @@ -35,10 +27,12 @@ run_tags=[], run_methods=["_get_ip_parameters"], ), - "rampup_fast": ShotSettings( - efit_tree_name="analysis", - signal_domain="rampup_and_flattop", - ), + "rampup_fast": { + "cmod": ShotSettings( + efit_tree_name="analysis", + signal_domain="rampup_and_flattop", + ) + }, "logging_full": ShotSettings( log_settings=LogSettings( log_file_path=f"{__file__}.log", @@ -53,38 +47,23 @@ ), } - -@pytest.fixture(scope="module") -def cmod_handler(): - return CModHandler() - - -@pytest.fixture(scope="module") -def shot_list(): - if "GITHUB_ACTIONS" in os.environ: - return [v for k, v in TEST_SHOTS.items() if k.endswith("_fast")] - return list(TEST_SHOTS.values()) - - -@pytest.fixture(scope="module") -def shot_settings_keys(): - if "GITHUB_ACTIONS" in os.environ: - return [k for k in TEST_SETTINGS if k.endswith("_fast")] - return TEST_SETTINGS.keys() - - -@pytest.mark.skipif( - os.path.exists("/fusion/projects/disruption_warning"), reason="on DIII-D" -) @pytest.mark.parametrize("shot_settings_key", TEST_SETTINGS.keys()) def test_features_serial( - cmod_handler, shot_list, shot_settings_key, shot_settings_keys + handler : Handler, tokamak, shotlist, shot_settings_key ): - if shot_settings_key not in shot_settings_keys: + if "GITHUB_ACTIONS" in os.environ and "_fast" not in shot_settings_key: pytest.skip("fast execution") - results = cmod_handler.get_shots_data( - shot_ids_request=shot_list, - shot_settings=TEST_SETTINGS[shot_settings_key], + + test_setting = TEST_SETTINGS[shot_settings_key] + if isinstance(test_setting, dict): + if tokamak.value in test_setting: + test_setting = test_setting[tokamak.value] + else: + pytest.skip(f"not tested for tokamak {tokamak.value}") + + results = handler.get_shots_data( + shot_ids_request=shotlist, + shot_settings=test_setting, output_type_request=[ "list", "dataframe", @@ -96,15 +75,12 @@ def test_features_serial( list_output, df_output, csv_processed, hdf_processed = results assert isinstance(list_output, list) assert isinstance(df_output, pd.DataFrame) - assert csv_processed == hdf_processed == len(shot_list) + assert csv_processed == hdf_processed == len(shotlist) -@pytest.mark.skipif( - os.path.exists("/fusion/projects/disruption_warning"), reason="on DIII-D" -) -def test_features_parallel(cmod_handler, shot_list): - results = cmod_handler.get_shots_data( - shot_ids_request=shot_list, +def test_features_parallel(handler : Handler, shotlist): + results = handler.get_shots_data( + shot_ids_request=shotlist, shot_settings=TEST_SETTINGS["default_fast"], output_type_request=[ "list", @@ -117,7 +93,7 @@ def test_features_parallel(cmod_handler, shot_list): list_output, df_output, csv_processed, hdf_processed = results assert isinstance(list_output, list) assert isinstance(df_output, pd.DataFrame) - assert csv_processed == hdf_processed == len(shot_list) + assert csv_processed == hdf_processed == len(shotlist) @pytest.fixture(scope="session", autouse=True)