diff --git a/nni/retiarii/execution/base.py b/nni/retiarii/execution/base.py index ec8b5e4e6d..e1d464c9ce 100644 --- a/nni/retiarii/execution/base.py +++ b/nni/retiarii/execution/base.py @@ -8,27 +8,39 @@ from typing import Any, Dict, Iterable, List from .interface import AbstractExecutionEngine, AbstractGraphListener +from .utils import get_mutation_summary from .. import codegen, utils from ..graph import Model, ModelStatus, MetricData, Evaluator from ..integration_api import send_trial, receive_trial_parameters, get_advisor _logger = logging.getLogger(__name__) - class BaseGraphData: - def __init__(self, model_script: str, evaluator: Evaluator) -> None: + """ + Attributes + ---------- + model_script + code of an instantiated PyTorch model + evaluator + training approach for model_script + mutation_summary + a dict of all the choices during mutations in the HPO search space format + """ + def __init__(self, model_script: str, evaluator: Evaluator, mutation_summary: dict) -> None: self.model_script = model_script self.evaluator = evaluator + self.mutation_summary = mutation_summary def dump(self) -> dict: return { 'model_script': self.model_script, - 'evaluator': self.evaluator + 'evaluator': self.evaluator, + 'mutation_summary': self.mutation_summary } @staticmethod def load(data) -> 'BaseGraphData': - return BaseGraphData(data['model_script'], data['evaluator']) + return BaseGraphData(data['model_script'], data['evaluator'], data['mutation_summary']) class BaseExecutionEngine(AbstractExecutionEngine): @@ -111,7 +123,8 @@ def budget_exhausted(self) -> bool: @classmethod def pack_model_data(cls, model: Model) -> Any: - return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator) + mutation_summary = get_mutation_summary(model) + return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator, mutation_summary) @classmethod def trial_execute_graph(cls) -> None: diff --git a/nni/retiarii/execution/benchmark.py b/nni/retiarii/execution/benchmark.py index 7b325d4b6b..a3e6ac4c3f 100644 --- a/nni/retiarii/execution/benchmark.py +++ b/nni/retiarii/execution/benchmark.py @@ -5,7 +5,7 @@ from ..graph import Model from ..integration_api import receive_trial_parameters from .base import BaseExecutionEngine -from .python import get_mutation_dict +from .utils import get_mutation_dict class BenchmarkGraphData: diff --git a/nni/retiarii/execution/cgo_engine.py b/nni/retiarii/execution/cgo_engine.py index 9482f9b493..e6e204dbe7 100644 --- a/nni/retiarii/execution/cgo_engine.py +++ b/nni/retiarii/execution/cgo_engine.py @@ -156,7 +156,7 @@ def _submit_models_in_batch(self, *models: List[Model]) -> None: phy_models_and_placements = self._assemble(logical) for model, placement, grouped_models in phy_models_and_placements: - data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement), model.evaluator) + data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement), model.evaluator, {}) placement_constraint = self._extract_placement_constaint(placement) trial_id = send_trial(data.dump(), placement_constraint=placement_constraint) # unique non-cpu devices used by the trial diff --git a/nni/retiarii/execution/python.py b/nni/retiarii/execution/python.py index a13b96ac98..e33f198de3 100644 --- a/nni/retiarii/execution/python.py +++ b/nni/retiarii/execution/python.py @@ -1,9 +1,10 @@ -from typing import Dict, Any, List +from typing import Dict, Any from ..graph import Evaluator, Model from ..integration_api import receive_trial_parameters from ..utils import ContextStack, import_, get_importable_name from .base import BaseExecutionEngine +from .utils import get_mutation_dict, mutation_dict_to_summary class PythonGraphData: @@ -13,13 +14,15 @@ def __init__(self, class_name: str, init_parameters: Dict[str, Any], self.init_parameters = init_parameters self.mutation = mutation self.evaluator = evaluator + self.mutation_summary = mutation_dict_to_summary(mutation) def dump(self) -> dict: return { 'class_name': self.class_name, 'init_parameters': self.init_parameters, 'mutation': self.mutation, - 'evaluator': self.evaluator + 'evaluator': self.evaluator, + 'mutation_summary': self.mutation_summary } @staticmethod @@ -55,13 +58,3 @@ def __init__(self): with ContextStack('fixed', graph_data.mutation): graph_data.evaluator._execute(_model) - - -def _unpack_if_only_one(ele: List[Any]): - if len(ele) == 1: - return ele[0] - return ele - - -def get_mutation_dict(model: Model): - return {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history} diff --git a/nni/retiarii/execution/utils.py b/nni/retiarii/execution/utils.py new file mode 100644 index 0000000000..db9efe85cd --- /dev/null +++ b/nni/retiarii/execution/utils.py @@ -0,0 +1,25 @@ +from typing import Any, List +from ..graph import Model + +def _unpack_if_only_one(ele: List[Any]): + if len(ele) == 1: + return ele[0] + return ele + +def get_mutation_dict(model: Model): + return {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history} + +def mutation_dict_to_summary(mutation: dict) -> dict: + mutation_summary = {} + for label, samples in mutation.items(): + # FIXME: this check might be wrong + if not isinstance(samples, list): + mutation_summary[label] = samples + else: + for i, sample in enumerate(samples): + mutation_summary[f'{label}_{i}'] = sample + return mutation_summary + +def get_mutation_summary(model: Model) -> dict: + mutation = get_mutation_dict(model) + return mutation_dict_to_summary(mutation) diff --git a/nni/retiarii/experiment/pytorch.py b/nni/retiarii/experiment/pytorch.py index dc9472b748..191bddc851 100644 --- a/nni/retiarii/experiment/pytorch.py +++ b/nni/retiarii/experiment/pytorch.py @@ -28,13 +28,14 @@ from ..converter import convert_to_graph from ..converter.graph_gen import GraphConverterWithShape from ..execution import list_models, set_execution_engine -from ..execution.python import get_mutation_dict +from ..execution.utils import get_mutation_dict from ..graph import Evaluator from ..integration import RetiariiAdvisor from ..mutator import Mutator from ..nn.pytorch.mutator import extract_mutation_from_pt_module, process_inline_mutation from ..oneshot.interface import BaseOneShotTrainer from ..strategy import BaseStrategy +from ..strategy.utils import dry_run_for_formatted_search_space _logger = logging.getLogger(__name__) @@ -193,6 +194,8 @@ def _start_strategy(self): ) _logger.info('Start strategy...') + search_space = dry_run_for_formatted_search_space(base_model_ir, self.applied_mutators) + self.update_search_space(search_space) self.strategy.run(base_model_ir, self.applied_mutators) _logger.info('Strategy exit') # TODO: find out a proper way to show no more trial message on WebUI diff --git a/nni/retiarii/integration_api.py b/nni/retiarii/integration_api.py index 68737a2561..cceff57cab 100644 --- a/nni/retiarii/integration_api.py +++ b/nni/retiarii/integration_api.py @@ -31,7 +31,6 @@ def send_trial(parameters: dict, placement_constraint=None) -> int: """ return get_advisor().send_trial(parameters, placement_constraint) - def receive_trial_parameters() -> dict: """ Received a new trial. Executed on trial end. diff --git a/nni/retiarii/strategy/local_debug_strategy.py b/nni/retiarii/strategy/local_debug_strategy.py index 743d6b2fc6..dd842babcf 100644 --- a/nni/retiarii/strategy/local_debug_strategy.py +++ b/nni/retiarii/strategy/local_debug_strategy.py @@ -8,6 +8,7 @@ from .. import Sampler, codegen, utils from ..execution.base import BaseGraphData +from ..execution.utils import get_mutation_summary from .base import BaseStrategy _logger = logging.getLogger(__name__) @@ -22,7 +23,8 @@ class _LocalDebugStrategy(BaseStrategy): """ def run_one_model(self, model): - graph_data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator) + mutation_summary = get_mutation_summary(model) + graph_data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator, mutation_summary) random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6)) file_name = f'_generated_model/{random_str}.py' os.makedirs(os.path.dirname(file_name), exist_ok=True) diff --git a/nni/retiarii/strategy/utils.py b/nni/retiarii/strategy/utils.py index 87d9d24037..4262674f86 100644 --- a/nni/retiarii/strategy/utils.py +++ b/nni/retiarii/strategy/utils.py @@ -27,6 +27,16 @@ def dry_run_for_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any, search_space[(mutator, i)] = candidates return search_space +def dry_run_for_formatted_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any, Dict[Any, Any]]: + search_space = collections.OrderedDict() + for mutator in mutators: + recorded_candidates, model = mutator.dry_run(model) + if len(recorded_candidates) == 1: + search_space[mutator.label] = {'_type': 'choice', '_value': recorded_candidates[0]} + else: + for i, candidate in enumerate(recorded_candidates): + search_space[f'{mutator.label}_{i}'] = {'_type': 'choice', '_value': candidate} + return search_space def get_targeted_model(base_model: Model, mutators: List[Mutator], sample: dict) -> Model: sampler = _FixedSampler(sample) diff --git a/test/ut/retiarii/test_highlevel_apis.py b/test/ut/retiarii/test_highlevel_apis.py index 2ab5da1269..a24c517dc0 100644 --- a/test/ut/retiarii/test_highlevel_apis.py +++ b/test/ut/retiarii/test_highlevel_apis.py @@ -8,7 +8,7 @@ from nni.retiarii import InvalidMutation, Sampler, basic_unit from nni.retiarii.converter import convert_to_graph from nni.retiarii.codegen import model_to_pytorch_script -from nni.retiarii.execution.python import _unpack_if_only_one +from nni.retiarii.execution.utils import _unpack_if_only_one from nni.retiarii.nn.pytorch.mutator import process_inline_mutation, extract_mutation_from_pt_module from nni.retiarii.serializer import model_wrapper from nni.retiarii.utils import ContextStack diff --git a/ts/nni_manager/core/nnimanager.ts b/ts/nni_manager/core/nnimanager.ts index 2639e18505..67453f8ca6 100644 --- a/ts/nni_manager/core/nnimanager.ts +++ b/ts/nni_manager/core/nnimanager.ts @@ -513,8 +513,9 @@ class NNIManager implements Manager { if (this.dispatcher === undefined) { throw new Error('Error: tuner has not been setup'); } + this.log.info(`Updated search space ${searchSpace}`); this.dispatcher.sendCommand(UPDATE_SEARCH_SPACE, searchSpace); - this.experimentProfile.params.searchSpace = searchSpace; + this.experimentProfile.params.searchSpace = JSON.parse(searchSpace); return; } diff --git a/ts/webui/src/static/interface.ts b/ts/webui/src/static/interface.ts index 561fddbbb2..0ead649f29 100644 --- a/ts/webui/src/static/interface.ts +++ b/ts/webui/src/static/interface.ts @@ -228,6 +228,10 @@ interface SearchItems { isChoice: boolean; // for parameters: type = choice and status also as choice type } +interface RetiariiParameter { + mutation_summary: object; // retiarii experiment's parameter +} + export { TableObj, TableRecord, @@ -253,5 +257,6 @@ export { SortInfo, AllExperimentList, Tensorboard, - SearchItems + SearchItems, + RetiariiParameter }; diff --git a/ts/webui/src/static/model/trial.ts b/ts/webui/src/static/model/trial.ts index 2ecce34e0a..02db184317 100644 --- a/ts/webui/src/static/model/trial.ts +++ b/ts/webui/src/static/model/trial.ts @@ -7,7 +7,8 @@ import { Parameters, FinalType, MultipleAxes, - SingleAxis + SingleAxis, + RetiariiParameter } from '../interface'; import { getFinal, @@ -31,9 +32,11 @@ function inferTrialParameters( space: MultipleAxes, prefix: string = '' ): [Map, Map] { + const latestedParamObj = + 'mutation_summary' in paramObj ? (paramObj as RetiariiParameter).mutation_summary : paramObj; const parameters = new Map(); const unexpectedEntries = new Map(); - for (const [k, v] of Object.entries(paramObj)) { + for (const [k, v] of Object.entries(latestedParamObj)) { // prefix can be a good fallback when corresponding item is not found in namespace const axisKey = space.axes.get(k); if (prefix && k === '_name') continue;