Fixes some typing issues #3418

Merged 3 commits on Jun 21, 2024
Changes from 1 commit
4 changes: 4 additions & 0 deletions composer/callbacks/eval_output_logging_callback.py
@@ -114,6 +114,10 @@ def eval_batch_end(self, state: State, logger: Logger) -> None:
         self.rows.extend(rows)

     def eval_end(self, state: State, logger: Logger) -> None:
+        # eval_batch_end will have set these
+        assert self.columns is not None
+        assert self.name is not None
+
         list_of_rows = dist.all_gather_object(self.rows)
         rows = [row for rows in list_of_rows for row in rows]
         for dest_logger in logger.destinations:
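The two added asserts exist to satisfy static type checkers: `columns` and `name` are declared as Optional and only populated once `eval_batch_end` has run, so narrowing them at the top of `eval_end` removes the possibly-None errors without changing runtime behavior. A minimal sketch of the pattern (illustrative names only, not Composer's actual callback):

from typing import Optional

class _Sketch:
    """Toy illustration of Optional narrowing via assert."""

    def __init__(self) -> None:
        # Attributes start as None and are populated by a per-batch hook.
        self.columns: Optional[list[str]] = None
        self.name: Optional[str] = None

    def eval_batch_end(self) -> None:
        self.columns = ['input', 'output']
        self.name = 'my_eval'

    def eval_end(self) -> None:
        # Without these asserts, pyright/mypy flag the attributes below as
        # possibly None; the asserts narrow them to their concrete types.
        assert self.columns is not None
        assert self.name is not None
        header = ', '.join(self.columns)
        print(f'{self.name}: {header}')

cb = _Sketch()
cb.eval_batch_end()
cb.eval_end()  # prints "my_eval: input, output"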
10 changes: 5 additions & 5 deletions composer/core/evaluator.py
@@ -67,7 +67,7 @@ class Evaluator:

             When specifying ``eval_interval``, the evaluator(s) are also run at the ``Event.FIT_END`` if it doesn't
             evenly divide the training duration.
-        device_eval_microbatch_size (int, optional): The number of samples to use for each microbatch when evaluating.
+        device_eval_microbatch_size (str | int | float, optional): The number of samples to use for each microbatch when evaluating.
             If set to ``auto``, dynamically decreases device_eval_microbatch_size if microbatch is too large for GPU.
             If None, sets `device_eval_microbatch_size` to per rank batch size. (default: ``None``)
     """
@@ -80,7 +80,7 @@ def __init__(
         metric_names: Optional[list[str]] = None,
         subset_num_batches: Optional[int] = None,
         eval_interval: Optional[Union[int, str, Time, Callable[[State, Event], bool]]] = None,
-        device_eval_microbatch_size: Optional[Union[int, str]] = None,
+        device_eval_microbatch_size: Optional[Union[int, str, float]] = None,
     ):
         self.label = label
         self.dataloader = ensure_data_spec(dataloader)
@@ -142,7 +142,7 @@ def ensure_evaluator(evaluator: Union[Evaluator, DataSpec, Iterable, dict[str, A
     )


-def _is_auto_microbatching(device_eval_microbatch_size: Optional[Union[int, str]]):
+def _is_auto_microbatching(device_eval_microbatch_size: Optional[Union[int, str, float]]):
     if device_eval_microbatch_size == 'auto':
         warnings.warn((
             "Setting `device_eval_microbatch_size='auto'` is an experimental feature which may cause "
@@ -155,10 +155,10 @@ def _is_auto_microbatching(device_eval_microbatch_size: Optional[Union[int, str]


 def _get_initial_device_eval_microbatch_size(
-    device_eval_microbatch_size: Optional[Union[int, str]],
+    device_eval_microbatch_size: Optional[Union[int, str, float]],
     auto_microbatching: bool,
     dataloader: Iterable,
-) -> int:
+) -> Union[int, float]:
     """Sets initial value of device_eval_microbatch_size.

     If auto_microbatching, sets initial `device_eval_microbatch_size` to per rank batch size.
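The widened annotations simply describe values these helpers already accept: the string 'auto', an int, or now a float. For reference, a small self-contained sketch of the selection logic implied by the docstrings above (the real `_get_initial_device_eval_microbatch_size` derives the per-rank batch size from the dataloader; this sketch takes it as a parameter and is an assumption, not Composer's implementation):

from typing import Optional, Union

def get_initial_microbatch_size(
    device_eval_microbatch_size: Optional[Union[int, str, float]],
    auto_microbatching: bool,
    per_rank_batch_size: int,
) -> Union[int, float]:
    """Sketch of the behavior described in the docstrings above."""
    if auto_microbatching or device_eval_microbatch_size is None:
        # 'auto' and None both start from the per-rank batch size.
        return per_rank_batch_size
    if isinstance(device_eval_microbatch_size, str):
        raise ValueError(f'Unknown device_eval_microbatch_size: {device_eval_microbatch_size!r}')
    # Both ints and floats pass through, matching the new Union[int, float] return type.
    return device_eval_microbatch_size

print(get_initial_microbatch_size('auto', True, 16))  # 16
print(get_initial_microbatch_size(None, False, 16))   # 16
print(get_initial_microbatch_size(0.5, False, 16))    # 0.5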
4 changes: 3 additions & 1 deletion composer/loggers/mlflow_logger.py
@@ -185,6 +185,9 @@ def __init__(
     def _start_mlflow_run(self, state):
         import mlflow

+        # This function is only called if self._enabled is True, and therefore self._experiment_id is not None.
+        assert self._experiment_id is not None
+
         env_run_id = os.getenv(
             mlflow.environment_variables.MLFLOW_RUN_ID.name,  # pyright: ignore[reportGeneralTypeIssues]
             None,
@@ -193,7 +196,6 @@ def _start_mlflow_run(self, state):
             self._run_id = env_run_id
         elif self.resume:
             # Search for an existing run tagged with this Composer run if `self.resume=True`.
-            assert self._experiment_id is not None
             run_name = self.tags['run_name']
             existing_runs = mlflow.search_runs(
                 experiment_ids=[self._experiment_id],
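Hoisting the assert to the top of `_start_mlflow_run` narrows `self._experiment_id` for every branch of the method rather than only the `resume` branch, and the new comment records the invariant it relies on: the method is only reached when the logger is enabled. A toy illustration of that guard-plus-assert pattern (names are illustrative, not the real MLFlowLogger):

from typing import Optional

class _LoggerSketch:
    """Toy illustration of the enabled-guard invariant."""

    def __init__(self, enabled: bool) -> None:
        self._enabled = enabled
        # Only populated when the logger is enabled.
        self._experiment_id: Optional[str] = 'exp-123' if enabled else None
        self._run_id: Optional[str] = None

    def init(self) -> None:
        # The guard establishes the invariant the assert below relies on.
        if self._enabled:
            self._start_run()

    def _start_run(self) -> None:
        # Asserting once at the top narrows _experiment_id for every branch
        # in the method, not just the one that previously held the assert.
        assert self._experiment_id is not None
        self._run_id = f'run-for-{self._experiment_id}'

logger = _LoggerSketch(enabled=True)
logger.init()
print(logger._run_id)  # run-for-exp-123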