Commit e0f29c6

Reorganize various things (色々再構成)

qryxip committed Nov 8, 2023
1 parent cc84068
Showing 7 changed files with 209 additions and 126 deletions.
20 changes: 12 additions & 8 deletions crates/voicevox_core/src/engine/synthesis_engine.rs
@@ -5,7 +5,10 @@ use std::sync::Arc;
 use super::full_context_label::Utterance;
 use super::open_jtalk::OpenJtalk;
 use super::*;
-use crate::infer::{InferenceRuntime, Output};
+use crate::infer::{
+    signatures::{Decode, PredictDuration, PredictIntonation},
+    InferenceRuntime, SupportsInferenceSignature,
+};
 use crate::numerics::F32Ext as _;
 use crate::InferenceCore;
 
@@ -15,19 +18,20 @@ const MORA_PHONEME_LIST: &[&str] = &[
     "a", "i", "u", "e", "o", "N", "A", "I", "U", "E", "O", "cl", "pau",
 ];
 
+pub const DEFAULT_SAMPLING_RATE: u32 = 24000;
+
 #[derive(new)]
 pub(crate) struct SynthesisEngine<R: InferenceRuntime> {
     inference_core: InferenceCore<R>,
     open_jtalk: Arc<OpenJtalk>,
 }
 
-impl<R> SynthesisEngine<R>
-where
-    R: InferenceRuntime,
-    (Vec<f32>,): Output<R>,
+impl<
+    R: SupportsInferenceSignature<PredictDuration>
+        + SupportsInferenceSignature<PredictIntonation>
+        + SupportsInferenceSignature<Decode>,
+> SynthesisEngine<R>
 {
-    pub const DEFAULT_SAMPLING_RATE: u32 = 24000;
-
     pub fn inference_core(&self) -> &InferenceCore<R> {
         &self.inference_core
     }
@@ -426,7 +430,7 @@ where
         let num_channels: u16 = if output_stereo { 2 } else { 1 };
         let bit_depth: u16 = 16;
         let repeat_count: u32 =
-            (output_sampling_rate / Self::DEFAULT_SAMPLING_RATE) * num_channels as u32;
+            (output_sampling_rate / DEFAULT_SAMPLING_RATE) * num_channels as u32;
         let block_size: u16 = bit_depth * num_channels / 8;
 
         let bytes_size = wave.len() as u32 * repeat_count * 2;
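
The key change in this file is the impl header: instead of a single where clause requiring (Vec<f32>,): Output<R>, the engine now demands one SupportsInferenceSignature bound per model signature it runs. A minimal, self-contained Rust sketch of this bound style follows; the signature and runtime types here are mock stand-ins, not the crate's real PredictDuration/Decode/InferenceRuntime.

// Mock stand-ins for the crate's signature types.
pub struct PredictDuration;
pub struct Decode;

pub trait InferenceSignature {
    type Output;
}

impl InferenceSignature for PredictDuration {
    type Output = (Vec<f32>,);
}

impl InferenceSignature for Decode {
    type Output = (Vec<f32>,);
}

// A runtime "supports" a signature when it can run it end to end.
pub trait SupportsInferenceSignature<S: InferenceSignature> {
    fn run_signature(&self, input: S) -> S::Output;
}

pub struct MockRuntime;

impl SupportsInferenceSignature<PredictDuration> for MockRuntime {
    fn run_signature(&self, _input: PredictDuration) -> (Vec<f32>,) {
        (vec![0.1, 0.2],) // dummy per-phoneme durations
    }
}

impl SupportsInferenceSignature<Decode> for MockRuntime {
    fn run_signature(&self, _input: Decode) -> (Vec<f32>,) {
        (vec![0.0; 256],) // dummy waveform samples
    }
}

// Mirrors the new `SynthesisEngine` impl header: one bound per signature,
// so a runtime missing any of them is rejected at compile time.
pub struct Engine<R>(R);

impl<R> Engine<R>
where
    R: SupportsInferenceSignature<PredictDuration> + SupportsInferenceSignature<Decode>,
{
    pub fn synthesize(&self) -> Vec<f32> {
        // Both bounds expose `run_signature`, so fully qualified syntax
        // picks which signature we mean.
        let (_durations,) = <R as SupportsInferenceSignature<PredictDuration>>::run_signature(
            &self.0,
            PredictDuration,
        );
        let (wave,) = <R as SupportsInferenceSignature<Decode>>::run_signature(&self.0, Decode);
        wave
    }
}

fn main() {
    assert_eq!(Engine(MockRuntime).synthesize().len(), 256);
}
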
108 changes: 60 additions & 48 deletions crates/voicevox_core/src/infer.rs
@@ -6,53 +6,70 @@ use std::{collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc};
 
 use derive_new::new;
 use enum_map::{Enum, EnumMap};
-use ndarray::{Array, Dimension, LinalgScalar};
 use thiserror::Error;
 
 use crate::{ErrorRepr, SupportedDevices};
 
 pub(crate) trait InferenceRuntime: 'static {
-    type Session: Session;
-    type RunBuilder<'a>: RunBuilder<'a, Session = Self::Session>;
+    type Session: InferenceSession;
+    type RunContext<'a>: RunContext<'a, Session = Self::Session>;
     fn supported_devices() -> crate::Result<SupportedDevices>;
 }
 
-pub(crate) trait Session: Sized + Send + 'static {
+pub(crate) trait InferenceSession: Sized + Send + 'static {
     fn new(
         model: impl FnOnce() -> std::result::Result<Vec<u8>, DecryptModelError>,
-        options: SessionOptions,
+        options: InferenceSessionOptions,
     ) -> anyhow::Result<Self>;
 }
 
-pub(crate) trait RunBuilder<'a>: From<&'a mut Self::Session> {
-    type Session: Session;
-    fn input(&mut self, tensor: Array<impl InputScalar, impl Dimension + 'static>) -> &mut Self;
+pub(crate) trait RunContext<'a>: From<&'a mut Self::Session> {
+    type Session: InferenceSession;
 }
 
-pub(crate) trait InputScalar: LinalgScalar + Debug + sealed::InputScalar {}
+pub(crate) trait SupportsInferenceSignature<S: InferenceSignature>:
+    SupportsInferenceInputTensors<S::Input> + SupportsInferenceOutput<S::Output>
+{
+}
+
+impl<
+    R: SupportsInferenceInputTensors<S::Input> + SupportsInferenceOutput<S::Output>,
+    S: InferenceSignature,
+> SupportsInferenceSignature<S> for R
+{
+}
+
+pub(crate) trait SupportsInferenceInputTensor<I>: InferenceRuntime {
+    fn input(ctx: &mut Self::RunContext<'_>, tensor: I);
+}
 
-impl InputScalar for i64 {}
-impl InputScalar for f32 {}
+pub(crate) trait SupportsInferenceInputTensors<I: InferenceInput>: InferenceRuntime {
+    fn input(ctx: &mut Self::RunContext<'_>, tensors: I);
+}
+
+pub(crate) trait SupportsInferenceOutput<O: Send>: InferenceRuntime {
+    fn run(ctx: Self::RunContext<'_>) -> anyhow::Result<O>;
+}
 
-pub(crate) trait Signature: Sized + Send + 'static {
+pub(crate) trait InferenceSignature: Sized + Send + 'static {
     type Kind: Enum + Copy;
-    type Output;
+    type Input: InferenceInput;
+    type Output: Send;
     const KIND: Self::Kind;
-    fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>);
 }
 
-pub(crate) trait Output<R: InferenceRuntime>: Sized + Send {
-    fn run(ctx: R::RunBuilder<'_>) -> anyhow::Result<Self>;
+pub(crate) trait InferenceInput: Send + 'static {
+    type Signature: InferenceSignature;
 }
 
-pub(crate) struct SessionSet<K: Enum, R: InferenceRuntime>(
+pub(crate) struct InferenceSessionSet<K: Enum, R: InferenceRuntime>(
     EnumMap<K, Arc<std::sync::Mutex<R::Session>>>,
 );
 
-impl<K: Enum + Copy, R: InferenceRuntime> SessionSet<K, R> {
+impl<K: Enum + Copy, R: InferenceRuntime> InferenceSessionSet<K, R> {
     pub(crate) fn new(
         model_bytes: &EnumMap<K, Vec<u8>>,
-        mut options: impl FnMut(K) -> SessionOptions,
+        mut options: impl FnMut(K) -> InferenceSessionOptions,
     ) -> anyhow::Result<Self> {
         let mut sessions = model_bytes
             .iter()
@@ -68,52 +85,47 @@ impl<K: Enum + Copy, R: InferenceRuntime> SessionSet<K, R> {
     }
 }
 
-impl<K: Enum, R: InferenceRuntime> SessionSet<K, R> {
-    pub(crate) fn get<S: Signature<Kind = K>>(&self) -> SessionCell<R, S> {
-        SessionCell {
-            inner: self.0[S::KIND].clone(),
+impl<K: Enum, R: InferenceRuntime> InferenceSessionSet<K, R> {
+    pub(crate) fn get<I>(&self) -> InferenceSessionCell<R, I>
+    where
+        I: InferenceInput,
+        I::Signature: InferenceSignature<Kind = K>,
+    {
+        InferenceSessionCell {
+            inner: self.0[<I::Signature as InferenceSignature>::KIND].clone(),
             marker: PhantomData,
         }
     }
 }
 
-pub(crate) struct SessionCell<R: InferenceRuntime, S> {
+pub(crate) struct InferenceSessionCell<R: InferenceRuntime, I> {
     inner: Arc<std::sync::Mutex<R::Session>>,
-    marker: PhantomData<fn(S)>,
+    marker: PhantomData<fn(I)>,
 }
 
-impl<R: InferenceRuntime, S: Signature> SessionCell<R, S> {
-    pub(crate) fn run(self, input: S) -> crate::Result<S::Output>
-    where
-        S::Output: Output<R>,
-    {
+impl<
+    R: SupportsInferenceInputTensors<I>
+        + SupportsInferenceOutput<<I::Signature as InferenceSignature>::Output>,
+    I: InferenceInput,
+> InferenceSessionCell<R, I>
+{
+    pub(crate) fn run(
+        self,
+        input: I,
+    ) -> crate::Result<<I::Signature as InferenceSignature>::Output> {
         let mut inner = self.inner.lock().unwrap();
-        let mut ctx = R::RunBuilder::from(&mut inner);
-        input.input(&mut ctx);
-        S::Output::run(ctx).map_err(|e| ErrorRepr::InferenceFailed(e).into())
+        let mut ctx = R::RunContext::from(&mut inner);
+        R::input(&mut ctx, input);
+        R::run(ctx).map_err(|e| ErrorRepr::InferenceFailed(e).into())
     }
 }
 
 #[derive(new, Clone, Copy)]
-pub(crate) struct SessionOptions {
+pub(crate) struct InferenceSessionOptions {
     pub(crate) cpu_num_threads: u16,
     pub(crate) use_gpu: bool,
 }
 
 #[derive(Error, Debug)]
 #[error("不正なモデルファイルです")]
 pub(crate) struct DecryptModelError;
-
-mod sealed {
-    pub(crate) trait InputScalar: OnnxruntimeInputScalar {}
-
-    impl InputScalar for i64 {}
-    impl InputScalar for f32 {}
-
-    pub(crate) trait OnnxruntimeInputScalar:
-        onnxruntime::TypeToTensorElementDataType
-    {
-    }
-
-    impl<T: onnxruntime::TypeToTensorElementDataType> OnnxruntimeInputScalar for T {}
-}
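
Note that SupportsInferenceSignature is never implemented by hand: the blanket impl directly after its declaration derives it for any runtime that already has the input and output capabilities. A runnable sketch of the same pattern with toy capability traits follows; every name in it is illustrative, not from the crate.

// Capability traits, analogous to SupportsInferenceInputTensors and
// SupportsInferenceOutput in the diff above.
trait AcceptsInput<I> {
    fn feed(&mut self, input: I);
}

trait ProducesOutput<O> {
    fn finish(&mut self) -> O;
}

// The composed trait plus its blanket impl: any type with both
// capabilities gets the composed trait for free.
trait SupportsJob<I, O>: AcceptsInput<I> + ProducesOutput<O> {}

impl<T, I, O> SupportsJob<I, O> for T where T: AcceptsInput<I> + ProducesOutput<O> {}

// A toy "runtime" that sums the inputs it is fed.
#[derive(Default)]
struct SumRuntime {
    acc: i64,
}

impl AcceptsInput<Vec<i64>> for SumRuntime {
    fn feed(&mut self, input: Vec<i64>) {
        self.acc += input.iter().sum::<i64>();
    }
}

impl ProducesOutput<i64> for SumRuntime {
    fn finish(&mut self) -> i64 {
        self.acc
    }
}

// Callers can now demand the whole capability set with a single bound.
fn run_job<R: SupportsJob<Vec<i64>, i64>>(mut runtime: R, input: Vec<i64>) -> i64 {
    runtime.feed(input);
    runtime.finish()
}

fn main() {
    assert_eq!(run_job(SumRuntime::default(), vec![1, 2, 3]), 6);
}
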
31 changes: 19 additions & 12 deletions crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
@@ -1,14 +1,18 @@
+use std::fmt::Debug;
+
 use ndarray::{Array, Dimension};
 use once_cell::sync::Lazy;
-use onnxruntime::{environment::Environment, GraphOptimizationLevel, LoggingLevel};
+use onnxruntime::{
+    environment::Environment, GraphOptimizationLevel, LoggingLevel, TypeToTensorElementDataType,
+};
 
 use self::assert_send::AssertSend;
 use crate::{
     devices::SupportedDevices,
     error::ErrorRepr,
     infer::{
-        DecryptModelError, InferenceRuntime, InputScalar, Output, RunBuilder, Session,
-        SessionOptions,
+        DecryptModelError, InferenceRuntime, InferenceSession, InferenceSessionOptions, RunContext,
+        SupportsInferenceInputTensor, SupportsInferenceOutput,
     },
 };
 
@@ -17,7 +21,7 @@ pub(crate) enum Onnxruntime {}
 
 impl InferenceRuntime for Onnxruntime {
     type Session = AssertSend<onnxruntime::session::Session<'static>>;
-    type RunBuilder<'a> = OnnxruntimeInferenceBuilder<'a>;
+    type RunContext<'a> = OnnxruntimeInferenceBuilder<'a>;
 
     fn supported_devices() -> crate::Result<SupportedDevices> {
         let mut cuda_support = false;
@@ -42,10 +46,10 @@ impl InferenceRuntime for Onnxruntime {
     }
 }
 
-impl Session for AssertSend<onnxruntime::session::Session<'static>> {
+impl InferenceSession for AssertSend<onnxruntime::session::Session<'static>> {
     fn new(
         model: impl FnOnce() -> std::result::Result<Vec<u8>, DecryptModelError>,
-        options: SessionOptions,
+        options: InferenceSessionOptions,
     ) -> anyhow::Result<Self> {
         let mut builder = ENVIRONMENT
             .new_session_builder()?
@@ -106,20 +110,23 @@ impl<'sess> From<&'sess mut AssertSend<onnxruntime::session::Session<'static>>>
     }
 }
 
-impl<'sess> RunBuilder<'sess> for OnnxruntimeInferenceBuilder<'sess> {
+impl<'sess> RunContext<'sess> for OnnxruntimeInferenceBuilder<'sess> {
     type Session = AssertSend<onnxruntime::session::Session<'static>>;
+}
 
-    fn input(&mut self, tensor: Array<impl InputScalar, impl Dimension + 'static>) -> &mut Self {
-        self.inputs
+impl<A: TypeToTensorElementDataType + Debug + 'static, D: Dimension + 'static>
+    SupportsInferenceInputTensor<Array<A, D>> for Onnxruntime
+{
+    fn input(ctx: &mut Self::RunContext<'_>, tensor: Array<A, D>) {
+        ctx.inputs
             .push(Box::new(onnxruntime::session::NdArray::new(tensor)));
-        self
     }
 }
 
-impl Output<Onnxruntime> for (Vec<f32>,) {
+impl SupportsInferenceOutput<(Vec<f32>,)> for Onnxruntime {
     fn run(
         OnnxruntimeInferenceBuilder { sess, mut inputs }: OnnxruntimeInferenceBuilder<'_>,
-    ) -> anyhow::Result<Self> {
+    ) -> anyhow::Result<(Vec<f32>,)> {
         let outputs = sess.run(inputs.iter_mut().map(|t| &mut **t as &mut _).collect())?;
 
         // FIXME: 2個以上の出力や二次元以上の出力をちゃんとしたやりかたで弾く
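
With RunContext stripped down to a marker trait, tensor pushing now lives in one generic impl per backend: any element type satisfying TypeToTensorElementDataType + Debug gets boxed into the context's input list. Below is a toy Rust version of that shape; the Backend, Tensor, and RunContext types here are hypothetical stand-ins, not the crate's API.

use std::fmt::Debug;

// A toy context that erases heterogeneous tensors into a common trait
// object, in the spirit of OnnxruntimeInferenceBuilder's `inputs` vector.
trait AnyTensor: Debug {}

#[derive(Debug)]
struct Tensor<A: Debug>(Vec<A>);

impl<A: Debug> AnyTensor for Tensor<A> {}

#[derive(Default, Debug)]
struct RunContext {
    inputs: Vec<Box<dyn AnyTensor>>,
}

// One generic impl covers every scalar type, like the diff's
// `impl<A: TypeToTensorElementDataType + Debug + 'static, D> ...`.
trait PushInput<T> {
    fn input(ctx: &mut RunContext, tensor: T);
}

enum Backend {}

impl<A: Debug + 'static> PushInput<Tensor<A>> for Backend {
    fn input(ctx: &mut RunContext, tensor: Tensor<A>) {
        ctx.inputs.push(Box::new(tensor));
    }
}

fn main() {
    let mut ctx = RunContext::default();
    Backend::input(&mut ctx, Tensor(vec![1i64, 2, 3])); // an i64 tensor
    Backend::input(&mut ctx, Tensor(vec![0.5f32])); // an f32 tensor
    assert_eq!(ctx.inputs.len(), 2);
}
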
(4 more changed files not shown)
