Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Audio narration #195

Closed
wants to merge 25 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
351d87b
added sounddevice to optionally record narration
angelala3252 May 26, 2023
f19a84a
added sounddevice to optionally record narration and initial whisper …
angelala3252 May 27, 2023
e143767
updated requirements for audio narration
angelala3252 May 29, 2023
6f07b93
small changes
angelala3252 May 31, 2023
d3ef09a
fixed issue with created audio file being really slow
angelala3252 May 31, 2023
9e86193
updated to save audio data and transcribed text in database
angelala3252 May 31, 2023
87a814f
pull from main
angelala3252 May 31, 2023
ce84a1b
new alembic migration
angelala3252 May 31, 2023
5c584b2
edited audio tables
angelala3252 Jun 1, 2023
802c8a2
convert audio array to required format for whisper
angelala3252 Jun 1, 2023
aca8cdc
visualize audio info
angelala3252 Jun 1, 2023
42b1007
FLAC compression before storing
angelala3252 Jun 1, 2023
9f4c280
store word by word timestamps
angelala3252 Jun 1, 2023
20d29e1
style changes
angelala3252 Jun 2, 2023
109ffe0
Merge branch 'main' into feat/audio_narration
angelala3252 Jun 14, 2023
8d27b4f
changed tiktoken version
angelala3252 Jun 16, 2023
d631b2d
removed unused tiktoken code
angelala3252 Jun 16, 2023
ab0805e
Merge branch 'main' into feat/audio_narration
angelala3252 Jun 16, 2023
e30538b
alphabetic order, removed redundant dependencies
angelala3252 Jun 18, 2023
9469043
merged AudioInfo and AudioFile
angelala3252 Jun 18, 2023
47bf845
Merge remote-tracking branch 'audio/feat/audio_narration' into feat/a…
angelala3252 Jun 18, 2023
e9f2d36
move audio recording into record_audio function
angelala3252 Jun 19, 2023
9293b0b
use thread-local scoped_session
angelala3252 Jun 19, 2023
a66acbc
Merge branch 'main' into feat/audio_narration
angelala3252 Jun 23, 2023
888d335
remove redundant requirement
angelala3252 Jun 23, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions alembic/versions/3d3239ae4849_add_audio_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Add audio info

Revision ID: 3d3239ae4849
Revises: 104d4a614d95
Create Date: 2023-05-31 19:36:46.269697

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '3d3239ae4849'
down_revision = '104d4a614d95'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Apply the migration: create the ``audio_info`` table.

    The table stores compressed audio bytes (``flac_data``), the
    transcribed text, word-level timestamps, and the sample rate,
    linked to a recording via its timestamp.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('audio_info',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('flac_data', sa.LargeBinary(), nullable=True),
        sa.Column('transcribed_text', sa.String(), nullable=True),
        # Foreign key to recording.timestamp (see constraint below).
        sa.Column('recording_timestamp', sa.Integer(), nullable=True),
        sa.Column('sample_rate', sa.Integer(), nullable=True),
        sa.Column('words_with_timestamps', sa.Text(), nullable=True),
        sa.ForeignKeyConstraint(['recording_timestamp'], ['recording.timestamp'], name=op.f('fk_audio_info_recording_timestamp_recording')),
        sa.PrimaryKeyConstraint('id', name=op.f('pk_audio_info'))
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    """Revert the migration: drop the ``audio_info`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('audio_info')
    # ### end Alembic commands ###
67 changes: 51 additions & 16 deletions openadapt/crud.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from loguru import logger
import sqlalchemy as sa
import json

from openadapt.db import Session

from openadapt.models import (
ActionEvent,
Screenshot,
Recording,
WindowEvent,
PerformanceStat,
AudioInfo
)


BATCH_SIZE = 1

db = Session()
Expand All @@ -19,6 +21,7 @@
window_events = []
performance_stats = []


def _insert(event_data, table, buffer=None):
"""Insert using Core API for improved performance (no rows are returned)"""

Expand Down Expand Up @@ -74,6 +77,7 @@ def insert_window_event(recording_timestamp, event_timestamp, event_data):
}
_insert(event_data, WindowEvent, window_events)


def insert_perf_stat(recording_timestamp, event_type, start_time, end_time):
"""
Insert event performance stat into db
Expand All @@ -87,18 +91,40 @@ def insert_perf_stat(recording_timestamp, event_type, start_time, end_time):
}
_insert(event_perf_stat, PerformanceStat, performance_stats)


def get_perf_stats(recording_timestamp):
"""
return performance stats for a given recording
"""

return (
db
.query(PerformanceStat)
.filter(PerformanceStat.recording_timestamp == recording_timestamp)
.order_by(PerformanceStat.start_time)
.all()
.query(PerformanceStat)
.filter(PerformanceStat.recording_timestamp == recording_timestamp)
.order_by(PerformanceStat.start_time)
.all()
)


