Skip to content

Commit

Permalink
update nas (PaddlePaddle#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
ceci3 authored Dec 31, 2019
1 parent cbaac40 commit fdb09f0
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 21 deletions.
60 changes: 43 additions & 17 deletions paddleslim/common/sa_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""The controller used to search hyperparameters or neural architecture"""

import os
import sys
import copy
import math
import logging
Expand All @@ -34,27 +35,29 @@ def __init__(self,
range_table=None,
reduce_rate=0.85,
init_temperature=1024,
max_try_times=None,
max_try_times=300,
init_tokens=None,
reward=-1,
max_reward=-1,
iters=0,
best_tokens=None,
constrain_func=None,
checkpoints=None):
checkpoints=None,
searched=None):
"""Initialize.
Args:
range_table(list<int>): Range table.
reduce_rate(float): The decay rate of temperature.
init_temperature(float): Init temperature.
max_try_times(int): max try times before get legal tokens.
max_try_times(int): max try times before get legal tokens. Default: 300.
init_tokens(list<int>): The initial tokens. Default: None.
reward(float): The reward of current tokens. Default: -1.
max_reward(float): The max reward in the search of sanas, in general, best tokens get max reward. Default: -1.
iters(int): The iteration of sa controller. Default: 0.
best_tokens(list<int>): The best tokens in the search of sanas, in general, best tokens get max reward. Default: None.
constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None.
checkpoints(str): if checkpoint is None, donnot save checkpoints, else save scene to checkpoints file.
searched(dict<list, float>): remember tokens which are searched.
"""
super(SAController, self).__init__()
self._range_table = range_table
Expand All @@ -70,6 +73,7 @@ def __init__(self,
self._best_tokens = best_tokens
self._iter = iters
self._checkpoints = checkpoints
self._searched = searched if searched != None else dict()

def __getstate__(self):
d = {}
Expand All @@ -78,6 +82,18 @@ def __getstate__(self):
d[key] = self.__dict__[key]
return d

@property
def best_tokens(self):
return self._best_tokens

@property
def max_reward(self):
return self._max_reward

@property
def current_tokens(self):
return self._tokens

def update(self, tokens, reward, iter):
"""
Update the controller according to latest tokens and reward.
Expand All @@ -88,6 +104,7 @@ def update(self, tokens, reward, iter):
iter = int(iter)
if iter > self._iter:
self._iter = iter
self._searched[str(tokens)] = reward
temperature = self._init_temperature * self._reduce_rate**self._iter
if (reward > self._reward) or (np.random.random() <= math.exp(
(reward - self._reward) / temperature)):
Expand All @@ -112,22 +129,31 @@ def next_tokens(self, control_token=None):
tokens = control_token[:]
else:
tokens = self._tokens
new_tokens = tokens[:]
index = int(len(self._range_table[0]) * np.random.random())
new_tokens[index] = np.random.randint(self._range_table[0][index],
self._range_table[1][index])
_logger.debug("change index[{}] from {} to {}".format(index, tokens[
index], new_tokens[index]))
if self._constrain_func is None or self._max_try_times is None:
return new_tokens
for _ in range(self._max_try_times):
if not self._constrain_func(new_tokens):
index = int(len(self._range_table[0]) * np.random.random())
new_tokens = tokens[:]
new_tokens[index] = np.random.randint(
self._range_table[0][index], self._range_table[1][index])
for it in range(self._max_try_times):
new_tokens = tokens[:]
index = int(len(self._range_table[0]) * np.random.random())
new_tokens[index] = np.random.randint(self._range_table[0][index],
self._range_table[1][index])
_logger.debug("change index[{}] from {} to {}".format(
index, tokens[index], new_tokens[index]))

if self._searched.has_key(str(new_tokens)):
_logger.debug('get next tokens including searched tokens: {}'.
format(new_tokens))
continue
else:
self._searched[str(new_tokens)] = -1
break

if it == self._max_try_times - 1:
_logger.info(
"cannot get a effective search space which is not searched in max try times!!!"
)
sys.exit()

if self._constrain_func is None or self._max_try_times is None:
return new_tokens

return new_tokens

def _save_checkpoint(self, output_dir):
Expand Down
23 changes: 19 additions & 4 deletions paddleslim/nas/sa_nas.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,29 +93,32 @@ def __init__(self,
premax_reward = scene['_max_reward']
prebest_tokens = scene['_best_tokens']
preiter = scene['_iter']
psearched = screen['_searched']
else:
preinit_tokens = init_tokens
prereward = -1
premax_reward = -1
prebest_tokens = None
preiter = 0
psearched = None

controller = SAController(
self._controller = SAController(
range_table,
self._reduce_rate,
self._init_temperature,
max_try_times=None,
max_try_times=500,
init_tokens=preinit_tokens,
reward=prereward,
max_reward=premax_reward,
iters=preiter,
best_tokens=prebest_tokens,
constrain_func=None,
checkpoints=save_checkpoint)
checkpoints=save_checkpoint,
searched = psearched)

max_client_num = 100
self._controller_server = ControllerServer(
controller=controller,
controller=self._controller,
address=(server_ip, server_port),
max_client_num=max_client_num,
search_steps=search_steps,
Expand All @@ -137,6 +140,18 @@ def _get_host_ip(self):
def tokens2arch(self, tokens):
return self._search_space.token2arch(tokens)

def current_info(self):
"""
Get current information, including best tokens, best reward in all the search, and current token.
Returns:
dict<name, value>: a dictionary include best tokens, best reward and current reward.
"""
current_dict = dict()
current_dict['best_tokens'] = self._controller.best_tokens
current_dict['best_reward'] = self._controller.max_reward
current_dict['current_tokens'] = self._controller.current_tokens
return current_dict

def next_archs(self):
"""
Get next network architectures.
Expand Down

0 comments on commit fdb09f0

Please sign in to comment.