Skip to content

Commit

Permalink
Add nemo integration (#140)
Browse files (browse the repository at this point in the history)
* Add nemo support

* Fix old tokenization guidance

* Fix nemo preset
  • Loading branch information
miku448 authored Aug 5, 2024
1 parent 7e73c6f commit 8a5552a
Show file tree
Hide file tree
Showing 16 changed files with 353 additions and 243 deletions.
1 change: 1 addition & 0 deletions apps/bot-directory/views/index.ejs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@
<option value="solar">Solar</option>
<option value="cohere">Cohere</option>
<option value="wizardlm2">WizardLM-2</option>
<option value="nemo">Nemo</option>
</select>
</div>
<div class="flex gap-2">
Expand Down
3 changes: 2 additions & 1 deletion apps/interactor/src/libs/prompts/PromptBuilder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class PromptBuilder<StrategyClass extends AbstractPromptStrategy<unknown, unknow
return recursiveBinarySearch(midIndex + 1, maxIndex, maxTokens);
}
};
const memorySize = recursiveBinarySearch(0, maxMemorySize, this.options.truncationLength - 200) - 1;
const memorySize =
recursiveBinarySearch(0, maxMemorySize, this.options.truncationLength - this.options.maxNewTokens) - 1;

return this.strategy.buildGuidancePrompt(this.options.maxNewTokens, memorySize, input);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { RootState } from '../../../../state/store'
import { RootState } from '../../../../state/store';

export default {
objectives: [],
Expand Down Expand Up @@ -493,7 +493,7 @@ export default {
],
},
version: 'v3',
} as RootState
} as RootState;

export const expectedResult = {
template:
Expand Down Expand Up @@ -532,5 +532,5 @@ export const expectedResult = {
' amused',
],
},
totalTokens: 1413,
}
totalTokens: 1360,
};
Original file line number Diff line number Diff line change
@@ -1,38 +1,30 @@
import llamaTokenizer, { Tokenizer } from '../_llama-tokenizer'
import mistralTokenizer from '../_mistral-tokenizer'
import { Tokenizer } from '../_llama-tokenizer';
import mistralTokenizer from '../_mistral-tokenizer';

export abstract class AbstractPromptStrategy<Input, Output> {
private tokenizer: Tokenizer
constructor(tokenizerSkug?: string) {
if (tokenizerSkug === 'mistral') {
this.tokenizer = mistralTokenizer
} else {
this.tokenizer = llamaTokenizer
}
private tokenizer: Tokenizer;
constructor(_tokenizerSkug?: string) {
this.tokenizer = mistralTokenizer;
}
public abstract buildGuidancePrompt(
maxNewTokens: number,
memorySize: number,
input: Input
input: Input,
): {
template: string
variables: Record<string, string | string[]>
totalTokens: number
}
template: string;
variables: Record<string, string | string[]>;
totalTokens: number;
};

public abstract completeResponse(
input: Input,
response: Output,
variables: Map<string, string>
): Output
public abstract completeResponse(input: Input, response: Output, variables: Map<string, string>): Output;

protected countTokens(template: string): number {
let maxTokens: number = 0
let maxTokens: number = 0;
template.replace(/max_tokens=(\d+)/g, (_, _maxTokens) => {
maxTokens += parseInt(_maxTokens) || 0
return ''
})
const _template = template.replace(/{{.*?}}/g, '')
return this.tokenizer.encode(_template).length + maxTokens
maxTokens += parseInt(_maxTokens) || 0;
return '';
});
const _template = template.replace(/{{.*?}}/g, '');
return this.tokenizer.encode(_template).length + maxTokens;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ export class RoleplayStrategyMistral extends AbstractRoleplayStrategy {
askLine: '[/INST]# Reaction + 2 paragraphs (engaging, natural, authentic, descriptive, creative)\n',
instruction: '[INST]',
response: '[/INST]',
stops: ['[/INST]', '[INST]'],
stops: ['INST', '/INST'],
};
}
}
4 changes: 2 additions & 2 deletions apps/interactor/src/state/listeners/interaction.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ const interactionEffect = async (
new PromptBuilder<AbstractRoleplayStrategy>({
maxNewTokens: 200,
strategy: primaryStrategy,
truncationLength: truncation_length,
truncationLength: truncation_length - 150,
}),
new PromptBuilder<AbstractRoleplayStrategy>({
maxNewTokens: 200,
strategy: secondaryStrategy,
truncationLength: secondary.truncation_length,
truncationLength: secondary.truncation_length - 150,
}),
];

