diff --git a/alembic/versions/3d3239ae4849_add_audio_info.py b/alembic/versions/3d3239ae4849_add_audio_info.py new file mode 100644 index 000000000..aedae40fb --- /dev/null +++ b/alembic/versions/3d3239ae4849_add_audio_info.py @@ -0,0 +1,37 @@ +"""Add audio info + +Revision ID: 3d3239ae4849 +Revises: 104d4a614d95 +Create Date: 2023-05-31 19:36:46.269697 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '3d3239ae4849' +down_revision = '104d4a614d95' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('audio_info', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('flac_data', sa.LargeBinary(), nullable=True), + sa.Column('transcribed_text', sa.String(), nullable=True), + sa.Column('recording_timestamp', sa.Integer(), nullable=True), + sa.Column('sample_rate', sa.Integer(), nullable=True), + sa.Column('words_with_timestamps', sa.Text(), nullable=True), + sa.ForeignKeyConstraint(['recording_timestamp'], ['recording.timestamp'], name=op.f('fk_audio_info_recording_timestamp_recording')), + sa.PrimaryKeyConstraint('id', name=op.f('pk_audio_info')) + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('audio_info') + # ### end Alembic commands ### diff --git a/openadapt/crud.py b/openadapt/crud.py index 08149ee79..62e438029 100644 --- a/openadapt/crud.py +++ b/openadapt/crud.py @@ -1,16 +1,18 @@ from loguru import logger import sqlalchemy as sa +import json from openadapt.db import Session + from openadapt.models import ( ActionEvent, Screenshot, Recording, WindowEvent, PerformanceStat, + AudioInfo ) - BATCH_SIZE = 1 db = Session() @@ -19,6 +21,7 @@ window_events = [] performance_stats = [] + def _insert(event_data, table, buffer=None): """Insert using Core API for improved performance (no rows are returned)""" @@ -74,6 +77,7 @@ def insert_window_event(recording_timestamp, event_timestamp, event_data): } _insert(event_data, WindowEvent, window_events) + def insert_perf_stat(recording_timestamp, event_type, start_time, end_time): """ Insert event performance stat into db @@ -87,6 +91,7 @@ def insert_perf_stat(recording_timestamp, event_type, start_time, end_time): } _insert(event_perf_stat, PerformanceStat, performance_stats) + def get_perf_stats(recording_timestamp): """ return performance stats for a given recording @@ -94,11 +99,32 @@ def get_perf_stats(recording_timestamp): return ( db - .query(PerformanceStat) - .filter(PerformanceStat.recording_timestamp == recording_timestamp) - .order_by(PerformanceStat.start_time) - .all() + .query(PerformanceStat) + .filter(PerformanceStat.recording_timestamp == recording_timestamp) + .order_by(PerformanceStat.start_time) + .all() + ) + + +def insert_audio_info( + audio_data, + transcribed_text, + recording_timestamp, + sample_rate, + word_list +): + """Insert an AudioInfo entry into the database.""" + thread_local_db = Session() + audio_info = AudioInfo( + flac_data=audio_data, + transcribed_text=transcribed_text, + recording_timestamp=recording_timestamp, + sample_rate=sample_rate, + words_with_timestamps=json.dumps(word_list) ) + thread_local_db.add(audio_info) + thread_local_db.commit() + def insert_recording(recording_data): db_obj = Recording(**recording_data) @@ -111,29 +137,29 @@ def insert_recording(recording_data): def get_latest_recording(): return ( db - .query(Recording) - 
.order_by(sa.desc(Recording.timestamp)) - .limit(1) - .first() + .query(Recording) + .order_by(sa.desc(Recording.timestamp)) + .limit(1) + .first() ) def get_recording(timestamp): return ( db - .query(Recording) - .filter(Recording.timestamp == timestamp) - .first() + .query(Recording) + .filter(Recording.timestamp == timestamp) + .first() ) def _get(table, recording_timestamp): return ( db - .query(table) - .filter(table.recording_timestamp == recording_timestamp) - .order_by(table.timestamp) - .all() + .query(table) + .filter(table.recording_timestamp == recording_timestamp) + .order_by(table.timestamp) + .all() ) @@ -158,3 +184,12 @@ def get_screenshots(recording, precompute_diffs=False): def get_window_events(recording): return _get(WindowEvent, recording.timestamp) + + +def get_audio_info(recording): + return ( + db + .query(AudioInfo) + .filter(AudioInfo.recording_timestamp == recording.timestamp) + .first() + ) diff --git a/openadapt/db.py b/openadapt/db.py index 290b6b2e0..e6c2bd9d8 100644 --- a/openadapt/db.py +++ b/openadapt/db.py @@ -1,6 +1,6 @@ import sqlalchemy as sa from dictalchemy import DictableModel -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.schema import MetaData from sqlalchemy.ext.declarative import declarative_base @@ -50,4 +50,5 @@ def get_base(engine): engine = get_engine() Base = get_base(engine) -Session = sessionmaker(bind=engine) +session_factory = sessionmaker(bind=engine) +Session = scoped_session(session_factory) diff --git a/openadapt/models.py b/openadapt/models.py index e76197695..afccbbaa9 100644 --- a/openadapt/models.py +++ b/openadapt/models.py @@ -48,6 +48,8 @@ class Recording(db.Base): order_by="WindowEvent.timestamp", ) + audio_info = sa.orm.relationship("AudioInfo", back_populates="recording") + _processed_action_events = None @property @@ -58,7 +60,6 @@ def processed_action_events(self): return self._processed_action_events - class ActionEvent(db.Base): __tablename__ = "action_event" @@ -86,8 +87,8 @@ class ActionEvent(db.Base): children = sa.orm.relationship("ActionEvent") # TODO: replacing the above line with the following two results in an error: # AttributeError: 'list' object has no attribute '_sa_instance_state' - #children = sa.orm.relationship("ActionEvent", remote_side=[id], back_populates="parent") - #parent = sa.orm.relationship("ActionEvent", remote_side=[parent_id], back_populates="children") + # children = sa.orm.relationship("ActionEvent", remote_side=[id], back_populates="parent") + # parent = sa.orm.relationship("ActionEvent", remote_side=[parent_id], back_populates="children") recording = sa.orm.relationship("Recording", back_populates="action_events") screenshot = sa.orm.relationship("Screenshot", back_populates="action_event") @@ -269,11 +270,11 @@ def take_screenshot(cls): sct_img = utils.take_screenshot() screenshot = Screenshot(sct_img=sct_img) return screenshot - + def crop_active_window(self, action_event): window_event = action_event.window_event width_ratio, height_ratio = utils.get_scale_ratios(action_event) - + x0 = window_event.left * width_ratio y0 = window_event.top * height_ratio x1 = x0 + window_event.width * width_ratio @@ -305,6 +306,19 @@ def get_active_window_event(cls): return WindowEvent(**window.get_active_window_data()) +class AudioInfo(db.Base): + __tablename__ = "audio_info" + + id = sa.Column(sa.Integer, primary_key=True) + flac_data = sa.Column(sa.LargeBinary) + transcribed_text = sa.Column(sa.String) + recording_timestamp = 
sa.Column(sa.ForeignKey("recording.timestamp")) + sample_rate = sa.Column(sa.Integer) + words_with_timestamps = sa.Column(sa.Text) + + recording = sa.orm.relationship("Recording", back_populates="audio_info") + + class PerformanceStat(db.Base): __tablename__ = "performance_stat" diff --git a/openadapt/record.py b/openadapt/record.py index 6f03528ea..5ea03846f 100644 --- a/openadapt/record.py +++ b/openadapt/record.py @@ -24,6 +24,11 @@ from openadapt import config, crud, utils, window +import sounddevice +import soundfile +import whisper +import numpy as np +import io EVENT_TYPES = ("screen", "action", "window") LOG_LEVEL = "INFO" @@ -34,7 +39,6 @@ } PLOT_PERFORMANCE = False - Event = namedtuple("Event", ("timestamp", "type", "data")) @@ -369,6 +373,7 @@ def read_window_events( window_data = window.get_active_window_data() if not window_data: continue + if window_data["title"] != prev_window_data.get("title") or window_data[ "window_id" ] != prev_window_data.get("window_id"): @@ -504,14 +509,88 @@ def read_mouse_events( mouse_listener.stop() -def record( - task_description: str, -): +def record_audio( + terminate_event: multiprocessing.Event, + recording_timestamp: float, +) -> None: + utils.configure_logging(logger, LOG_LEVEL) + utils.set_start_time(recording_timestamp) + + audio_frames = [] # to store audio frames + + def audio_callback(indata, frames, time, status): + # called whenever there is new audio frames + audio_frames.append(indata.copy()) + + # open InputStream and start recording while ActionEvents are recorded + audio_stream = sounddevice.InputStream( + callback=audio_callback, samplerate=16000, channels=1 + ) + logger.info("Audio recording started.") + audio_stream.start() + terminate_event.wait() + audio_stream.stop() + audio_stream.close() + + # Concatenate into one Numpy array + concatenated_audio = np.concatenate(audio_frames, axis=0) + # convert concatenated_audio to format expected by whisper + converted_audio = concatenated_audio.flatten().astype(np.float32) + + # Convert audio to text using OpenAI's Whisper + logger.info("Transcribing audio...") + model = whisper.load_model("base") + result_info = model.transcribe(converted_audio, word_timestamps=True, fp16=False) + logger.info(f"The narrated text is: {result_info['text']}") + # empty word_list if the user didn't say anything + word_list = [] + # segments could be empty + if len(result_info["segments"]) > 0: + # there won't be a 'words' list if the user didn't say anything + if "words" in result_info["segments"][0]: + word_list = result_info["segments"][0]["words"] + + # compress and convert to bytes to save to database + logger.info( + "Size of uncompressed audio data: {} bytes".format(converted_audio.nbytes) + ) + # Create an in-memory file-like object + file_obj = io.BytesIO() + # Write the audio data using lossless compression + soundfile.write( + file_obj, converted_audio, int(audio_stream.samplerate), format="FLAC" + ) + # Get the compressed audio data as bytes + compressed_audio_bytes = file_obj.getvalue() + + logger.info( + "Size of compressed audio data: {} bytes".format(len(compressed_audio_bytes)) + ) + + file_obj.close() + + # To decompress the audio and restore it to its original form: + # restored_audio, restored_samplerate = sf.read( + # io.BytesIO(compressed_audio_bytes)) + + # Create AudioInfo entry + crud.insert_audio_info( + compressed_audio_bytes, + result_info["text"], + recording_timestamp, + int(audio_stream.samplerate), + word_list, + ) + + +def record(task_description: str, enable_audio: bool 
= False): """ - Record Screenshots/ActionEvents/WindowEvents. + Record Screenshots/ActionEvents/WindowEvents. Optionally record audio narration from the user + describing what tasks are being done. Args: task_description: a text description of the task that will be recorded + enable_audio: a flag to enable or disable audio recording (default: False) """ utils.configure_logging(logger, LOG_LEVEL) @@ -616,6 +695,13 @@ def record( ) perf_stat_writer.start() + if enable_audio: + audio_recorder = threading.Thread( + target=record_audio, + args=(terminate_event, recording_timestamp), + ) + audio_recorder.start() + # TODO: discard events until everything is ready try: @@ -633,6 +719,8 @@ def record( screen_event_writer.join() action_event_writer.join() window_event_writer.join() + if enable_audio: + audio_recorder.join() terminate_perf_event.set() diff --git a/openadapt/strategies/mixins/openai.py b/openadapt/strategies/mixins/openai.py index 2698cf61b..353be33f0 100644 --- a/openadapt/strategies/mixins/openai.py +++ b/openadapt/strategies/mixins/openai.py @@ -11,7 +11,6 @@ class MyReplayStrategy(OpenAIReplayStrategyMixin): from loguru import logger import openai -import tiktoken from openadapt.strategies.base import BaseReplayStrategy from openadapt import cache, config, models @@ -30,7 +29,6 @@ class MyReplayStrategy(OpenAIReplayStrategyMixin): MODEL_NAME = "gpt-4" openai.api_key = config.OPENAI_API_KEY -encoding = tiktoken.get_encoding("cl100k_base") class OpenAIReplayStrategyMixin(BaseReplayStrategy): @@ -160,49 +158,3 @@ def _get_completion( logger.debug(f"appending assistant_message=\n{pformat(assistant_message)}") messages.append(assistant_message) return messages - - -# XXX TODO not currently in use -# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb -def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"): - """Returns the number of tokens used by a list of messages.""" - try: - encoding = tiktoken.encoding_for_model(model) - except KeyError: - logger.info("Warning: model not found. Using cl100k_base encoding.") - encoding = tiktoken.get_encoding("cl100k_base") - if model == "gpt-3.5-turbo": - logger.info( - "Warning: gpt-3.5-turbo may change over time. Returning num tokens " - "assuming gpt-3.5-turbo-0301." - ) - return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301") - elif model == "gpt-4": - logger.info( - "Warning: gpt-4 may change over time. Returning num tokens " - "assuming gpt-4-0314." - ) - return num_tokens_from_messages(messages, model="gpt-4-0314") - elif model == "gpt-3.5-turbo-0301": - tokens_per_message = ( - 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n - ) - tokens_per_name = -1 # if there's a name, the role is omitted - elif model == "gpt-4-0314": - tokens_per_message = 3 - tokens_per_name = 1 - else: - raise NotImplementedError( - f"""num_tokens_from_messages() is not implemented for model " - "{model}. 
See " - "https://github.com/openai/openai-python/blob/main/chatml.md for " - information on how messages are converted to tokens.""") - num_tokens = 0 - for message in messages: - num_tokens += tokens_per_message - for key, value in message.items(): - num_tokens += len(encoding.encode(value)) - if key == "name": - num_tokens += tokens_per_name - num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> - return num_tokens diff --git a/openadapt/visualize.py b/openadapt/visualize.py index 7fec235ae..d893ea373 100644 --- a/openadapt/visualize.py +++ b/openadapt/visualize.py @@ -11,6 +11,7 @@ from openadapt.crud import ( get_latest_recording, + get_audio_info ) from openadapt.events import ( get_events, @@ -149,6 +150,10 @@ def main(): scrub.scrub_text(recording.task_description) logger.debug(f"{recording=}") + audio_info = row2dict(get_audio_info(recording)) + # don't display the FLAC data + del audio_info['flac_data'] + meta = {} action_events = get_events(recording, process=PROCESS_EVENTS, meta=meta) event_dicts = rows2dicts(action_events) @@ -178,6 +183,11 @@ def main(): width_policy="max", ), ), + row( + Div( + text=f"{dict2html(audio_info)}", + ), + ), ] logger.info(f"{len(action_events)=}") for idx, action_event in enumerate(action_events): diff --git a/requirements.txt b/requirements.txt index ca5b3150e..a5eaf9d11 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,17 +25,17 @@ pywin32==306; sys_platform == 'win32' pytesseract==0.3.7 -e git+https://github.com/abrichr/pynput.git#egg=pynput pytest==7.1.3 +python-dotenv==1.0.0 +python-Levenshtein==0.21.1 rapidocr-onnxruntime==1.2.3 scikit-learn==1.2.2 scipy==1.9.3 setuptools-lint +sounddevice~=0.4.6 +soundfile~=0.12.1 sphinx sqlalchemy==1.4.43 -sumy==0.11.0 -tiktoken==0.4.0 torch==2.0.0 tqdm==4.64.0 -nicegui==1.2.16 transformers==4.29.2 -python-dotenv==1.0.0 -python-Levenshtein==0.21.1 \ No newline at end of file +git+https://github.com/openai/whisper.git