Skip to content

Commit

Permalink
Don't assume order and length is the same between observed and predic…
Browse files Browse the repository at this point in the history
…ted features (#3250)

Summary:

Combining diffs D68274239 & D68294872:

1 - **D68274239 - [analysis] Don't assume order is the same in observed and predicted features**

This is causing only 3 of 14 arms to show on an experiment.  We're still assuming that within an observation the order of the metric names matches the order of the corresponding data.

2 - **D68294872 - [analysis] Don't assume observed and predicted metrics are the same length**

Comment from  [here](https://www.internalfb.com/diff/D68274239?dst_version_fbid=9435339503152040&transaction_fbid=1415387593232440):

In N6432597 I encounter a `StopIteration` error when rebased on D68274239. I believe this is because `predicted.metric_names` is quite a bit longer than `observed.data.metric_names` and so there's a chance that

```
predicted_i = next(
            i
            for i in range(len(observed.data.metric_names))
            if predicted.metric_names[i] == metric_name
        )
```

never finds the metric it needs and throws the `StopIteration` error.

Reviewed By: danielcohenlive

Differential Revision: D68336952
  • Loading branch information
eonofrey authored and facebook-github-bot committed Feb 3, 2025
1 parent 6ae638c commit 2bfef3e
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions ax/analysis/plotly/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,25 +161,30 @@ def _prepare_data(
f"trial {trial_index}, but has observations from trial "
f"{observed.features.trial_index}."
)
for i in range(len(observed.data.metric_names)):
# Find the index of the metric we want to plot
if not (
observed.data.metric_names[i] == metric_name
and predicted.metric_names[i] == metric_name
):
continue

# Find the index of the metric in observed and predicted
observed_i = next(
(
i
for i, name in enumerate(observed.data.metric_names)
if name == metric_name
),
None,
)
predicted_i = next(
(i for i, name in enumerate(predicted.metric_names) if name == metric_name),
None,
)
# Check if both indices are found
if observed_i is not None and predicted_i is not None:
record = {
"arm_name": observed.arm_name,
"observed": observed.data.means[i],
"predicted": predicted.means[i],
# Take the square root of the the SEM to get the standard deviation
"observed_sem": observed.data.covariance[i][i] ** 0.5,
"predicted_sem": predicted.covariance[i][i] ** 0.5,
"observed": observed.data.means[observed_i],
"predicted": predicted.means[predicted_i],
# Take the square root of the SEM to get the standard deviation
"observed_sem": observed.data.covariance[observed_i][observed_i] ** 0.5,
"predicted_sem": predicted.covariance[predicted_i][predicted_i] ** 0.5,
}
records.append(record)
break

return pd.DataFrame.from_records(records)


Expand Down

0 comments on commit 2bfef3e

Please sign in to comment.