Skip to content

Commit

Permalink
Add count_only_trials_with_data to trial-based TCs (facebook#3012)
Browse files Browse the repository at this point in the history
Summary:

For some GenerationStrategies we only want to progress if the trials in the expected statuses also actually have data, this extends trial based criterion to consider if we have data or not

Differential Revision: D64617987
  • Loading branch information
mgarrard authored and facebook-github-bot committed Nov 6, 2024
1 parent 93c236e commit c0ac309
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 5 deletions.
3 changes: 2 additions & 1 deletion ax/modelbridge/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,8 @@ def test_node_string_representation(self) -> None:
"'not_in_statuses': None, 'transition_to': None, "
"'block_transition_if_unmet': True, 'block_gen_if_met': False, "
"'use_all_trials_in_exp': False, "
"'continue_trial_generation': False})])"
"'continue_trial_generation': False, "
"'count_only_trials_with_data': False})])"
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def test_all_constructors_have_expected_signature_for_purpose(self) -> None:
untested_constructors.remove(constructor)

# There should be no untested constructors left.
print(untested_constructors)
self.assertEqual(len(untested_constructors), 0)

def test_consume_all_n_constructor(self) -> None:
Expand Down
6 changes: 4 additions & 2 deletions ax/modelbridge/tests/test_transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,8 @@ def test_repr(self) -> None:
+ "'block_transition_if_unmet': False, "
+ "'block_gen_if_met': True, "
+ "'use_all_trials_in_exp': False, "
+ "'continue_trial_generation': False})",
+ "'continue_trial_generation': False, "
+ "'count_only_trials_with_data': False})",
)
minimum_trials_in_status_criterion = MinTrials(
only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED],
Expand All @@ -599,7 +600,8 @@ def test_repr(self) -> None:
+ "'block_transition_if_unmet': False, "
+ "'block_gen_if_met': True, "
+ "'use_all_trials_in_exp': False, "
+ "'continue_trial_generation': False})",
+ "'continue_trial_generation': False, "
+ "'count_only_trials_with_data': False})",
)
minimum_preference_occurrences_criterion = MinimumPreferenceOccurances(
metric_name="m1", threshold=3
Expand Down
23 changes: 22 additions & 1 deletion ax/modelbridge/transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@


class TransitionCriterion(SortableBase, SerializationMixin):
# TODO: @mgarrard rename to ActionCriterion
"""
Simple class to describe a condition which must be met for this GenerationNode to
take an action such as generation, transition, etc.
Expand Down Expand Up @@ -262,6 +261,8 @@ class TrialBasedCriterion(TransitionCriterion):
batch from different ``GenerationNodes``. This flag should be set to
True for the last node in a set of ``GenerationNodes`` expected to
create a given ``BatchTrial``.
count_only_trials_with_data: If set to True, only trials with data will be
counted towards the ``threshold``. Defaults to False.
"""

def __init__(
Expand All @@ -274,11 +275,13 @@ def __init__(
transition_to: str | None = None,
use_all_trials_in_exp: bool | None = False,
continue_trial_generation: bool | None = False,
count_only_trials_with_data: bool = False,
) -> None:
self.threshold = threshold
self.only_in_statuses = only_in_statuses
self.not_in_statuses = not_in_statuses
self.use_all_trials_in_exp = use_all_trials_in_exp
self.count_only_trials_with_data = count_only_trials_with_data
super().__init__(
transition_to=transition_to,
block_transition_if_unmet=block_transition_if_unmet,
Expand Down Expand Up @@ -334,6 +337,14 @@ def num_contributing_to_threshold(
trials_from_node: The set of trials generated by this GenerationNode.
"""
all_trials_to_check = self.all_trials_to_check(experiment=experiment)
if self.count_only_trials_with_data:
all_trials_to_check = {
trial_index
for trial_index in all_trials_to_check
# TODO[@mgarrard]: determine if we need to actually check data with
# more granularity, e.g. number of days of data, etc.
if trial_index in experiment.data_by_trial
}
# Some criteria may rely on experiment level data, instead of only trials
# generated from the node associated with the criterion.
if self.use_all_trials_in_exp:
Expand Down Expand Up @@ -415,6 +426,8 @@ class MaxGenerationParallelism(TrialBasedCriterion):
create a given ``BatchTrial``. Defaults to False for
MaxGenerationParallelism since this criterion isn't currently used for
node -> node or trial -> trial transition.
count_only_trials_with_data: If set to True, only trials with data will be
counted towards the ``threshold``. Defaults to False.
"""

def __init__(
Expand Down Expand Up @@ -497,6 +510,8 @@ class MaxTrials(TrialBasedCriterion):
batch from different ``GenerationNodes``. This flag should be set to
True for the last node in a set of ``GenerationNodes`` expected to
create a given ``BatchTrial``.
count_only_trials_with_data: If set to True, only trials with data will be
counted towards the ``threshold``. Defaults to False.
"""

def __init__(
Expand All @@ -509,6 +524,7 @@ def __init__(
block_gen_if_met: bool | None = False,
use_all_trials_in_exp: bool | None = False,
continue_trial_generation: bool | None = False,
count_only_trials_with_data: bool = False,
) -> None:
super().__init__(
threshold=threshold,
Expand All @@ -519,6 +535,7 @@ def __init__(
block_transition_if_unmet=block_transition_if_unmet,
use_all_trials_in_exp=use_all_trials_in_exp,
continue_trial_generation=continue_trial_generation,
count_only_trials_with_data=count_only_trials_with_data,
)

def block_continued_generation_error(
Expand Down Expand Up @@ -573,6 +590,8 @@ class MinTrials(TrialBasedCriterion):
batch from different ``GenerationNodes``. This flag should be set to
True for the last node in a set of ``GenerationNodes`` expected to
create a given ``BatchTrial``.
count_only_trials_with_data: If set to True, only trials with data will be
counted towards the ``threshold``. Defaults to False.
"""

def __init__(
Expand All @@ -585,6 +604,7 @@ def __init__(
block_gen_if_met: bool | None = False,
use_all_trials_in_exp: bool | None = False,
continue_trial_generation: bool | None = False,
count_only_trials_with_data: bool = False,
) -> None:
super().__init__(
threshold=threshold,
Expand All @@ -595,6 +615,7 @@ def __init__(
block_transition_if_unmet=block_transition_if_unmet,
use_all_trials_in_exp=use_all_trials_in_exp,
continue_trial_generation=continue_trial_generation,
count_only_trials_with_data=count_only_trials_with_data,
)

def block_continued_generation_error(
Expand Down

0 comments on commit c0ac309

Please sign in to comment.