diff --git a/ax/modelbridge/tests/test_generation_node.py b/ax/modelbridge/tests/test_generation_node.py index b682490f65b..9667f808ddc 100644 --- a/ax/modelbridge/tests/test_generation_node.py +++ b/ax/modelbridge/tests/test_generation_node.py @@ -314,7 +314,8 @@ def test_node_string_representation(self) -> None: "'not_in_statuses': None, 'transition_to': None, " "'block_transition_if_unmet': True, 'block_gen_if_met': False, " "'use_all_trials_in_exp': False, " - "'continue_trial_generation': False})])" + "'continue_trial_generation': False, " + "'count_only_trials_with_data': False})])" ), ) diff --git a/ax/modelbridge/tests/test_generation_node_input_constructors.py b/ax/modelbridge/tests/test_generation_node_input_constructors.py index 732de6a1338..3c4ae6ac2e4 100644 --- a/ax/modelbridge/tests/test_generation_node_input_constructors.py +++ b/ax/modelbridge/tests/test_generation_node_input_constructors.py @@ -113,7 +113,6 @@ def test_all_constructors_have_expected_signature_for_purpose(self) -> None: untested_constructors.remove(constructor) # There should be no untested constructors left. - print(untested_constructors) self.assertEqual(len(untested_constructors), 0) def test_consume_all_n_constructor(self) -> None: diff --git a/ax/modelbridge/tests/test_transition_criterion.py b/ax/modelbridge/tests/test_transition_criterion.py index 78ca0e50032..3dc5e334890 100644 --- a/ax/modelbridge/tests/test_transition_criterion.py +++ b/ax/modelbridge/tests/test_transition_criterion.py @@ -580,7 +580,8 @@ def test_repr(self) -> None: + "'block_transition_if_unmet': False, " + "'block_gen_if_met': True, " + "'use_all_trials_in_exp': False, " - + "'continue_trial_generation': False})", + + "'continue_trial_generation': False, " + + "'count_only_trials_with_data': False})", ) minimum_trials_in_status_criterion = MinTrials( only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED], @@ -599,7 +600,8 @@ def test_repr(self) -> None: + "'block_transition_if_unmet': False, " + "'block_gen_if_met': True, " + "'use_all_trials_in_exp': False, " - + "'continue_trial_generation': False})", + + "'continue_trial_generation': False, " + + "'count_only_trials_with_data': False})", ) minimum_preference_occurrences_criterion = MinimumPreferenceOccurances( metric_name="m1", threshold=3 diff --git a/ax/modelbridge/transition_criterion.py b/ax/modelbridge/transition_criterion.py index 7727a33fc80..19e44906fa9 100644 --- a/ax/modelbridge/transition_criterion.py +++ b/ax/modelbridge/transition_criterion.py @@ -31,7 +31,6 @@ class TransitionCriterion(SortableBase, SerializationMixin): - # TODO: @mgarrard rename to ActionCriterion """ Simple class to describe a condition which must be met for this GenerationNode to take an action such as generation, transition, etc. @@ -262,6 +261,8 @@ class TrialBasedCriterion(TransitionCriterion): batch from different ``GenerationNodes``. This flag should be set to True for the last node in a set of ``GenerationNodes`` expected to create a given ``BatchTrial``. + count_only_trials_with_data: If set to True, only trials with data will be + counted towards the ``threshold``. Defaults to False. """ def __init__( @@ -274,11 +275,13 @@ def __init__( transition_to: str | None = None, use_all_trials_in_exp: bool | None = False, continue_trial_generation: bool | None = False, + count_only_trials_with_data: bool = False, ) -> None: self.threshold = threshold self.only_in_statuses = only_in_statuses self.not_in_statuses = not_in_statuses self.use_all_trials_in_exp = use_all_trials_in_exp + self.count_only_trials_with_data = count_only_trials_with_data super().__init__( transition_to=transition_to, block_transition_if_unmet=block_transition_if_unmet, @@ -334,6 +337,14 @@ def num_contributing_to_threshold( trials_from_node: The set of trials generated by this GenerationNode. """ all_trials_to_check = self.all_trials_to_check(experiment=experiment) + if self.count_only_trials_with_data: + all_trials_to_check = { + trial_index + for trial_index in all_trials_to_check + # TODO[@mgarrard]: determine if we need to actually check data with + # more granularity, e.g. number of days of data, etc. + if trial_index in experiment.data_by_trial + } # Some criteria may rely on experiment level data, instead of only trials # generated from the node associated with the criterion. if self.use_all_trials_in_exp: @@ -415,6 +426,8 @@ class MaxGenerationParallelism(TrialBasedCriterion): create a given ``BatchTrial``. Defaults to False for MaxGenerationParallelism since this criterion isn't currently used for node -> node or trial -> trial transition. + count_only_trials_with_data: If set to True, only trials with data will be + counted towards the ``threshold``. Defaults to False. """ def __init__( @@ -497,6 +510,8 @@ class MaxTrials(TrialBasedCriterion): batch from different ``GenerationNodes``. This flag should be set to True for the last node in a set of ``GenerationNodes`` expected to create a given ``BatchTrial``. + count_only_trials_with_data: If set to True, only trials with data will be + counted towards the ``threshold``. Defaults to False. """ def __init__( @@ -509,6 +524,7 @@ def __init__( block_gen_if_met: bool | None = False, use_all_trials_in_exp: bool | None = False, continue_trial_generation: bool | None = False, + count_only_trials_with_data: bool = False, ) -> None: super().__init__( threshold=threshold, @@ -519,6 +535,7 @@ def __init__( block_transition_if_unmet=block_transition_if_unmet, use_all_trials_in_exp=use_all_trials_in_exp, continue_trial_generation=continue_trial_generation, + count_only_trials_with_data=count_only_trials_with_data, ) def block_continued_generation_error( @@ -573,6 +590,8 @@ class MinTrials(TrialBasedCriterion): batch from different ``GenerationNodes``. This flag should be set to True for the last node in a set of ``GenerationNodes`` expected to create a given ``BatchTrial``. + count_only_trials_with_data: If set to True, only trials with data will be + counted towards the ``threshold``. Defaults to False. """ def __init__( @@ -585,6 +604,7 @@ def __init__( block_gen_if_met: bool | None = False, use_all_trials_in_exp: bool | None = False, continue_trial_generation: bool | None = False, + count_only_trials_with_data: bool = False, ) -> None: super().__init__( threshold=threshold, @@ -595,6 +615,7 @@ def __init__( block_transition_if_unmet=block_transition_if_unmet, use_all_trials_in_exp=use_all_trials_in_exp, continue_trial_generation=continue_trial_generation, + count_only_trials_with_data=count_only_trials_with_data, ) def block_continued_generation_error(