diff --git a/web/src/index.ts b/web/src/index.ts index 9099d8f37347..edc695978f50 100644 --- a/web/src/index.ts +++ b/web/src/index.ts @@ -26,7 +26,7 @@ export { } from "./runtime"; export { Disposable, LibraryProvider } from "./types"; export { RPCServer } from "./rpc_server"; -export { wasmPath } from "./support"; +export { wasmPath, LinearCongruentialGenerator } from "./support"; export { detectGPUDevice, GPUDeviceDetectOutput } from "./webgpu"; export { assert } from "./support"; export { createPolyfillWASI } from "./compact"; diff --git a/web/src/runtime.ts b/web/src/runtime.ts index ea022d1b3e9d..9142571b9e4a 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -23,7 +23,7 @@ import { Pointer, PtrOffset, SizeOf, ArgTypeCode } from "./ctypes"; import { Disposable } from "./types"; import { Memory, CachedCallStack } from "./memory"; -import { assert, StringToUint8Array } from "./support"; +import { assert, StringToUint8Array, LinearCongruentialGenerator } from "./support"; import { Environment } from "./environment"; import { AsyncifyHandler } from "./asyncify"; import { FunctionInfo, WebGPUContext } from "./webgpu"; @@ -1079,6 +1079,7 @@ export class Instance implements Disposable { private ctx: RuntimeContext; private asyncifyHandler: AsyncifyHandler; private initProgressCallback: Array = []; + private rng: LinearCongruentialGenerator; /** * Internal function(registered by the runtime) @@ -1131,6 +1132,7 @@ export class Instance implements Disposable { ); this.registerEnvGlobalPackedFuncs(); this.registerObjectFactoryFuncs(); + this.rng = new LinearCongruentialGenerator(); } /** @@ -1811,11 +1813,18 @@ export class Instance implements Disposable { const scale = high - low; const input = new Float32Array(size); for (let i = 0; i < input.length; ++i) { - input[i] = low + Math.random() * scale; + input[i] = low + this.rng.randomFloat() * scale; } return ret.copyFrom(input); } + /** + * Set the seed of the internal LinearCongruentialGenerator. + */ + setSeed(seed: number): void { + this.rng.setSeed(seed); + } + /** * Sample index via top-p sampling. * @@ -1825,7 +1834,7 @@ export class Instance implements Disposable { * @returns The sampled index. */ sampleTopPFromLogits(logits: NDArray, temperature: number, top_p: number): number { - return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, Math.random()); + return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, this.rng.randomFloat()); } /** @@ -1836,7 +1845,7 @@ export class Instance implements Disposable { * @returns The sampled index. */ sampleTopPFromProb(prob: NDArray, top_p: number): number { - return this.ctx.sampleTopPFromProb(prob, top_p, Math.random()); + return this.ctx.sampleTopPFromProb(prob, top_p, this.rng.randomFloat()); } /** diff --git a/web/src/support.ts b/web/src/support.ts index b03fa363cdce..2fa87ed291a2 100644 --- a/web/src/support.ts +++ b/web/src/support.ts @@ -74,3 +74,79 @@ export function assert(condition: boolean, msg?: string): asserts condition { export function wasmPath(): string { return __dirname + "/wasm"; } + +/** + * Linear congruential generator for random number generating that can be seeded. + * + * Follows the implementation of `include/tvm/support/random_engine.h`, which follows the + * sepcification in https://en.cppreference.com/w/cpp/numeric/random/linear_congruential_engine. + * + * Note `Number.MAX_SAFE_INTEGER = 2^53 - 1`, and our intermediates are strictly less than 2^48. + */ + +export class LinearCongruentialGenerator { + readonly modulus: number; + readonly multiplier: number; + readonly increment: number; + // Always within the range (0, 2^32 - 1) non-inclusive; if 0, will forever generate 0. + private rand_state: number; + + /** + * Set modulus, multiplier, and increment. Initialize `rand_state` according to `Date.now()`. + */ + constructor() { + this.modulus = 2147483647; // 2^32 - 1 + this.multiplier = 48271; // between 2^15 and 2^16 + this.increment = 0; + this.setSeed(Date.now()); + } + + /** + * Sets `rand_state` after normalized with `modulus` to ensure that it is within range. + * @param seed Any integer. Used to set `rand_state` after normalized with `modulus`. + * + * Postcondition: pass `checkRandState()`, i.e. rand_state > 0 and is an integer. + */ + setSeed(seed: number) { + if (!Number.isInteger(seed)) { + throw new Error("Seed should be an integer."); + } + this.rand_state = seed % this.modulus; + if (this.rand_state == 0) { + this.rand_state = 1; + } + this.checkRandState(); + } + + /** + * Generate the next integer in the range (0, this.modulus) non-inclusive, updating `rand_state`. + * + * Postcondition: pass `checkRandState()`, i.e. rand_state > 0 and is an integer. + */ + nextInt(): number { + // `intermediate` is always < 2^48, hence less than `Number.MAX_SAFE_INTEGER` due to the + // invariants as commented in the constructor. + const intermediate = this.multiplier * this.rand_state + this.increment; + this.rand_state = intermediate % this.modulus; + this.checkRandState(); + return this.rand_state; + } + + /** + * Generates random float between (0, 1) non-inclusive, updating `rand_state`. + * + * Postcondition: pass `checkRandState()`, i.e. rand_state > 0 and is an integer. + */ + randomFloat(): number { + return this.nextInt() / this.modulus; + } + + private checkRandState(): void { + if (this.rand_state <= 0) { + throw new Error("Random state is unexpectedly not strictly positive."); + } + if (!Number.isInteger(this.rand_state)) { + throw new Error("Random state is unexpectedly not an integer."); + } + } +} diff --git a/web/tests/node/test_random_generator.js b/web/tests/node/test_random_generator.js new file mode 100644 index 000000000000..adc6635d0576 --- /dev/null +++ b/web/tests/node/test_random_generator.js @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* eslint-disable no-undef */ + +const tvmjs = require("../../dist"); + +test("Test coverage of [0,100] inclusive", () => { + const covered = Array(100); + const rng = new tvmjs.LinearCongruentialGenerator(); + for (let i = 0; i < 100000; i++) { + covered[rng.nextInt() % 100] = true; + } + const notCovered = []; + for (let i = 0; i < 100; i++) { + if (!covered[i]) { + notCovered.push(i); + } + } + expect(notCovered).toEqual([]); +}); + +test("Test whether the same seed make two RNGs generate same results", () => { + const rng1 = new tvmjs.LinearCongruentialGenerator(); + const rng2 = new tvmjs.LinearCongruentialGenerator(); + rng1.setSeed(42); + rng2.setSeed(42); + + for (let i = 0; i < 100; i++) { + expect(rng1.randomFloat()).toBeCloseTo(rng2.randomFloat()); + } +}); + +test("Test two RNGs with different seeds generate different results", () => { + const rng1 = new tvmjs.LinearCongruentialGenerator(); + const rng2 = new tvmjs.LinearCongruentialGenerator(); + rng1.setSeed(41); + rng2.setSeed(42); + let numSame = 0; + const numTest = 100; + + // Generate `numTest` random numbers, make sure not all are the same. + for (let i = 0; i < numTest; i++) { + if (rng1.nextInt() === rng2.nextInt()) { + numSame += 1; + } + } + expect(numSame < numTest).toBe(true); +}); + +test('Illegal argument to `setSeed()`', () => { + expect(() => { + const rng1 = new tvmjs.LinearCongruentialGenerator(); + rng1.setSeed(42.5); + }).toThrow("Seed should be an integer."); +});