def insert_audio_info(
    audio_data,
    transcribed_text,
    recording_timestamp,
    sample_rate,
    word_list
):
    """Insert an AudioInfo entry into the database.

    Uses a thread-local scoped session (``Session`` is a
    ``scoped_session``) so this is safe to call from the audio
    recording thread rather than the main process.

    Args:
        audio_data: compressed audio bytes to store in ``flac_data``.
        transcribed_text: full transcription of the narration.
        recording_timestamp: timestamp of the associated Recording.
        sample_rate: sample rate of the recorded audio, in Hz.
        word_list: word-level timestamp entries; serialized to JSON.
    """
    thread_local_db = Session()
    audio_info = AudioInfo(
        flac_data=audio_data,
        transcribed_text=transcribed_text,
        recording_timestamp=recording_timestamp,
        sample_rate=sample_rate,
        words_with_timestamps=json.dumps(word_list),
    )
    try:
        thread_local_db.add(audio_info)
        thread_local_db.commit()
    except Exception:
        # Don't leave the thread-local session in a broken,
        # half-flushed state if the commit fails.
        thread_local_db.rollback()
        raise


def insert_recording(recording_data):
db_obj = Recording(**recording_data)
Expand All @@ -111,29 +137,29 @@ def insert_recording(recording_data):
def get_latest_recording():
return (
db
.query(Recording)
.order_by(sa.desc(Recording.timestamp))
.limit(1)
.first()
.query(Recording)
.order_by(sa.desc(Recording.timestamp))
.limit(1)
.first()
)


def get_recording(timestamp):
return (
db
.query(Recording)
.filter(Recording.timestamp == timestamp)
.first()
.query(Recording)
.filter(Recording.timestamp == timestamp)
.first()
)


def _get(table, recording_timestamp):
return (
db
.query(table)
.filter(table.recording_timestamp == recording_timestamp)
.order_by(table.timestamp)
.all()
.query(table)
.filter(table.recording_timestamp == recording_timestamp)
.order_by(table.timestamp)
.all()
)


Expand All @@ -158,3 +184,12 @@ def get_screenshots(recording, precompute_diffs=False):

def get_window_events(recording):
return _get(WindowEvent, recording.timestamp)


def get_audio_info(recording):
    """Return the AudioInfo row for *recording*, or None if absent."""
    query = db.query(AudioInfo).filter(
        AudioInfo.recording_timestamp == recording.timestamp
    )
    return query.first()
5 changes: 3 additions & 2 deletions openadapt/db.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sqlalchemy as sa
from dictalchemy import DictableModel
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.schema import MetaData
from sqlalchemy.ext.declarative import declarative_base

Expand Down Expand Up @@ -50,4 +50,5 @@ def get_base(engine):

engine = get_engine()
Base = get_base(engine)
Session = sessionmaker(bind=engine)
session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory)
24 changes: 19 additions & 5 deletions openadapt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class Recording(db.Base):
order_by="WindowEvent.timestamp",
)

audio_info = sa.orm.relationship("AudioInfo", back_populates="recording")

_processed_action_events = None

@property
Expand All @@ -58,7 +60,6 @@ def processed_action_events(self):
return self._processed_action_events



class ActionEvent(db.Base):
__tablename__ = "action_event"

Expand Down Expand Up @@ -86,8 +87,8 @@ class ActionEvent(db.Base):
children = sa.orm.relationship("ActionEvent")
# TODO: replacing the above line with the following two results in an error:
# AttributeError: 'list' object has no attribute '_sa_instance_state'
#children = sa.orm.relationship("ActionEvent", remote_side=[id], back_populates="parent")
#parent = sa.orm.relationship("ActionEvent", remote_side=[parent_id], back_populates="children")
# children = sa.orm.relationship("ActionEvent", remote_side=[id], back_populates="parent")
# parent = sa.orm.relationship("ActionEvent", remote_side=[parent_id], back_populates="children")

recording = sa.orm.relationship("Recording", back_populates="action_events")
screenshot = sa.orm.relationship("Screenshot", back_populates="action_event")
Expand Down Expand Up @@ -269,11 +270,11 @@ def take_screenshot(cls):
sct_img = utils.take_screenshot()
screenshot = Screenshot(sct_img=sct_img)
return screenshot

def crop_active_window(self, action_event):
window_event = action_event.window_event
width_ratio, height_ratio = utils.get_scale_ratios(action_event)

x0 = window_event.left * width_ratio
y0 = window_event.top * height_ratio
x1 = x0 + window_event.width * width_ratio
Expand Down Expand Up @@ -305,6 +306,19 @@ def get_active_window_event(cls):
return WindowEvent(**window.get_active_window_data())


class AudioInfo(db.Base):
    """ORM model for audio narration captured alongside a recording."""

    __tablename__ = "audio_info"

    id = sa.Column(sa.Integer, primary_key=True)
    # Compressed audio bytes (FLAC, per the column name — written by the
    # audio recorder; confirm format against the producer).
    flac_data = sa.Column(sa.LargeBinary)
    # Full transcription of the narration.
    transcribed_text = sa.Column(sa.String)
    # Links this audio to its Recording by timestamp.
    recording_timestamp = sa.Column(sa.ForeignKey("recording.timestamp"))
    # Sample rate of the recorded audio, in Hz.
    sample_rate = sa.Column(sa.Integer)
    # JSON-serialized list of word-level timestamp entries.
    words_with_timestamps = sa.Column(sa.Text)

    recording = sa.orm.relationship("Recording", back_populates="audio_info")


