Skip to content

Commit

Permalink
Added context free grammar constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
grantslatton committed May 14, 2023
1 parent 08737ef commit 007e26a
Show file tree
Hide file tree
Showing 6 changed files with 1,042 additions and 4 deletions.
8 changes: 7 additions & 1 deletion examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,13 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.mirostat_tau = std::stof(argv[i]);
} else if (arg == "-b" || arg == "--batch-size") {
} else if (arg == "--grammar") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.token_grammar_path = argv[i];
} else if (arg == "-b" || arg == "--batch_size") {
if (++i >= argc) {
invalid_param = true;
break;
Expand Down
5 changes: 3 additions & 2 deletions examples/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ struct gpt_params {
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
std::string prompt = "";
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
std::string input_prefix = ""; // string to prefix user inputs with
std::string input_suffix = ""; // string to suffix user inputs with
std::string token_grammar_path = ""; // path to file containing serialized token validator
std::string input_prefix = ""; // string to prefix user inputs with
std::string input_suffix = ""; // string to suffix user inputs with
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted

std::string lora_adapter = ""; // lora adapter path
Expand Down
11 changes: 11 additions & 0 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,15 @@ int main(int argc, char ** argv) {
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
}


    // load the serialized token grammar from params.token_grammar_path, if one was given
std::string token_grammar_path = params.token_grammar_path;
void* grammar = nullptr;
if (!token_grammar_path.empty()) {
fprintf(stderr, "%s: attempting to parse token grammar from '%s'\n", __func__, token_grammar_path.c_str());
grammar = llama_load_token_grammar_from_path(token_grammar_path.c_str());
}

// determine the maximum memory usage needed to do inference for the given n_batch and n_predict parameters
// uncomment the "used_mem" line in llama.cpp to see the results
if (params.mem_test) {
Expand Down Expand Up @@ -420,6 +429,7 @@ int main(int argc, char ** argv) {
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

// Apply penalties
llama_grammar_penalty(ctx, &candidates_p, grammar);
float nl_logit = logits[llama_token_nl()];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
llama_sample_repetition_penalty(ctx, &candidates_p,
Expand Down Expand Up @@ -459,6 +469,7 @@ int main(int argc, char ** argv) {

last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
llama_grammar_accept_token(ctx, id, grammar);
}

// replace end of text token with newline token when in interactive mode
Expand Down
Loading

1 comment on commit 007e26a

@Sciumo
Copy link

@Sciumo Sciumo commented on 007e26a May 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have any tests to add as well? Even the ones you tweeted.

Please sign in to comment.