fix woq_tune
Signed-off-by: Mengni Wang <[email protected]>
mengniwang95 committed Jun 28, 2024
1 parent 04c17ea commit bd4f99f
Showing 1 changed file with 22 additions and 7 deletions.
onnx_neural_compressor/quantization/tuning.py: 29 changes (22 additions & 7 deletions)
@@ -492,15 +492,26 @@ def autotune(
     eval_func_wrapper = EvaluationFuncWrapper(eval_fn, eval_args)
     config_loader, tuning_logger, tuning_monitor = init_tuning(tuning_config=tune_config)
     tmp_folder = tempfile.TemporaryDirectory()
+    pathlib.Path(tmp_folder.name).joinpath("./eval").mkdir()
     if optimization_level != ort.GraphOptimizationLevel.ORT_DISABLE_ALL:
         sess_options = ort.SessionOptions()
         sess_options.graph_optimization_level = optimization_level
-        sess_options.optimized_model_filepath = pathlib.Path(tmp_folder.name).joinpath("opt.onnx").as_posix()
+        sess_options.optimized_model_filepath = pathlib.Path(tmp_folder.name).joinpath("model.onnx").as_posix()
         sess_options.add_session_config_entry(
-            "session.optimized_model_external_initializers_file_name", "opt.onnx_data"
+            "session.optimized_model_external_initializers_file_name", "model.onnx_data"
         )
         sess_options.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "1024")
         session = ort.InferenceSession(model_input, sess_options)
+
+        # copy config.json to tmp dir for evaluation, LLMs evaluation may need it
+        if isinstance(model_input, str) and os.path.exists(
+            pathlib.Path(model_input).parent.joinpath("config.json").as_posix()
+        ):
+            shutil.copyfile(
+                pathlib.Path(model_input).parent.joinpath("config.json").as_posix(),
+                pathlib.Path(tmp_folder.name).joinpath("config.json").as_posix(),
+            )
+
         model_input = sess_options.optimized_model_filepath
         del session

@@ -513,6 +524,7 @@ def autotune(
         logger.warning("Please pass model path to autotune API rather than onnx.ModelProto.")
         print(traceback.format_exc())
         exit(0)
+
     tuning_monitor.set_baseline(baseline)
     tuning_logger.tuning_start()
     for trial_index, quant_config in enumerate(config_loader):
@@ -534,10 +546,9 @@ def autotune(
         # evaluate API requires str input
         onnx.save_model(
             q_model,
-            pathlib.Path(tmp_folder.name).joinpath("eval.onnx").as_posix(),
+            pathlib.Path(tmp_folder.name).joinpath("./eval/model.onnx").as_posix(),
             save_as_external_data=True,
             all_tensors_to_one_file=True,
-            location="eval.onnx_data",
             size_threshold=1024,
             convert_attribute=False,
         )
@@ -547,15 +558,19 @@ def autotune(
         ):
             shutil.copyfile(
                 pathlib.Path(model_input).parent.joinpath("config.json").as_posix(),
-                pathlib.Path(tmp_folder.name).joinpath("config.json").as_posix(),
+                pathlib.Path(tmp_folder.name).joinpath("./eval/config.json").as_posix(),
             )
-        eval_result: float = eval_func_wrapper.evaluate(pathlib.Path(tmp_folder.name).joinpath("eval.onnx").as_posix())
+        eval_result: float = eval_func_wrapper.evaluate(
+            pathlib.Path(tmp_folder.name).joinpath("./eval/model.onnx").as_posix()
+        )
         tuning_logger.evaluation_end()
         logger.info("Evaluation result: %.4f", eval_result)
         tuning_monitor.add_trial_result(trial_index, eval_result, quant_config)
         tuning_logger.trial_end(trial_index)
         if tuning_monitor.need_stop():
-            external_data_helper.load_external_data_for_model(q_model, tmp_folder.name)
+            external_data_helper.load_external_data_for_model(
+                q_model, pathlib.Path(tmp_folder.name).joinpath("./eval").as_posix()
+            )
             best_quant_model = q_model
             break


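For context, here is a minimal, self-contained sketch (not part of this commit; the toy Identity model and every name below are illustrative) of the save/reload pattern the patched trial loop relies on: onnx.save_model without an explicit location writes the external initializer file next to eval/model.onnx, and load_external_data_for_model then only needs the ./eval directory as its base_dir.

import pathlib
import tempfile

import numpy as np
import onnx
from onnx import helper, numpy_helper
from onnx.external_data_helper import load_external_data_for_model

# Toy model whose single initializer (64*64*4 bytes = 16 KB) exceeds the
# 1024-byte size_threshold used in tuning.py, so it is stored externally.
weight = numpy_helper.from_array(np.random.rand(64, 64).astype(np.float32), name="W")
node = helper.make_node("Identity", inputs=["W"], outputs=["Y"])
graph = helper.make_graph(
    [node],
    "toy",
    inputs=[],
    outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [64, 64])],
    initializer=[weight],
)
model = helper.make_model(graph)

tmp_folder = tempfile.TemporaryDirectory()
pathlib.Path(tmp_folder.name).joinpath("eval").mkdir()
eval_model_path = pathlib.Path(tmp_folder.name).joinpath("eval/model.onnx").as_posix()

# Same save pattern as the patched loop: no explicit `location`, so the
# external data file is created alongside eval/model.onnx.
onnx.save_model(
    model,
    eval_model_path,
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    size_threshold=1024,
    convert_attribute=False,
)

# Reload the tensors the way the patch does for the best trial: the base_dir
# passed to load_external_data_for_model is the ./eval subfolder.
reloaded = onnx.load(eval_model_path, load_external_data=False)
load_external_data_for_model(reloaded, pathlib.Path(tmp_folder.name).joinpath("eval").as_posix())
print(numpy_helper.to_array(reloaded.graph.initializer[0]).shape)  # (64, 64)

tmp_folder.cleanup()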