Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reorganize Dirs #205

Merged
merged 2 commits into from
Jul 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions mmrazor/core/__init__.py

This file was deleted.

15 changes: 12 additions & 3 deletions mmrazor/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hook import DumpSubnetHook
from .optim import SeparateOptimWrapperConstructor
from .hooks import DumpSubnetHook
from .optimizers import SeparateOptimWrapperConstructor
from .runner import (AutoSlimValLoop, DartsEpochBasedTrainLoop,
DartsIterBasedTrainLoop, EvolutionSearchLoop,
GreedySamplerTrainLoop, SingleTeacherDistillValLoop,
SlimmableValLoop)

__all__ = ['SeparateOptimWrapperConstructor', 'DumpSubnetHook']
__all__ = [
'SeparateOptimWrapperConstructor', 'DumpSubnetHook',
'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
'GreedySamplerTrainLoop', 'AutoSlimValLoop'
]
File renamed without changes.
4 changes: 0 additions & 4 deletions mmrazor/engine/optim/__init__.py

This file was deleted.

File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,11 @@
from mmengine.utils import is_list_of
from torch.utils.data import DataLoader

from mmrazor.models.subnet import Candidates, FlopsEstimator, export_fix_subnet
from mmrazor.models.subnet.random_subnet import (MULTI_MUTATORS_RANDOM_SUBNET,
SINGLE_MUTATOR_RANDOM_SUBNET)
from mmrazor.registry import LOOPS
from mmrazor.structures import Candidates, FlopsEstimator, export_fix_subnet
from mmrazor.utils import SupportRandomSubnet
from .utils import crossover

random_subnet_type = Union[SINGLE_MUTATOR_RANDOM_SUBNET,
MULTI_MUTATORS_RANDOM_SUBNET]


@LOOPS.register_module()
class EvolutionSearchLoop(EpochBasedTrainLoop):
Expand Down Expand Up @@ -215,14 +211,14 @@ def gen_crossover_candidates(self) -> List:
crossover_candidates.append(crossover_candidate)
return crossover_candidates

def _mutation(self) -> random_subnet_type:
def _mutation(self) -> SupportRandomSubnet:
"""Mutate with the specified mutate_prob."""
candidate1 = random.choice(self.top_k_candidates.subnets)
candidate2 = self.model.sample_subnet()
candidate = crossover(candidate1, candidate2, prob=self.mutate_prob)
return candidate

def _crossover(self) -> random_subnet_type:
def _crossover(self) -> SupportRandomSubnet:
"""Crossover."""
candidate1 = random.choice(self.top_k_candidates.subnets)
candidate2 = random.choice(self.top_k_candidates.subnets)
Expand Down Expand Up @@ -292,7 +288,7 @@ def _save_searcher_ckpt(self) -> None:
if osp.isfile(ckpt_path):
os.remove(ckpt_path)

def _check_constraints(self, random_subnet: random_subnet_type) -> bool:
def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool:
"""Check whether is beyond constraints.

Returns:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,9 @@
from mmengine.utils import is_list_of
from torch.utils.data import DataLoader

from mmrazor.models.subnet import (MULTI_MUTATORS_RANDOM_SUBNET,
SINGLE_MUTATOR_RANDOM_SUBNET, Candidates,
FlopsEstimator, export_fix_subnet)
from mmrazor.registry import LOOPS

random_subnet_type = Union[SINGLE_MUTATOR_RANDOM_SUBNET,
MULTI_MUTATORS_RANDOM_SUBNET]
from mmrazor.structures import Candidates, FlopsEstimator, export_fix_subnet
from mmrazor.utils import SupportRandomSubnet


class BaseSamplerTrainLoop(IterBasedTrainLoop):
Expand Down Expand Up @@ -48,7 +44,7 @@ def __init__(self,
self.model = runner.model

@abstractmethod
def sample_subnet(self) -> random_subnet_type:
def sample_subnet(self) -> SupportRandomSubnet:
"""Sample a subnet to train the supernet."""

def run_iter(self, data_batch: Sequence[dict]) -> None:
Expand Down Expand Up @@ -197,7 +193,7 @@ def run(self) -> None:
self.runner.call_hook('after_train_epoch')
self.runner.call_hook('after_train')

def sample_subnet(self) -> random_subnet_type:
def sample_subnet(self) -> SupportRandomSubnet:
"""Sample a subnet from top_k candidates one by one, then to train the
surpernet with the subnet.

