Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add count_only_trials_with_data to trial-based TCs #3012

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ax/modelbridge/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,8 @@ def test_node_string_representation(self) -> None:
"'not_in_statuses': None, 'transition_to': None, "
"'block_transition_if_unmet': True, 'block_gen_if_met': False, "
"'use_all_trials_in_exp': False, "
"'continue_trial_generation': False})])"
"'continue_trial_generation': False, "
"'count_only_trials_with_data': False})])"
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def test_all_constructors_have_expected_signature_for_purpose(self) -> None:
untested_constructors.remove(constructor)

# There should be no untested constructors left.
print(untested_constructors)
self.assertEqual(len(untested_constructors), 0)

def test_consume_all_n_constructor(self) -> None:
Expand Down
6 changes: 4 additions & 2 deletions ax/modelbridge/tests/test_transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,8 @@ def test_repr(self) -> None:
+ "'block_transition_if_unmet': False, "
+ "'block_gen_if_met': True, "
+ "'use_all_trials_in_exp': False, "
+ "'continue_trial_generation': False})",
+ "'continue_trial_generation': False, "
+ "'count_only_trials_with_data': False})",
)
minimum_trials_in_status_criterion = MinTrials(
only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED],
Expand All @@ -599,7 +600,8 @@ def test_repr(self) -> None:
+ "'block_transition_if_unmet': False, "
+ "'block_gen_if_met': True, "
+ "'use_all_trials_in_exp': False, "
+ "'continue_trial_generation': False})",
+ "'continue_trial_generation': False, "
+ "'count_only_trials_with_data': False})",
)
minimum_preference_occurrences_criterion = MinimumPreferenceOccurances(
metric_name="m1", threshold=3
Expand Down
23 changes: 22 additions & 1 deletion ax/modelbridge/transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@


class TransitionCriterion(SortableBase, SerializationMixin):
# TODO: @mgarrard rename to ActionCriterion
"""
Simple class to describe a condition which must be met for this GenerationNode to
take an action such as generation, transition, etc.
Expand Down Expand Up @@ -262,6 +261,8 @@ class TrialBasedCriterion(TransitionCriterion):
batch from different ``GenerationNodes``. This flag should be set to
True for the last node in a set of ``GenerationNodes`` expected to
create a given ``BatchTrial``.
count_only_trials_with_data: If set to True, only trials with data will be
counted towards the ``threshold``. Defaults to False.
"""

def __init__(
Expand All @@ -274,11 +275,13 @@ def __init__(
transition_to: str | None = None,
use_all_trials_in_exp: bool | None = False,
continue_trial_generation: bool | None = False,
count_only_trials_with_data: bool = False,
) -> None:
self.threshold = threshold
self.only_in_statuses = only_in_statuses
self.not_in_statuses = not_in_statuses
self.use_all_trials_in_exp = use_all_trials_in_exp
self.count_only_trials_with_data = count_only_trials_with_data
super().__init__(
transition_to=transition_to,
block_transition_if_unmet=block_transition_if_unmet,
Expand Down Expand Up @@ -334,6 +337,14 @@ def num_contributing_to_threshold(
trials_from_node: The set of trials generated by this GenerationNode.
"""
all_trials_to_check = self.all_trials_to_check(experiment=experiment)
if self.count_only_trials_with_data:
all_trials_to_check = {
trial_index
for trial_index in all_trials_to_check
# TODO[@mgarrard]: determine if we need to actually check data with
# more granularity, e.g. number of days of data, etc.
if trial_index in experiment.data_by_trial
}
# Some criteria may rely on experiment level data, instead of only trials
# generated from the node associated with the criterion.
if self.use_all_trials_in_exp:
Expand Down Expand Up @@ -415,6 +426,8 @@ class MaxGenerationParallelism(TrialBasedCriterion):
create a given ``BatchTrial``. Defaults to False for
MaxGenerationParallelism since this criterion isn't currently used for
node -> node or trial -> trial transition.
count_only_trials_with_data: If set to True, only trials with data will be
counted towards the ``threshold``. Defaults to False.
"""

def __init__(
Expand Down Expand Up @@ -497,6 +510,8 @@ class MaxTrials(TrialBasedCriterion):
batch from different ``GenerationNodes``. This flag should be set to
True for the last node in a set of ``GenerationNodes`` expected to
create a given ``BatchTrial``.
count_only_trials_with_data: If set to True, only trials with data will be
counted towards the ``threshold``. Defaults to False.
"""

def __init__(
Expand All @@ -509,6 +524,7 @@ def __init__(
block_gen_if_met: bool | None = False,
use_all_trials_in_exp: bool | None = False,
continue_trial_generation: bool | None = False,
count_only_trials_with_data: bool = False,
) -> None:
super().__init__(
threshold=threshold,
Expand All @@ -519,6 +535,7 @@ def __init__(
block_transition_if_unmet=block_transition_if_unmet,
use_all_trials_in_exp=use_all_trials_in_exp,
continue_trial_generation=continue_trial_generation,
count_only_trials_with_data=count_only_trials_with_data,
)

def block_continued_generation_error(
Expand Down Expand Up @@ -573,6 +590,8 @@ class MinTrials(TrialBasedCriterion):
batch from different ``GenerationNodes``. This flag should be set to
True for the last node in a set of ``GenerationNodes`` expected to
create a given ``BatchTrial``.
count_only_trials_with_data: If set to True, only trials with data will be
counted towards the ``threshold``. Defaults to False.
"""

def __init__(
Expand All @@ -585,6 +604,7 @@ def __init__(
block_gen_if_met: bool | None = False,
use_all_trials_in_exp: bool | None = False,
continue_trial_generation: bool | None = False,
count_only_trials_with_data: bool = False,
) -> None:
super().__init__(
threshold=threshold,
Expand All @@ -595,6 +615,7 @@ def __init__(
block_transition_if_unmet=block_transition_if_unmet,
use_all_trials_in_exp=use_all_trials_in_exp,
continue_trial_generation=continue_trial_generation,
count_only_trials_with_data=count_only_trials_with_data,
)

def block_continued_generation_error(
Expand Down
Loading