Skip to content

Commit

Permalink
refreshing tabuli to changes in main
Browse files Browse the repository at this point in the history
plus further simplifications

~2-3 % speedup by avoiding one sqrt-square pair that did nothing

Before:

|Score type |MSE               |Min score         |Max score         |Mean score        |
|-----------|------------------|------------------|------------------|------------------|
|Zimtohrli  |0.101331579109522 |0.567006001464443 |0.774702668464531 |0.690847270049710 |
|ViSQOL     |0.115330916105424 |0.520833375452983 |0.801480831107469 |0.675101633981268 |
|2f         |0.129541391104905 |0.484687555319526 |0.797475783883375 |0.661870345773127 |
|PESQ       |0.147425552045669 |0.342342966279351 |0.841271127756762 |0.647128996775172 |
|CDPAM      |0.153471222942756 |0.441558428344727 |0.728779141125759 |0.620699318941738 |
|PARLAQ     |0.185057687192323 |0.445261140223642 |0.784370761057963 |0.587162756572532 |
|AQUA       |0.223207996944378 |0.331645933512413 |0.739286336419790 |0.547804951221731 |
|PEAQB      |0.225217321572038 |0.278744167467764 |0.851011116004117 |0.553935720513487 |
|DPAM       |0.315810440183130 |0.186717781679534 |0.690564701717118 |0.460415212267967 |
|WARP-Q     |0.339686211572685 |0.067600137543649 |0.777119464646524 |0.475793617709890 |
|GVPMOS     |0.412937133868407 |0.006851162794410 |0.783946603687895 |0.412912222208318 |

After:

|Score type |MSE               |Min score         |Max score         |Mean score        |
|-----------|------------------|------------------|------------------|------------------|
|Zimtohrli  |0.101180623602238 |0.563709113431054 |0.773872237318597 |0.691051313421366 |
|ViSQOL     |0.115330916105424 |0.520833375452983 |0.801480831107469 |0.675101633981268 |
|2f         |0.129541391104905 |0.484687555319526 |0.797475783883375 |0.661870345773127 |
|PESQ       |0.147425552045669 |0.342342966279351 |0.841271127756762 |0.647128996775172 |
|CDPAM      |0.153471222942756 |0.441558428344727 |0.728779141125759 |0.620699318941738 |
|PARLAQ     |0.185057687192323 |0.445261140223642 |0.784370761057963 |0.587162756572532 |
|AQUA       |0.223207996944378 |0.331645933512413 |0.739286336419790 |0.547804951221731 |
|PEAQB      |0.225217321572038 |0.278744167467764 |0.851011116004117 |0.553935720513487 |
|DPAM       |0.315810440183130 |0.186717781679534 |0.690564701717118 |0.460415212267967 |
|WARP-Q     |0.339686211572685 |0.067600137543649 |0.777119464646524 |0.475793617709890 |
|GVPMOS     |0.412937133868407 |0.006851162794410 |0.783946603687895 |0.412912222208318 |

real	8m5.190s
user	337m41.200s
sys	163m45.641s
  • Loading branch information
jyrkialakuijala committed Jun 14, 2024
1 parent 9fc866a commit e5142d4
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 359 deletions.
2 changes: 1 addition & 1 deletion configure.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ elif [ "${1}" == "asan" ]; then
(cd asan_build && cmake -G Ninja -DCMAKE_C_FLAGS='-fsanitize=address -fPIC' -DCMAKE_CXX_FLAGS='-fsanitize=address -fPIC' -DCMAKE_LINKER_FLAGS_DEBUG='-fsanitize=address' -DCMAKE_BUILD_TYPE=RelWithDebInfo ..)
else
mkdir -p build
(cd build && cmake -G Ninja -DCMAKE_C_FLAGS='-fPIC -mavx2' -DCMAKE_CXX_FLAGS='-fPIC -mavx2' -DCMAKE_BUILD_TYPE=Release ..)
(cd build && cmake -G Ninja -DCMAKE_C_FLAGS='-fPIC -march=native -O3' -DCMAKE_CXX_FLAGS='-fPIC -march=native -O3' -DCMAKE_BUILD_TYPE=Release ..)
fi
297 changes: 33 additions & 264 deletions cpp/zimt/fourier_bank.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,80 +43,33 @@ float GetRotatorGains(int i) {
return kRotatorGains[i];
}

int Rotators::FindMedian3xLeaker(float window) {
// Approximate filter delay. TODO: optimize this value along with gain values.
// Recordings can sound better with -2.32 as it pushes the bass signals a bit
// earlier and likely compensates human hearing's deficiency for temporal
// separation.
const float kMagic = -2.2028003503591482;
const float kAlmostHalfForRounding = 0.4687;
return static_cast<int>(kMagic / log(window) + kAlmostHalfForRounding);
}