Expand Down Expand Up @@ -298,18 +294,18 @@ def _val_candidate(self) -> Dict:
metrics = self.evaluator.evaluate(len(self.dataloader_val.dataset))
return metrics

def _sample_from_supernet(self) -> random_subnet_type:
def _sample_from_supernet(self) -> SupportRandomSubnet:
"""Sample from the supernet."""
subnet = self.model.sample_subnet()
return subnet

def _sample_from_candidates(self) -> random_subnet_type:
def _sample_from_candidates(self) -> SupportRandomSubnet:
"""Sample from the candidates."""
assert len(self.candidates) > 0
subnet = random.choice(self.candidates)
return subnet

def _check_constraints(self, random_subnet: random_subnet_type) -> bool:
def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool:
"""Check whether is beyond constraints.

Returns:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

import numpy as np

from mmrazor.models.subnet import SINGLE_MUTATOR_RANDOM_SUBNET
from mmrazor.utils import SingleMutatorRandomSubnet


def crossover(random_subnet1: SINGLE_MUTATOR_RANDOM_SUBNET,
random_subnet2: SINGLE_MUTATOR_RANDOM_SUBNET,
prob: float = 0.5) -> SINGLE_MUTATOR_RANDOM_SUBNET:
def crossover(random_subnet1: SingleMutatorRandomSubnet,
random_subnet2: SingleMutatorRandomSubnet,
prob: float = 0.5) -> SingleMutatorRandomSubnet:
"""Crossover in genetic algorithm.

Args:
Expand Down
1 change: 0 additions & 1 deletion mmrazor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,3 @@
from .mutables import * # noqa: F401,F403
from .mutators import * # noqa: F401,F403
from .ops import * # noqa: F401,F403
from .subnet import * # noqa: F401,F403
6 changes: 3 additions & 3 deletions mmrazor/models/algorithms/nas/autoslim.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

from mmrazor.models.distillers import ConfigurableDistiller
from mmrazor.models.mutators import OneShotChannelMutator
from mmrazor.models.subnet import SINGLE_MUTATOR_RANDOM_SUBNET
from mmrazor.models.utils import (add_prefix,
reinitialize_optim_wrapper_count_status)
from mmrazor.registry import MODEL_WRAPPERS, MODELS
from mmrazor.utils import SingleMutatorRandomSubnet
from ..base import BaseAlgorithm

