diff --git a/web/src/artifact_cache.ts b/web/src/artifact_cache.ts index da9aaddfb0d6..f833df1be523 100644 --- a/web/src/artifact_cache.ts +++ b/web/src/artifact_cache.ts @@ -16,27 +16,401 @@ * specific language governing permissions and limitations * under the License. */ + +export interface NDArrayCacheEntry { + name: string; + shape: Array; + dtype: string; + format: "f32-to-bf16" | "raw"; + byteOffset: number; + nbytes: number; +} + +export interface NDArrayShardEntry { + dataPath: string; + format: "raw-shard"; + nbytes: number; + records: Array; +} + /** * Common Interface for the artifact cache */ export interface ArtifactCacheTemplate { /** - * fetch key url from cache + * Retrieve data object that corresponds to `url` from cache. If data object does not exist in + * cache, fetch the data and then add to cache. + * + * @param url: The url to the data to be cached. + * @param storetype: This field is required so that `ArtifactIndexedDBCache` can store the + * actual data object (see `addToCache()`), while `ArtifactCache` which uses the Cache API can + * return the actual data object rather than the request. There are two options: + * 1. "json": returns equivalent to `fetch(url).json()` + * 2. "arraybuffer": returns equivalent to `fetch(url).arrayBuffer()` + * @return The data object (i.e. users do not need to call `.json()` or `.arrayBuffer()`). + * + * @note This is an async function. */ - fetchWithCache(url: string); + fetchWithCache(url: string, storetype?: string): Promise; /** - * add ey url to cache + * Fetch data from url and add into cache. If already exists in cache, should return instantly. + * + * @param url: The url to the data to be cached. + * @param storetype: Only applies to `ArtifactIndexedDBCache`. Since `indexedDB` stores the actual + * data rather than a request, we specify `storetype`. There are two options: + * 1. "json": IndexedDB stores `fetch(url).json()` + * 2.
"arraybuffer": IndexedDB stores `fetch(url).arrayBuffer()` + * + * @note This is an async function. */ - addToCache(url: string); + addToCache(url: string, storetype?: string): Promise; /** * check if cache has all keys in Cache + * + * @note This is an async function. */ - hasAllKeys(keys: string[]); + hasAllKeys(keys: string[]): Promise; /** * Delete url in cache if url exists + * + * @note This is an async function. + */ + deleteInCache(url: string): Promise; +} + + +/** + * Cache to store model related data, implemented with the Cache API. + */ +export class ArtifactCache implements ArtifactCacheTemplate { + private scope: string; + private cache?: Cache; + + constructor(scope: string) { + this.scope = scope; + } + + /** + * Convert the Response object to the expected storetype instead */ - deleteInCache(url: string); + async responseTostoretype(response: Response, storetype?: string): Promise { + if (storetype === undefined) { + return response; + } else if (storetype.toLowerCase() === "json") { + return await response.json(); + } else if (storetype.toLowerCase() === "arraybuffer") { + return await response.arrayBuffer(); + } else { + console.error("Unknown storage type " + storetype + ", returning raw response"); + return response; + } + } + + /** + * fetch the corresponding url object in response or stored object format + * @param url url + * @param storetype the storage type for indexedDB + * @returns response in json, arraybuffer or pure response format + */ + async fetchWithCache(url: string, storetype?: string): Promise { + await this.addToCache(url, storetype); + const result = await this.cache.match(new Request(url)); + if (result === undefined) { + // Already called `addToCache()`, should expect the request in cache. 
+ throw Error("Cannot fetch " + url); + } + return await this.responseTostoretype(result, storetype); + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + async addToCache(url: string, storetype?: string) { + const request = new Request(url); + if (this.cache === undefined) { + this.cache = await caches.open(this.scope); + } + const result = await this.cache.match(request); + if (result === undefined) { + await this.cache.add(request); + } + } + + /** + * Determine if all keys exist in the cache + * @param keys the url key list of the strings + * @returns boolean value indicate if all keys are in cache + */ + async hasAllKeys(keys: string[]) { + if (this.cache === undefined) { + this.cache = await caches.open(this.scope); + } + return this.cache.keys() + .then(requests => requests.map(request => request.url)) + .then(cacheKeys => keys.every(key => cacheKeys.indexOf(key) !== -1)) + .catch(() => false); + } + + /** + * Delete the corresponding url object in cache + * @param url the corresponding url object to be deleted + */ + async deleteInCache(url: string) { + if (this.cache === undefined) { + this.cache = await caches.open(this.scope); + } + await this.cache.delete(url); + } +} + +/** + * Cache by IndexedDB to support caching model data + */ +export class ArtifactIndexedDBCache implements ArtifactCacheTemplate { + private dbName?: string; + private dbVersion = 1; + private db: IDBDatabase | undefined; + + constructor(dbName: string) { + this.dbName = dbName; + } + + /** + * Init the indexed DB database if it is not initialized. 
+ */ + private async initDB() { + if (this.db != null) { + return; // the db is already initialized + } + return new Promise((resolve, reject) => { + const request = indexedDB.open(this.dbName, this.dbVersion); + request.onupgradeneeded = (event) => { + this.db = (event.target as IDBOpenDBRequest).result; + if (!this.db.objectStoreNames.contains('urls')) { + this.db.createObjectStore('urls', { keyPath: 'url' }); + } + }; + request.onsuccess = (event) => { + this.db = (event.target as IDBOpenDBRequest).result; + resolve(); + }; + request.onerror = (event) => { + console.error("Database error: ", (event.target as IDBOpenDBRequest).error); + reject((event.target as IDBOpenDBRequest).error); + }; + }); + } + + /** + * Check if current url object is in indexedDB or not + * @param url the url link + * @returns boolean indicating whether the url object is in indexedDB + */ + private async isUrlInDB(url: string): Promise { + return new Promise((resolve, reject) => { + const transaction = this.db?.transaction(['urls'], 'readonly'); + if (transaction === undefined) { + return false; + } + const store = transaction.objectStore('urls'); + const request = store.get(url); + request.onsuccess = () => { + resolve(request.result !== undefined); + }; + request.onerror = (event) => { + reject((event.target as IDBRequest).error); + }; + }); + } + + async asyncGetHelper(url: string): Promise { + return new Promise((resolve, reject) => { + let result: any; + const transaction = this.db?.transaction(['urls'], 'readonly'); + if (transaction === undefined) { + return false; + } + transaction.oncomplete = () => resolve(result); + transaction.onerror = () => reject(transaction.error); + const objectStore = transaction.objectStore('urls'); + const getRequest = objectStore.get(url); + getRequest.onsuccess = () => { + result = getRequest.result; + } + }) + } + + async fetchWithCache(url: string, storetype?: string): Promise { + await this.addToCache(url, storetype); + let result = await
this.asyncGetHelper(url); + if (result === null) { + // previously null data in cache or somehow failed to add to cache, delete and retry + await this.deleteInCache(url); + await this.addToCache(url, storetype); + result = await this.asyncGetHelper(url); + } + if (result != null && typeof result === "object" && "data" in result) { + // `storetype` not used here because the data stored in indexedDB is already in that type + return result.data; + } + throw Error("ArtifactIndexedDBCache failed to fetch: " + url); + } + + async addToIndexedDB(url: string, response: any, storetype?: string) { + await this.initDB(); + let data: any; + // IndexedDB, unlike the Cache API, stores the actual data object, so we convert response here. + if (storetype != undefined) { + if (storetype.toLowerCase() === "json") { + data = await response.json(); + } else if (storetype.toLocaleLowerCase() === "arraybuffer") { + data = await response.arrayBuffer(); + } else { + throw Error("Unsupported storetyp for IndexedDB: " + storetype); + } + } + return new Promise((resolve, reject) => { + const transaction = this.db?.transaction(['urls'], 'readwrite'); + if (transaction === undefined) { + return; + } + const store = transaction.objectStore('urls'); + const request = store.add({ data, url }); // IndexedDB follows a {value, key} format, instead of {key, value} format!
+ request.onsuccess = () => resolve(); + request.onerror = (event) => reject((event.target as IDBRequest).error); + }); + } + + async addToCache(url: string, storetype?: string): Promise { + await this.initDB(); // await the initDB process + // If already cached, nothing to do + const isInDB = await this.isUrlInDB(url); + if (isInDB) { + return; + } + try { + const response = await fetch(url); + if (!response.ok) { + throw new Error('Network response was not ok'); + } + const response_copy = response.clone(); + await this.addToIndexedDB(url, response_copy, storetype); + } catch (error) { + throw Error("Failed to store " + url + " with error: " + error); + } + } + + async hasAllKeys(keys: string[]): Promise { + await this.initDB(); // Ensure the DB is initialized + if (!this.db) { + throw new Error('Database is not initialized'); + } + return new Promise((resolve, reject) => { + const transaction = this.db.transaction(['urls'], 'readonly'); + const store = transaction.objectStore('urls'); + const promises = keys.map(key => { + return new Promise((resolve) => { + const request = store.get(key); + request.onsuccess = () => { + if (request.result === undefined) { + resolve(false); // Key not found, resolve with false + } else { + resolve(true); // Key found, resolve with true + } + }; + request.onerror = () => { + resolve(false); // On error, resolve as if the key was not found + }; + }); + }); + Promise.all(promises).then(results => { + const allExist = results.every(exists => exists); + resolve(allExist); + }).catch(error => { + reject(error); // Reject the main promise if any of the promises are rejected + }); + }); + } + + async deleteInCache(url: string) { + await this.initDB(); // Make sure the DB is initialized + const transaction = this.db?.transaction(['urls'], 'readwrite'); + if (transaction === undefined) { + return; + } + const store = transaction.objectStore('urls'); + const request = store.delete(url); + // Await completion of the delete request + await 
new Promise((resolve, reject) => { + request.onsuccess = () => resolve(); + request.onerror = () => reject(request.error); + }); + return; + } +} + + +/** + * Function to check if NDarray is in Cache or not + * + * @param ndarrayCacheUrl The cache url which links to the NDArray + * @param cacheScope The scope identifier of the cache + * @param cacheType The type of the cache: "cache" or "indexedDB" + * @returns the result if the cache has NDArray + */ +export async function hasNDArrayInCache( + ndarrayCacheUrl: string, + cacheScope = "tvmjs", + cacheType = "cache" +): Promise { + let artifactCache: ArtifactCacheTemplate; + if (cacheType.toLowerCase() === "cache") { + artifactCache = new ArtifactCache(cacheScope); + } else if (cacheType.toLowerCase() == "indexeddb") { + artifactCache = new ArtifactIndexedDBCache(cacheScope); + } else { + console.error("Unsupported cacheType: " + cacheType + ", using default ArtifactCache."); + artifactCache = new ArtifactCache(cacheScope); + } + const jsonUrl = new URL("ndarray-cache.json", ndarrayCacheUrl).href; + const hasJsonUrlInCache = await artifactCache.hasAllKeys([jsonUrl]); + if (!hasJsonUrlInCache) { + return false; + } + let list = await artifactCache.fetchWithCache(jsonUrl, "json"); + list = list["records"] as Array; + return await artifactCache.hasAllKeys(list.map(key => new URL(key.dataPath, ndarrayCacheUrl).href)); +} + + +/** + * Given cacheUrl, search up items to delete based on cacheUrl/ndarray-cache.json + * + * @param cacheUrl The cacheUrl for the items + * @param cacheScope The scope identifier of the cache + * @param cacheType The type of the cache: "cache" or "indexedDB" + */ +export async function deleteNDArrayCache( + cacheUrl: string, + cacheScope = "tvmjs", + cacheType = "cache" +) { + let artifactCache: ArtifactCacheTemplate; + if (cacheType.toLowerCase() === "cache") { + artifactCache = new ArtifactCache(cacheScope); + } else if (cacheType.toLowerCase() == "indexeddb") { + artifactCache = new 
ArtifactIndexedDBCache(cacheScope); + } else { + console.error("Unsupported cacheType: " + cacheType + ", using default ArtifactCache."); + artifactCache = new ArtifactCache(cacheScope); + } + const jsonUrl = new URL("ndarray-cache.json", cacheUrl).href; + const list = await artifactCache.fetchWithCache(jsonUrl, "json"); + const arrayentry = list["records"] as Array; + const processShard = async (i: number) => { + const dataUrl = new URL(arrayentry[i].dataPath, cacheUrl).href; + await artifactCache.deleteInCache(dataUrl); + } + await Promise.all(arrayentry.map((_, index) => processShard(index))); } diff --git a/web/src/index.ts b/web/src/index.ts index edc695978f50..d4fc9b9187e6 100644 --- a/web/src/index.ts +++ b/web/src/index.ts @@ -22,11 +22,17 @@ export { PackedFunc, Module, NDArray, TVMArray, TVMObject, VirtualMachine, InitProgressCallback, InitProgressReport, - ArtifactCache, Instance, instantiate, hasNDArrayInCache, deleteNDArrayCache + Instance, instantiate } from "./runtime"; +export { + ArtifactCacheTemplate, + ArtifactCache, + ArtifactIndexedDBCache, + hasNDArrayInCache, + deleteNDArrayCache +} from "./artifact_cache"; export { Disposable, LibraryProvider } from "./types"; export { RPCServer } from "./rpc_server"; -export { wasmPath, LinearCongruentialGenerator } from "./support"; +export { assert, wasmPath, LinearCongruentialGenerator } from "./support"; export { detectGPUDevice, GPUDeviceDetectOutput } from "./webgpu"; -export { assert } from "./support"; export { createPolyfillWASI } from "./compact"; diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 9142571b9e4a..4b40bbc34152 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -27,8 +27,12 @@ import { assert, StringToUint8Array, LinearCongruentialGenerator } from "./suppo import { Environment } from "./environment"; import { AsyncifyHandler } from "./asyncify"; import { FunctionInfo, WebGPUContext } from "./webgpu"; -import { ArtifactCacheTemplate } from "./artifact_cache"; - +import 
{ + ArtifactCache, + ArtifactCacheTemplate, + ArtifactIndexedDBCache, + NDArrayShardEntry, +} from "./artifact_cache"; import * as compact from "./compact"; import * as ctypes from "./ctypes"; @@ -970,88 +974,15 @@ enum AsyncCallbackCode { kReturn = 4, kException = 5, } -export interface NDArrayCacheEntry { - name: string; - shape: Array; - dtype: string; - format: "f32-to-bf16" | "raw"; - byteOffset: number; - nbytes: number; -} - -export interface NDArrayShardEntry { - dataPath: string; - format: "raw-shard"; - nbytes: number; - records: Array; -} export interface InitProgressReport { progress: number; timeElapsed: number; - cacheOnly: boolean; text: string; } export type InitProgressCallback = (report: InitProgressReport) => void; -/** - * Cache to store model related data. - */ -export class ArtifactCache implements ArtifactCacheTemplate { - private scope: string; - private cache?: Cache; - - constructor(scope: string) { - this.scope = scope; - } - - async fetchWithCache(url: string) { - const request = new Request(url); - if (this.cache === undefined) { - this.cache = await caches.open(this.scope); - } - let result = await this.cache.match(request); - if (result === undefined) { - await this.cache.add(request); - result = await this.cache.match(request); - } - if (result === undefined) { - throw Error("Cannot fetch " + url); - } - return result; - } - - async addToCache(url: string) { - const request = new Request(url); - if (this.cache === undefined) { - this.cache = await caches.open(this.scope); - } - const result = await this.cache.match(request); - if (result === undefined) { - await this.cache.add(request); - } - } - - async hasAllKeys(keys: string[]) { - if (this.cache === undefined) { - this.cache = await caches.open(this.scope); - } - return this.cache.keys() - .then(requests => requests.map(request => request.url)) - .then(cacheKeys => keys.every(key => cacheKeys.indexOf(key) !== -1)) - .catch(err => false); - } - - async deleteInCache(url: string) { 
- if (this.cache === undefined) { - this.cache = await caches.open(this.scope); - } - const result = await this.cache.delete(url); - return result; - } -} - /** * TVM runtime instance. * @@ -1500,21 +1431,26 @@ export class Instance implements Disposable { * @param ndarrayCacheUrl The cache url. * @param device The device to be fetched to. * @param cacheScope The scope identifier of the cache + * @param cacheType The type of the cache: "cache" or "indexedDB" * @returns The meta data */ async fetchNDArrayCache( ndarrayCacheUrl: string, device: DLDevice, - cacheScope = "tvmjs" + cacheScope = "tvmjs", + cacheType = "cache" ): Promise { - const artifactCache = new ArtifactCache(cacheScope); - const jsonUrl = new URL("ndarray-cache.json", ndarrayCacheUrl).href; - const result = await artifactCache.fetchWithCache(jsonUrl); - - let list; - if (result instanceof Response) { - list = await result.json(); + let artifactCache: ArtifactCacheTemplate; + if (cacheType === undefined || cacheType.toLowerCase() === "cache") { + artifactCache = new ArtifactCache(cacheScope); + } else if (cacheType.toLowerCase() == "indexeddb") { + artifactCache = new ArtifactIndexedDBCache(cacheScope); + } else { + console.error("Unsupported cacheType: " + cacheType + ", using default ArtifactCache."); + artifactCache = new ArtifactCache(cacheScope); } + const jsonUrl = new URL("ndarray-cache.json", ndarrayCacheUrl).href; + const list = await artifactCache.fetchWithCache(jsonUrl, "json"); await this.fetchNDArrayCacheInternal( ndarrayCacheUrl, list["records"] as Array, device, artifactCache); @@ -1538,7 +1474,6 @@ export class Instance implements Disposable { ) { const perf = compact.getPerformance(); const tstart = perf.now(); - let totalBytes = 0; for (let i = 0; i < list.length; ++i) { totalBytes += list[i].nbytes; @@ -1547,15 +1482,14 @@ export class Instance implements Disposable { let fetchedShards = 0; let timeElapsed = 0; - const cacheOnly = await artifactCache.hasAllKeys(list.map(key => new 
URL(key.dataPath, ndarrayCacheUrl).href)) + const cacheOnly = await artifactCache.hasAllKeys(list.map(key => new URL(key.dataPath, ndarrayCacheUrl).href)); + // `loading`: we have finished downloading (or already cacheOnly) and are loading onto WebGPU const reportCallback = (iter: number, loading = false) => { // report for (let j = 0; j < this.initProgressCallback.length; ++j) { let text: string; if (loading) { - text = "Finished fetching params, loading onto WebGPU."; - } else if (cacheOnly) { text = "Loading model from cache[" + iter + "/" + list.length + "]: "; text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB loaded. " text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, " @@ -1571,7 +1505,6 @@ export class Instance implements Disposable { this.initProgressCallback[j]({ progress: fetchedBytes / totalBytes, timeElapsed: timeElapsed, - cacheOnly: cacheOnly, text: text }); } @@ -1581,7 +1514,6 @@ export class Instance implements Disposable { this.initProgressCallback[j]({ progress: fetchedBytes / totalBytes, timeElapsed: 0, - cacheOnly: cacheOnly, text: "Start to fetch params", }); } @@ -1593,25 +1525,26 @@ export class Instance implements Disposable { const shard = list[i]; const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href; try { - await artifactCache.addToCache(dataUrl); + await artifactCache.addToCache(dataUrl, "arraybuffer"); } catch (err) { this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err); throw err; } timeElapsed = Math.ceil((perf.now() - tstart) / 1000); fetchedBytes += shard.nbytes; - reportCallback(fetchedShards++); + reportCallback(fetchedShards++, /*loading=*/false); } } // We launch 4 parallel for loops to limit the max concurrency to 4 download - const loopSize = Math.floor(list.length / 4); - await Promise.all([ - downloadCache(0, loopSize), - downloadCache(loopSize, 2 * loopSize), - downloadCache(2 * loopSize, 3 * loopSize), - downloadCache(3 * loopSize, list.length) - ]); - 
reportCallback(list.length, /*loading=*/true); + if (!cacheOnly) { + const loopSize = Math.floor(list.length / 4); + await Promise.all([ + downloadCache(0, loopSize), + downloadCache(loopSize, 2 * loopSize), + downloadCache(2 * loopSize, 3 * loopSize), + downloadCache(3 * loopSize, list.length) + ]); + } // Then iteratively, load the shard from cache for (let i = 0; i < list.length; ++i) { @@ -1619,7 +1552,7 @@ export class Instance implements Disposable { const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href; let buffer; try { - buffer = await (await artifactCache.fetchWithCache(dataUrl)).arrayBuffer(); + buffer = await artifactCache.fetchWithCache(dataUrl, "arraybuffer"); } catch (err) { this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err); throw err; @@ -1661,6 +1594,7 @@ export class Instance implements Disposable { throw err; } } + reportCallback(i + 1, /*loading=*/true); } } @@ -2118,7 +2052,6 @@ export class Instance implements Disposable { }).then(() => { finishCounter += 1; const tend = perf.now(); - const timeReportGap = 1000; // skip report if gap is smaller than 1000 if ((tend - tlastReport) < 1000 && finishCounter != fmapEntries.length) { return; @@ -2134,7 +2067,6 @@ export class Instance implements Disposable { this.initProgressCallback[j]({ progress: progress, timeElapsed: timeElapsed, - cacheOnly: false, text: text }); } @@ -2583,47 +2515,3 @@ export function instantiate( } ); } - -export async function hasNDArrayInCache( - ndarrayCacheUrl: string, - cacheScope = "tvmjs" -): Promise { - const artifactCache = new ArtifactCache(cacheScope); - const jsonUrl = new URL("ndarray-cache.json", ndarrayCacheUrl).href; - const hasJsonUrlInCache = await artifactCache.hasAllKeys([jsonUrl]); - if (!hasJsonUrlInCache) { - return false; - } - const result = await artifactCache.fetchWithCache(jsonUrl); - let list; - if (result instanceof Response) { - list = await result.json(); - } - list = list["records"] as Array; - return await 
artifactCache.hasAllKeys(list.map(key => new URL(key.dataPath, ndarrayCacheUrl).href)); -} - -/** - * Given cacheUrl, search up items to delete based on cacheUrl/ndarray-cache.json - * - * @param cacheUrl - * @param cacheScope - */ -export async function deleteNDArrayCache( - cacheUrl: string, - cacheScope = "tvmjs" -) { - const artifactCache = new ArtifactCache(cacheScope); - const jsonUrl = new URL("ndarray-cache.json", cacheUrl).href; - const result = await artifactCache.fetchWithCache(jsonUrl); - let list; - if (result instanceof Response) { - list = await result.json(); - } - const arrayentry = list["records"] as Array; - const processShard = async (i: number) => { - const dataUrl = new URL(arrayentry[i].dataPath, cacheUrl).href; - await artifactCache.deleteInCache(dataUrl); - } - await Promise.all(arrayentry.map((_, index) => processShard(index))); -}