Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Logging fixes #80

Merged
merged 12 commits into from
Apr 18, 2024
Merged
87 changes: 51 additions & 36 deletions whisper_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
import librosa
from functools import lru_cache
import time
import logging


import io
import soundfile as sf
import math

logger = logging.getLogger(__name__)

@lru_cache
def load_audio(fname):
a, _ = librosa.load(fname, sr=16000, dtype=np.float32)
Expand Down Expand Up @@ -62,7 +67,7 @@ def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
from whisper_timestamped import transcribe_timestamped
self.transcribe_timestamped = transcribe_timestamped
if model_dir is not None:
print("ignoring model_dir, not implemented",file=self.logfile)
logger.debug("ignoring model_dir, not implemented")
return whisper.load_model(modelsize, download_root=cache_dir)

def transcribe(self, audio, init_prompt=""):
Expand Down Expand Up @@ -101,8 +106,9 @@ class FasterWhisperASR(ASRBase):

def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
from faster_whisper import WhisperModel
logging.getLogger("faster_whisper").setLevel(logger.level)
if model_dir is not None:
print(f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.",file=self.logfile)
logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.")
model_size_or_path = model_dir
elif modelsize is not None:
model_size_or_path = modelsize
Expand Down Expand Up @@ -225,7 +231,7 @@ def transcribe(self, audio_data, prompt=None, *args, **kwargs):

# Process transcription/translation
transcript = proc.create(**params)
print(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds",file=self.logfile)
logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")

return transcript

Expand Down Expand Up @@ -268,9 +274,11 @@ def insert(self, new, offset):
c = " ".join([self.commited_in_buffer[-j][2] for j in range(1,i+1)][::-1])
tail = " ".join(self.new[j-1][2] for j in range(1,i+1))
if c == tail:
print("removing last",i,"words:",file=self.logfile)
words = []
for j in range(i):
print("\t",self.new.pop(0),file=self.logfile)
words.append(repr(self.new.pop(0)))
words_msg = "\t".join(words)
logger.debug(f"removing last {i} words: {words_msg}")
break

def flush(self):
Expand Down Expand Up @@ -359,9 +367,9 @@ def process_iter(self):
"""

prompt, non_prompt = self.prompt()
print("PROMPT:", prompt, file=self.logfile)
print("CONTEXT:", non_prompt, file=self.logfile)
print(f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}",file=self.logfile)
logger.debug(f"PROMPT: {prompt}")
logger.debug(f"CONTEXT: {non_prompt}")
logger.debug(f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit-pick, and I also took shortcuts here so feel free to ignore, but formatting text in logs is better done using logger.debug("...%s...", value) because it avoids interpolation in case the log is not written because of the log level.

Similarly, if you create variables or do any processing just for logging (like in line 381 below) it is good form to wrap with if logger.isEnabledFor(somelevel):.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, agree. I'll tidy those up.

res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)

# transform to [(beg,end,"word1"), ...]
Expand All @@ -370,8 +378,10 @@ def process_iter(self):
self.transcript_buffer.insert(tsw, self.buffer_time_offset)
o = self.transcript_buffer.flush()
self.commited.extend(o)
print(">>>>COMPLETE NOW:",self.to_flush(o),file=self.logfile,flush=True)
print("INCOMPLETE:",self.to_flush(self.transcript_buffer.complete()),file=self.logfile,flush=True)
completed = self.to_flush(o)
logger.debug(f">>>>COMPLETE NOW: {completed}")
the_rest = self.to_flush(self.transcript_buffer.complete())
logger.debug(f"INCOMPLETE: {the_rest}")

# there is a newly confirmed text

Expand All @@ -395,26 +405,26 @@ def process_iter(self):
#while k>0 and self.commited[k][1] > l:
# k -= 1
#t = self.commited[k][1]
print(f"chunking segment",file=self.logfile)
logger.debug(f"chunking segment")
#self.chunk_at(t)

print(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}",file=self.logfile)
logger.debug(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}")
return self.to_flush(o)

def chunk_completed_sentence(self):
if self.commited == []: return
print(self.commited,file=self.logfile)
logger.debug(self.commited)
sents = self.words_to_sentences(self.commited)
for s in sents:
print("\t\tSENT:",s,file=self.logfile)
logger.debug(f"\t\tSENT: {s}")
if len(sents) < 2:
return
while len(sents) > 2:
sents.pop(0)
# we will continue with audio processing at this timestamp
chunk_at = sents[-2][1]

print(f"--- sentence chunked at {chunk_at:2.2f}",file=self.logfile)
logger.debug(f"--- sentence chunked at {chunk_at:2.2f}")
self.chunk_at(chunk_at)

def chunk_completed_segment(self, res):
Expand All @@ -431,12 +441,12 @@ def chunk_completed_segment(self, res):
ends.pop(-1)
e = ends[-2]+self.buffer_time_offset
if e <= t:
print(f"--- segment chunked at {e:2.2f}",file=self.logfile)
logger.debug(f"--- segment chunked at {e:2.2f}")
self.chunk_at(e)
else:
print(f"--- last segment not within commited area",file=self.logfile)
logger.debug(f"--- last segment not within commited area")
else:
print(f"--- not enough segments to chunk",file=self.logfile)
logger.debug(f"--- not enough segments to chunk")



Expand Down Expand Up @@ -482,7 +492,7 @@ def finish(self):
"""
o = self.transcript_buffer.complete()
f = self.to_flush(o)
print("last, noncommited:",f,file=self.logfile)
logger.debug("last, noncommited: {f}")
return f


Expand Down Expand Up @@ -522,7 +532,7 @@ def split(self, text):

# the following languages are in Whisper, but not in wtpsplit:
if lan in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split():
print(f"{lan} code is not supported by wtpsplit. Going to use None lang_code option.", file=sys.stderr)
logger.debug(f"{lan} code is not supported by wtpsplit. Going to use None lang_code option.")
lan = None

from wtpsplit import WtP
Expand All @@ -548,14 +558,15 @@ def add_shared_args(parser):
parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
parser.add_argument('--buffer_trimming', type=str, default="segment", choices=["sentence", "segment"],help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.')
parser.add_argument('--buffer_trimming_sec', type=float, default=15, help='Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.')
parser.add_argument("-l", "--log-level", dest="log_level", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help="Set the log level", default='DEBUG')

def asr_factory(args, logfile=sys.stderr):
"""
Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
"""
backend = args.backend
if backend == "openai-api":
print("Using OpenAI API.", file=logfile)
logger.debug("Using OpenAI API.")
asr = OpenaiApiASR(lan=args.lan)
else:
if backend == "faster-whisper":
Expand All @@ -566,14 +577,14 @@ def asr_factory(args, logfile=sys.stderr):
# Only for FasterWhisperASR and WhisperTimestampedASR
size = args.model
t = time.time()
print(f"Loading Whisper {size} model for {args.lan}...", file=logfile, end=" ", flush=True)
logger.debug(f"Loading Whisper {size} model for {args.lan}...")
asr = asr_cls(modelsize=size, lan=args.lan, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
e = time.time()
print(f"done. It took {round(e-t,2)} seconds.", file=logfile)
logger.debug(f"done. It took {round(e-t,2)} seconds.")

# Apply common configurations
if getattr(args, 'vad', False): # Checks if VAD argument is present and True
print("Setting VAD filter", file=logfile)
logger.info("Setting VAD filter")
asr.use_vad()

language = args.lan
Expand Down Expand Up @@ -611,14 +622,18 @@ def asr_factory(args, logfile=sys.stderr):
logfile = sys.stderr

if args.offline and args.comp_unaware:
print("No or one option from --offline and --comp_unaware are available, not both. Exiting.",file=logfile)
logger.error("No or one option from --offline and --comp_unaware are available, not both. Exiting.")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think command line errors go to stderr explicitly. If you set the log level too high the command may confusingly fail without error message. (I think the file argument in the original was also a mistake.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, yes.

sys.exit(1)

if args.log_level:
logging.basicConfig(format='whisper-%(levelname)s:%(name)s: %(message)s',
level=getattr(logging, args.log_level))
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should always call basicConfig. Because of the default, log_level is set anyway.


audio_path = args.audio_path

SAMPLING_RATE = 16000
duration = len(load_audio(audio_path))/SAMPLING_RATE
print("Audio duration is: %2.2f seconds" % duration, file=logfile)
logger.info("Audio duration is: %2.2f seconds" % duration)

asr, online = asr_factory(args, logfile=logfile)
min_chunk = args.min_chunk_size
Expand All @@ -645,16 +660,16 @@ def output_transcript(o, now=None):
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=logfile,flush=True)
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
else:
print(o,file=logfile,flush=True)
# No text, so no output
pass
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems a functional change, but could be a good one. If o[0] is the emission time, the "no text" comment seems incorrect?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do we think? Is the emission time on a "no text" segment useful? Now you've said it I'm inclined to bring it back in, because otherwise the consumer has no clue that any silence has been processed.


if args.offline: ## offline mode processing (for testing/debugging)
a = load_audio(audio_path)
online.insert_audio_chunk(a)
try:
o = online.process_iter()
except AssertionError:
print("assertion error",file=logfile)
pass
except AssertionError as e:
log.error(f"assertion error: {repr(e)}")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have never seen these errors, would a full backtrace be useful here? (If so, log.exception would be better.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we were making changes here I'd be inclined to go the other way and say that a failed assertion should just crash the process. assert should be telling us that something's coded wrong, so to a certain degree all bets are off. That being said, given that I don't know what the original problem was that caused these exception handlers to be written, logging.exception is a good compromise if that seems overly harsh.

else:
output_transcript(o)
now = None
Expand All @@ -665,13 +680,13 @@ def output_transcript(o, now=None):
online.insert_audio_chunk(a)
try:
o = online.process_iter()
except AssertionError:
print("assertion error",file=logfile)
except AssertionError as e:
logger.error(f"assertion error: {repr(e)}")
pass
else:
output_transcript(o, now=end)

print(f"## last processed {end:.2f}s",file=logfile,flush=True)
logger.debug(f"## last processed {end:.2f}s")

if end >= duration:
break
Expand All @@ -697,13 +712,13 @@ def output_transcript(o, now=None):

try:
o = online.process_iter()
except AssertionError:
print("assertion error",file=logfile)
except AssertionError as e:
logger.error(f"assertion error: {e}")
pass
else:
output_transcript(o)
now = time.time() - start
print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=logfile,flush=True)
logger.debug(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}")

if end >= duration:
break
Expand Down
42 changes: 15 additions & 27 deletions whisper_online_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import sys
import argparse
import os
import logging
import numpy as np

logger = logging.getLogger(__name__)
parser = argparse.ArgumentParser()

# server options
Expand All @@ -18,6 +21,9 @@
add_shared_args(parser)
args = parser.parse_args()

if args.log_level:
logging.basicConfig(format='whisper-server-%(levelname)s:%(name)s: %(message)s',
level=getattr(logging, args.log_level))

# setting whisper object by args

Expand All @@ -28,35 +34,25 @@
asr, online = asr_factory(args)
min_chunk = args.min_chunk_size


if args.buffer_trimming == "sentence":
tokenizer = create_tokenizer(tgt_language)
else:
tokenizer = None
online = OnlineASRProcessor(asr,tokenizer,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))

# warm up the ASR because the very first transcribe takes more time than the others.
# Test results in https://github.com/ufal/whisper_streaming/pull/81
msg = "Whisper is not warmed up. The first chunk processing may take longer."
if args.warmup_file:
if os.path.isfile(args.warmup_file):
a = load_audio_chunk(args.warmup_file,0,1)
asr.transcribe(a)
print("INFO: Whisper is warmed up.",file=sys.stderr)
logger.info("Whisper is warmed up.")
else:
print("WARNING: The warm up file is not available. "+msg,file=sys.stderr)
logger.warning("The warm up file is not available. "+msg)
else:
print("WARNING: " + msg, file=sys.stderr)
logger.warning(msg)


######### Server objects

import line_packet
import socket

import logging


class Connection:
'''it wraps conn object'''
PACKET_SIZE = 65536
Expand Down Expand Up @@ -104,8 +100,6 @@ def receive_audio_chunk(self):
out = []
while sum(len(x) for x in out) < self.min_chunk*SAMPLING_RATE:
raw_bytes = self.connection.non_blocking_receive_audio()
print(raw_bytes[:10])
print(len(raw_bytes))
if not raw_bytes:
break
sf = soundfile.SoundFile(io.BytesIO(raw_bytes), channels=1,endian="LITTLE",samplerate=SAMPLING_RATE, subtype="PCM_16",format="RAW")
Expand Down Expand Up @@ -136,7 +130,7 @@ def format_output_transcript(self,o):
print("%1.0f %1.0f %s" % (beg,end,o[2]),flush=True,file=sys.stderr)
return "%1.0f %1.0f %s" % (beg,end,o[2])
else:
print(o,file=sys.stderr,flush=True)
logger.debug("No text in this segment")
return None

def send_result(self, o):
Expand All @@ -150,39 +144,33 @@ def process(self):
while True:
a = self.receive_audio_chunk()
if a is None:
print("break here",file=sys.stderr)
break
self.online_asr_proc.insert_audio_chunk(a)
o = online.process_iter()
try:
self.send_result(o)
except BrokenPipeError:
print("broken pipe -- connection closed?",file=sys.stderr)
logger.info("broken pipe -- connection closed?")
break

# o = online.finish() # this should be working
# self.send_result(o)




# Start logging.
level = logging.INFO
logging.basicConfig(level=level, format='whisper-server-%(levelname)s: %(message)s')

# server loop

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind((args.host, args.port))
s.listen(1)
logging.info('INFO: Listening on'+str((args.host, args.port)))
logger.info('Listening on'+str((args.host, args.port)))
while True:
conn, addr = s.accept()
logging.info('INFO: Connected to client on {}'.format(addr))
logger.info('Connected to client on {}'.format(addr))
connection = Connection(conn)
proc = ServerProcessor(connection, online, min_chunk)
proc.process()
conn.close()
logging.info('INFO: Connection to client closed')
logging.info('INFO: Connection closed, terminating.')
logger.info('Connection to client closed')
logger.info('Connection closed, terminating.')