Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[retiarii] support visualize model space with the hpo chart on webui #4304

Merged
merged 25 commits into from
Dec 10, 2021
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions nni/retiarii/execution/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,39 @@
from typing import Any, Dict, Iterable, List

from .interface import AbstractExecutionEngine, AbstractGraphListener
from .utils import get_mutation_summary
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData, Evaluator
from ..integration_api import send_trial, receive_trial_parameters, get_advisor

_logger = logging.getLogger(__name__)


class BaseGraphData:
def __init__(self, model_script: str, evaluator: Evaluator) -> None:
"""
Parameters
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AttributesParameters

----------
model_script
code of an instantiated PyTorch model
evaluator
training approach for model_script
mutation_summary
a dict of all the choices during mutations in the HPO search space format
"""
def __init__(self, model_script: str, evaluator: Evaluator, mutation_summary: dict) -> None:
self.model_script = model_script
self.evaluator = evaluator
self.mutation_summary = mutation_summary
Copy link
Contributor

@cruiseliu cruiseliu Dec 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This introduces a new term so it needs doc/comment to describe what is "mutation summary".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added


def dump(self) -> dict:
return {
'model_script': self.model_script,
'evaluator': self.evaluator
'evaluator': self.evaluator,
'mutation_summary': self.mutation_summary
}

@staticmethod
def load(data) -> 'BaseGraphData':
return BaseGraphData(data['model_script'], data['evaluator'])
return BaseGraphData(data['model_script'], data['evaluator'], data['mutation_summary'])


class BaseExecutionEngine(AbstractExecutionEngine):
Expand Down Expand Up @@ -111,7 +123,8 @@ def budget_exhausted(self) -> bool:

@classmethod
def pack_model_data(cls, model: Model) -> Any:
return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator)
mutation_summary = get_mutation_summary(model)
return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator, mutation_summary)

@classmethod
def trial_execute_graph(cls) -> None:
Expand Down
2 changes: 1 addition & 1 deletion nni/retiarii/execution/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..graph import Model
from ..integration_api import receive_trial_parameters
from .base import BaseExecutionEngine
from .python import get_mutation_dict
from .utils import get_mutation_dict


class BenchmarkGraphData:
Expand Down
2 changes: 1 addition & 1 deletion nni/retiarii/execution/cgo_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _submit_models_in_batch(self, *models: List[Model]) -> None:

phy_models_and_placements = self._assemble(logical)
for model, placement, grouped_models in phy_models_and_placements:
data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement), model.evaluator)
data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement), model.evaluator, {})
placement_constraint = self._extract_placement_constaint(placement)
trial_id = send_trial(data.dump(), placement_constraint=placement_constraint)
# unique non-cpu devices used by the trial
Expand Down
17 changes: 5 additions & 12 deletions nni/retiarii/execution/python.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Dict, Any, List
from typing import Dict, Any

from ..graph import Evaluator, Model
from ..integration_api import receive_trial_parameters
from ..utils import ContextStack, import_, get_importable_name
from .base import BaseExecutionEngine
from .utils import get_mutation_dict, mutation_dict_to_summary


class PythonGraphData:
Expand All @@ -13,13 +14,15 @@ def __init__(self, class_name: str, init_parameters: Dict[str, Any],
self.init_parameters = init_parameters
self.mutation = mutation
self.evaluator = evaluator
self.mutation_summary = mutation_dict_to_summary(mutation)

def dump(self) -> dict:
return {
'class_name': self.class_name,
'init_parameters': self.init_parameters,
'mutation': self.mutation,
'evaluator': self.evaluator
'evaluator': self.evaluator,
'mutation_summary': self.mutation_summary
}

@staticmethod
Expand Down Expand Up @@ -55,13 +58,3 @@ def __init__(self):

with ContextStack('fixed', graph_data.mutation):
graph_data.evaluator._execute(_model)


def _unpack_if_only_one(ele: List[Any]):
if len(ele) == 1:
return ele[0]
return ele


def get_mutation_dict(model: Model):
return {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history}
25 changes: 25 additions & 0 deletions nni/retiarii/execution/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Any, List
from ..graph import Model

def _unpack_if_only_one(ele: List[Any]):
if len(ele) == 1:
return ele[0]
return ele

def get_mutation_dict(model: Model):
return {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history}

def mutation_dict_to_summary(mutation: dict) -> dict:
mutation_summary = {}
for label, samples in mutation.items():
# FIXME: this check might be wrong
if not isinstance(samples, list):
mutation_summary[label] = samples
else:
for i, sample in enumerate(samples):
mutation_summary[f'{label}_{i}'] = sample
return mutation_summary

def get_mutation_summary(model: Model) -> dict:
mutation = get_mutation_dict(model)
return mutation_dict_to_summary(mutation)
5 changes: 4 additions & 1 deletion nni/retiarii/experiment/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@
from ..converter import convert_to_graph
from ..converter.graph_gen import GraphConverterWithShape
from ..execution import list_models, set_execution_engine
from ..execution.python import get_mutation_dict
from ..execution.utils import get_mutation_dict
from ..graph import Evaluator
from ..integration import RetiariiAdvisor
from ..mutator import Mutator
from ..nn.pytorch.mutator import extract_mutation_from_pt_module, process_inline_mutation
from ..oneshot.interface import BaseOneShotTrainer
from ..strategy import BaseStrategy
from ..strategy.utils import dry_run_for_formatted_search_space

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -193,6 +194,8 @@ def _start_strategy(self):
)

