More EpochResultStore refactors! 🎉 #5522

Merged · 12 commits · Feb 11, 2021
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch

@@ -148,48 +148,31 @@ def get_epoch_log_metrics(self, *_, **__) -> List[Dict]:
def get_forked_metrics(self, *_, **__) -> List[Dict]:
return self.get_epoch_from_func_name("get_forked_metrics")

@staticmethod
def _append_to_structure(primary_dict, opt_idx, batch_idx, result) -> None:
primary_dict.setdefault(opt_idx, {})
primary_dict[opt_idx].setdefault(batch_idx, [])
primary_dict[opt_idx][batch_idx].append(result)

def append(self, result, dataloader_idx: Optional[int] = None, extra_info: Optional[dict] = None) -> None:
assert isinstance(result, Result)
if dataloader_idx is None:
dataloader_idx = 0

if extra_info is None:
extra_info = {}

# [dataloader_idx][optimizer_idx][training_step_idx] is a list
if len(extra_info) > 0:
self._internal_type = ResultStoreType.INSIDE_BATCH_TRAIN_LOOP
# initialize dictionary
def append(self, result: Result, info: Dict) -> None:
dataloader_idx = info["dataloader_idx"]
self._internal_type = info["type"]
opt_idx = info["opt_idx"]

if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP:
# [dataloader_idx][optimizer_idx][training_step_idx] is a list
if dataloader_idx not in self._internals:
self._internals[dataloader_idx] = {}
self._internals_reduced[dataloader_idx] = defaultdict(dict)
self._latest_ref[dataloader_idx] = {}

# extract infos
opt_idx = extra_info["opt_idx"]
batch_idx = extra_info["batch_idx"]

self._append_to_structure(self._internals[dataloader_idx], opt_idx, batch_idx, result)

self._latest_ref[dataloader_idx][opt_idx] = result

# [dataloader_idx] is a list
batch_idx = info["batch_idx"]
self._internals[dataloader_idx].setdefault(opt_idx, {})
self._internals[dataloader_idx][opt_idx].setdefault(batch_idx, [])
self._internals[dataloader_idx][opt_idx][batch_idx].append(result)
else:
self._internal_type = ResultStoreType.OUTSIDE_BATCH_TRAIN_LOOP
self._internals.setdefault(dataloader_idx, [])
self._internals[dataloader_idx].append(result)

if dataloader_idx not in self._latest_ref:
self._latest_ref[dataloader_idx] = {}
self._latest_ref[dataloader_idx][0] = {}

self._latest_ref[dataloader_idx][0] = result
self._latest_ref[dataloader_idx][opt_idx] = result
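
For illustration (not part of the diff), a minimal sketch of the nested layout ``append`` builds when logging inside the training batch loop, using a plain dict and hypothetical string results in place of ``Result`` objects:

    # _internals[dataloader_idx][opt_idx][batch_idx] holds a list with one entry per tbptt split
    internals = {}
    dataloader_idx, opt_idx = 0, 0
    for batch_idx in range(2):
        for split_result in ("result_split_0", "result_split_1"):
            internals.setdefault(dataloader_idx, {})
            internals[dataloader_idx].setdefault(opt_idx, {})
            internals[dataloader_idx][opt_idx].setdefault(batch_idx, [])
            internals[dataloader_idx][opt_idx][batch_idx].append(split_result)
    assert internals == {0: {0: {0: ["result_split_0", "result_split_1"],
                                 1: ["result_split_0", "result_split_1"]}}}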

def auto_reduce_results_on_epoch_end(self) -> None:
"""
@@ -212,36 +195,32 @@ def auto_reduce_results_on_epoch_end(self) -> None:
for opt_idx in range(num_opt_idx + 1):
# TODO: Figure out to reduce memory
# TODO: How to start training in middle of epoch
opt_outputs = epoch_metrics[opt_idx]

outputs = epoch_metrics[opt_idx]
# reduce across time first
time_reduced_outputs = []
for batch_idx in opt_outputs.keys():
tbptt_outs = opt_outputs[batch_idx]
tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs)
if len(tbptt_outs) > 1:
time_reduced_outputs.append(tbptt_outs)
for tbptt_outputs in outputs.values():
tbptt_outputs = type(tbptt_outputs[0]).reduce_across_time(tbptt_outputs)
if len(tbptt_outputs) > 1:
time_reduced_outputs.append(tbptt_outputs)

if len(time_reduced_outputs) == 0:
continue

# reduce across training steps
opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs)
outputs = type(time_reduced_outputs[0]).reduce_on_epoch_end(time_reduced_outputs)

# with manual opt need 1 + metrics because meta is always there
if opt_outputs.minimize is not None:
opt_outputs.minimize = opt_outputs.minimize.mean()
if outputs.minimize is not None:
outputs.minimize = outputs.minimize.mean()

self._internals_reduced[dl_idx][opt_idx] = opt_outputs
self._internals_reduced[dl_idx][opt_idx] = outputs

# free memory
del self._internals[dl_idx][opt_idx]
else:
# no need to reduce as called only once
if len(epoch_metrics) == 1:
reduced_epoch_metrics = epoch_metrics[0]
else:
reduced_epoch_metrics = epoch_metrics[0].__class__.reduce_on_epoch_end(epoch_metrics)
reduced_epoch_metrics = epoch_metrics[0]
if len(epoch_metrics) != 1:
reduced_epoch_metrics = type(reduced_epoch_metrics).reduce_on_epoch_end(epoch_metrics)

self._internals_reduced[dl_idx] = reduced_epoch_metrics
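
As context (not from the PR), a minimal sketch of the two-stage reduction performed above, with plain tensors standing in for ``Result`` objects and ``mean`` standing in for the actual ``reduce_across_time`` / ``reduce_on_epoch_end`` logic:

    import torch

    # hypothetical values logged for one optimizer: 2 batches, each with 2 tbptt splits
    per_batch_splits = {
        0: [torch.tensor(1.0), torch.tensor(3.0)],
        1: [torch.tensor(2.0), torch.tensor(4.0)],
    }
    # stage 1: reduce across time (tbptt splits) within each batch
    time_reduced = [torch.stack(splits).mean() for splits in per_batch_splits.values()]
    # stage 2: reduce across training steps at epoch end
    epoch_reduced = torch.stack(time_reduced).mean()
    assert epoch_reduced.item() == 2.5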

@@ -281,18 +260,23 @@ def __getitem__(self, key: str) -> Any:
return self._internals.get(key, None)

@property
def has_split_and_opt_idx(self):
"""
This function informs if we are running within training batch loop
"""
return self._split_idx is not None and self._opt_idx is not None

@property
def extra_info(self):
def info(self):
"""
This function provides the parameters needed to properly configure the HookResultStore object
"""
return {"batch_idx": self.trainer.batch_idx, "split_idx": self._split_idx, "opt_idx": self._opt_idx}
model_ref = self.trainer.get_model()
return {
"batch_idx": self.trainer.batch_idx,
"fx_name": model_ref._current_hook_fx_name or model_ref._current_fx_name,
"dataloader_idx": model_ref._current_dataloader_idx or 0,
"opt_idx": self._opt_idx or 0,
"split_idx": self._split_idx or 0,
"type": (
ResultStoreType.INSIDE_BATCH_TRAIN_LOOP
if self._opt_idx is not None and self._split_idx is not None else
ResultStoreType.OUTSIDE_BATCH_TRAIN_LOOP
)
}
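
For illustration (not part of the diff), the dictionary returned by ``info`` during the training batch loop could look like the following; the keys mirror what ``HookResultStore.append`` reads, and the concrete values are hypothetical:

    example_info = {
        "batch_idx": 3,
        "fx_name": "training_step",
        "dataloader_idx": 0,
        "opt_idx": 0,
        "split_idx": 0,
        "type": ResultStoreType.INSIDE_BATCH_TRAIN_LOOP,
    }
    # hook_result_store.append(hook_result, example_info)  # hypothetical call site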

def reset_model(self):
"""
@@ -303,17 +287,6 @@ def reset_model(self):
model_ref._current_hook_fx_name = None
model_ref._current_fx_name = ''

def current_model_info(self):
"""
This function is used to extract
information related to current function scoping `self.log` call.
"""
model_ref = self.trainer.get_model()
# extract hook information
fx_name = model_ref._current_hook_fx_name or model_ref._current_fx_name
dataloader_idx = model_ref._current_dataloader_idx
return fx_name, dataloader_idx

def cache_result(self) -> None:
"""
This function is called after every hook
@@ -330,13 +303,11 @@ def cache_result(self) -> None:
model_ref._current_fx_name = ''
return

# extract model information
fx_name, dataloader_idx = self.current_model_info()
info = self.info
fx_name = info["fx_name"]

self._internals.setdefault(fx_name, HookResultStore(fx_name))

extra_info = self.extra_info if self.has_split_and_opt_idx else {}

# attach capture batch_size
Result.attach_batch_size(self._batch_size, hook_result)

@@ -346,16 +317,15 @@ def cache_result(self) -> None:
elif self.trainer._distrib_type == DistributedType.DP:
hook_result.to(torch.device("cuda", self.trainer.root_gpu))

self._internals[fx_name].append(hook_result, dataloader_idx=dataloader_idx, extra_info=extra_info)
self._internals[fx_name].append(hook_result, info)

# update logged_metrics, progress_bar_metrics, callback_metrics

if "epoch_end" in fx_name:
self.update_logger_connector()

self.reset_model()

def update_logger_connector(self) -> None:
def update_logger_connector(self) -> Tuple[Dict, Dict]:
"""
This function is called every time we capture a hook
It automatically updates the following on the logger_connector:
@@ -507,24 +477,24 @@ def __call__(

Example::

result: Result = self(fx_name="training_step", dl_idx="0", opt_idx="0", reduced=True)
result: Result = self(fx_name="training_step", dl_idx=0, opt_idx=0, reduced=True)
result['train_loss_epoch'] # aggregated train_loss over one epoch.

Args:

fx_name: Hook name from ModelHooks or Callback. Example: `training_step`
fx_name: Hook name from ModelHooks or Callback. Example: ``"training_step"``

dl_idx: Dataloader idx in short. It starts from 0 to num_dataloaders - 1
dl_idx: Dataloader index in short. From ``0`` to ``num_dataloaders - 1``

opt_idx: Optimizer idx in short. It starts from 0 to num_optimizers - 1
opt_idx: Optimizer index in short. From ``0`` to ``num_optimizers - 1``

batch_idx: Index of batch idx seen during batch training or evaluation.
Works only with reduced=False
batch_idx: Batch index seen during batch training or evaluation.
Works only with ``reduced=False``

split_idx: Index of the split in the training loop when tbptt is used.

reduced: Data are being aggregated on on_epoch_end.
Indicates if we want to access aggregated Result or not.
Indicates if we want to access the aggregated Result or not.
"""
hook_result = self[fx_name]
internal_type = hook_result._internal_type
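A hedged usage sketch (not from the PR) of the ``__call__`` accessor shown above; ``cached_results`` is an assumed name for an ``EpochResultStore`` populated during training:

    # aggregated Result for the whole epoch
    epoch_result = cached_results(fx_name="training_step", dl_idx=0, opt_idx=0, reduced=True)
    # per-batch Result (only valid with reduced=False)
    batch_result = cached_results(fx_name="training_step", dl_idx=0, opt_idx=0, batch_idx=2, reduced=False)
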
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -25,7 +25,6 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
@@ -48,6 +47,7 @@
from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -57,7 +57,7 @@
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import rank_zero_warn, DeviceType
from pytorch_lightning.utilities import DeviceType, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -903,7 +903,7 @@ def call_hook(self, hook_name, *args, **kwargs):
hook_fx = getattr(model_ref, hook_name)
output = hook_fx(*args, **kwargs)

# if the PL module doesn't have the hook then call the accelerator
# if the PL module doesn't have the hook then call the accelerator
# used to auto-reduce things for the user with Results obj
elif hasattr(self.accelerator_backend, hook_name):
accelerator_hook = getattr(self.accelerator_backend, hook_name)
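For context (not from the PR), a minimal sketch of the dispatch pattern used by ``call_hook``: prefer the hook defined on the LightningModule, otherwise fall back to the accelerator backend; the argument names here are simplified stand-ins:

    def call_hook_sketch(model_ref, accelerator_backend, hook_name, *args, **kwargs):
        output = None
        # call the hook on the LightningModule if it defines one
        if hasattr(model_ref, hook_name):
            output = getattr(model_ref, hook_name)(*args, **kwargs)
        # otherwise fall back to the accelerator, which auto-reduces Results for the user
        elif hasattr(accelerator_backend, hook_name):
            output = getattr(accelerator_backend, hook_name)(*args, **kwargs)
        return output
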
117 changes: 0 additions & 117 deletions tests/trainer/dynamic_args/test_multiple_optimizers.py

This file was deleted.
