feat: add support for text data & tokenization
Showing 13 changed files with 224 additions and 59 deletions.
@@ -1,38 +1,10 @@
-import { tf, Task } from '../..'
-
-type PreprocessImage = (image: tf.TensorContainer) => tf.TensorContainer
-
-export type Preprocessing = ImagePreprocessing
-
-export interface ImageTensorContainer extends tf.TensorContainerObject {
-  xs: tf.Tensor3D | tf.Tensor4D
-  ys: tf.Tensor1D | number | undefined
-}
-
 export enum ImagePreprocessing {
   Normalize = 'normalize',
-  Resize = 'resize'
+  Resize = 'resize',
 }
 
-export function getPreprocessImage (task: Task): PreprocessImage {
-  const preprocessImage: PreprocessImage = (tensorContainer: tf.TensorContainer): tf.TensorContainer => {
-    // TODO unsafe cast, tfjs does not provide the right interface
-    const info = task.trainingInformation
-    let { xs, ys } = tensorContainer as ImageTensorContainer
-    if (info.preprocessingFunctions?.includes(ImagePreprocessing.Normalize)) {
-      xs = xs.div(tf.scalar(255))
-    }
-    if (info.preprocessingFunctions?.includes(ImagePreprocessing.Resize) &&
-      info.IMAGE_H !== undefined &&
-      info.IMAGE_W !== undefined) {
-      xs = tf.image.resizeBilinear(xs, [
-        info.IMAGE_H, info.IMAGE_W
-      ])
-    }
-    return {
-      xs,
-      ys
-    }
-  }
-  return preprocessImage
-}
+export enum TextPreprocessing {}
+
+export enum TabularPreprocessing {}
+
+export type Preprocessing = ImagePreprocessing | TextPreprocessing | TabularPreprocessing
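With the union widened beyond images, a task can keep declaring its preprocessing steps as a list of Preprocessing values. A minimal sketch, assuming trainingInformation still carries a preprocessingFunctions array as the removed getPreprocessImage code did:

// Hypothetical task configuration, not part of this commit
import { ImagePreprocessing, Preprocessing } from './preprocessing'

const preprocessingFunctions: Preprocessing[] = [
  ImagePreprocessing.Normalize,
  ImagePreprocessing.Resize
]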
@@ -0,0 +1,7 @@
import { Data } from './data'

export class TextData extends Data {
  batch (): TextData {
    return new TextData(this.batchedDataset, this.task, this.size)
  }
}
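The new TextData class follows the existing Data subclassing pattern: batch () simply re-wraps the batched dataset in the same type. A minimal sketch of how it is obtained and batched, assuming the inherited Data.init factory used by the loader below:

// Sketch only; `dataset` and `task` are assumed to be in scope
const data = await TextData.init(dataset, task, undefined)
const batched = data.batch() // still a TextData, now over the batched dataset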
discojs/discojs-core/src/dataset/data_loader/text_loader.ts (57 additions, 0 deletions)
@@ -0,0 +1,57 @@
import GPT3Tokenizer from 'gpt3-tokenizer'
import { List } from 'immutable'

import { Task } from '../..'
import { DataLoader, DataConfig } from './data_loader'
import { Dataset } from '../dataset'
import { DataSplit } from '../data/data_split'
import { TextData } from '../data/text_data'

const BUFFER_SIZE = 50

type Tokenizer = 'gpt3'

export abstract class TextLoader<Source> extends DataLoader<Source> {
  public readonly tokenize: (token: string) => string

  constructor (
    task: Task,
    tokenizerType?: Tokenizer,
    public readonly delimiter = '\n'
  ) {
    super(task)

    if (tokenizerType === undefined) {
      tokenizerType = 'gpt3'
    }

    switch (tokenizerType) {
      default:
        this.tokenize = new GPT3Tokenizer({ type: tokenizerType }).bpe
        break
    }
  }

  abstract loadTextDatasetFrom (source: Source): Promise<Dataset>

  async load (source: Source, config?: DataConfig): Promise<Dataset> {
    const dataset = await this.loadTextDatasetFrom(source)
    return config?.shuffle ? dataset.shuffle(BUFFER_SIZE) : dataset
  }

  async loadAll (sources: Source[], config: DataConfig): Promise<DataSplit> {
    const datasets = await Promise.all(sources.map(async (source) =>
      await this.load(source, { ...config, shuffle: false })))
    let dataset = List(datasets).reduce((acc: Dataset, dataset) =>
      acc.concatenate(dataset))
    dataset = config?.shuffle ? dataset.shuffle(BUFFER_SIZE) : dataset
    const data = await TextData.init(
      dataset,
      this.task,
      undefined
    )
    return {
      train: data
    }
  }
}
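A usage sketch of the loader API added here, assuming the concrete NodeTextLoader defined below and an existing task object; the file names are placeholders:

// Hypothetical usage, not part of this commit; `task` is assumed to be in scope
const loader = new NodeTextLoader(task)
const split = await loader.loadAll(['train-part-0.txt', 'train-part-1.txt'], { shuffle: true })
// each source is loaded and tokenized line by line, concatenated, re-shuffled, and wrapped in TextData
const textData = split.train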
discojs/discojs-node/src/dataset/data_loader/text_loader.ts (30 additions, 0 deletions)
@@ -0,0 +1,30 @@
import fs from 'node:fs'

import split2 from 'split2'

import { tf } from '../..'
import { TextLoader } from 'core/dataset/data_loader/text_loader'
import { Dataset } from 'core/dataset'
import { DataConfig } from 'core/dataset/data_loader'

export class NodeTextLoader extends TextLoader<string> {
  async loadTextDatasetFrom (source: string, config?: DataConfig): Promise<Dataset> {
    const prefix = 'file://'
    if (source.slice(0, 7) !== prefix) {
      source = prefix + source
    }
    // create stream being read by generator
    const stream = fs.createReadStream(source, { encoding: 'utf-8' })
    // eslint-disable-next-line @typescript-eslint/no-this-alias
    const self = this

    async function * dataGenerator (): AsyncGenerator<tf.TensorContainer> {
      // TODO @s314cy
      const withLabels = config?.labels !== undefined
      stream.pipe(split2())
      stream.on('data', (data) => yield self.tokenize(data))
    }

    return tf.data.generator(dataGenerator)
  }
}
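Note that, as committed, dataGenerator never yields: the yield sits inside the stream.on callback (which is not valid in an arrow function) and the split2 pipeline is never read. A minimal sketch of one way to drive the generator from the line stream, assuming split2 emits one line per chunk and that a tokenized string is acceptable as a tf.TensorContainer:

// Sketch only, not part of this commit: async-iterate the line stream inside the generator
const lines = stream.pipe(split2())
async function * dataGenerator (): AsyncGenerator<tf.TensorContainer> {
  for await (const line of lines) {
    // each chunk from split2 is one line of text; tokenize it before yielding
    yield self.tokenize(line as string)
  }
}
return tf.data.generator(dataGenerator)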
@@ -0,0 +1,9 @@
import { tf } from '../..'
import { Dataset } from 'core/dataset'
import { TextLoader } from 'core/dataset/data_loader/text_loader'

export class WebTextLoader extends TextLoader<File> {
  async loadTextDatasetFrom (source: File): Promise<Dataset> {
    return new tf.data.TextLineDataset(new tf.data.FileDataSource(source))
  }
}
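A minimal browser-side sketch, assuming a task object in scope and a file picked from an input element:

// Hypothetical browser usage, not part of this commit; `task` is assumed to be in scope
const input = document.querySelector('input[type="file"]') as HTMLInputElement
input.addEventListener('change', async () => {
  const file = input.files?.[0]
  if (file === undefined) return
  const split = await new WebTextLoader(task).loadAll([file], { shuffle: false })
  // split.train is a TextData instance built from the selected file
})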