_logger.info('Start strategy...')
search_space = dry_run_for_formatted_search_space(base_model_ir, self.applied_mutators)
self.update_search_space(search_space)
self.strategy.run(base_model_ir, self.applied_mutators)
_logger.info('Strategy exit')
# TODO: find out a proper way to show no more trial message on WebUI
Expand Down
1 change: 0 additions & 1 deletion nni/retiarii/integration_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def send_trial(parameters: dict, placement_constraint=None) -> int:
"""
return get_advisor().send_trial(parameters, placement_constraint)


def receive_trial_parameters() -> dict:
"""
Received a new trial. Executed on trial end.
Expand Down
4 changes: 3 additions & 1 deletion nni/retiarii/strategy/local_debug_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .. import Sampler, codegen, utils
from ..execution.base import BaseGraphData
from ..execution.utils import get_mutation_summary
from .base import BaseStrategy

_logger = logging.getLogger(__name__)
Expand All @@ -22,7 +23,8 @@ class _LocalDebugStrategy(BaseStrategy):
"""

def run_one_model(self, model):
graph_data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator)
mutation_summary = get_mutation_summary(model)
graph_data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator, mutation_summary)
random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
file_name = f'_generated_model/{random_str}.py'
os.makedirs(os.path.dirname(file_name), exist_ok=True)
Expand Down
10 changes: 10 additions & 0 deletions nni/retiarii/strategy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ def dry_run_for_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any,
search_space[(mutator, i)] = candidates
return search_space

def dry_run_for_formatted_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any, Dict[Any, Any]]:
search_space = collections.OrderedDict()
for mutator in mutators:
recorded_candidates, model = mutator.dry_run(model)
if len(recorded_candidates) == 1:
search_space[mutator.label] = {'_type': 'choice', '_value': recorded_candidates[0]}
else:
for i, candidate in enumerate(recorded_candidates):
search_space[f'{mutator.label}_{i}'] = {'_type': 'choice', '_value': candidate}
return search_space

def get_targeted_model(base_model: Model, mutators: List[Mutator], sample: dict) -> Model:
sampler = _FixedSampler(sample)
Expand Down
2 changes: 1 addition & 1 deletion test/ut/retiarii/test_highlevel_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from nni.retiarii import InvalidMutation, Sampler, basic_unit
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.execution.python import _unpack_if_only_one
from nni.retiarii.execution.utils import _unpack_if_only_one
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation, extract_mutation_from_pt_module
from nni.retiarii.serializer import model_wrapper
from nni.retiarii.utils import ContextStack
Expand Down
3 changes: 2 additions & 1 deletion ts/nni_manager/core/nnimanager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,9 @@ class NNIManager implements Manager {
if (this.dispatcher === undefined) {
throw new Error('Error: tuner has not been setup');
}
this.log.info(`Updated search space ${searchSpace}`);
this.dispatcher.sendCommand(UPDATE_SEARCH_SPACE, searchSpace);
this.experimentProfile.params.searchSpace = searchSpace;
this.experimentProfile.params.searchSpace = JSON.parse(searchSpace);

return;
}
Expand Down
7 changes: 6 additions & 1 deletion ts/webui/src/static/interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,10 @@ interface SearchItems {
isChoice: boolean; // for parameters: type = choice and status also as choice type
}

interface RetiariiParameter {
mutation_summary: object; // retiarii experiment's parameter
}

export {
TableObj,
TableRecord,
Expand All @@ -253,5 +257,6 @@ export {
SortInfo,
AllExperimentList,
Tensorboard,
SearchItems
SearchItems,
RetiariiParameter
};
7 changes: 5 additions & 2 deletions ts/webui/src/static/model/trial.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import {
Parameters,
FinalType,
MultipleAxes,
SingleAxis
SingleAxis,
RetiariiParameter
} from '../interface';
import {
getFinal,
Expand All @@ -31,9 +32,11 @@ function inferTrialParameters(
space: MultipleAxes,
prefix: string = ''
): [Map<SingleAxis, any>, Map<string, any>] {
const latestedParamObj =
'mutation_summary' in paramObj ? (paramObj as RetiariiParameter).mutation_summary : paramObj;
const parameters = new Map<SingleAxis, any>();
const unexpectedEntries = new Map<string, any>();
for (const [k, v] of Object.entries(paramObj)) {
for (const [k, v] of Object.entries(latestedParamObj)) {
// prefix can be a good fallback when corresponding item is not found in namespace
const axisKey = space.axes.get(k);
if (prefix && k === '_name') continue;
Expand Down