Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Fix a few issues in Retiarii #3725

Merged
merged 5 commits into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions nni/retiarii/nn/pytorch/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _replicate_and_instantiate(blocks, repeat):

class Cell(nn.Module):
"""
Cell structure [1]_ [2]_ that is popularly used in NAS literature.
Cell structure [zophnas]_ [zophnasnet]_ that is popularly used in NAS literature.

A cell consists of multiple "nodes". Each node is a sum of multiple operators. Each operator is chosen from
``op_candidates``, and takes one input from previous nodes and predecessors. Predecessor means the input of cell.
Expand All @@ -95,8 +95,8 @@ class Cell(nn.Module):

References
----------
.. [1] Barret Zoph, Quoc V. Le, "Neural Architecture Search with Reinforcement Learning". https://arxiv.org/abs/1611.01578
.. [2] Barret Zoph, Vijay Vasudevan, Jonathon Shlens, Quoc V. Le,
.. [zophnas] Barret Zoph, Quoc V. Le, "Neural Architecture Search with Reinforcement Learning". https://arxiv.org/abs/1611.01578
.. [zophnasnet] Barret Zoph, Vijay Vasudevan, Jonathon Shlens, Quoc V. Le,
"Learning Transferable Architectures for Scalable Image Recognition". https://arxiv.org/abs/1707.07012
"""

Expand Down
11 changes: 6 additions & 5 deletions nni/retiarii/nn/pytorch/mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

class LayerChoiceMutator(Mutator):
def __init__(self, nodes: List[Node]):
super().__init__()
super().__init__(label=nodes[0].operation.parameters['label'])
self.nodes = nodes

def mutate(self, model):
Expand All @@ -40,7 +40,7 @@ def mutate(self, model):

class InputChoiceMutator(Mutator):
def __init__(self, nodes: List[Node]):
super().__init__()
super().__init__(label=nodes[0].operation.parameters['label'])
self.nodes = nodes

def mutate(self, model):
Expand All @@ -56,7 +56,7 @@ def mutate(self, model):

class ValueChoiceMutator(Mutator):
def __init__(self, nodes: List[Node], candidates: List[Any]):
super().__init__()
super().__init__(label=nodes[0].operation.parameters['label'])
self.nodes = nodes
self.candidates = candidates

Expand All @@ -69,7 +69,8 @@ def mutate(self, model):

class ParameterChoiceMutator(Mutator):
def __init__(self, nodes: List[Tuple[Node, str]], candidates: List[Any]):
super().__init__()
node, argname = nodes[0]
super().__init__(label=node.operation.parameters[argname].label)
self.nodes = nodes
self.candidates = candidates

Expand All @@ -84,7 +85,7 @@ def mutate(self, model):
class RepeatMutator(Mutator):
def __init__(self, nodes: List[Node]):
# nodes is a subgraph consisting of repeated blocks.
super().__init__()
super().__init__(label=nodes[0].operation.parameters['label'])
self.nodes = nodes

def _retrieve_chain_from_graph(self, graph: Graph) -> List[Node]:
Expand Down
5 changes: 4 additions & 1 deletion nni/retiarii/strategy/_rl_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This file might cause import error for those who didn't install RL-related dependencies

import logging
import threading
from multiprocessing.pool import ThreadPool

import gym
Expand All @@ -18,6 +19,7 @@


_logger = logging.getLogger(__name__)
_thread_lock = threading.Lock()


class MultiThreadEnvWorker(EnvWorker):
Expand Down Expand Up @@ -100,7 +102,8 @@ def step(self, action):
if self.cur_step < self.num_steps else self.action_dim
}
if self.cur_step == self.num_steps:
model = get_targeted_model(self.base_model, self.mutators, self.sample)
with _thread_lock:
model = get_targeted_model(self.base_model, self.mutators, self.sample)
_logger.info(f'New model created: {self.sample}')
submit_models(model)
wait_models(model)
Expand Down
4 changes: 2 additions & 2 deletions nni/retiarii/strategy/bruteforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def run(self, base_model, applied_mutators):
search_space = dry_run_for_search_space(base_model, applied_mutators)
for sample in grid_generator(search_space, shuffle=self.shuffle):
_logger.debug('New model created. Waiting for resource. %s', str(sample))
if query_available_resources() <= 0:
while query_available_resources() <= 0:
time.sleep(self._polling_interval)
submit_models(get_targeted_model(base_model, applied_mutators, sample))

Expand Down Expand Up @@ -113,6 +113,6 @@ def run(self, base_model, applied_mutators):
search_space = dry_run_for_search_space(base_model, applied_mutators)
for sample in random_generator(search_space, dedup=self.dedup):
_logger.debug('New model created. Waiting for resource. %s', str(sample))
if query_available_resources() <= 0:
while query_available_resources() <= 0:
time.sleep(self._polling_interval)
submit_models(get_targeted_model(base_model, applied_mutators, sample))
12 changes: 12 additions & 0 deletions nni/retiarii/strategy/tpe_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ def choice(self, candidates, mutator, model, index):


class TPEStrategy(BaseStrategy):
"""
The Tree-structured Parzen Estimator (TPE) [bergstrahpo]_ is a sequential model-based optimization (SMBO) approach.
SMBO methods sequentially construct models to approximate the performance of hyperparameters based on historical measurements,
and then subsequently choose new hyperparameters to test based on this model.

References
----------

.. [bergstrahpo] Bergstra et al., "Algorithms for Hyper-Parameter Optimization".
https://papers.nips.cc/paper/4443-algorithms-for-hyper-parameter-optimization.pdf
"""

def __init__(self):
self.tpe_sampler = TPESampler()
self.model_id = 0
Expand Down