From 00ccb8b06d94c023f09d14d9bbef93aea2ad1cfd Mon Sep 17 00:00:00 2001 From: Ben Schmidt Date: Sun, 15 Sep 2024 09:42:42 -0400 Subject: [PATCH] new pivot sort strategy --- src/deepscatter.ts | 2 +- src/selection.ts | 354 ++++++++++++++++++++++++++++++++-------- src/utilityFunctions.ts | 49 +++--- tests/dataset.spec.js | 33 +++- tests/datasetHelpers.js | 4 + 5 files changed, 345 insertions(+), 97 deletions(-) diff --git a/src/deepscatter.ts b/src/deepscatter.ts index 793123c0a..0498311ba 100644 --- a/src/deepscatter.ts +++ b/src/deepscatter.ts @@ -1,5 +1,5 @@ export { Scatterplot } from './scatterplot'; -export { Bitmask, DataSelection } from './selection'; +export { Bitmask, DataSelection, SortedDataSelection } from './selection'; export { Deeptable } from './Deeptable'; export { LabelMaker } from './label_rendering'; export { dictionaryFromArrays } from './utilityFunctions'; diff --git a/src/selection.ts b/src/selection.ts index c6298be00..36edb3ef0 100644 --- a/src/selection.ts +++ b/src/selection.ts @@ -4,8 +4,15 @@ import { Scatterplot } from './scatterplot'; import { Tile } from './tile'; import { getTileFromRow } from './tixrixqid'; import type * as DS from './types'; -import { Bool, Float, StructRowProxy, Utf8, Vector, makeData } from 'apache-arrow'; -import { range } from 'd3-array'; +import { + Bool, + Struct, + StructRowProxy, + Utf8, + Vector, + makeData, +} from 'apache-arrow'; +import { bisectLeft, bisectRight, range } from 'd3-array'; interface SelectParams { name: string; useNameCache?: boolean; // If true and a selection with that name already exists, use it and ignore all passed parameters. Otherwise, throw an error. @@ -157,8 +164,6 @@ function isFunctionSelectParam( return (params as FunctionSelectParams).tileFunction !== undefined; } - - /** * A Bitmask is used to hold boolean filters across a single record batch. * It it used internally to manage selections, and can also be useful @@ -215,10 +220,10 @@ export class Bitmask { const result: number[] = []; for (let chunk = 0; chunk < this.length / 8; chunk++) { const b = this.mask[chunk]; - // THese are sparse, so we can usually + // THese are sparse, so we can usually skip the whole byte. if (b !== 0) { for (let bit = 0; bit < 8; bit++) { - if (1 << bit !== 0) { + if ((b & (1 << bit)) !== 0) { result.push(chunk * 8 + bit); } } @@ -298,6 +303,14 @@ export class Bitmask { } } +type SelectionSortInfo = { + indices: Uint16Array; + // Note that we can't sort by strings. + values: Float64Array; + start: number; + end: number; +}; + class SelectionTile { // The deepscatter Tile object. public tile: Tile; @@ -306,24 +319,22 @@ class SelectionTile { // used to access numbers by index. public _matchCount: number; - // An access order into the tile. If defined array of the same - // length as matchCount whose values are indices into the - // underlying tile. - // If not defined (which saves a great deal of memory) - // we just inspect the bitmask directly. - public indices?: Uint16Array; - - // The current point in the indices at which a cursor points. - // Only used for sorted selections. - public cursor?: number = 0; + public sorts: Record = {}; public bitmask: Vector; // Created with a tile and the set of matches. // If building from another SelectionTile, may also pass // the matchCount. - constructor({ tile, arrowBitmask, matchCount }: - { tile: Tile, arrowBitmask: Vector, matchCount?: number }) { + constructor({ + tile, + arrowBitmask, + matchCount, + }: { + tile: Tile; + arrowBitmask: Vector; + matchCount?: number; + }) { this.tile = tile; this.bitmask = arrowBitmask; if (matchCount !== undefined) { @@ -333,7 +344,7 @@ class SelectionTile { get matchCount(): number { if (this._matchCount) { - return this._matchCount + return this._matchCount; } let matchCount = 0; const { bitmask } = this; @@ -346,18 +357,32 @@ class SelectionTile { return this._matchCount; } - setCursorBelow(x: number, field: string) { - if (this.indices === undefined) { - throw new Error("Can't operate on unsorted array") + addSort( + key: string, + getter: (row: StructRowProxy) => number, + order: 'ascending' | 'descending', + ) { + const { bitmask } = this; + const indices = Bitmask.from_arrow(bitmask).which(); + const pairs: [number, number][] = new Array(indices.length); + for (let i = 0; i < indices.length; i++) { + const v = getter(this.tile.record_batch.get(indices[i])); + pairs[i] = [v, indices[i]]; } - const vals = await this.tile.record_batch.getChild(field) as Vector; - - let i = Math.floor(this.matchCount / 2); - let stride = i; - - + // Sort according to the specified order + pairs.sort((a, b) => (order === 'ascending' ? a[0] - b[0] : b[0] - a[0])); + const values = new Float64Array(indices.length); + for (let i = 0; i < indices.length; i++) { + indices[i] = pairs[i][1]; + values[i] = pairs[i][0]; + } + this.sorts[key] = { + indices, + values, + start: 0, + end: indices.length, + }; } - } /** @@ -429,6 +454,11 @@ export class DataSelection { type?: string; composition: null | Composition = null; private events: { [key: string]: Array<(args) => void> } = {}; + public params: + | IdSelectParams + | BooleanColumnParams + | FunctionSelectParams + | CompositeSelectParams; constructor( deeptable: Deeptable, @@ -443,7 +473,7 @@ export class DataSelection { throw new Error("Can't create a selection without a deeptable"); } this.name = params.name; - let markReady = function () { }; + let markReady = function () {}; this.ready = new Promise((resolve) => { markReady = resolve; }); @@ -468,6 +498,7 @@ export class DataSelection { return bitmask.to_arrow(); }).then(markReady); } + this.params = params; } /** @@ -483,7 +514,7 @@ export class DataSelection { this.events[event].push(listener); } - private dispatch(event: string, args: unknown): void { + protected dispatch(event: string, args: unknown): void { if (this.events[event]) { this.events[event].forEach((listener) => listener(args)); } @@ -507,7 +538,7 @@ export class DataSelection { // triggers creation of the deeptable column as a side-effect. return tile.get_column(this.name); }), - ).then(() => { }); + ).then(() => {}); } /** @@ -758,8 +789,8 @@ export class DataSelection { await tile.populateManifest(); const t = new SelectionTile({ arrowBitmask: array, - tile - }) + tile, + }); this.tiles.push(t); this.selectionSize += t.matchCount; this.evaluationSetSize += tile.manifest.nPoints; @@ -975,13 +1006,10 @@ function stringmatcher(field: string, matches: string[]) { export class SortedDataSelection extends DataSelection { public tiles: SelectionTile[] = []; public neededFields: string[]; - public comparisonGetter: (a: StructRowProxy) => number | Date; + public comparisonGetter: (a: StructRowProxy) => number; public order: 'ascending' | 'descending'; + public key: string; - // The current point in the selection at which we're focused. - public cursor: number = 0; - // A list of all the tiles in the selection, each with a cursor, in order. - private sortStack: [SelectionTile[]] = null; constructor( deeptable: Deeptable, params: @@ -989,62 +1017,248 @@ export class SortedDataSelection extends DataSelection { | BooleanColumnParams | FunctionSelectParams | CompositeSelectParams, - sortOperation: (a: StructRowProxy) => number | Date, + sortOperation: (a: StructRowProxy) => number, neededFields: string[], - tiles: SelectionTile[] = [], order: 'ascending' | 'descending' = 'ascending', + key?: string, ) { super(deeptable, params); - this.tiles = tiles; this.neededFields = neededFields; this.comparisonGetter = sortOperation; this.order = order; + this.key = key || Math.random().toFixed(10).slice(2); } - private sortTilesToNthPoint(n: number) { - // Sort the stack in ascending order put - this.sortStack.sort((a, b) => ) - let currentVal = this.sortStack[0] - let numberBelowThis = - } + // To create a sorted selection from a selection that already + // has some tiles loaded on it, we need to + // go back and create and add all the stats that would have been + // calculated at wrapWithSelectionMetadata. + static async fromSelection( + sel: DataSelection, + neededFields: string[], + sortOperation: (a: StructRowProxy) => number, + order: 'ascending' | 'descending' = 'ascending', + tKey: string | undefined = undefined, + name: string | undefined = undefined, + ): Promise { + const key = tKey || Math.random().toFixed(10).slice(2); + const newer = new SortedDataSelection( + sel.deeptable, + { + name: Math.random().toFixed(10).slice(2), + tileFunction: async (tile: Tile): Promise> => + tile.get_column(sel.name), + }, + sortOperation, + neededFields, + order, + key, + ); + // Ensure that all the fields we need are ready. + const withSort = sel.tiles.map( + async (tile: SelectionTile): Promise => { + await Promise.all(neededFields.map((f) => tile.tile.get_column(f))); + tile.addSort(key, sortOperation, order); + return tile; + }, + ); + newer.tiles = await Promise.all(withSort); + newer.selectionSize = newer.tiles.reduce((sum, t) => sum + t.matchCount, 0); + return newer; + } + // In addition to the regular things, we also need to add sort fields. protected wrapWithSelectionMetadata( functionToApply: DS.BoolTransformation, ): DS.BoolTransformation { - // When we wrap a filter selection, - const wrappedFunction = super.wrapWithSelectionMetadata(functionToApply); return async (tile: Tile) => { - const tileWeJustAdded = this.tiles[this.tiles.length - 1]; + const array = await functionToApply(tile); + + await tile.populateManifest(); + // Ensure that all the fields needed for the sort operation are present. - const [selection] = await Promise.all([ - wrappedFunction(tile), - ...this.neededFields.map((field) => - // Return null to avoid wasting memory on them. - tile.get_column(field).then(() => null), - ), - ]); - const indices = Bitmask.from_arrow(selection).which(); - tileWeJustAdded.indices = indices.sort((a, b) => { - return this.comparisonGetter(tile.record_batch.get(a)) > - this.comparisonGetter(tile.record_batch.get(b)) - ? -1 - : 1; - }); - return selection; + await Promise.all(this.neededFields.map((f) => tile.get_column(f))); + + // Store the indices and values in the tile + + let ix = this.tiles.findIndex((having) => having.tile === tile); + let t: SelectionTile; + if (ix !== -1) { + t = this.tiles[ix]; + t.addSort(this.key, this.comparisonGetter, this.order); + } else { + t = new SelectionTile({ + arrowBitmask: array, + tile, + }); + t.addSort(this.key, this.comparisonGetter, this.order); + this.selectionSize += t.matchCount; + this.evaluationSetSize += tile.manifest.nPoints; + this.tiles.push(t); + } + + this.dispatch('tile loaded', tile); + return array; }; } + /** + * Returns the k-th element in the sorted selection. + * This implementation uses Quickselect with a pivot selected from actual data. + */ + get(k: number): StructRowProxy | undefined { + if (k < 0 || k >= this.selectionSize) { + console.error('Index out of bounds'); + return undefined; + } + + // Adjust k based on the order + const targetIndex = + this.order === 'ascending' ? k : this.selectionSize - k - 1; + + // Implement Quickselect over the combined data + return quickSelect(targetIndex, this.tiles, this.key); + } + + // Given a point, returns cursor number that would select it in this selection + which(row: StructRowProxy) {} + *yieldSorted(start = undefined, direction = 'up') { if (start !== undefined) { this.cursor = start; - this.sortStack = null; } } +} - // get(i: number | undefined = undefined): StructRowProxy { - // // - // } +interface QuickSortTile { + tile: Tile; + sorts: Record; +} +function quickSelect( + k: number, + tiles: QuickSortTile[], + key: string, +): StructRowProxy | undefined { + // Recalculate size based on the current tiles + const size = tiles.reduce( + (acc, t) => acc + (t.sorts[key].end - t.sorts[key].start), + 0, + ); + + if (size === 1) { + for (const t of tiles) { + const { indices, start, end } = t.sorts[key]; + if (end - start > 0) { + const recordIndex = indices[start]; + return t.tile.record_batch.get(recordIndex); + } + } + return undefined; + } + + // Select a random pivot from actual data + const pivot = randomPivotFromData(tiles, key); + + let countLess = 0; + let countEqual = 0; + let countGreater = 0; + + const lessTiles: QuickSortTile[] = []; + const equalTiles: QuickSortTile[] = []; + const greaterTiles: QuickSortTile[] = []; + + for (const t of tiles) { + const { values, indices, start, end } = t.sorts[key]; + + const left = bisectLeft(values, pivot, start, end); + const right = bisectRight(values, pivot, start, end); + + const lessSize = left - start; + const equalSize = right - left; + const greaterSize = end - right; + + if (lessSize > 0) { + lessTiles.push({ + tile: t.tile, + sorts: { + [key]: { indices, values, start, end: left }, + }, + }); + countLess += lessSize; + } + + if (equalSize > 0) { + equalTiles.push({ + tile: t.tile, + sorts: { + [key]: { indices, values, start: left, end: right }, + }, + }); + countEqual += equalSize; + } + + if (greaterSize > 0) { + greaterTiles.push({ + tile: t.tile, + sorts: { + [key]: { indices, values, start: right, end }, + }, + }); + countGreater += greaterSize; + } + } + + // Verify that counts sum up correctly + if (countLess + countEqual + countGreater !== size) { + throw new Error('Counts do not sum up to size'); + } + + if (k < countLess) { + return quickSelect(k, lessTiles, key); + } else if (k < countLess + countEqual) { + const indexInEqual = k - countLess; + return selectInEqualTiles(indexInEqual, equalTiles, key); + } else { + const newK = k - (countLess + countEqual); + return quickSelect(newK, greaterTiles, key); + } } +function selectInEqualTiles( + indexInEqual: number, + tiles: QuickSortTile[], + key: string, +): StructRowProxy | undefined { + let count = 0; + for (const t of tiles) { + const { indices, start, end } = t.sorts[key]; + const numValues = end - start; + if (indexInEqual < count + numValues) { + const idxInTile = start + (indexInEqual - count); + const recordIndex = indices[idxInTile]; + return t.tile.record_batch.get(recordIndex); + } + count += numValues; + } + return undefined; +} +function randomPivotFromData(tiles: QuickSortTile[], key: string): number { + const totalSize = tiles.reduce( + (acc, t) => acc + (t.sorts[key].end - t.sorts[key].start), + 0, + ); + const randomIndex = Math.floor(Math.random() * totalSize); + let count = 0; + for (const t of tiles) { + const { values, start, end } = t.sorts[key]; + const numValues = end - start; + if (randomIndex < count + numValues) { + const idxInTile = start + (randomIndex - count); + return values[idxInTile]; + } + count += numValues; + } + throw new Error('Got lost in randomPivotFromData'); +} diff --git a/src/utilityFunctions.ts b/src/utilityFunctions.ts index f92848d7c..a97aa7b37 100644 --- a/src/utilityFunctions.ts +++ b/src/utilityFunctions.ts @@ -9,8 +9,11 @@ import { makeVector, } from 'apache-arrow'; -type IndicesType = Int8Array | Int16Array | Int32Array; -type DictionaryType = Dictionary; +type ArrayToArrayMap = { + Int8Array: Int8; + Int16Array: Int16; + Int32Array: Int32; +}; // We need to keep track of the current dictionary number // to avoid conflicts. I start these at 07540, the zip code @@ -19,13 +22,13 @@ type DictionaryType = Dictionary; let currentDictNumber = 7540; // Function overloads to make this curryable. -export function dictionaryFromArrays( +export function dictionaryFromArrays( labels: string[], -): (indices: IndicesType) => Vector; -export function dictionaryFromArrays( +): (indices: T) => Vector>; +export function dictionaryFromArrays( labels: string[], - indices: IndicesType, -): Vector; + indices: T, +): Vector>; /** * Create a dictionary from labels and integer indices. @@ -33,45 +36,41 @@ export function dictionaryFromArrays( * dictionaries--this method is *strongly* recommended if you don't * want things to be really slow. */ -export function dictionaryFromArrays( +export function dictionaryFromArrays( labels: string[], - indices?: IndicesType, -): Vector | ((indices: IndicesType) => Vector) { + indices?: T, +): + | Vector> + | ((indices: T) => Vector>) { // Run vectorFromArray only once to create labelsArrow. const labelsArrow: Vector = vectorFromArray(labels, new Utf8()); // Return a function that captures labelsArrow. if (indices === undefined) { - return (indices: IndicesType) => - createDictionaryWithVector(labelsArrow, indices); + return (indices: T) => createDictionaryWithVector(labelsArrow, indices); } return createDictionaryWithVector(labelsArrow, indices); } -function createDictionaryWithVector( +function createDictionaryWithVector( labelsArrow: Vector, - indices: IndicesType, -): Vector { - let t; + indices: T, +): Vector> { + let t: ArrayToArrayMap[T]; if (indices[Symbol.toStringTag] === `Int8Array`) { - t = new Int8(); + t = new Int8() as ArrayToArrayMap[T]; } else if (indices[Symbol.toStringTag] === `Int16Array`) { - t = new Int16(); + t = new Int16() as ArrayToArrayMap[T]; } else if (indices[Symbol.toStringTag] === `Int32Array`) { - t = new Int32(); + t = new Int32() as ArrayToArrayMap[T]; } else { throw new Error( 'values must be an array of signed integers, 32 bit or smaller.', ); } - const type = new Dictionary( - labelsArrow.type, - t, - currentDictNumber++, - false, - ) as Dictionary; + const type = new Dictionary(labelsArrow.type, t, currentDictNumber++, false); const returnval = makeVector({ type, length: indices.length, diff --git a/tests/dataset.spec.js b/tests/dataset.spec.js index 4ca700bb6..09a1dd76c 100644 --- a/tests/dataset.spec.js +++ b/tests/dataset.spec.js @@ -1,4 +1,9 @@ -import { Deeptable, DataSelection, Bitmask } from '../dist/deepscatter.js'; +import { + Deeptable, + DataSelection, + SortedDataSelection, + Bitmask, +} from '../dist/deepscatter.js'; import { Table, vectorFromArray, Utf8 } from 'apache-arrow'; import { test } from 'uvu'; import * as assert from 'uvu/assert'; @@ -94,4 +99,30 @@ test('Test composition of selections', async () => { console.log(v); }); +test('Test sorting of selections', async () => { + const dataset = createIntegerDataset(); + await dataset.root_tile.preprocessRootTileInfo(); + const selectEvens = new DataSelection(dataset, { + name: 'twos2', + tileFunction: selectFunctionForFactorsOf(2), + }); + await selectEvens.applyToAllTiles(); + const sorted = await SortedDataSelection.fromSelection( + selectEvens, + ['random'], + ({ random }) => random, + ); + await sorted.applyToAllTiles(); + // This should be 8192. + const bottom = sorted.get(0); + assert.ok(bottom.random < 0.01); + + const foo = sorted.get(sorted.selectionSize - 1); + assert.ok(foo.random > 0.99); + + const mid = sorted.get(Math.floor(sorted.selectionSize / 2)); + assert.ok(mid.random > 0.45); + assert.ok(mid.random < 0.55); +}); + test.run(); diff --git a/tests/datasetHelpers.js b/tests/datasetHelpers.js index bf433f3ec..62dfcfca6 100644 --- a/tests/datasetHelpers.js +++ b/tests/datasetHelpers.js @@ -21,6 +21,8 @@ function make_batch(start = 0, length = 65536, batch_number_here = 0) { let integers = new Int32Array(length); let ix = new Uint32Array(length); let batch_id = new Float32Array(length).fill(batch_number_here); + let randoms = new Float32Array(length); + for (let i = start; i < start + length; i++) { ix[i - start] = i; let x_ = 0; @@ -39,6 +41,7 @@ function make_batch(start = 0, length = 65536, batch_number_here = 0) { x[i - start] = x_; y[i - start] = y_; integers[i - start] = i; + randoms[i - start] = Math.random(); } function num_to_string(num) { @@ -51,6 +54,7 @@ function make_batch(start = 0, length = 65536, batch_number_here = 0) { _id: vectorFromArray(vs, new Utf8()), integers: vectorFromArray(integers), batch_id: vectorFromArray(batch_id), + random: vectorFromArray(randoms), }); }