From 3d51ca6bb5a7d9feb98f0c034dc3c98948edace3 Mon Sep 17 00:00:00 2001
From: Miles Olson
Date: Fri, 6 Dec 2024 12:23:21 -0800
Subject: [PATCH] Implement compute_analyses (#3143)

Summary:
Implement a method which, given a list of Analyses, computes each and returns
the resulting cards. If an Exception is encountered, a MarkdownAnalysisCard
containing the traceback is created instead, with priority DEBUG. If no list
of Analyses is provided, defaults are chosen.

Reviewed By: lena-kashtelyan

Differential Revision: D66677557
---
 ax/analysis/analysis.py                   |  4 +-
 ax/analysis/markdown/markdown_analysis.py | 24 ++++++-
 ax/analysis/tests/test_utils.py           | 36 +++++++++++
 ax/analysis/utils.py                      | 76 +++++++++++++++++++++++
 ax/preview/api/client.py                  | 33 +++++++++-
 ax/preview/api/tests/test_client.py       | 46 ++++++++++++++
 sphinx/source/analysis.rst                |  8 +++
 7 files changed, 224 insertions(+), 3 deletions(-)
 create mode 100644 ax/analysis/tests/test_utils.py
 create mode 100644 ax/analysis/utils.py

diff --git a/ax/analysis/analysis.py b/ax/analysis/analysis.py
index 5d079802455..c738b1d51e2 100644
--- a/ax/analysis/analysis.py
+++ b/ax/analysis/analysis.py
@@ -5,6 +5,8 @@
 
 # pyre-strict
 
+from __future__ import annotations
+
 import json
 from collections.abc import Iterable
 from enum import IntEnum
@@ -132,7 +134,7 @@ def compute_result(
         self,
         experiment: Experiment | None = None,
         generation_strategy: GenerationStrategyInterface | None = None,
-    ) -> Result[AnalysisCard, ExceptionE]:
+    ) -> Result[AnalysisCard, AnalysisE]:
         """
         Utility method to compute an AnalysisCard as a Result. This can be useful
         for computing many Analyses at once and handling Exceptions later.
diff --git a/ax/analysis/markdown/markdown_analysis.py b/ax/analysis/markdown/markdown_analysis.py
index 75393630a0d..c8958d59838 100644
--- a/ax/analysis/markdown/markdown_analysis.py
+++ b/ax/analysis/markdown/markdown_analysis.py
@@ -6,8 +6,10 @@
 
 # pyre-strict
 
+import traceback
+
 import pandas as pd
-from ax.analysis.analysis import Analysis, AnalysisCard
+from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel, AnalysisE
 from ax.core.experiment import Experiment
 from ax.core.generation_strategy_interface import GenerationStrategyInterface
 from IPython.display import display, Markdown
@@ -59,3 +61,23 @@ def _create_markdown_analysis_card(
             df=df,
             blob=message,
         )
+
+
+def markdown_analysis_card_from_analysis_e(
+    analysis_e: AnalysisE,
+) -> MarkdownAnalysisCard:
+    return MarkdownAnalysisCard(
+        name=analysis_e.analysis.name,
+        title=f"{analysis_e.analysis.name} Error",
+        subtitle=f"An error occurred while computing {analysis_e.analysis}",
+        attributes=analysis_e.analysis.attributes,
+        blob="".join(
+            traceback.format_exception(
+                type(analysis_e.exception),
+                analysis_e.exception,
+                analysis_e.exception.__traceback__,
+            )
+        ),
+        df=pd.DataFrame(),
+        level=AnalysisCardLevel.DEBUG,
+    )
diff --git a/ax/analysis/tests/test_utils.py b/ax/analysis/tests/test_utils.py
new file mode 100644
index 00000000000..1e9f4c4f2a3
--- /dev/null
+++ b/ax/analysis/tests/test_utils.py
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from ax.analysis.utils import choose_analyses
+from ax.utils.common.testutils import TestCase
+from ax.utils.testing.core_stubs import (
+    get_branin_experiment,
+    get_branin_experiment_with_multi_objective,
+)
+
+
+class TestUtils(TestCase):
+    def test_choose_analyses(self) -> None:
+        analyses = choose_analyses(experiment=get_branin_experiment())
+        self.assertEqual(
+            {analysis.name for analysis in analyses},
+            {
+                "ParallelCoordinatesPlot",
+                "InteractionPlot",
+                "Summary",
+                "CrossValidationPlot",
+            },
+        )
+
+        # Multi-objective case
+        analyses = choose_analyses(
+            experiment=get_branin_experiment_with_multi_objective()
+        )
+        self.assertEqual(
+            {analysis.name for analysis in analyses},
+            {"InteractionPlot", "ScatterPlot", "Summary", "CrossValidationPlot"},
+        )
diff --git a/ax/analysis/utils.py b/ax/analysis/utils.py
new file mode 100644
index 00000000000..7438f598a9c
--- /dev/null
+++ b/ax/analysis/utils.py
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import itertools
+
+from ax.analysis.analysis import Analysis
+from ax.analysis.plotly.cross_validation import CrossValidationPlot
+from ax.analysis.plotly.interaction import InteractionPlot
+from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot
+from ax.analysis.plotly.scatter import ScatterPlot
+from ax.analysis.summary import Summary
+from ax.core.experiment import Experiment
+from ax.core.objective import MultiObjective, ScalarizedObjective
+
+
+def choose_analyses(experiment: Experiment) -> list[Analysis]:
+    """
+    Choose a default set of Analyses to compute based on the current state of the
+    Experiment.
+    """
+    if (optimization_config := experiment.optimization_config) is None:
+        return []
+
+    if isinstance(optimization_config.objective, MultiObjective) or isinstance(
+        optimization_config.objective, ScalarizedObjective
+    ):
+        # Pareto frontiers for each objective
+        objective_plots = [
+            *[
+                ScatterPlot(x_metric_name=x, y_metric_name=y, show_pareto_frontier=True)
+                for x, y in itertools.combinations(
+                    optimization_config.objective.metric_names, 2
+                )
+            ],
+        ]
+
+        other_scatters = []
+
+        interactions = [
+            InteractionPlot(metric_name=name)
+            for name in optimization_config.objective.metric_names
+        ]
+
+    else:
+        objective_name = optimization_config.objective.metric.name
+        # ParallelCoordinates and leave-one-out cross validation
+        objective_plots = [
+            ParallelCoordinatesPlot(metric_name=objective_name),
+        ]
+
+        # Up to six ScatterPlots for other metrics versus the objective,
+        # prioritizing optimization config metrics over tracking metrics
+        tracking_metric_names = [metric.name for metric in experiment.tracking_metrics]
+        other_scatters = [
+            ScatterPlot(
+                x_metric_name=objective_name,
+                y_metric_name=name,
+                show_pareto_frontier=False,
+            )
+            for name in [
+                *optimization_config.metrics,
+                *tracking_metric_names,
+            ]
+            if name != objective_name
+        ][:6]
+
+        interactions = [InteractionPlot(metric_name=objective_name)]
+
+    # Leave-one-out cross validation for each objective and outcome constraint
+    cv_plots = [
+        CrossValidationPlot(metric_name=name) for name in optimization_config.metrics
+    ]
+
+    return [*objective_plots, *other_scatters, *interactions, *cv_plots, Summary()]
diff --git a/ax/preview/api/client.py b/ax/preview/api/client.py
index 3560222c946..8230bbba024 100644
--- a/ax/preview/api/client.py
+++ b/ax/preview/api/client.py
@@ -11,6 +11,10 @@
 
 import numpy as np
 from ax.analysis.analysis import Analysis, AnalysisCard  # Used as a return type
+from ax.analysis.markdown.markdown_analysis import (
+    markdown_analysis_card_from_analysis_e,
+)
+from ax.analysis.utils import choose_analyses
 from ax.core.base_trial import TrialStatus  # Used as a return type
 
 
@@ -634,7 +638,34 @@ def compute_analyses(
         Returns:
             A list of AnalysisCards.
         """
-        ...
+
+        analyses = (
+            analyses
+            if analyses is not None
+            else choose_analyses(experiment=self._none_throws_experiment())
+        )
+
+        # Compute Analyses one by one and accumulate Results holding either the
+        # AnalysisCard or an Exception and some metadata
+        results = [
+            analysis.compute_result(
+                experiment=self._none_throws_experiment(),
+                generation_strategy=self._generation_strategy_or_choose(),
+            )
+            for analysis in analyses
+        ]
+
+        # Turn Exceptions into MarkdownAnalysisCards with the traceback as the message
+        cards = [
+            result.unwrap_or_else(markdown_analysis_card_from_analysis_e)
+            for result in results
+        ]
+
+        if self._db_config is not None:
+            # TODO[mpolson64] Save cards to database
+            ...
+
+        return cards
 
     def get_best_trial(
         self, use_model_predictions: bool = True
diff --git a/ax/preview/api/tests/test_client.py b/ax/preview/api/tests/test_client.py
index bb8a9742c78..db396c7dae8 100644
--- a/ax/preview/api/tests/test_client.py
+++ b/ax/preview/api/tests/test_client.py
@@ -10,6 +10,7 @@
 
 import numpy as np
 import pandas as pd
+from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot
 from ax.core.base_trial import TrialStatus
 
 
@@ -865,6 +866,51 @@ def test_get_next_trials_then_run_trials(self) -> None:
             5,
         )
 
+    def test_compute_analyses(self) -> None:
+        client = Client()
+
+        client.configure_experiment(
+            ExperimentConfig(
+                parameters=[
+                    RangeParameterConfig(
+                        name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1)
+                    ),
+                ],
+                name="foo",
+            )
+        )
+        client.configure_optimization(objective="foo")
+
+        with self.assertLogs(logger="ax.analysis", level="ERROR") as lg:
+            cards = client.compute_analyses(analyses=[ParallelCoordinatesPlot()])
+
+            self.assertEqual(len(cards), 1)
+            self.assertEqual(cards[0].name, "ParallelCoordinatesPlot")
+            self.assertEqual(cards[0].title, "ParallelCoordinatesPlot Error")
+            self.assertEqual(
+                cards[0].subtitle,
+                f"An error occurred while computing {ParallelCoordinatesPlot()}",
+            )
+            self.assertIn("Traceback", cards[0].blob)
+            self.assertTrue(
+                any(
+                    (
+                        "Failed to compute ParallelCoordinatesPlot: "
+                        "No data found for metric "
+                    )
+                    in msg
+                    for msg in lg.output
+                )
+            )
+
+        for trial_index, _ in client.get_next_trials(maximum_trials=1).items():
+            client.complete_trial(trial_index=trial_index, raw_data={"foo": 1.0})
+
+        cards = client.compute_analyses(analyses=[ParallelCoordinatesPlot()])
+
+        self.assertEqual(len(cards), 1)
+        self.assertEqual(cards[0].name, "ParallelCoordinatesPlot")
+
 
 class DummyRunner(IRunner):
     @override
diff --git a/sphinx/source/analysis.rst b/sphinx/source/analysis.rst
index 5df991dda22..3561fb1ff68 100644
--- a/sphinx/source/analysis.rst
+++ b/sphinx/source/analysis.rst
@@ -159,3 +159,11 @@ Utils
     :members:
     :undoc-members:
     :show-inheritance:
+
+Utils
+~~~~~
+
+.. automodule:: ax.analysis.utils
+    :members:
+    :undoc-members:
+    :show-inheritance:
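
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new
Client.compute_analyses method added to ax/preview/api/client.py is meant to be called.
It mirrors the setup in test_compute_analyses above; the "x1" and "foo" names are
placeholders, and the ExperimentConfig/RangeParameterConfig/ParameterType imports are
assumed to come from ax.preview.api.configs as in the existing preview tests.

# Illustrative sketch only -- exercises the Client.compute_analyses method
# introduced by this patch. The "x1"/"foo" names are placeholders.
from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot
from ax.preview.api.client import Client
from ax.preview.api.configs import (  # assumed import path for the preview configs
    ExperimentConfig,
    ParameterType,
    RangeParameterConfig,
)

client = Client()
client.configure_experiment(
    ExperimentConfig(
        name="foo",
        parameters=[
            RangeParameterConfig(
                name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1)
            ),
        ],
    )
)
client.configure_optimization(objective="foo")

# Passing analyses=None (the default) lets choose_analyses() pick a default set
# based on the experiment's state; passing a list computes exactly those Analyses.
cards = client.compute_analyses(analyses=[ParallelCoordinatesPlot()])

# An Analysis that raises does not abort the call: the failure is returned as a
# MarkdownAnalysisCard whose blob holds the traceback, at AnalysisCardLevel.DEBUG.
for card in cards:
    print(card.name, card.title)

If no data has been attached to the experiment yet, the ParallelCoordinatesPlot card comes
back as the DEBUG-level error card described in the summary, which is exactly what the new
test asserts before any trials are completed.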