Expand Down
2 changes: 1 addition & 1 deletion apps/services/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@
"devDependencies": {
"nodemon": "^2.0.20"
}
}
}
1 change: 1 addition & 0 deletions apps/services/src/server.mts
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ Promise.all([
loadTokenizer(TokenizerType.LLAMA_3),
loadTokenizer(TokenizerType.MISTRAL),
loadTokenizer(TokenizerType.SOLAR),
loadTokenizer(TokenizerType.NEMO),
loadTokenizer(TokenizerType.COHERE),
loadTokenizer(TokenizerType.WIZARDLM2),
]).then(() => {
Expand Down
49 changes: 35 additions & 14 deletions apps/services/src/services/text/data/presets.mts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { OpenAIAphroditeConfig } from "../lib/aphroditeTokenGenerator.mjs";
import { PresetType } from "./rpModelTypes.mjs";
import { OpenAIAphroditeConfig } from '../lib/aphroditeTokenGenerator.mjs';
import { PresetType } from './rpModelTypes.mjs';

//truncation_length: 4096,

Expand All @@ -26,7 +26,7 @@ export const presets = new Map<PresetType, OpenAIAphroditeConfig>([
use_beam_search: false,
length_penalty: 1.0,
early_stopping: false,
stop: ["\n###", "</s>", "<|", "\n#", "\n\n\n"],
stop: ['\n###', '</s>', '<|', '\n#', '\n\n\n'],
ignore_eos: false,
skip_special_tokens: true,
spaces_between_special_tokens: true,
Expand Down Expand Up @@ -55,7 +55,7 @@ export const presets = new Map<PresetType, OpenAIAphroditeConfig>([
use_beam_search: false,
length_penalty: 1.0,
early_stopping: false,
stop: ["<|user|>", "<|model|>", "<|system|>"],
stop: ['<|user|>', '<|model|>', '<|system|>'],
ignore_eos: false,
skip_special_tokens: true,
spaces_between_special_tokens: true,
Expand Down Expand Up @@ -83,7 +83,7 @@ export const presets = new Map<PresetType, OpenAIAphroditeConfig>([
use_beam_search: false,
length_penalty: 1.0,
early_stopping: false,
stop: ["\n###", "</s>", "<|", "\n#", "\n\n\n"],
stop: ['\n###', '</s>', '<|', '\n#', '\n\n\n'],
ignore_eos: false,
skip_special_tokens: true,
spaces_between_special_tokens: true,
Expand Down Expand Up @@ -112,15 +112,36 @@ export const presets = new Map<PresetType, OpenAIAphroditeConfig>([
use_beam_search: false,
length_penalty: 1.0,
early_stopping: false,
stop: [
"\n###",
"</s>",
"<|eot_id|>",
"<|end_of_text|>",
"<|",
"\n#",
"\n\n\n",
],
stop: ['\n###', '</s>', '<|eot_id|>', '<|end_of_text|>', '<|', '\n#', '\n\n\n'],
ignore_eos: false,
skip_special_tokens: true,
spaces_between_special_tokens: true,
},
],
[
PresetType.NEMO,
{
n: 1,
best_of: 1,
presence_penalty: 0.0,
frequency_penalty: 0.0,
repetition_penalty: 0,
temperature: 0.4,
min_p: 0.1,
top_p: 0.1,
top_k: 0,
top_a: 0,
tfs: 1,
eta_cutoff: 0,
epsilon_cutoff: 0,
typical_p: 1,
mirostat_mode: 0,
mirostat_tau: 5.0,
mirostat_eta: 0.1,
use_beam_search: false,
length_penalty: 1.0,
early_stopping: false,
stop: ['\n###', '</s>', '[INST]', '[/INST]', '\n#', '\n\n\n'],
ignore_eos: false,
skip_special_tokens: true,
spaces_between_special_tokens: true,
Expand Down
2 changes: 2 additions & 0 deletions apps/services/src/services/text/data/rpModelTypes.mts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ export enum RPModelTokenizers {
LLAMA3 = 'llama3',
COHERE = 'cohere',
WIZARDLM2 = 'wizardlm2',
NEMO = 'nemo',
}

export enum RPModelStrategy {
Expand All @@ -20,6 +21,7 @@ export enum PresetType {
LLAMA_PRECISE = 'LLAMA_PRECISE',
MINIMAL_WORK = 'MINIMAL_WORK',
STHENO_V3 = 'STHENO_V3',
NEMO = 'NEMO',
}

export enum RPModelPermission {
Expand Down
1 change: 1 addition & 0 deletions apps/services/src/services/text/index.mts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ const getTokenizer = (_tokenizer: string): Guidance.Tokenizer.AbstractTokenizer
if (_tokenizer === 'mistral') return tokenizers.get(TokenizerType.MISTRAL)!;
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
if (_tokenizer === 'solar') return tokenizers.get(TokenizerType.SOLAR)!;
if (_tokenizer === 'nemo') return tokenizers.get(TokenizerType.NEMO)!;
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
if (_tokenizer === 'llama3') return tokenizers.get(TokenizerType.LLAMA_3)!;
if (_tokenizer === 'wizardlm2')
Expand Down
Loading

0 comments on commit 8a5552a

Please sign in to comment.