diff --git a/ax/analysis/plotly/cross_validation.py b/ax/analysis/plotly/cross_validation.py index 4214a5156e0..ae954265436 100644 --- a/ax/analysis/plotly/cross_validation.py +++ b/ax/analysis/plotly/cross_validation.py @@ -161,25 +161,30 @@ def _prepare_data( f"trial {trial_index}, but has observations from trial " f"{observed.features.trial_index}." ) - for i in range(len(observed.data.metric_names)): - # Find the index of the metric we want to plot - if not ( - observed.data.metric_names[i] == metric_name - and predicted.metric_names[i] == metric_name - ): - continue - + # Find the index of the metric in observed and predicted + observed_i = next( + ( + i + for i, name in enumerate(observed.data.metric_names) + if name == metric_name + ), + None, + ) + predicted_i = next( + (i for i, name in enumerate(predicted.metric_names) if name == metric_name), + None, + ) + # Check if both indices are found + if observed_i is not None and predicted_i is not None: record = { "arm_name": observed.arm_name, - "observed": observed.data.means[i], - "predicted": predicted.means[i], - # Take the square root of the the SEM to get the standard deviation - "observed_sem": observed.data.covariance[i][i] ** 0.5, - "predicted_sem": predicted.covariance[i][i] ** 0.5, + "observed": observed.data.means[observed_i], + "predicted": predicted.means[predicted_i], + # Take the square root of the SEM to get the standard deviation + "observed_sem": observed.data.covariance[observed_i][observed_i] ** 0.5, + "predicted_sem": predicted.covariance[predicted_i][predicted_i] ** 0.5, } records.append(record) - break - return pd.DataFrame.from_records(records)