void Rotators::Filter(hwy::Span<const float> signal,
hwy::AlignedNDArray<float, 2>& channels) {
const int audio_channel = 0;

size_t out_ix = 0;
OccasionallyRenormalize();
for (int64_t i = 0; i < signal.size(); ++i) {
for (int k = 0; k < kNumRotators; ++k) {
<<<<<<< HEAD
<<<<<<< HEAD
int64_t delayed_ix = i - 7; // advance[k] * 0.2;
=======
int64_t delayed_ix = i - advance[k];
>>>>>>> 62123e9 (Hacked together a replacement of the ellitic filters with the tabuli)
=======
int64_t delayed_ix = i - 7; // advance[k] * 0.2;
>>>>>>> c2dc9b5 (...)
float sample = 0;
if (delayed_ix > 0) {
sample = signal[delayed_ix];
}
AddAudio(audio_channel, k, sample);
if ((i & 0xfff) == 0) {
OccasionallyRenormalize();
}
IncrementAll();
if (i >= max_delay_) {
for (int k = 0; k < kNumRotators; ++k) {
float amplitude =
<<<<<<< HEAD
<<<<<<< HEAD
=======
>>>>>>> c2dc9b5 (...)
std::sqrt(channel[0].accu[4][k] * channel[0].accu[4][k] +
channel[0].accu[5][k] * channel[0].accu[5][k]);
const float windowM1 = 1 - window[k];
amplitude *= std::sqrt(windowM1) * windowM1;
channels[{out_ix}][k] = 2.2 * amplitude;
<<<<<<< HEAD
=======
std::sqrt(rot[2][k] * rot[2][k] + rot[3][k] * rot[3][k]);
channels[{out_ix}][k] = HardClip(amplitude);
>>>>>>> 62123e9 (Hacked together a replacement of the ellitic filters with the tabuli)
=======
>>>>>>> c2dc9b5 (...)
}
++out_ix;
IncrementAll(signal[i]);
for (int k = 0; k < kNumRotators; ++k) {
float energy =
channel[0].accu[4][k] * channel[0].accu[4][k] +
channel[0].accu[5][k] * channel[0].accu[5][k];
channels[{out_ix}][k] = energy;
}
++out_ix;
}
}

<<<<<<< HEAD
<<<<<<< HEAD
=======
>>>>>>> c2dc9b5 (...)
double CalculateBandwidth(double low, double mid, double high) {
const double geo_mean_low = std::sqrt(low * mid);
const double geo_mean_high = std::sqrt(mid * high);
return std::abs(geo_mean_high - mid) + std::abs(mid - geo_mean_low);
}

<<<<<<< HEAD
=======
>>>>>>> 62123e9 (Hacked together a replacement of the ellitic filters with the tabuli)
=======
>>>>>>> c2dc9b5 (...)
Rotators::Rotators(int num_channels, std::vector<float> frequency,
std::vector<float> filter_gains, const float sample_rate,
float global_gain) {
Expand All @@ -126,234 +79,50 @@ Rotators::Rotators(int num_channels, std::vector<float> frequency,
// of triple leaking integrator.
float kWindow = 0.9996;
float w40Hz = std::pow(kWindow, 128.0 / kNumRotators); // at 40 Hz.
<<<<<<< HEAD
<<<<<<< HEAD
=======
>>>>>>> c2dc9b5 (...)
float bw = CalculateBandwidth(
i == 0 ? frequency[1] : frequency[i - 1], frequency[i],
i + 1 == kNumRotators ? frequency[i - 1] : frequency[i + 1]);
window[i] = std::pow(kWindow, bw * 0.5 * 1.4);
<<<<<<< HEAD
=======
window[i] = pow(w40Hz, std::max(1.0, frequency[i] / 40.0));
>>>>>>> 62123e9 (Hacked together a replacement of the ellitic filters with the tabuli)
=======
>>>>>>> c2dc9b5 (...)
delay[i] = FindMedian3xLeaker(window[i]);
window[i] = std::pow(kWindow, bw * 0.7018);
float windowM1 = 1.0f - window[i];
max_delay_ = std::max(max_delay_, delay[i]);
float f = frequency[i] * 2.0f * M_PI / sample_rate;
gain[i] = filter_gains[i] * global_gain * pow(windowM1, 3.0);
gain[i] = 2.0 * filter_gains[i] * global_gain * pow(windowM1, 3.0);
rot[0][i] = float(std::cos(f));
rot[1][i] = float(-std::sin(f));
rot[2][i] = sqrt(gain[i]);
rot[2][i] = gain[i];
rot[3][i] = 0.0f;
}
for (size_t i = 0; i < kNumRotators; ++i) {
advance[i] = max_delay_ - delay[i];
}
rotator_frequency = frequency;
}

void Rotators::Increment(int c, int i, float audio) {
if (c == 0) {
float tr = rot[0][i] * rot[2][i] - rot[1][i] * rot[3][i];
float tc = rot[0][i] * rot[3][i] + rot[1][i] * rot[2][i];
rot[2][i] = tr;
rot[3][i] = tc;
}
channel[c].accu[0][i] *= window[i];
channel[c].accu[1][i] *= window[i];
channel[c].accu[2][i] *= window[i];
channel[c].accu[3][i] *= window[i];
channel[c].accu[4][i] *= window[i];
channel[c].accu[5][i] *= window[i];
channel[c].accu[0][i] += rot[2][i] * audio;
channel[c].accu[1][i] += rot[3][i] * audio;
channel[c].accu[2][i] += channel[c].accu[0][i];
channel[c].accu[3][i] += channel[c].accu[1][i];
channel[c].accu[4][i] += channel[c].accu[2][i];
channel[c].accu[5][i] += channel[c].accu[3][i];
}

void Rotators::AddAudio(int c, int i, float audio) {
channel[c].accu[0][i] += rot[2][i] * audio;
channel[c].accu[1][i] += rot[3][i] * audio;
}
void Rotators::OccasionallyRenormalize() {
for (int i = 0; i < kNumRotators; ++i) {
float norm =
sqrt(gain[i] / (rot[2][i] * rot[2][i] + rot[3][i] * rot[3][i]));
float norm = gain[i] / sqrt(rot[2][i] * rot[2][i] + rot[3][i] * rot[3][i]);
rot[2][i] *= norm;
rot[3][i] *= norm;
}
}
void Rotators::IncrementAll() {

void Rotators::IncrementAll(float signal) {
for (int i = 0; i < kNumRotators; i++) {
const float tr = rot[0][i] * rot[2][i] - rot[1][i] * rot[3][i];
const float tc = rot[0][i] * rot[3][i] + rot[1][i] * rot[2][i];
rot[2][i] = tr;
rot[3][i] = tc;
}
for (int c = 0; c < channel.size(); ++c) {
for (int i = 0; i < kNumRotators; i++) {
const float w = window[i];
channel[c].accu[0][i] *= w;
channel[c].accu[1][i] *= w;
channel[c].accu[2][i] *= w;
channel[c].accu[3][i] *= w;
channel[c].accu[4][i] *= w;
channel[c].accu[5][i] *= w;
channel[c].accu[2][i] += channel[c].accu[0][i];
channel[c].accu[3][i] += channel[c].accu[1][i];
channel[c].accu[4][i] += channel[c].accu[2][i];
channel[c].accu[5][i] += channel[c].accu[3][i];
}
}
}
float Rotators::GetSampleAll(int c) {
float retval = 0;
for (int i = 0; i < kNumRotators; ++i) {
retval +=
(rot[2][i] * channel[c].accu[4][i] + rot[3][i] * channel[c].accu[5][i]);
}
return retval;
}
float Rotators::GetSample(int c, int i, FilterMode mode) const {
return (
mode == IDENTITY ? (rot[2][i] * channel[c].accu[4][i] +
rot[3][i] * channel[c].accu[5][i])
: mode == AMPLITUDE
? std::sqrt(gain[i] * (channel[c].accu[4][i] * channel[c].accu[4][i] +
channel[c].accu[5][i] * channel[c].accu[5][i]))
: std::atan2(channel[c].accu[4][i], channel[c].accu[5][i]));
}

float BarkFreq(float v) {
constexpr float linlogsplit = 0.1;
if (v < linlogsplit) {
return 20.0 + (v / linlogsplit) * 20.0; // Linear 20-40 Hz.
} else {
float normalized_v = (v - linlogsplit) * (1.0 / (1.0 - linlogsplit));
return 40.0 * pow(500.0, normalized_v); // Logarithmic 40-20000 Hz.
}
}

float HardClip(float v) { return std::max(-1.0f, std::min(1.0f, v)); }

RotatorFilterBank::RotatorFilterBank(size_t num_rotators, size_t num_channels,
size_t samplerate, size_t num_threads,
const std::vector<float>& filter_gains,
float global_gain) {
num_rotators_ = num_rotators;
num_channels_ = num_channels;
num_threads_ = num_threads;
std::vector<float> freqs(num_rotators);
for (size_t i = 0; i < num_rotators_; ++i) {
freqs[i] = BarkFreq(static_cast<float>(i) / (num_rotators_ - 1));
// printf("%d %g\n", i, freqs[i]);
}
rotators_.reset(
new Rotators(num_channels, freqs, filter_gains, samplerate, global_gain));

max_delay_ = rotators_->max_delay_;
QCHECK_LE(max_delay_, kBlockSize);
fprintf(stderr, "Rotator bank output delay: %zu\n", max_delay_);
filter_outputs_.resize(num_rotators);
for (std::vector<float>& output : filter_outputs_) {
output.resize(num_channels_ * kBlockSize, 0.f);
}
}

// TODO(jyrki): filter all at once in the generic case, filtering one
// is not memory friendly in this memory tabulation.
void RotatorFilterBank::FilterOne(size_t f_ix, const float* history,
int64_t total_in, int64_t len,
FilterMode mode, float* output) {
size_t out_ix = 0;
for (int64_t i = 0; i < len; ++i) {
int64_t delayed_ix = total_in + i - rotators_->advance[f_ix];
size_t histo_ix = num_channels_ * (delayed_ix & kHistoryMask);
for (size_t c = 0; c < num_channels_; ++c) {
float delayed = history[histo_ix + c];
rotators_->Increment(c, f_ix, delayed);
}
if (total_in + i >= max_delay_) {
for (size_t c = 0; c < num_channels_; ++c) {
output[out_ix * num_channels_ + c] =
rotators_->GetSample(c, f_ix, mode);
}
++out_ix;
}
}
}

int64_t RotatorFilterBank::FilterAllSingleThreaded(const float* history,
int64_t total_in,
int64_t len, FilterMode mode,
float* output,
size_t output_size) {
size_t out_ix = 0;
for (size_t c = 0; c < num_channels_; ++c) {
rotators_->OccasionallyRenormalize();
}
for (int64_t i = 0; i < len; ++i) {
for (size_t c = 0; c < num_channels_; ++c) {
for (int k = 0; k < kNumRotators; ++k) {
int64_t delayed_ix = total_in + i - rotators_->advance[k];
size_t histo_ix = num_channels_ * (delayed_ix & kHistoryMask);
float delayed = history[histo_ix + c];
rotators_->AddAudio(c, k, delayed);
}
}
rotators_->IncrementAll();
if (total_in + i >= max_delay_) {
for (size_t c = 0; c < num_channels_; ++c) {
output[out_ix * num_channels_ + c] =
HardClip(rotators_->GetSampleAll(c));
}
++out_ix;
}
}
size_t out_len = total_in < max_delay_
? std::max<int64_t>(0, len - (max_delay_ - total_in))
: len;
return out_len;
}

int64_t RotatorFilterBank::FilterAll(const float* history, int64_t total_in,
int64_t len, FilterMode mode,
float* output, size_t output_size) {
auto run = [&](size_t thread) {
while (true) {
size_t my_task = next_task_++;
if (my_task >= num_rotators_) return;
FilterOne(my_task, history, total_in, len, mode,
filter_outputs_[my_task].data());
}
};
next_task_ = 0;
std::vector<std::future<void>> futures;
futures.reserve(num_threads_);
for (size_t i = 0; i < num_threads_; ++i) {
futures.push_back(std::async(std::launch::async, run, i));
}
for (size_t i = 0; i < num_threads_; ++i) {
futures[i].get();
}
size_t out_len = total_in < max_delay_
? std::max<int64_t>(0, len - (max_delay_ - total_in))
: len;
for (size_t i = 0; i < out_len; ++i) {
for (size_t j = 0; j < num_rotators_; ++j) {
for (size_t c = 0; c < num_channels_; ++c) {
size_t out_idx = (i * num_rotators_ + j) * num_channels_ + c;
output[out_idx] = filter_outputs_[j][i * num_channels_ + c];
}
}
}
return out_len;
}

} // namespace tabuli
const float w = window[i];
int c = 0;
channel[c].accu[0][i] *= w;
channel[c].accu[1][i] *= w;
channel[c].accu[2][i] *= w;
channel[c].accu[3][i] *= w;
channel[c].accu[4][i] *= w;
channel[c].accu[5][i] *= w;
channel[c].accu[2][i] += channel[c].accu[0][i];
channel[c].accu[3][i] += channel[c].accu[1][i];
channel[c].accu[4][i] += channel[c].accu[2][i];
channel[c].accu[5][i] += channel[c].accu[3][i];
channel[c].accu[0][i] += rot[2][i] * signal;
channel[c].accu[1][i] += rot[3][i] * signal;
}
}

} // namespace tabuli
Loading

0 comments on commit e5142d4

Please sign in to comment.