class PerformanceStat(db.Base):
__tablename__ = "performance_stat"

Expand Down
98 changes: 93 additions & 5 deletions openadapt/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@

from openadapt import config, crud, utils, window

import sounddevice
import soundfile
import whisper
import numpy as np
import io

EVENT_TYPES = ("screen", "action", "window")
LOG_LEVEL = "INFO"
Expand All @@ -34,7 +39,6 @@
}
PLOT_PERFORMANCE = False


Event = namedtuple("Event", ("timestamp", "type", "data"))


Expand Down Expand Up @@ -369,6 +373,7 @@ def read_window_events(
window_data = window.get_active_window_data()
if not window_data:
continue

if window_data["title"] != prev_window_data.get("title") or window_data[
"window_id"
] != prev_window_data.get("window_id"):
Expand Down Expand Up @@ -504,14 +509,88 @@ def read_mouse_events(
mouse_listener.stop()


def record(
task_description: str,
):
def record_audio(
    terminate_event: multiprocessing.Event,
    recording_timestamp: float,
) -> None:
    """Record microphone audio until terminated, then transcribe and store it.

    Captures 16 kHz mono audio via sounddevice, transcribes it with
    Whisper (base model), compresses the samples to FLAC, and inserts
    an AudioInfo row for the given recording.

    Args:
        terminate_event: set by the recorder to stop audio capture.
        recording_timestamp: timestamp of the Recording this audio
            narration belongs to.
    """
    utils.configure_logging(logger, LOG_LEVEL)
    utils.set_start_time(recording_timestamp)

    audio_frames = []  # buffers a copy of each incoming block of samples

    def audio_callback(indata, frames, time, status):
        # Invoked by sounddevice on its own thread for each audio block.
        audio_frames.append(indata.copy())

    # Record at 16 kHz mono, the input format Whisper expects.
    audio_stream = sounddevice.InputStream(
        callback=audio_callback, samplerate=16000, channels=1
    )
    logger.info("Audio recording started.")
    audio_stream.start()
    terminate_event.wait()
    audio_stream.stop()
    audio_stream.close()

    if not audio_frames:
        # Nothing was captured (e.g. recording was terminated
        # immediately); np.concatenate raises on an empty list.
        logger.warning("No audio frames captured; skipping transcription.")
        return

    # Concatenate into one Numpy array, then flatten to the 1-D float32
    # array format expected by Whisper.
    concatenated_audio = np.concatenate(audio_frames, axis=0)
    converted_audio = concatenated_audio.flatten().astype(np.float32)

    # Convert audio to text using OpenAI's Whisper.
    logger.info("Transcribing audio...")
    model = whisper.load_model("base")
    result_info = model.transcribe(
        converted_audio, word_timestamps=True, fp16=False
    )
    logger.info(f"The narrated text is: {result_info['text']}")

    # Collect word-level timestamps from *all* segments: longer
    # narrations are split into multiple segments (the previous code
    # read only segments[0] and dropped the rest), and a segment may
    # lack a 'words' key if the user didn't say anything.
    word_list = []
    for segment in result_info["segments"]:
        word_list.extend(segment.get("words", []))

    logger.info(
        "Size of uncompressed audio data: {} bytes".format(converted_audio.nbytes)
    )
    # Compress with lossless FLAC (in memory) before storing in the db.
    file_obj = io.BytesIO()
    soundfile.write(
        file_obj, converted_audio, int(audio_stream.samplerate), format="FLAC"
    )
    compressed_audio_bytes = file_obj.getvalue()

    logger.info(
        "Size of compressed audio data: {} bytes".format(len(compressed_audio_bytes))
    )

    file_obj.close()

    # To decompress the audio and restore it to its original form:
    # restored_audio, restored_samplerate = soundfile.read(
    #     io.BytesIO(compressed_audio_bytes))

    # Create AudioInfo entry.
    crud.insert_audio_info(
        compressed_audio_bytes,
        result_info["text"],
        recording_timestamp,
        int(audio_stream.samplerate),
        word_list,
    )


def record(task_description: str, enable_audio: bool = False):
"""
Record Screenshots/ActionEvents/WindowEvents.
Record Screenshots/ActionEvents/WindowEvents. Optionally record audio narration from the user
describing what tasks are being done.

Args:
task_description: a text description of the task that will be recorded
enable_audio: a flag to enable or disable audio recording (default: False)
"""

utils.configure_logging(logger, LOG_LEVEL)
Expand Down Expand Up @@ -616,6 +695,13 @@ def record(
)
perf_stat_writer.start()

if enable_audio:
audio_recorder = threading.Thread(
target=record_audio,
args=(terminate_event, recording_timestamp),
)
audio_recorder.start()

# TODO: discard events until everything is ready

try:
Expand All @@ -633,6 +719,8 @@ def record(
screen_event_writer.join()
action_event_writer.join()
window_event_writer.join()
if enable_audio:
audio_recorder.join()

terminate_perf_event.set()

Expand Down
Loading