VALID_MUTATOR_TYPE = Union[OneShotChannelMutator, Dict]
Expand Down Expand Up @@ -70,10 +70,10 @@ def _build_distiller(

return distiller

def sample_subnet(self) -> SINGLE_MUTATOR_RANDOM_SUBNET:
def sample_subnet(self) -> SingleMutatorRandomSubnet:
return self.mutator.sample_choices()

def set_subnet(self, subnet: SINGLE_MUTATOR_RANDOM_SUBNET) -> None:
def set_subnet(self, subnet: SingleMutatorRandomSubnet) -> None:
self.mutator.set_choices(subnet)

def set_max_subnet(self) -> None:
Expand Down
11 changes: 8 additions & 3 deletions mmrazor/models/algorithms/nas/darts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
from torch.nn.modules.batchnorm import _BatchNorm

from mmrazor.models.mutators import DiffModuleMutator
from mmrazor.models.subnet import (FIX_MUTABLE, export_fix_subnet,
load_fix_subnet)
from mmrazor.models.utils import add_prefix
from mmrazor.registry import MODEL_WRAPPERS, MODELS
from mmrazor.utils import FixMutable
from ..base import BaseAlgorithm


Expand Down Expand Up @@ -47,7 +46,7 @@ class Darts(BaseAlgorithm):
def __init__(self,
architecture: Union[BaseModel, Dict],
mutator: Optional[Union[DiffModuleMutator, Dict]] = None,
fix_subnet: Optional[FIX_MUTABLE] = None,
fix_subnet: Optional[FixMutable] = None,
unroll: bool = False,
norm_training: bool = False,
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
Expand All @@ -57,6 +56,9 @@ def __init__(self,
# Darts has two training mode: supernet training and subnet retraining.
# fix_subnet is not None, means subnet retraining.
if fix_subnet:
# Avoid circular import
from mmrazor.structures import load_fix_subnet

# According to fix_subnet, delete the unchosen part of supernet
load_fix_subnet(self.architecture, fix_subnet)
self.is_supernet = False
Expand Down Expand Up @@ -86,6 +88,9 @@ def __init__(self,
def search_subnet(self):
"""Search subnet by mutator."""

# Avoid circular import
from mmrazor.structures import export_fix_subnet

subnet = self.mutator.sample_choices()
self.mutator.set_choices(subnet)
return export_fix_subnet(self)
Expand Down
12 changes: 7 additions & 5 deletions mmrazor/models/algorithms/nas/spos.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
from torch.nn.modules.batchnorm import _BatchNorm

from mmrazor.models.mutators import OneShotModuleMutator
from mmrazor.models.subnet import (SINGLE_MUTATOR_RANDOM_SUBNET,
VALID_FIX_MUTABLE_TYPE, load_fix_subnet)
from mmrazor.registry import MODELS
from mmrazor.utils import SingleMutatorRandomSubnet, ValidFixMutable
from ..base import BaseAlgorithm, LossResults


Expand Down Expand Up @@ -68,7 +67,7 @@ class SPOS(BaseAlgorithm):
def __init__(self,
architecture: Union[BaseModel, Dict],
mutator: Optional[Union[OneShotModuleMutator, Dict]] = None,
fix_subnet: Optional[VALID_FIX_MUTABLE_TYPE] = None,
fix_subnet: Optional[ValidFixMutable] = None,
norm_training: bool = False,
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
init_cfg: Optional[dict] = None):
Expand All @@ -77,6 +76,9 @@ def __init__(self,
# SPOS has two training mode: supernet training and subnet retraining.
# fix_subnet is not None, means subnet retraining.
if fix_subnet:
# Avoid circular import
from mmrazor.structures import load_fix_subnet

# According to fix_subnet, delete the unchosen part of supernet
load_fix_subnet(self.architecture, fix_subnet)
self.is_supernet = False
Expand All @@ -101,11 +103,11 @@ def __init__(self,

self.norm_training = norm_training

def sample_subnet(self) -> SINGLE_MUTATOR_RANDOM_SUBNET:
def sample_subnet(self) -> SingleMutatorRandomSubnet:
"""Random sample subnet by mutator."""
return self.mutator.sample_choices()

def set_subnet(self, subnet: SINGLE_MUTATOR_RANDOM_SUBNET):
def set_subnet(self, subnet: SingleMutatorRandomSubnet):
"""Set the subnet sampled by :meth:sample_subnet."""
self.mutator.set_choices(subnet)

Expand Down
3 changes: 2 additions & 1 deletion mmrazor/models/algorithms/pruning/slimmable_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from torch import nn

from mmrazor.models.mutators import SlimmableChannelMutator
from mmrazor.models.subnet import load_fix_subnet
from mmrazor.models.utils import (add_prefix,
reinitialize_optim_wrapper_count_status)
from mmrazor.registry import MODEL_WRAPPERS, MODELS
Expand Down Expand Up @@ -69,6 +68,8 @@ def __init__(self,

# must after `prepare_from_supernet`
if len(channel_cfg_paths) == 1:
# Avoid circular import
from mmrazor.structures import load_fix_subnet
load_fix_subnet(self.architecture, channel_cfg_paths[0])
self.is_deployed = True
else:
Expand Down
2 changes: 1 addition & 1 deletion mmrazor/models/distillers/configurable_distiller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from mmengine.model import BaseModel
from torch import nn

from mmrazor.core import DistillDeliveryManager, RecorderManager
from mmrazor.registry import MODELS
from mmrazor.structures import DistillDeliveryManager, RecorderManager
from ..algorithms.base import LossResults
from .base_distiller import BaseDistiller

Expand Down
2 changes: 1 addition & 1 deletion mmrazor/models/mutators/channel_mutator/channel_mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from torch.nn import Module

from mmrazor.core.tracer import ConcatNode, DepthWiseConvNode, PathList
from mmrazor.registry import MODELS, TASK_UTILS
from mmrazor.structures import ConcatNode, DepthWiseConvNode, PathList
from ...mutables import MutableChannel
from ..base_mutator import BaseMutator
from ..utils import DEFAULT_MODULE_CONVERTERS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from torch.nn import Module
from torch.nn.modules.batchnorm import _BatchNorm

from mmrazor.core.tracer import PathList
from mmrazor.models.architectures.dynamic_op import DynamicBatchNorm
from mmrazor.models.mutables import SlimmableMutableChannel
from mmrazor.registry import MODELS
from mmrazor.structures import PathList
from ..utils import switchable_bn_converter
from .channel_mutator import ChannelMutator

Expand Down
Loading