Skip to content

Commit

Permalink
unify LightningEnum (#5389)
Browse files Browse the repository at this point in the history
* unify LightningEnum

* hash

* Apply suggestions from code review

Co-authored-by: Carlos Mocholí <[email protected]>

* Update states.py

Co-authored-by: Carlos Mocholí <[email protected]>
  • Loading branch information
Borda and carmocca authored Jan 12, 2021
1 parent 54d20dc commit 51b9df3
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import torch

from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import DistributedType
from pytorch_lightning.utilities import DistributedType, LightningEnum


class LoggerStages(str, Enum):
class LoggerStages(LightningEnum):
""" Train/validation/test phase in each training step.
>>> # you can math the type with string
Expand All @@ -42,7 +41,7 @@ def determine_stage(stage_or_testing: Union[str, bool]) -> 'LoggerStages':
raise RuntimeError(f"Invalid stage {stage_or_testing} of type {type(stage_or_testing)} given")


class ResultStoreType(str, Enum):
class ResultStoreType(LightningEnum):
INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop"
OUTSIDE_BATCH_TRAIN_LOOP = "outside_batch_train_loop"

Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from functools import wraps
from typing import Callable, Optional

import pytorch_lightning
from pytorch_lightning.utilities import LightningEnum


class TrainerState(str, Enum):
class TrainerState(LightningEnum):
""" State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer`
to indicate what is currently or was executed.
Expand All @@ -28,7 +28,7 @@ class TrainerState(str, Enum):
True
>>> # which is case sensitive
>>> TrainerState.FINISHED == 'finished'
False
True
"""
INITIALIZING = 'INITIALIZING'
RUNNING = 'RUNNING'
Expand Down
5 changes: 5 additions & 0 deletions pytorch_lightning/utilities/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def __eq__(self, other: Union[str, Enum]) -> bool:
other = other.value if isinstance(other, Enum) else str(other)
return self.value.lower() == other.lower()

def __hash__(self):
# re-enable hashtable so it can be used as a dict key or in a set
# example: set(LightningEnum)
return hash(self.name)


class AMPType(LightningEnum):
"""Type of Automatic Mixed Precission used for training.
Expand Down

0 comments on commit 51b9df3

Please sign in to comment.