Skip to content

Commit

Permalink
Fixed the DTW benchmark.
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Bruse authored and zond committed Jun 13, 2024
1 parent 61ea219 commit 115c4ac
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions cpp/zimt/dtw_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,30 +64,42 @@ TEST(DTW, ChainDTWTest) {
}

void BM_DTW(benchmark::State& state) {
const hwy::AlignedNDArray<float, 2> spec_a(
hwy::AlignedNDArray<float, 2> spec_a(
{static_cast<size_t>(state.range(0)), 1024});
const hwy::AlignedNDArray<float, 2> spec_b(
hwy::AlignedNDArray<float, 2> spec_b(
{static_cast<size_t>(state.range(0)), 1024});
for (size_t step_index = 0; step_index < spec_a.shape()[0]; ++step_index) {
for (size_t channel_index = 0; channel_index < spec_a.shape()[1];
++channel_index) {
spec_a[{step_index}][channel_index] = 1.0;
}
}

for (auto s : state) {
DTW(spec_a, spec_b);
}
state.SetItemsProcessed(state.range(0) * state.iterations());
}
BENCHMARK_RANGE(BM_DTW, 1000, 10000);
BENCHMARK_RANGE(BM_DTW, 100, 1000);

void BM_ChainDTW(benchmark::State& state) {
const hwy::AlignedNDArray<float, 2> spec_a(
hwy::AlignedNDArray<float, 2> spec_a(
{static_cast<size_t>(state.range(0)), 1024});
const hwy::AlignedNDArray<float, 2> spec_b(
hwy::AlignedNDArray<float, 2> spec_b(
{static_cast<size_t>(state.range(0)), 1024});
for (size_t step_index = 0; step_index < spec_a.shape()[0]; ++step_index) {
for (size_t channel_index = 0; channel_index < spec_a.shape()[1];
++channel_index) {
spec_a[{step_index}][channel_index] = 1.0;
}
}

for (auto s : state) {
ChainDTW(spec_a, spec_b, 2000);
ChainDTW(spec_a, spec_b, 200);
}
state.SetItemsProcessed(state.range(0) * state.iterations());
}
BENCHMARK_RANGE(BM_ChainDTW, 1000, 50000);
BENCHMARK_RANGE(BM_ChainDTW, 100, 5000);

} // namespace

Expand Down

0 comments on commit 115c4ac

Please sign in to comment.