Implement compute_analyses (#3143)
Summary:

Implement a method that, given a list of Analyses, computes each one and returns the resulting cards. If an Exception is encountered, a MarkdownAnalysisCard containing the traceback is created instead, with priority DEBUG. If no list of Analyses is provided, defaults are chosen.
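A minimal usage sketch of the new method, mirroring the test added in this commit. The ax.preview.api.configs import path is an assumption, since the test's import block is not visible in this diff:

from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot
from ax.preview.api.client import Client
from ax.preview.api.configs import (  # assumed import path
    ExperimentConfig,
    ParameterType,
    RangeParameterConfig,
)

client = Client()
client.configure_experiment(
    ExperimentConfig(
        name="foo",
        parameters=[
            RangeParameterConfig(
                name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1)
            ),
        ],
    )
)
client.configure_optimization(objective="foo")

# Pass Analyses explicitly, or omit the argument to let choose_analyses pick
# defaults based on the Experiment's state. With no completed trials this
# returns a DEBUG-level error card rather than raising.
cards = client.compute_analyses(analyses=[ParallelCoordinatesPlot()])
for card in cards:
    print(card.name, card.title)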

Reviewed By: lena-kashtelyan

Differential Revision: D66677557
mpolson64 authored and facebook-github-bot committed Dec 6, 2024
1 parent 5ac3c71 commit 3d51ca6
Showing 7 changed files with 224 additions and 3 deletions.
4 changes: 3 additions & 1 deletion ax/analysis/analysis.py
@@ -5,6 +5,8 @@

# pyre-strict

from __future__ import annotations

import json
from collections.abc import Iterable
from enum import IntEnum
@@ -132,7 +134,7 @@ def compute_result(
self,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
-    ) -> Result[AnalysisCard, ExceptionE]:
+    ) -> Result[AnalysisCard, AnalysisE]:
"""
Utility method to compute an AnalysisCard as a Result. This can be useful for
computing many Analyses at once and handling Exceptions later.
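The tightened error type means a failed computation carries the Analysis that produced it. A hypothetical sketch of the flow this enables (experiment and generation_strategy are stand-ins for real objects, not part of this diff):

from ax.analysis.markdown.markdown_analysis import (
    markdown_analysis_card_from_analysis_e,
)
from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot

analysis = ParallelCoordinatesPlot()
result = analysis.compute_result(
    experiment=experiment,  # stand-in: an ax.core.experiment.Experiment
    generation_strategy=generation_strategy,  # stand-in
)

# On failure, unwrap_or_else converts the AnalysisE into a fallback card
# instead of raising (see markdown_analysis_card_from_analysis_e below).
card = result.unwrap_or_else(markdown_analysis_card_from_analysis_e)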
24 changes: 23 additions & 1 deletion ax/analysis/markdown/markdown_analysis.py
@@ -6,8 +6,10 @@
# pyre-strict


import traceback

import pandas as pd
-from ax.analysis.analysis import Analysis, AnalysisCard
+from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel, AnalysisE
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from IPython.display import display, Markdown
@@ -59,3 +61,23 @@ def _create_markdown_analysis_card(
df=df,
blob=message,
)


def markdown_analysis_card_from_analysis_e(
analysis_e: AnalysisE,
) -> MarkdownAnalysisCard:
return MarkdownAnalysisCard(
name=analysis_e.analysis.name,
title=f"{analysis_e.analysis.name} Error",
subtitle=f"An error occurred while computing {analysis_e.analysis}",
attributes=analysis_e.analysis.attributes,
blob="".join(
traceback.format_exception(
type(analysis_e.exception),
analysis_e.exception,
analysis_e.exception.__traceback__,
)
),
df=pd.DataFrame(),
level=AnalysisCardLevel.DEBUG,
)
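The blob is simply the formatted traceback; the same stdlib idiom in isolation, with a toy exception:

import traceback

# Reproduce the blob-building idiom used above.
try:
    raise ValueError("no data found for metric 'foo'")
except ValueError as e:
    blob = "".join(traceback.format_exception(type(e), e, e.__traceback__))

assert blob.startswith("Traceback")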
36 changes: 36 additions & 0 deletions ax/analysis/tests/test_utils.py
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from ax.analysis.utils import choose_analyses
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
get_branin_experiment,
get_branin_experiment_with_multi_objective,
)


class TestUtils(TestCase):
def test_choose_analyses(self) -> None:
analyses = choose_analyses(experiment=get_branin_experiment())
self.assertEqual(
{analysis.name for analysis in analyses},
{
"ParallelCoordinatesPlot",
"InteractionPlot",
"Summary",
"CrossValidationPlot",
},
)

# Multi-objective case
analyses = choose_analyses(
experiment=get_branin_experiment_with_multi_objective()
)
self.assertEqual(
{analysis.name for analysis in analyses},
{"InteractionPlot", "ScatterPlot", "Summary", "CrossValidationPlot"},
)
76 changes: 76 additions & 0 deletions ax/analysis/utils.py
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools

from ax.analysis.analysis import Analysis
from ax.analysis.plotly.cross_validation import CrossValidationPlot
from ax.analysis.plotly.interaction import InteractionPlot
from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot
from ax.analysis.plotly.scatter import ScatterPlot
from ax.analysis.summary import Summary
from ax.core.experiment import Experiment
from ax.core.objective import MultiObjective, ScalarizedObjective


def choose_analyses(experiment: Experiment) -> list[Analysis]:
"""
Choose a default set of Analyses to compute based on the current state of the
Experiment.
"""
if (optimization_config := experiment.optimization_config) is None:
return []

    if isinstance(
        optimization_config.objective, (MultiObjective, ScalarizedObjective)
    ):
        # Pareto frontier scatter plots for each pair of objective metrics
        objective_plots = [
            ScatterPlot(x_metric_name=x, y_metric_name=y, show_pareto_frontier=True)
            for x, y in itertools.combinations(
                optimization_config.objective.metric_names, 2
            )
        ]

other_scatters = []

interactions = [
InteractionPlot(metric_name=name)
for name in optimization_config.objective.metric_names
]

else:
objective_name = optimization_config.objective.metric.name
        # ParallelCoordinates and leave-one-out cross validation
objective_plots = [
ParallelCoordinatesPlot(metric_name=objective_name),
]

# Up to six ScatterPlots for other metrics versus the objective,
# prioritizing optimization config metrics over tracking metrics
tracking_metric_names = [metric.name for metric in experiment.tracking_metrics]
other_scatters = [
ScatterPlot(
x_metric_name=objective_name,
y_metric_name=name,
show_pareto_frontier=False,
)
for name in [
*optimization_config.metrics,
*tracking_metric_names,
]
if name != objective_name
][:6]

interactions = [InteractionPlot(metric_name=objective_name)]

# Leave-one-out cross validation for each objective and outcome constraint
cv_plots = [
CrossValidationPlot(metric_name=name) for name in optimization_config.metrics
]

return [*objective_plots, *other_scatters, *interactions, *cv_plots, Summary()]
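For intuition on the multi-objective branch: itertools.combinations produces one Pareto-frontier ScatterPlot per unordered pair of objective metrics.

import itertools

# Three objectives yield three pairwise Pareto scatter plots.
pairs = list(itertools.combinations(["a", "b", "c"], 2))
print(pairs)  # [('a', 'b'), ('a', 'c'), ('b', 'c')]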
33 changes: 32 additions & 1 deletion ax/preview/api/client.py
@@ -11,6 +11,10 @@
import numpy as np

from ax.analysis.analysis import Analysis, AnalysisCard # Used as a return type
from ax.analysis.markdown.markdown_analysis import (
markdown_analysis_card_from_analysis_e,
)
from ax.analysis.utils import choose_analyses

from ax.core.base_trial import TrialStatus # Used as a return type

@@ -634,7 +638,34 @@ def compute_analyses(
Returns:
A list of AnalysisCards.
"""
-        ...

analyses = (
analyses
if analyses is not None
else choose_analyses(experiment=self._none_throws_experiment())
)

# Compute Analyses one by one and accumulate Results holding either the
# AnalysisCard or an Exception and some metadata
results = [
analysis.compute_result(
experiment=self._none_throws_experiment(),
generation_strategy=self._generation_strategy_or_choose(),
)
for analysis in analyses
]

# Turn Exceptions into MarkdownAnalysisCards with the traceback as the message
cards = [
result.unwrap_or_else(markdown_analysis_card_from_analysis_e)
for result in results
]

if self._db_config is not None:
# TODO[mpolson64] Save cards to database
...

return cards

def get_best_trial(
self, use_model_predictions: bool = True
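Because failures surface as DEBUG-level cards rather than exceptions, downstream callers can filter them out for display. A hedged sketch, assuming DEBUG is the lowest AnalysisCardLevel and a client configured as in the test below:

from ax.analysis.analysis import AnalysisCardLevel

cards = client.compute_analyses()  # defaults chosen via choose_analyses
display_cards = [card for card in cards if card.level > AnalysisCardLevel.DEBUG]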
46 changes: 46 additions & 0 deletions ax/preview/api/tests/test_client.py
@@ -10,6 +10,7 @@
import numpy as np

import pandas as pd
from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot

from ax.core.base_trial import TrialStatus

@@ -865,6 +866,51 @@ def test_get_next_trials_then_run_trials(self) -> None:
5,
)

def test_compute_analyses(self) -> None:
client = Client()

client.configure_experiment(
ExperimentConfig(
parameters=[
RangeParameterConfig(
name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1)
),
],
name="foo",
)
)
client.configure_optimization(objective="foo")

with self.assertLogs(logger="ax.analysis", level="ERROR") as lg:
cards = client.compute_analyses(analyses=[ParallelCoordinatesPlot()])

self.assertEqual(len(cards), 1)
self.assertEqual(cards[0].name, "ParallelCoordinatesPlot")
self.assertEqual(cards[0].title, "ParallelCoordinatesPlot Error")
self.assertEqual(
cards[0].subtitle,
f"An error occurred while computing {ParallelCoordinatesPlot()}",
)
self.assertIn("Traceback", cards[0].blob)
self.assertTrue(
any(
(
"Failed to compute ParallelCoordinatesPlot: "
"No data found for metric "
)
in msg
for msg in lg.output
)
)

for trial_index, _ in client.get_next_trials(maximum_trials=1).items():
client.complete_trial(trial_index=trial_index, raw_data={"foo": 1.0})

cards = client.compute_analyses(analyses=[ParallelCoordinatesPlot()])

self.assertEqual(len(cards), 1)
self.assertEqual(cards[0].name, "ParallelCoordinatesPlot")


class DummyRunner(IRunner):
@override
8 changes: 8 additions & 0 deletions sphinx/source/analysis.rst
@@ -159,3 +159,11 @@ Utils
:members:
:undoc-members:
:show-inheritance:

Utils
~~~~~

.. automodule:: ax.analysis.utils
:members:
:undoc-members:
:show-inheritance:
