Skip to content

Commit

Permalink
whisper.android : address ARM's big.LITTLE arch by checking cpu info (o…
Browse files Browse the repository at this point in the history
  • Loading branch information
Digipom authored Sep 6, 2023
1 parent 64cb45f commit f990610
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ class WhisperContext private constructor(private var ptr: Long) {

suspend fun transcribeData(data: FloatArray): String = withContext(scope.coroutineContext) {
require(ptr != 0L)
WhisperLib.fullTranscribe(ptr, data)
val numThreads = WhisperCpuConfig.preferredThreadCount
Log.d(LOG_TAG, "Selecting $numThreads threads")
WhisperLib.fullTranscribe(ptr, numThreads, data)
val textCount = WhisperLib.getTextSegmentCount(ptr)
return@withContext buildString {
for (i in 0 until textCount) {
Expand Down Expand Up @@ -126,7 +128,7 @@ private class WhisperLib {
external fun initContextFromAsset(assetManager: AssetManager, assetPath: String): Long
external fun initContext(modelPath: String): Long
external fun freeContext(contextPtr: Long)
external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
external fun fullTranscribe(contextPtr: Long, numThreads: Int, audioData: FloatArray)
external fun getTextSegmentCount(contextPtr: Long): Int
external fun getTextSegment(contextPtr: Long, index: Int): String
external fun getSystemInfo(): String
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package com.whispercppdemo.whisper

import android.util.Log
import java.io.BufferedReader
import java.io.FileReader

object WhisperCpuConfig {
val preferredThreadCount: Int
// Always use at least 2 threads:
get() = CpuInfo.getHighPerfCpuCount().coerceAtLeast(2)
}

private class CpuInfo(private val lines: List<String>) {
private fun getHighPerfCpuCount(): Int = try {
getHighPerfCpuCountByFrequencies()
} catch (e: Exception) {
Log.d(LOG_TAG, "Couldn't read CPU frequencies", e)
getHighPerfCpuCountByVariant()
}

private fun getHighPerfCpuCountByFrequencies(): Int =
getCpuValues(property = "processor") { getMaxCpuFrequency(it.toInt()) }
.also { Log.d(LOG_TAG, "Binned cpu frequencies (frequency, count): ${it.binnedValues()}") }
.countDroppingMin()

private fun getHighPerfCpuCountByVariant(): Int =
getCpuValues(property = "CPU variant") { it.substringAfter("0x").toInt(radix = 16) }
.also { Log.d(LOG_TAG, "Binned cpu variants (variant, count): ${it.binnedValues()}") }
.countKeepingMin()

private fun List<Int>.binnedValues() = groupingBy { it }.eachCount()

private fun getCpuValues(property: String, mapper: (String) -> Int) = lines
.asSequence()
.filter { it.startsWith(property) }
.map { mapper(it.substringAfter(':').trim()) }
.sorted()
.toList()


private fun List<Int>.countDroppingMin(): Int {
val min = min()
return count { it > min }
}

private fun List<Int>.countKeepingMin(): Int {
val min = min()
return count { it == min }
}

companion object {
private const val LOG_TAG = "WhisperCpuConfig"

fun getHighPerfCpuCount(): Int = try {
readCpuInfo().getHighPerfCpuCount()
} catch (e: Exception) {
Log.d(LOG_TAG, "Couldn't read CPU info", e)
// Our best guess -- just return the # of CPUs minus 4.
(Runtime.getRuntime().availableProcessors() - 4).coerceAtLeast(0)
}

private fun readCpuInfo() = CpuInfo(
BufferedReader(FileReader("/proc/cpuinfo"))
.useLines { it.toList() }
)

private fun getMaxCpuFrequency(cpuIndex: Int): Int {
val path = "/sys/devices/system/cpu/cpu${cpuIndex}/cpufreq/cpuinfo_max_freq"
val maxFreq = BufferedReader(FileReader(path)).use { it.readLine() }
return maxFreq.toInt()
}
}
}
8 changes: 2 additions & 6 deletions examples/whisper.android/app/src/main/jni/whisper/jni.c
Original file line number Diff line number Diff line change
Expand Up @@ -163,16 +163,12 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_freeContext(

JNIEXPORT void JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
JNIEnv *env, jobject thiz, jlong context_ptr, jfloatArray audio_data) {
JNIEnv *env, jobject thiz, jlong context_ptr, jint num_threads, jfloatArray audio_data) {
UNUSED(thiz);
struct whisper_context *context = (struct whisper_context *) context_ptr;
jfloat *audio_data_arr = (*env)->GetFloatArrayElements(env, audio_data, NULL);
const jsize audio_data_length = (*env)->GetArrayLength(env, audio_data);

// Leave 2 processors free (i.e. the high-efficiency cores).
int max_threads = max(1, min(8, get_nprocs() - 2));
LOGI("Selecting %d threads", max_threads);

// The below adapted from the Objective-C iOS sample
struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
params.print_realtime = true;
Expand All @@ -181,7 +177,7 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
params.print_special = false;
params.translate = false;
params.language = "en";
params.n_threads = max_threads;
params.n_threads = num_threads;
params.offset_ms = 0;
params.no_context = true;
params.single_segment = false;
Expand Down

0 comments on commit f990610

Please sign in to comment.