Skip to content

Commit

Permalink
feat: add support for text data & tokenization
Browse files Browse the repository at this point in the history
  • Loading branch information
s314cy committed Jul 6, 2023
1 parent e1b67f1 commit 9c96d71
Show file tree
Hide file tree
Showing 39 changed files with 468 additions and 188 deletions.
56 changes: 53 additions & 3 deletions discojs/discojs-core/src/dataset/data/data.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import { Task } from '../..'
import { tf, Task } from '../..'
import { Dataset } from '../dataset'
import { PreprocessingFunction } from './preprocessing/base'

import { List } from 'immutable'

export abstract class Data {
public abstract readonly availablePreprocessing: List<PreprocessingFunction>

protected constructor (
public readonly dataset: Dataset,
public readonly task: Task,
Expand All @@ -15,7 +20,52 @@ export abstract class Data {
throw new Error('abstract')
}

abstract batch (): Data
// Callable abstract method instead of constructor
protected abstract create (dataset: Dataset, task: Task, size?: number): Data

batch (): Data {
return this.create(this.batchedDataset, this.task, this.size)
}

get batchedDataset (): Dataset {
const batchSize = this.task.trainingInformation.batchSize
return batchSize === undefined
? this.dataset
: this.dataset.batch(batchSize)
}

preprocess (): Data {
return this.create(this.preprocessedDataset, this.task, this.size)
}

get preprocessing (): (entry: tf.TensorContainer) => tf.TensorContainer {
const params = this.task.trainingInformation
const taskPreprocessing = params.preprocessingFunctions

if (
taskPreprocessing === undefined ||
taskPreprocessing.length === 0 ||
this.availablePreprocessing.size === 0
) {
return (x) => x
}

const applyPreprocessing = this.availablePreprocessing
.filter((e) => e.type in taskPreprocessing)
.map((e) => e.apply)

if (applyPreprocessing.size === 0) {
return (x) => x
}

abstract preprocess (): Promise<Data>
const preprocessingChain = applyPreprocessing
.reduce((acc: (x: tf.TensorContainer, task: Task) => tf.TensorContainer, fn) =>
(x: tf.TensorContainer, task: Task) => fn(acc(x, this.task), this.task))

return (x: tf.TensorContainer) => preprocessingChain(x, this.task)
}

get preprocessedDataset (): Dataset {
return this.dataset.map(this.preprocessing)
}
}
18 changes: 5 additions & 13 deletions discojs/discojs-core/src/dataset/data/image_data.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import { tf, Task } from '../..'
import { getPreprocessImage, ImagePreprocessing } from './preprocessing'
import { Dataset } from '../dataset'
import { Data } from './data'
import { ImagePreprocessing, IMAGE_PREPROCESSING } from './preprocessing'

export class ImageData extends Data {
public readonly availablePreprocessing = IMAGE_PREPROCESSING

static async init (
dataset: Dataset,
task: Task,
Expand Down Expand Up @@ -42,17 +44,7 @@ export class ImageData extends Data {
return new ImageData(dataset, task, size)
}

batch (): Data {
const batchSize = this.task.trainingInformation.batchSize
const newDataset = batchSize === undefined ? this.dataset : this.dataset.batch(batchSize)

return new ImageData(newDataset, this.task, this.size)
}

async preprocess (): Promise<Data> {
let newDataset = this.dataset
const preprocessImage = getPreprocessImage(this.task)
newDataset = newDataset.map((x: tf.TensorContainer) => preprocessImage(x))
return new ImageData(newDataset, this.task, this.size)
protected create (dataset: Dataset, task: Task, size: number): ImageData {
return new ImageData(dataset, task, size)
}
}
1 change: 1 addition & 0 deletions discojs/discojs-core/src/dataset/data/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ export { DataSplit } from './data_split'
export { Data } from './data'
export { ImageData } from './image_data'
export { TabularData } from './tabular_data'
export { TextData } from './text_data'
export { ImagePreprocessing, TabularPreprocessing } from './preprocessing'
77 changes: 0 additions & 77 deletions discojs/discojs-core/src/dataset/data/preprocessing.ts

This file was deleted.

11 changes: 11 additions & 0 deletions discojs/discojs-core/src/dataset/data/preprocessing/base.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { tf, Task } from '../../..'
import { ImagePreprocessing } from './image_preprocessing'
import { TabularPreprocessing } from './tabular_preprocessing'
import { TextPreprocessing } from './text_preprocessing'

export type Preprocessing = ImagePreprocessing | TextPreprocessing | TabularPreprocessing

export interface PreprocessingFunction {
type: Preprocessing
apply: (x: tf.TensorContainer, task: Task) => tf.TensorContainer
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import { Task, tf } from '../../..'
import { PreprocessingFunction } from './base'

import { List } from 'immutable'

export enum ImagePreprocessing {
Resize,
Normalize
}

interface ImageEntry extends tf.TensorContainerObject {
xs: tf.Tensor3D | tf.Tensor4D
ys: tf.Tensor1D | number | undefined
}

const resize: PreprocessingFunction = {
type: ImagePreprocessing.Resize,
apply: (entry: tf.TensorContainer, task: Task): tf.TensorContainer => {
const { xs, ys } = entry as ImageEntry
const params = task.trainingInformation
return {
xs: params.IMAGE_W !== undefined && params.IMAGE_H !== undefined
? xs.resizeBilinear([params.IMAGE_H, params.IMAGE_W])
: xs,
ys
}
}
}

const normalize: PreprocessingFunction = {
type: ImagePreprocessing.Normalize,
apply: (entry: tf.TensorContainer, task: Task): tf.TensorContainer => {
const { xs, ys } = entry as ImageEntry
return {
xs: xs.div(tf.scalar(255)),
ys
}
}
}

export const AVAILABLE_PREPROCESSING = List.of(
resize,
normalize
).sortBy((e) => e.type)
4 changes: 4 additions & 0 deletions discojs/discojs-core/src/dataset/data/preprocessing/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
export { Preprocessing, PreprocessingFunction } from './base'
export { AVAILABLE_PREPROCESSING as IMAGE_PREPROCESSING, ImagePreprocessing } from './image_preprocessing'
export { AVAILABLE_PREPROCESSING as TABULAR_PREPROCESSING, TabularPreprocessing } from './tabular_preprocessing'
export { AVAILABLE_PREPROCESSING as TEXT_PREPROCESSING, TextPreprocessing } from './text_preprocessing'
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import { List } from 'immutable'
import { PreprocessingFunction } from './base'

export enum TabularPreprocessing {
Sanitize,
Normalize
}

export const AVAILABLE_PREPROCESSING = List<PreprocessingFunction>()
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import { Task, tf } from '../../..'
import { PreprocessingFunction } from './base'

import GPT3Tokenizer from 'gpt3-tokenizer'
import { List } from 'immutable'

export enum TextPreprocessing {
Tokenize,
Padding
}

interface TextEntry extends tf.TensorContainerObject {
xs: string[]
ys: number[]
}

interface TokenizedEntry extends tf.TensorContainerObject {
xs: tf.Tensor1D
ys: tf.Tensor1D
}

const gpt3Tokenizer = new GPT3Tokenizer({ type: 'gpt3' })

const padding: PreprocessingFunction = {
type: TextPreprocessing.Padding,
apply: (x: tf.TensorContainer, task: Task) => {
const { xs, ys } = x as TokenizedEntry
// TODO: add to task definition
const maxLength = 64
if (maxLength === undefined) {
return { xs, ys }
}
return {
xs: xs
.pad([[0, Math.max(0, maxLength - xs.size)]])
.slice([0], [maxLength]),
ys
}
}
}

const tokenize: PreprocessingFunction = {
type: TextPreprocessing.Tokenize,
apply: (x: tf.TensorContainer, task: Task) => {
const { xs, ys } = x as TextEntry
const params = task.trainingInformation
// TODO: add to task definition
const tokenizer = (params as unknown as any).tokenizer

let tokenized: number[]
if (tokenizer === undefined) {
tokenized = gpt3Tokenizer.encode(xs[0]).bpe
} else {
throw new Error('tokenizer not implemented')
}

return {
xs: tf.tensor(tokenized),
ys: tf.tensor(ys)
}
}
}

export const AVAILABLE_PREPROCESSING = List.of(
tokenize,
padding
)
20 changes: 6 additions & 14 deletions discojs/discojs-core/src/dataset/data/tabular_data.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import { Task } from '../..'
import { getPreprocessTabular } from './preprocessing'
import { Dataset } from '../dataset'
import { Data } from './data'
import { TABULAR_PREPROCESSING } from './preprocessing'

export class TabularData extends Data {
public readonly availablePreprocessing = TABULAR_PREPROCESSING

static async init (
dataset: Dataset,
task: Task,
size?: number
): Promise<Data> {
): Promise<TabularData> {
// Force the check of the data column format (among other things) before proceeding
// to training, for better error handling. An incorrectly formatted line might still
// cause an error during training, because of the lazy aspect of the dataset; we only
Expand All @@ -22,17 +24,7 @@ export class TabularData extends Data {
return new TabularData(dataset, task, size)
}

batch (): Data {
const batchSize = this.task.trainingInformation.batchSize
const newDataset = batchSize === undefined ? this.dataset : this.dataset.batch(batchSize)

return new TabularData(newDataset, this.task, this.size)
}

async preprocess (): Promise<Data> {
let newDataset = this.dataset
const preprocessTabular = getPreprocessTabular(this.task)
newDataset = await preprocessTabular(newDataset)
return new TabularData(newDataset, this.task, this.size)
protected create (dataset: Dataset, task: Task, size: number): TabularData {
return new TabularData(dataset, task, size)
}
}
20 changes: 20 additions & 0 deletions discojs/discojs-core/src/dataset/data/text_data.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { Task } from '../..'
import { Dataset } from '../dataset'
import { Data } from './data'
import { TEXT_PREPROCESSING } from './preprocessing'

export class TextData extends Data {
public readonly availablePreprocessing = TEXT_PREPROCESSING

static async init (
dataset: Dataset,
task: Task,
size?: number
): Promise<TextData> {
return new TextData(dataset, task, size)
}

protected create (dataset: Dataset, task: Task, size?: number): TextData {
return new TextData(dataset, task, size)
}
}
Loading

0 comments on commit 9c96d71

Please sign in to comment.