From d85cc09f999305eaa655b1569f7ff67fc11ebf97 Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Thu, 5 Dec 2024 21:06:54 -0800 Subject: [PATCH] Add `count_only_trials_with_data` to trial-based TCs (#3012) Summary: For some GenerationStrategies we only want to progress if the trials in the expected statuses also actually have data, this extends trial based criterion to consider if we have data or not Reviewed By: lena-kashtelyan Differential Revision: D64617987 --- ax/modelbridge/tests/test_generation_node.py | 3 ++- ...test_generation_node_input_constructors.py | 1 - .../tests/test_transition_criterion.py | 6 +++-- ax/modelbridge/transition_criterion.py | 23 ++++++++++++++++++- 4 files changed, 28 insertions(+), 5 deletions(-) diff --git a/ax/modelbridge/tests/test_generation_node.py b/ax/modelbridge/tests/test_generation_node.py index b682490f65b..9667f808ddc 100644 --- a/ax/modelbridge/tests/test_generation_node.py +++ b/ax/modelbridge/tests/test_generation_node.py @@ -314,7 +314,8 @@ def test_node_string_representation(self) -> None: "'not_in_statuses': None, 'transition_to': None, " "'block_transition_if_unmet': True, 'block_gen_if_met': False, " "'use_all_trials_in_exp': False, " - "'continue_trial_generation': False})])" + "'continue_trial_generation': False, " + "'count_only_trials_with_data': False})])" ), ) diff --git a/ax/modelbridge/tests/test_generation_node_input_constructors.py b/ax/modelbridge/tests/test_generation_node_input_constructors.py index 732de6a1338..3c4ae6ac2e4 100644 --- a/ax/modelbridge/tests/test_generation_node_input_constructors.py +++ b/ax/modelbridge/tests/test_generation_node_input_constructors.py @@ -113,7 +113,6 @@ def test_all_constructors_have_expected_signature_for_purpose(self) -> None: untested_constructors.remove(constructor) # There should be no untested constructors left. - print(untested_constructors) self.assertEqual(len(untested_constructors), 0) def test_consume_all_n_constructor(self) -> None: diff --git a/ax/modelbridge/tests/test_transition_criterion.py b/ax/modelbridge/tests/test_transition_criterion.py index 78ca0e50032..3dc5e334890 100644 --- a/ax/modelbridge/tests/test_transition_criterion.py +++ b/ax/modelbridge/tests/test_transition_criterion.py @@ -580,7 +580,8 @@ def test_repr(self) -> None: + "'block_transition_if_unmet': False, " + "'block_gen_if_met': True, " + "'use_all_trials_in_exp': False, " - + "'continue_trial_generation': False})", + + "'continue_trial_generation': False, " + + "'count_only_trials_with_data': False})", ) minimum_trials_in_status_criterion = MinTrials( only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED], @@ -599,7 +600,8 @@ def test_repr(self) -> None: + "'block_transition_if_unmet': False, " + "'block_gen_if_met': True, " + "'use_all_trials_in_exp': False, " - + "'continue_trial_generation': False})", + + "'continue_trial_generation': False, " + + "'count_only_trials_with_data': False})", ) minimum_preference_occurrences_criterion = MinimumPreferenceOccurances( metric_name="m1", threshold=3 diff --git a/ax/modelbridge/transition_criterion.py b/ax/modelbridge/transition_criterion.py index 7727a33fc80..19e44906fa9 100644 --- a/ax/modelbridge/transition_criterion.py +++ b/ax/modelbridge/transition_criterion.py @@ -31,7 +31,6 @@ class TransitionCriterion(SortableBase, SerializationMixin): - # TODO: @mgarrard rename to ActionCriterion """ Simple class to describe a condition which must be met for this GenerationNode to take an action such as generation, transition, etc. @@ -262,6 +261,8 @@ class TrialBasedCriterion(TransitionCriterion): batch from different ``GenerationNodes``. This flag should be set to True for the last node in a set of ``GenerationNodes`` expected to create a given ``BatchTrial``. + count_only_trials_with_data: If set to True, only trials with data will be + counted towards the ``threshold``. Defaults to False. """ def __init__( @@ -274,11 +275,13 @@ def __init__( transition_to: str | None = None, use_all_trials_in_exp: bool | None = False, continue_trial_generation: bool | None = False, + count_only_trials_with_data: bool = False, ) -> None: self.threshold = threshold self.only_in_statuses = only_in_statuses self.not_in_statuses = not_in_statuses self.use_all_trials_in_exp = use_all_trials_in_exp + self.count_only_trials_with_data = count_only_trials_with_data super().__init__( transition_to=transition_to, block_transition_if_unmet=block_transition_if_unmet, @@ -334,6 +337,14 @@ def num_contributing_to_threshold( trials_from_node: The set of trials generated by this GenerationNode. """ all_trials_to_check = self.all_trials_to_check(experiment=experiment) + if self.count_only_trials_with_data: + all_trials_to_check = { + trial_index + for trial_index in all_trials_to_check + # TODO[@mgarrard]: determine if we need to actually check data with + # more granularity, e.g. number of days of data, etc. + if trial_index in experiment.data_by_trial + } # Some criteria may rely on experiment level data, instead of only trials # generated from the node associated with the criterion. if self.use_all_trials_in_exp: @@ -415,6 +426,8 @@ class MaxGenerationParallelism(TrialBasedCriterion): create a given ``BatchTrial``. Defaults to False for MaxGenerationParallelism since this criterion isn't currently used for node -> node or trial -> trial transition. + count_only_trials_with_data: If set to True, only trials with data will be + counted towards the ``threshold``. Defaults to False. """ def __init__( @@ -497,6 +510,8 @@ class MaxTrials(TrialBasedCriterion): batch from different ``GenerationNodes``. This flag should be set to True for the last node in a set of ``GenerationNodes`` expected to create a given ``BatchTrial``. + count_only_trials_with_data: If set to True, only trials with data will be + counted towards the ``threshold``. Defaults to False. """ def __init__( @@ -509,6 +524,7 @@ def __init__( block_gen_if_met: bool | None = False, use_all_trials_in_exp: bool | None = False, continue_trial_generation: bool | None = False, + count_only_trials_with_data: bool = False, ) -> None: super().__init__( threshold=threshold, @@ -519,6 +535,7 @@ def __init__( block_transition_if_unmet=block_transition_if_unmet, use_all_trials_in_exp=use_all_trials_in_exp, continue_trial_generation=continue_trial_generation, + count_only_trials_with_data=count_only_trials_with_data, ) def block_continued_generation_error( @@ -573,6 +590,8 @@ class MinTrials(TrialBasedCriterion): batch from different ``GenerationNodes``. This flag should be set to True for the last node in a set of ``GenerationNodes`` expected to create a given ``BatchTrial``. + count_only_trials_with_data: If set to True, only trials with data will be + counted towards the ``threshold``. Defaults to False. """ def __init__( @@ -585,6 +604,7 @@ def __init__( block_gen_if_met: bool | None = False, use_all_trials_in_exp: bool | None = False, continue_trial_generation: bool | None = False, + count_only_trials_with_data: bool = False, ) -> None: super().__init__( threshold=threshold, @@ -595,6 +615,7 @@ def __init__( block_transition_if_unmet=block_transition_if_unmet, use_all_trials_in_exp=use_all_trials_in_exp, continue_trial_generation=continue_trial_generation, + count_only_trials_with_data=count_only_trials_with_data, ) def block_continued_generation_error(