Change implementation to use add_instantiator.
mauvilsa committed Aug 22, 2023
1 parent 8388f88 commit ae8aa65
Showing 3 changed files with 38 additions and 35 deletions.
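
For context, here is a minimal sketch of the jsonargparse add_instantiator hook that this commit switches to: the parser calls a registered callable instead of the class itself whenever instantiate_classes builds an object from the parsed config. Model and logging_instantiator are illustrative only, not part of the commit, and this assumes a jsonargparse version recent enough to provide add_instantiator.

from jsonargparse import ArgumentParser


class Model:
    def __init__(self, hidden_dim: int = 32) -> None:
        self.hidden_dim = hidden_dim


def logging_instantiator(class_type, *args, **kwargs):
    # Called by the parser for Model (and, with subclasses=True, any subclass).
    print(f"instantiating {class_type.__name__} with {kwargs}")
    return class_type(*args, **kwargs)


parser = ArgumentParser()
parser.add_class_arguments(Model, "model")
parser.add_instantiator(logging_instantiator, Model, subclasses=True)

cfg = parser.parse_args(["--model.hidden_dim=64"])
init = parser.instantiate_classes(cfg)  # prints, and init.model is a Model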
67 changes: 35 additions & 32 deletions src/lightning/pytorch/cli.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import os
 import sys
 from functools import partial, update_wrapper
@@ -51,6 +52,8 @@
     locals()["ArgumentParser"] = object
     locals()["Namespace"] = object
 
+ModuleType = TypeVar("ModuleType")
+
 
 class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
     def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
@@ -198,30 +201,6 @@ def add_lr_scheduler_args(
         self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs)
         self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)
 
-    def class_instantiator(self, class_type, *args, **kwargs):
-        for key, (base_type, hparams) in getattr(self, "_hparam_context", {}).items():
-            if issubclass(class_type, base_type):
-                with given_hyperparameters_context(hparams):
-                    return super().class_instantiator(class_type, *args, **kwargs)
-        return super().class_instantiator(class_type, *args, **kwargs)
-
-    def instantiate_classes(
-        self,
-        cfg: Namespace,
-        instantiate_groups: bool = True,
-        hparam_context: Optional[Dict[str, type]] = None,
-    ) -> Namespace:
-        if hparam_context:
-            cfg_dict = yaml.safe_load(self.dump(cfg))  # TODO: do not remove link targets!
-            self._hparam_context = {}
-            for key, base_type in hparam_context.items():
-                hparams = cfg_dict.get(key, {})
-                self._hparam_context[key] = (base_type, hparams)
-        init = super().instantiate_classes(cfg, instantiate_groups=instantiate_groups)
-        if hparam_context:
-            delattr(self, "_hparam_context")
-        return init
-
 
 class SaveConfigCallback(Callback):
     """Saves a LightningCLI config to the log_dir when training starts.
@@ -405,6 +384,7 @@ def __init__(
 
         self._set_seed()
 
+        self._add_instantiators()
         self.before_instantiate_classes()
         self.instantiate_classes()
 
@@ -551,18 +531,28 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None:
         else:
             self.config = parser.parse_args(args)
 
+    def _add_instantiators(self) -> None:
+        self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False))
+        if "subcommand" in self.config:
+            self.config_dump = self.config_dump[self.config.subcommand]
+
+        self.parser.add_instantiator(
+            _InstantiatorFn(cli=self, key="model"),
+            _get_module_type(self._model_class),
+            subclasses=self.subclass_mode_model,
+        )
+        self.parser.add_instantiator(
+            _InstantiatorFn(cli=self, key="data"),
+            _get_module_type(self._datamodule_class),
+            subclasses=self.subclass_mode_data,
+        )
+
     def before_instantiate_classes(self) -> None:
         """Implement to run some code before instantiating the classes."""
 
     def instantiate_classes(self) -> None:
         """Instantiates the classes and sets their attributes."""
-        hparam_prefix = ""
-        if "subcommand" in self.config:
-            hparam_prefix = self.config["subcommand"] + "."
-        hparam_context = {hparam_prefix + "model": self._model_class}
-        if self.datamodule_class is not None:
-            hparam_context[hparam_prefix + "data"] = self._datamodule_class
-        self.config_init = self.parser.instantiate_classes(self.config, hparam_context=hparam_context)
+        self.config_init = self.parser.instantiate_classes(self.config)
         self.datamodule = self._get(self.config_init, "data")
         self.model = self._get(self.config_init, "model")
         self._add_configure_optimizers_method_to_model(self.subcommand)
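
As a quick illustration of the dump-narrowing step in _add_instantiators above (values hypothetical): when a subcommand such as fit is used, config_dump is first reduced to that subcommand's section, so _InstantiatorFn can index it with plain "model"/"data" keys in both the subcommand and no-subcommand cases.

config_dump = {"fit": {"model": {"hidden_dim": 64}, "data": {"batch_size": 8}}}
subcommand = "fit"
config_dump = config_dump[subcommand]  # the narrowing done in _add_instantiators
assert config_dump.get("model", {}) == {"hidden_dim": 64}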
@@ -788,7 +778,20 @@ def _get_short_description(component: object) -> Optional[str]:
         rank_zero_warn(f"Failed parsing docstring for {component}: {ex}")
 
 
-ModuleType = TypeVar("ModuleType")
+def _get_module_type(value: Union[Callable, type]) -> type:
+    if callable(value) and not isinstance(value, type):
+        return inspect.signature(value).return_annotation
+    return value
+
+
+class _InstantiatorFn:
+    def __init__(self, cli: LightningCLI, key: str) -> None:
+        self.cli = cli
+        self.key = key
+
+    def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType:
+        with given_hyperparameters_context(self.cli.config_dump.get(self.key, {})):
+            return class_type(*args, **kwargs)
 
 
 def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType:
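
The _get_module_type helper added above lets the CLI accept a factory callable in place of a class, recovering the module type from the callable's return annotation. A small usage illustration, where MyModel and build_model are hypothetical:

import inspect
from typing import Callable, Union


class MyModel:
    """Hypothetical module class."""


def build_model(hidden_dim: int = 32) -> MyModel:
    # Hypothetical factory; its return annotation names the module type.
    return MyModel()


def _get_module_type(value: Union[Callable, type]) -> type:
    # Same logic as the helper in the diff above.
    if callable(value) and not isinstance(value, type):
        return inspect.signature(value).return_annotation
    return value


assert _get_module_type(MyModel) is MyModel      # a class passes through unchanged
assert _get_module_type(build_model) is MyModel  # a factory yields its annotation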
4 changes: 2 additions & 2 deletions src/lightning/pytorch/core/mixins/hparams_mixin.py
@@ -17,7 +17,7 @@
 from argparse import Namespace
 from contextlib import contextmanager
 from contextvars import ContextVar
-from typing import Any, List, MutableMapping, Optional, Sequence, Union
+from typing import Any, Iterator, List, MutableMapping, Optional, Sequence, Union
 
 from lightning.pytorch.utilities.parsing import AttributeDict, save_hyperparameters
 
@@ -29,7 +29,7 @@
 
 
 @contextmanager
-def given_hyperparameters_context(value):
+def given_hyperparameters_context(value: dict) -> Iterator[None]:
     token = given_hyperparameters.set(value)
     try:
         yield
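
The diff view cuts off before the end of given_hyperparameters_context; below is a self-contained sketch of the standard ContextVar pattern it follows, with the reset in a finally block assumed rather than shown by the diff:

from contextlib import contextmanager
from contextvars import ContextVar
from typing import Iterator, Optional

given_hyperparameters: ContextVar[Optional[dict]] = ContextVar("given_hyperparameters", default=None)


@contextmanager
def given_hyperparameters_context(value: dict) -> Iterator[None]:
    token = given_hyperparameters.set(value)
    try:
        yield
    finally:
        given_hyperparameters.reset(token)  # assumed: restore the previous value


with given_hyperparameters_context({"hidden_dim": 64}):
    assert given_hyperparameters.get() == {"hidden_dim": 64}
assert given_hyperparameters.get() is None  # back to the default after the block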
2 changes: 1 addition & 1 deletion src/lightning/pytorch/core/saving.py
@@ -123,7 +123,7 @@ def _load_state(
     cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
     checkpoint: Dict[str, Any],
     strict: Optional[bool] = None,
-    instantiator=None,
+    instantiator: Optional[Callable] = None,
     **cls_kwargs_new: Any,
 ) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
     cls_spec = inspect.getfullargspec(cls.__init__)
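
A hedged sketch of why instantiator is optional in _load_state: a loader can prefer the registered callable and fall back to calling the class directly when none was provided. The _construct helper is hypothetical, not from this diff:

from typing import Any, Callable, Dict, Optional, Type


def _construct(cls: Type, kwargs: Dict[str, Any], instantiator: Optional[Callable] = None) -> Any:
    # Hypothetical: use the instantiator when registered, else the class itself.
    if instantiator is not None:
        return instantiator(cls, **kwargs)
    return cls(**kwargs)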
