Allow initialize() to skip loading all models #124
Head ref: モデル読み込みをlazyにする ("make model loading lazy")
Changes from all commits: 210b6e1, eafac33, a0c3539, 619e403, 49c8e6f, deb77c9, 9daf2e4, dcfa19f, 662147c, aeb7d4e
@@ -7,6 +7,7 @@
 #include <array>
 #include <exception>
 #include <memory>
+#include <optional>
 #include <string>
 #include <unordered_set>
@@ -19,9 +20,7 @@
 #include "core.h"

 #define NOT_INITIALIZED_ERR "Call initialize() first."
-#define NOT_FOUND_ERR "No such file or directory: "
-#define FAILED_TO_OPEN_MODEL_ERR "Unable to open model files."
-#define FAILED_TO_OPEN_METAS_ERR "Unable to open metas.json."
+#define NOT_LOADED_ERR "Model is not loaded."
 #define ONNX_ERR "ONNX raise exception: "
 #define JSON_ERR "JSON parser raise exception: "
 #define GPU_NOT_SUPPORTED_ERR "This library is CPU version. GPU is not supported."
@@ -43,13 +42,19 @@ EMBED_DECL(YUKARIN_S);
 EMBED_DECL(YUKARIN_SA);
 EMBED_DECL(DECODE);

-const struct {
+/**
+ * Bundles the three kinds of models into one
+ */
+struct VVMODEL {
   embed::EMBED_RES (*YUKARIN_S)();
   embed::EMBED_RES (*YUKARIN_SA)();
   embed::EMBED_RES (*DECODE)();
-} MODELS_LIST[] = {{YUKARIN_S, YUKARIN_SA, DECODE}};
+};
+const VVMODEL VVMODEL_LIST[] = {
+    {YUKARIN_S, YUKARIN_SA, DECODE},
+};
 }  // namespace EMBED_DECL_NAMESPACE
-using EMBED_DECL_NAMESPACE::MODELS_LIST;
+using EMBED_DECL_NAMESPACE::VVMODEL_LIST;
Comment on lines +45 to +57: Up to now, …
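A minimal sketch (not part of the PR) of what this refactor enables: giving the element type a name lets the rest of the file count and index the embedded-model table, which the lazy-loading code below relies on. EMBED_RES and dummy_res here are illustrative stand-ins for the real embed:: types and getters.

// Sketch only: EMBED_RES and dummy_res are placeholders, not the real embed:: API.
#include <iterator>

struct EMBED_RES {};                      // stand-in for embed::EMBED_RES

struct VVMODEL {                          // same shape as in the diff
  EMBED_RES (*YUKARIN_S)();
  EMBED_RES (*YUKARIN_SA)();
  EMBED_RES (*DECODE)();
};

EMBED_RES dummy_res() { return {}; }      // stand-in for an embedded getter

const VVMODEL VVMODEL_LIST[] = {
    {dummy_res, dummy_res, dummy_res},
};

int main() {
  // A named element type makes both of these natural, which the old
  // anonymous-struct MODELS_LIST made awkward:
  const int model_count = static_cast<int>(std::size(VVMODEL_LIST));  // used by initialize()
  const VVMODEL &model = VVMODEL_LIST[0];  // indexed by model number in load_model()
  (void)model_count;
  (void)model.YUKARIN_S;
  return 0;
}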
 // speaker_id mapping for when there are multiple models
 // {original speaker_id: {model index, new speaker_id}}
@@ -76,8 +81,23 @@ SupportedDevices get_supported_devices() {
 }

 struct Status {
-  Status(bool use_gpu_)
-      : use_gpu(use_gpu_), memory_info(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU)) {}
+  Status(int model_count, bool use_gpu, int cpu_num_threads)
+      : memory_info(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU)) {
+    yukarin_s_list = std::vector<std::optional<Ort::Session>>(model_count);
+    yukarin_sa_list = std::vector<std::optional<Ort::Session>>(model_count);
+    decode_list = std::vector<std::optional<Ort::Session>>(model_count);
+
+    session_options.SetInterOpNumThreads(cpu_num_threads).SetIntraOpNumThreads(cpu_num_threads);
+    if (use_gpu) {
+#ifdef DIRECTML
+      session_options.DisableMemPattern().SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
+      Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(session_options, 0));
+#else
+      const OrtCUDAProviderOptions cuda_options;
+      session_options.AppendExecutionProvider_CUDA(cuda_options);
+#endif
+    }
+  }
Comment: Previously, use_gpu was specified when Status was initialized and cpu_num_threads when a model was loaded; this has been changed so that both are specified at initialization.
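In call-site terms the change looks roughly like this (a sketch assembled from the old and new code in this diff, not an excerpt from the PR):

// Old flow: the thread count arrived only at load time.
//   status = std::make_unique<Status>(use_gpu);
//   status->load(cpu_num_threads);
//
// New flow: model count, GPU flag, and thread count are fixed up front,
// and individual models are loaded later by index.
//   status = std::make_unique<Status>(std::size(VVMODEL_LIST), use_gpu, cpu_num_threads);
//   status->load_model(model_index);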
Comment on lines +84 to +100: In C++, if an exception is thrown inside a constructor, the resources held by the class's or struct's fields are not released, which causes problems such as memory leaks. Constructors in C++ should therefore generally be limited to little more than field initialization.

Comment: I see.

Comment: The reason an exception inside a constructor leads to a memory leak is that the destructor is not called; if the exception handling is written properly, no memory leak occurs.

Comment: Yes, to be precise, it is the destructor that is not called. In this case the resources are managed by each class's destructor, so if the processing is written outside the constructor, each class's destructor is invoked when an exception occurs and no memory leak should happen.

Comment: Sorry, it seems I had slightly misunderstood how constructors behave: even with this implementation as-is, the destructors of the fields apparently do get called...? Since the decision was to leave it as-is, this is not very useful information, but noting it anyway.
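A standalone sketch (not from the PR) illustrating the thread's conclusion: when a constructor body throws, destructors of members that were already fully constructed do run during unwinding, but the object's own destructor does not.

// Demonstrates member cleanup on a throwing constructor body.
#include <cstdio>
#include <stdexcept>

struct Member {
  ~Member() { std::puts("Member destructor runs"); }
};

struct Holder {
  Member m;  // fully constructed before the body throws
  Holder() { throw std::runtime_error("constructor failed"); }
  ~Holder() { std::puts("Holder destructor (never printed)"); }
};

int main() {
  try {
    Holder h;
  } catch (const std::exception &e) {
    std::printf("caught: %s\n", e.what());
  }
  // Output:
  //   Member destructor runs
  //   caught: constructor failed
}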
   /**
    * Loads the metas.json.
    *
@@ -89,7 +109,7 @@ struct Status {
    * version: string
    * }]
    */
-  bool load(int cpu_num_threads) {
+  bool load_metas() {
     embed::Resource metas_file = METAS();

     metas = nlohmann::json::parse(metas_file.data, metas_file.data + metas_file.size);
@@ -100,36 +120,32 @@ struct Status {
         supported_styles.insert(style["id"].get<int64_t>());
       }
     }
+    return true;
+  }

-    for (const auto MODELS : MODELS_LIST) {
-      embed::Resource yukarin_s_model = MODELS.YUKARIN_S();
-      embed::Resource yukarin_sa_model = MODELS.YUKARIN_SA();
-      embed::Resource decode_model = MODELS.DECODE();
-
-      Ort::SessionOptions session_options;
-      session_options.SetInterOpNumThreads(cpu_num_threads).SetIntraOpNumThreads(cpu_num_threads);
-      yukarin_s_list.push_back(Ort::Session(env, yukarin_s_model.data, yukarin_s_model.size, session_options));
-      yukarin_sa_list.push_back(Ort::Session(env, yukarin_sa_model.data, yukarin_sa_model.size, session_options));
-      if (use_gpu) {
-#ifdef DIRECTML
-        session_options.DisableMemPattern().SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
-        Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(session_options, 0));
-#else
-        const OrtCUDAProviderOptions cuda_options;
-        session_options.AppendExecutionProvider_CUDA(cuda_options);
-#endif
-      }
-      decode_list.push_back(Ort::Session(env, decode_model.data, decode_model.size, session_options));
-    }
+  /**
+   * Load a model
+   */
+  bool load_model(int model_index) {
+    const auto VVMODEL = VVMODEL_LIST[model_index];
+    embed::Resource yukarin_s_model = VVMODEL.YUKARIN_S();
+    embed::Resource yukarin_sa_model = VVMODEL.YUKARIN_SA();
+    embed::Resource decode_model = VVMODEL.DECODE();
+
+    yukarin_s_list[model_index] =
+        std::move(Ort::Session(env, yukarin_s_model.data, yukarin_s_model.size, session_options));
+    yukarin_sa_list[model_index] =
+        std::move(Ort::Session(env, yukarin_sa_model.data, yukarin_sa_model.size, session_options));
+    decode_list[model_index] = std::move(Ort::Session(env, decode_model.data, decode_model.size, session_options));
     return true;
   }

   std::string root_dir_path;
-  bool use_gpu;
+  Ort::SessionOptions session_options;
   Ort::MemoryInfo memory_info;

   Ort::Env env{ORT_LOGGING_LEVEL_ERROR};
-  std::vector<Ort::Session> yukarin_s_list, yukarin_sa_list, decode_list;
+  std::vector<std::optional<Ort::Session>> yukarin_s_list, yukarin_sa_list, decode_list;

   nlohmann::json metas;
   std::string metas_str;
@@ -166,7 +182,7 @@ std::pair<int64_t, int64_t> get_model_index_and_speaker_id(int64_t speaker_id) {
   return found->second;
 }

-bool initialize(bool use_gpu, int cpu_num_threads) {
+bool initialize(bool use_gpu, int cpu_num_threads, bool load_all_models) {
   initialized = false;

 #ifdef DIRECTML
@@ -178,18 +194,29 @@ bool initialize(bool use_gpu, int cpu_num_threads) {
     return false;
   }
   try {
-    status = std::make_unique<Status>(use_gpu);
-    if (!status->load(cpu_num_threads)) {
+    const int model_count = std::size(VVMODEL_LIST);
+    status = std::make_unique<Status>(model_count, use_gpu, cpu_num_threads);
+    if (!status->load_metas()) {
       return false;
     }
-    if (use_gpu) {
-      // Run once so that enough GPU memory gets allocated
-      int length = 500;
-      int phoneme_size = 45;
-      std::vector<float> phoneme(length * phoneme_size), f0(length);
-      int64_t speaker_id = 0;
-      std::vector<float> output(length * 256);
-      decode_forward(length, phoneme_size, f0.data(), phoneme.data(), &speaker_id, output.data());
-    }
+
+    if (load_all_models) {
+      for (int model_index = 0; model_index < model_count; model_index++) {
+        if (!status->load_model(model_index)) {
+          return false;
+        }
+      }
+
+      if (use_gpu) {
+        // Run once so that enough GPU memory gets allocated
+        // TODO: do this for every MODEL
+        int length = 500;
+        int phoneme_size = 45;
+        std::vector<float> phoneme(length * phoneme_size), f0(length);
+        int64_t speaker_id = 0;
+        std::vector<float> output(length * 256);
+        decode_forward(length, phoneme_size, f0.data(), phoneme.data(), &speaker_id, output.data());
+      }
+    }
   } catch (const Ort::Exception &e) {
     error_message = ONNX_ERR;
@@ -208,6 +235,17 @@ bool initialize(bool use_gpu, int cpu_num_threads) {
   return true;
 }

+bool load_model(int64_t speaker_id) {
+  auto [model_index, _] = get_model_index_and_speaker_id(speaker_id);
+  return status->load_model(model_index);
+}
+
+bool is_model_loaded(int64_t speaker_id) {
+  auto [model_index, _] = get_model_index_and_speaker_id(speaker_id);
+  return (status->yukarin_s_list[model_index].has_value() && status->yukarin_sa_list[model_index].has_value() &&
+          status->decode_list[model_index].has_value());
+}
+
 void finalize() {
   initialized = false;
   status.reset();
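Putting the new entry points together, the intended lazy flow is roughly the following (a usage sketch based only on functions that appear in this diff; the wrapper function and parameter values are illustrative, and error handling is elided):

// Sketch: initialize without loading any model, then load on demand.
void example_lazy_flow() {
  bool use_gpu = false;
  int cpu_num_threads = 4;
  if (!initialize(use_gpu, cpu_num_threads, /*load_all_models=*/false)) {
    return;  // initialization failed
  }

  int64_t speaker_id = 0;
  if (!is_model_loaded(speaker_id) && !load_model(speaker_id)) {
    return;  // loading the model for this speaker failed
  }
  // yukarin_s_forward / yukarin_sa_forward / decode_forward can now be called
  // for this speaker; on an unloaded model they fail with NOT_LOADED_ERR.
}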
@@ -231,6 +269,11 @@ bool yukarin_s_forward(int64_t length, int64_t *phoneme_list, int64_t *speaker_i
     return false;
   }
   auto [model_index, model_speaker_id] = get_model_index_and_speaker_id(*speaker_id);
+  auto &model = status->yukarin_s_list[model_index];
+  if (!model) {
+    error_message = NOT_LOADED_ERR;
+    return false;
+  }
   try {
     const char *inputs[] = {"phoneme_list", "speaker_id"};
     const char *outputs[] = {"phoneme_length"};
@@ -240,8 +283,8 @@ bool yukarin_s_forward(int64_t length, int64_t *phoneme_list, int64_t *speaker_i
                                          to_tensor(&model_speaker_id, speaker_shape)};
     Ort::Value output_tensor = to_tensor(output, phoneme_shape);

-    status->yukarin_s_list[model_index].Run(Ort::RunOptions{nullptr}, inputs, input_tensors.data(),
-                                            input_tensors.size(), outputs, &output_tensor, 1);
+    model.value().Run(Ort::RunOptions{nullptr}, inputs, input_tensors.data(), input_tensors.size(), outputs,
+                      &output_tensor, 1);

     for (int64_t i = 0; i < length; i++) {
       if (output[i] < PHONEME_LENGTH_MINIMAL) output[i] = PHONEME_LENGTH_MINIMAL;
@@ -266,6 +309,11 @@ bool yukarin_sa_forward(int64_t length, int64_t *vowel_phoneme_list, int64_t *co
     return false;
   }
   auto [model_index, model_speaker_id] = get_model_index_and_speaker_id(*speaker_id);
+  auto &model = status->yukarin_sa_list[model_index];
+  if (!model) {
+    error_message = NOT_LOADED_ERR;
+    return false;
+  }
   try {
     const char *inputs[] = {
         "length", "vowel_phoneme_list", "consonant_phoneme_list", "start_accent_list",
@@ -283,8 +331,8 @@ bool yukarin_sa_forward(int64_t length, int64_t *vowel_phoneme_list, int64_t *co
                                          to_tensor(&model_speaker_id, speaker_shape)};
     Ort::Value output_tensor = to_tensor(output, phoneme_shape);

-    status->yukarin_sa_list[model_index].Run(Ort::RunOptions{nullptr}, inputs, input_tensors.data(),
-                                             input_tensors.size(), outputs, &output_tensor, 1);
+    model.value().Run(Ort::RunOptions{nullptr}, inputs, input_tensors.data(), input_tensors.size(), outputs,
+                      &output_tensor, 1);
   } catch (const Ort::Exception &e) {
     error_message = ONNX_ERR;
     error_message += e.what();
@@ -346,6 +394,11 @@ bool decode_forward(int64_t length, int64_t phoneme_size, float *f0, float *phon
     return false;
   }
   auto [model_index, model_speaker_id] = get_model_index_and_speaker_id(*speaker_id);
+  auto &model = status->decode_list[model_index];
+  if (!model) {
+    error_message = NOT_LOADED_ERR;
+    return false;
+  }
   try {
     // Contains a workaround that keeps the audio from cutting off
     // TODO: remove this padding once the model is improved
@@ -381,8 +434,8 @@ bool decode_forward(int64_t length, int64_t phoneme_size, float *f0, float *phon
     const char *inputs[] = {"f0", "phoneme", "speaker_id"};
     const char *outputs[] = {"wave"};

-    status->decode_list[model_index].Run(Ort::RunOptions{nullptr}, inputs, input_tensor.data(), input_tensor.size(),
-                                         outputs, &output_tensor, 1);
+    model.value().Run(Ort::RunOptions{nullptr}, inputs, input_tensor.data(), input_tensor.size(), outputs,
+                      &output_tensor, 1);

     // TODO: remove this copy once the model is improved
     copy_output_with_padding_to_output(output_with_padding, output, padding_f0_size);
Comment: What was removed are the unused error messages.