# -*- coding:utf-8 -*-

import os
import sys
import time

import tensorflow as tf

import DeepSpeech
import const

# Resolve paths relative to this file so they work from any working directory.
current_dir_path = os.path.dirname(os.path.realpath(__file__))
data_path = os.path.join(current_dir_path, "data")


# Module-level guard so init() runs only once per process.
initialized = False

def init(n_hidden=const.DEEP_SPEECH_N_HIDDEN,
         checkpoint_dir=const.DEEP_SPEECH_CHECKPOINT_DIR,
         alphabet_config_path=const.DEEP_SPEECH_ALPHABET_PATH,
         use_lm=False,
         language_tool_language=''):
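    """Initialize DeepSpeech's global configuration, exactly once.

    DeepSpeech reads its settings from command-line flags, so the chosen
    options are appended to sys.argv before initialize_globals() parses
    them. Subsequent calls are no-ops.
    """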
    global initialized

    if initialized:
        return

    sys.argv.append("--alphabet_config_path")
    sys.argv.append(alphabet_config_path)
    sys.argv.append("--n_hidden")
    sys.argv.append(str(n_hidden))
    sys.argv.append("--checkpoint_dir")
    sys.argv.append(checkpoint_dir)
    sys.argv.append("--infer_use_lm="+("1" if use_lm else "0"))
    sys.argv.append("--lt_lang="+language_tool_language)    


    DeepSpeech.initialize_globals()

    initialized = True

def init_session():
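    """Build the inference graph and restore the most recent checkpoint.

    Returns a (session, inputs, outputs) tuple suitable for passing
    directly to infer().
    """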
    print('Use Language Model: %s' % str(DeepSpeech.FLAGS.infer_use_lm))

    session = tf.Session(config=DeepSpeech.session_config)

    # Build a single-utterance (batch_size=1) inference graph; the LM-aware
    # ("new") decoder is used only when the language-model flag is set.
    inputs, outputs = DeepSpeech.create_inference_graph(batch_size=1, use_new_decoder=DeepSpeech.FLAGS.infer_use_lm)
    # Create a saver using variables from the above newly created graph
    saver = tf.train.Saver(tf.global_variables())
    # Restore variables from training checkpoint
    # TODO: This restores the most recent checkpoint, but if we use validation to counteract
    #       over-fitting, we may want to restore an earlier checkpoint.
    checkpoint = tf.train.get_checkpoint_state(DeepSpeech.FLAGS.checkpoint_dir)
    if not checkpoint:
        print('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(DeepSpeech.FLAGS.checkpoint_dir))
        sys.exit(1)

    checkpoint_path = checkpoint.model_checkpoint_path
    saver.restore(session, checkpoint_path)

    return session, inputs, outputs


def infer(wav_path, session_tuple):
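    """Transcribe a single WAV file.

    session_tuple is the (session, inputs, outputs) tuple returned by
    init_session(); the decoded transcript is returned as a string.
    """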
    session, inputs, outputs = session_tuple

    # Convert the audio file into the MFCC feature windows the model expects.
    mfcc = DeepSpeech.audiofile_to_input_vector(wav_path, DeepSpeech.n_input, DeepSpeech.n_context)

    # Run the inference graph on a single-utterance batch.
    output = session.run(outputs['outputs'], feed_dict={
        inputs['input']: [mfcc],
        inputs['input_lengths']: [len(mfcc)],
    })

    # Decode the network output back into a transcript string.
    text = DeepSpeech.ndarray_to_text(output[0][0], DeepSpeech.alphabet)

    return text

if __name__ == "__main__":
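    # Smoke test: initialize DeepSpeech, restore a session, then time
    # repeated transcriptions of a single test file.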

    start_time = time.time()
    init(use_lm=True)
    print("DeepSpeech init took %.2f sec" % (time.time() - start_time))

    

    start_time = time.time()
    session = init_session()
    print("session init took %.2f sec" % (time.time() - start_time))

    test_file_path = os.path.join(const.DATA_DIR, "infer_test_3.wav")

    for _ in range(10):
        start_time = time.time()
        print(infer(test_file_path, session_tuple))
        print("infer took %.2f sec" % (time.time() - start_time))