Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: append data on slice index change instead of save on end #268

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 10 additions & 30 deletions cpp/rn-audioutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,24 @@

namespace rnaudioutils {

/**
 * Flattens a list of 16-bit sample slices into a single byte vector.
 *
 * Each sample is emitted as two bytes, low byte first, regardless of the
 * host byte order.
 *
 * @param buffers          Pointers to the slices; a null entry is skipped.
 * @param slice_n_samples  Number of shorts to read from the slice at the same
 *                         index; non-positive counts are skipped.
 * @return The concatenated bytes, two per sample, in slice order.
 *
 * Slices that have no matching entry in `slice_n_samples` are ignored, since
 * there is no way to know how many samples could be read from them.
 */
std::vector<uint8_t> concat_short_buffers(const std::vector<short*>& buffers, const std::vector<int>& slice_n_samples) {
    // Only indices present in both vectors can be read safely.
    const size_t n_slices = buffers.size() < slice_n_samples.size()
        ? buffers.size()
        : slice_n_samples.size();

    // First pass: count the samples so the output is allocated exactly once.
    size_t total_samples = 0;
    for (size_t i = 0; i < n_slices; i++) {
        if (buffers[i] != nullptr && slice_n_samples[i] > 0) {
            total_samples += static_cast<size_t>(slice_n_samples[i]);
        }
    }

    std::vector<uint8_t> output_data;
    output_data.reserve(total_samples * 2);

    // Second pass: split every short into its low and high byte.
    for (size_t i = 0; i < n_slices; i++) {
        const short* slice = buffers[i];
        const int size = slice_n_samples[i]; // Number of shorts
        if (slice == nullptr || size <= 0) {
            continue;
        }

        for (int j = 0; j < size; j++) {
            output_data.push_back(static_cast<uint8_t>(slice[j] & 0xFF));        // Lower byte
            output_data.push_back(static_cast<uint8_t>((slice[j] >> 8) & 0xFF)); // Higher byte
        }
    }

    return output_data;
}

std::vector<uint8_t> remove_trailing_zeros(const std::vector<uint8_t>& audio_data) {
auto last = std::find_if(audio_data.rbegin(), audio_data.rend(), [](uint8_t byte) { return byte != 0; });
return std::vector<uint8_t>(audio_data.begin(), last.base());
/**
 * Appends raw 16-bit samples to the end of a file.
 *
 * The file is opened in binary append mode, so it is created when missing and
 * existing content is left in place. Samples are written exactly as they are
 * laid out in memory (no byte-order conversion is performed here).
 *
 * @param data       Pointer to the samples to write; a null pointer writes nothing.
 * @param n_samples  Number of shorts to write; non-positive values write nothing.
 * @param file       Path of the file to append to.
 *
 * If the file cannot be opened the call returns without writing.
 */
void append_wav_data(const short* data, const int n_samples, const std::string& file) {
    std::ofstream output(file, std::ios::binary | std::ios::app);
    if (!output.is_open()) {
        return;
    }

    // Guarding here keeps a negative count from wrapping into a huge unsigned
    // byte length, and avoids reading through a null pointer.
    if (data == nullptr || n_samples <= 0) {
        return;
    }

    const std::streamsize n_bytes =
        static_cast<std::streamsize>(n_samples) * static_cast<std::streamsize>(sizeof(short));
    output.write(reinterpret_cast<const char*>(data), n_bytes);
    output.close();
}

void save_wav_file(const std::vector<uint8_t>& raw, const std::string& file) {
std::vector<uint8_t> data = remove_trailing_zeros(raw);

std::ofstream output(file, std::ios::binary);
void add_wav_header_to_file(const std::string& file, const int data_size) {
std::ofstream output(file, std::ios::binary | std::ios::app);

if (!output.is_open()) {
RNWHISPER_LOG_ERROR("Failed to open file for writing: %s\n", file.c_str());
return;
}

// WAVE header
output.seekp(0, std::ios::beg);

output.write("RIFF", 4);
int32_t chunk_size = 36 + static_cast<int32_t>(data.size());
int32_t chunk_size = 36 + static_cast<int32_t>(data_size);
output.write(reinterpret_cast<char*>(&chunk_size), sizeof(chunk_size));
output.write("WAVE", 4);
output.write("fmt ", 4);
Expand All @@ -56,13 +39,10 @@ void save_wav_file(const std::vector<uint8_t>& raw, const std::string& file) {
short bits_per_sample = 16;
output.write(reinterpret_cast<char*>(&bits_per_sample), sizeof(bits_per_sample));
output.write("data", 4);
int32_t sub_chunk2_size = static_cast<int32_t>(data.size());
int32_t sub_chunk2_size = static_cast<int32_t>(data_size);
output.write(reinterpret_cast<char*>(&sub_chunk2_size), sizeof(sub_chunk2_size));
output.write(reinterpret_cast<const char*>(data.data()), data.size());

output.close();

RNWHISPER_LOG_INFO("Saved audio file: %s\n", file.c_str());
}

} // namespace rnaudioutils
4 changes: 2 additions & 2 deletions cpp/rn-audioutils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace rnaudioutils {

std::vector<uint8_t> concat_short_buffers(const std::vector<short*>& buffers, const std::vector<int>& slice_n_samples);
void save_wav_file(const std::vector<uint8_t>& raw, const std::string& file);
void append_wav_data(const short* data, const int n_samples, const std::string& file);
void add_wav_header_to_file(const std::string& file, const int data_size);

} // namespace rnaudioutils
1 change: 1 addition & 0 deletions cpp/rn-whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ void job::put_pcm_data(short* data, int slice_index, int n_samples, int n) {
for (int i = 0; i < n; i++) {
pcm[i + n_samples] = data[i];
}
pcm_data_size += n;
}

float* job::pcm_slice_to_f32(int slice_index, int size) {
Expand Down
1 change: 1 addition & 0 deletions cpp/rn-whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct job {
int audio_slice_sec = 0;
float audio_min_sec = 0;
const char* audio_output_path = nullptr;
int pcm_data_size = 0;
std::vector<short *> pcm_slices;
void set_realtime_params(vad_params vad, int sec, int slice_sec, float min_sec, const char* output_path);
bool vad_simple(int slice_index, int n_samples, int n);
Expand Down
15 changes: 11 additions & 4 deletions ios/RNWhisperContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,9 @@ void AudioInputCallback(void * inUserData,
- (void)finishRealtimeTranscribe:(RNWhisperContextRecordState*) state result:(NSDictionary*)result {
// Save wav if needed
if (state->job->audio_output_path != nullptr) {
// TODO: Append in real time so we don't need to keep all slices & also reduce memory usage
rnaudioutils::save_wav_file(
rnaudioutils::concat_short_buffers(state->job->pcm_slices, state->sliceNSamples),
state->job->audio_output_path
rnaudioutils::add_wav_header_to_file(
state->job->audio_output_path,
state->job->pcm_data_size
);
}
state->transcribeHandler(state->job->job_id, @"end", result);
Expand Down Expand Up @@ -284,6 +283,14 @@ - (void)fullTranscribeSamples:(RNWhisperContextRecordState*) state {
state->nSamplesTranscribing == nSamplesOfIndex &&
state->transcribeSliceIndex != state->sliceIndex
) {
if (state->job->audio_output_path != nullptr) {
rnaudioutils::append_wav_data(
state->job->pcm_slices[state->transcribeSliceIndex],
state->sliceNSamples[state->transcribeSliceIndex],
state->job->audio_output_path
);
}
// TODO: Clean up the previous slice
state->transcribeSliceIndex++;
state->nSamplesTranscribing = 0;
}
Expand Down
Loading