From 007e26a99d485007f724957fa8545331ab8d50c3 Mon Sep 17 00:00:00 2001
From: Grant Slatton
Date: Sat, 13 May 2023 18:25:51 -0700
Subject: [PATCH] Added context free grammar constraints

---
 examples/common.cpp    |   8 +-
 examples/common.h      |   5 +-
 examples/main/main.cpp |  11 +
 grammar.py             | 668 +++++++++++++++++++++++++++++++++++++++++
 llama.cpp              | 350 ++++++++++++++++++++-
 llama.h                |   4 +
 6 files changed, 1042 insertions(+), 4 deletions(-)
 create mode 100644 grammar.py

diff --git a/examples/common.cpp b/examples/common.cpp
index 86c1eef41b475..03fb7ce8fa12b 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -231,7 +231,13 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.mirostat_tau = std::stof(argv[i]);
-        } else if (arg == "-b" || arg == "--batch-size") {
+        } else if (arg == "--grammar") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.token_grammar_path = argv[i];
+        } else if (arg == "-b" || arg == "--batch-size") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
diff --git a/examples/common.h b/examples/common.h
index 717838f06e064..9adb490f13261 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -48,8 +48,9 @@ struct gpt_params {
     std::string model = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
     std::string path_prompt_cache = "";  // path to file for saving/loading prompt eval state
-    std::string input_prefix = "";       // string to prefix user inputs with
-    std::string input_suffix = "";       // string to suffix user inputs with
+    std::string token_grammar_path = ""; // path to file containing serialized token validator
+    std::string input_prefix = "";       // string to prefix user inputs with
+    std::string input_suffix = "";       // string to suffix user inputs with
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted

     std::string lora_adapter = "";  // lora adapter path
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 8543414dd0fbb..055d9c2490ef2 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -117,6 +117,15 @@ int main(int argc, char ** argv) {
                 params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }

+    // load the token grammar from params.token_grammar_path, if one was given
+    std::string token_grammar_path = params.token_grammar_path;
+    void* grammar = nullptr;
+    if (!token_grammar_path.empty()) {
+        fprintf(stderr, "%s: attempting to parse token grammar from '%s'\n", __func__, token_grammar_path.c_str());
+        grammar = llama_load_token_grammar_from_path(token_grammar_path.c_str());
+    }
+
     // determine the maximum memory usage needed to do inference for the given n_batch and n_predict parameters
     // uncomment the "used_mem" line in llama.cpp to see the results
     if (params.mem_test) {
@@ -420,6 +429,7 @@ int main(int argc, char ** argv) {
                 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

                 // Apply penalties
+                llama_grammar_penalty(ctx, &candidates_p, grammar);
                 float nl_logit = logits[llama_token_nl()];
                 auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
                 llama_sample_repetition_penalty(ctx, &candidates_p,
@@ -459,6 +469,7 @@ int main(int argc, char ** argv) {

                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);
+                llama_grammar_accept_token(ctx, id, grammar);
             }

             // replace end of text token with newline token when in interactive mode
diff --git a/grammar.py b/grammar.py
new file mode 100644
index 0000000000000..a890e8d354219
--- /dev/null
+++ b/grammar.py
@@ -0,0 +1,668 @@
+# Define a little parser combinator microlibrary to help with parsing the grammars
+
+def none_of(chars):
+    def inner(s):
+        if not s or s[0] in chars:
+            raise ValueError("unexpected " + (s[0] if s else "end of input"))
+        return (s[1:], s[0])
+    return inner
+
+def one_of(chars):
+    def inner(s):
+        if not s or s[0] not in chars:
+            raise ValueError("expected one of " + chars)
+        return (s[1:], s[0])
+    return inner
+
+def series(*parsers):
+    def inner(s):
+        result = []
+        for parser in parsers:
+            (s, value) = parser(s)
+            result.append(value)
+        return (s, result)
+    return inner
+
+def alt(*parsers):
+    def inner(s):
+        exceptions = []
+        for parser in parsers:
+            try:
+                return parser(s)
+            except ValueError as e:
+                exceptions.append(e)
+        raise ValueError("expected one of " + ", ".join(map(str, exceptions)) + " but got " + s[:5])
+    return inner
+
+def many(parser):
+    def inner(s):
+        result = []
+        while True:
+            try:
+                (s, value) = parser(s)
+                result.append(value)
+            except ValueError:
+                break
+        return (s, result)
+    return inner
+
+def maybe(parser):
+    def inner(s):
+        try:
+            (s, value) = parser(s)
+            return (s, value)
+        except ValueError:
+            return (s, None)
+    return inner
+
+def many1(parser):
+    def inner(s):
+        (s, value) = parser(s)
+        (s, values) = many(parser)(s)
+        return (s, [value] + values)
+    return inner
+
+def intersperse(parser, sep_parser):
+    def inner(s):
+        try:
+            (s, value) = parser(s)
+            (s, values) = many(series(sep_parser, parser))(s)
+            return (s, [value] + [v for (_, v) in values])
+        except ValueError:
+            return (s, [])
+    return inner
+
+def span_spaces():
+    def inner(s):
+        (s, _) = many(one_of(" \t"))(s)
+        return (s, None)
+    return inner
+
+def spaces():
+    def inner(s):
+        (s, _) = many(one_of(" \t\r\n"))(s)
+        return (s, None)
+    return inner
+
+def nl():
+    return alt(literal("\r\n"), literal("\n"))
+
+def alpha():
+    return one_of("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
+
+def digit():
+    return one_of("0123456789")
+
+def literal(string):
+    return series(*[one_of(c) for c in string])
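+
+# Illustration (not part of the library itself): every parser above takes a
+# string and returns (remaining_input, parsed_value), raising ValueError on
+# failure. For example, assuming the definitions above:
+#
+#   (rest, value) = many1(digit())("42abc")        # rest == "abc", value == ["4", "2"]
+#   (rest, value) = maybe(literal("hi"))("hello")  # rest == "hello", value is None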
+
+# Grammar data structures
+
+class Terminal:
+    def __init__(self, value):
+        if not isinstance(value, str):
+            raise TypeError("Terminal value must be a string")
+        self.bytes = [ord(c) for c in value]
+
+    def __eq__(self, other):
+        return self.bytes == other.bytes
+
+class NonTerminal:
+    def __init__(self, rule_name):
+        self.rule_name = rule_name
+
+    def __eq__(self, other):
+        return self.rule_name == other.rule_name
+
+class Branch:
+    def __init__(self, elements):
+        if not isinstance(elements, list):
+            raise TypeError("Branch elements must be a list")
+        if not all(isinstance(element, (Terminal, NonTerminal)) for element in elements):
+            raise TypeError("Branch elements must be a list of Terminals and NonTerminals")
+        self.elements = elements
+
+    def __eq__(self, other):
+        return self.elements == other.elements
+
+class Rule:
+    def __init__(self, rule_name, branches):
+        self.rule_name = rule_name
+        self.branches = branches
+
+    def __eq__(self, other):
+        return self.rule_name == other.rule_name and self.branches == other.branches
+
+class Grammar:
+    def __init__(self, rules, main_rule):
+        self.rules = rules
+        self.main_rule = main_rule
+
+    def add_rule(self, rule):
+        if isinstance(rule, str):
+            (rem, rule) = parse_rule()(rule)
+            if rem:
+                raise ValueError("rule had trailing characters: " + rem)
+
+        if rule.rule_name in self.rules:
+            if rule != self.rules[rule.rule_name]:
+                raise ValueError("rule name collision")
+        else:
+            self.rules[rule.rule_name] = rule
+
+    def grammar(self):
+        def get_rule_names(node, visited=None):
+            if visited is None:
+                visited = set()
+            if isinstance(node, NonTerminal):
+                return get_rule_names(self.rules[node.rule_name], visited)
+            elif isinstance(node, Branch):
+                for element in node.elements:
+                    get_rule_names(element, visited)
+            elif isinstance(node, Rule):
+                if node.rule_name not in visited:
+                    visited.add(node.rule_name)
+                    for branch in node.branches:
+                        get_rule_names(branch, visited)
+            return visited
+
+        all_rules = get_rule_names(self.main_rule)
+        rule_ids = { rule_name : i for (i, rule_name) in enumerate(all_rules) }
+        id = lambda rule: rule_ids[rule.rule_name]
+
+        result = ""
+        result += f"{id(self.main_rule)} {len(all_rules)}\n"
+        for rule_name in rule_ids:
+            rule = self.rules[rule_name]
+            result += f"{id(rule)} {len(rule.branches)}\n"
+            for branch in rule.branches:
+                result += f"{len(branch.elements)}\n"
+                for element in branch.elements:
+                    if isinstance(element, Terminal):
+                        result += f"0 {len(element.bytes)} {' '.join(map(str, element.bytes))}\n"
+                    elif isinstance(element, NonTerminal):
+                        result += f"1 {id(element)}\n"
+                    else:
+                        raise TypeError("unsupported element type")
+
+        return result
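+
+# Illustrative example of the serialized format produced by Grammar.grammar():
+# for the one-rule grammar `greeting = "hi"` it returns four lines:
+#
+#     0 1          (main rule id, total rule count)
+#     0 1          (rule id, branch count)
+#     1            (element count for the only branch)
+#     0 2 104 105  (terminal element: byte count, then the bytes of "hi")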
+
+# Grammar parsing
+
+def rule_char():
+    return alt(alpha(), digit(), one_of("_-."))
+
+def parse_rule_name():
+    def inner(s):
+        (s, chars) = many1(rule_char())(s)
+        return (s, "".join(chars))
+    return inner
+
+def parse_terminal():
+    def double_quoted(s):
+        (s, _) = literal("\"")(s)
+        (s, chars) = many(alt(none_of("\""), literal("\\\"")))(s)
+        (s, _) = literal("\"")(s)
+        return (s, Terminal("".join(chars)))
+
+    def single_quoted(s):
+        (s, _) = literal("'")(s)
+        (s, chars) = many(alt(none_of("'"), literal("\\'")))(s)
+        (s, _) = literal("'")(s)
+        return (s, Terminal("".join(chars)))
+
+    return alt(double_quoted, single_quoted)
+
+def parse_non_terminal():
+    def inner(s):
+        (s, _) = literal("<")(s)
+        (s, rule_name) = parse_rule_name()(s)
+        (s, _) = literal(">")(s)
+        return (s, NonTerminal(rule_name))
+    return inner
+
+def parse_element():
+    return alt(parse_terminal(), parse_non_terminal())
+
+def parse_branch():
+    def inner(s):
+        (s, elements) = intersperse(parse_element(), span_spaces())(s)
+        return (s, Branch(elements))
+    return inner
+
+def parse_rule_body():
+    return intersperse(parse_branch(), series(span_spaces(), literal("|"), span_spaces()))
+
+def parse_rule():
+    def inner(s):
+        (s, rule_name) = parse_rule_name()(s)
+        (s, _) = span_spaces()(s)
+        (s, _) = literal("=")(s)
+        (s, _) = span_spaces()(s)
+        (s, branches) = parse_rule_body()(s)
+        (s, _) = span_spaces()(s)
+        return (s, Rule(rule_name, branches))
+    return inner
+
+def grammar():
+    def inner(s):
+        (s, rules) = intersperse(parse_rule(), nl())(s)
+        (s, _) = spaces()(s)
+        return (s, rules)
+    return inner
+
+def type_char():
+    return alt(alpha(), digit(), one_of("_"))
+
+def parse_type_name():
+    def inner(s):
+        (s, chars) = many1(type_char())(s)
+        return (s, "".join(chars))
+    return inner
+
+# JSON spec data structures
+
+class JsonType:
+    def __eq__(self, other):
+        return repr(self) == repr(other)
+
+    def __hash__(self):
+        return hash(repr(self))
+
+    def rule_name(self):
+        return f"rule_{abs(hash(repr(self)))}"
+
+    def expr_rule_name(self):
+        return f"<{self.rule_name()}>"
+
+    def visit_types(self):
+        yield self
+
+class JsonBoolean(JsonType):
+    def __repr__(self):
+        return "boolean"
+
+    def add_to_grammar(self, grammar, user_types):
+        grammar.add_rule(f'{self.rule_name()} = "true" | "false"')
+
+class JsonInteger(JsonType):
+    def __repr__(self):
+        return "integer"
+
+    def add_to_grammar(self, grammar, user_types):
+        grammar.add_rule(f'{self.rule_name()} = <sign> <digits>')
+
+class JsonUnsigned(JsonType):
+    def __repr__(self):
+        return "unsigned"
+
+    def add_to_grammar(self, grammar, user_types):
+        grammar.add_rule(f'{self.rule_name()} = <digits>')
+
+class JsonFloat(JsonType):
+    def __repr__(self):
+        return "float"
+
+    def add_to_grammar(self, grammar, user_types):
+        grammar.add_rule(f'{self.rule_name()} = <sign> <digits> | <sign> <digits> "." <digits>')
+
+class JsonString(JsonType):
+    def __repr__(self):
+        return "string"
+
+    def add_to_grammar(self, grammar, user_types):
+        grammar.add_rule(f'{self.rule_name()} = ' + ''' '"' <string_chars> '"' '''.strip())
+
+class JsonNull(JsonType):
+    def __repr__(self):
+        return "null"
+
+    def add_to_grammar(self, grammar, user_types):
+        grammar.add_rule(f'{self.rule_name()} = "null"')
+
+class JsonTuple(JsonType):
+    def __init__(self, types):
+        self.types = types
+
+    def __repr__(self):
+        return "[" + ", ".join(map(repr, self.types)) + "]"
+
+    def visit_types(self):
+        yield from super().visit_types()
+        for t in self.types:
+            yield from t.visit_types()
+
+    def add_to_grammar(self, grammar, user_types):
+        for t in self.types:
+            t.add_to_grammar(grammar, user_types)
+
+        rule_body = ' "[" '
+        for (i, t) in enumerate(self.types):
+            if i > 0:
+                rule_body += ' ", " '
+            rule_body += f'{self.types[i].expr_rule_name()}'
+        rule_body += ' "]" '
+
+        grammar.add_rule(f'{self.rule_name()} = {rule_body}')
+
+class JsonObject(JsonType):
+    def __init__(self, fields):
+        self.fields = fields
+
+    def __repr__(self):
+        return "{" + ", ".join(map(lambda f: f"{f[0]}: {f[1]}", self.fields)) + "}"
+
+    def visit_types(self):
+        yield from super().visit_types()
+        for (_, t) in self.fields:
+            yield from t.visit_types()
+
+    def add_to_grammar(self, grammar, user_types):
+        for (_, t) in self.fields:
+            t.add_to_grammar(grammar, user_types)
+
+        rule_body = ' "{" '
+        for (i, (k, t)) in enumerate(self.fields):
+            if i > 0:
+                rule_body += ' "," '
+            rule_body += f"' \"{k}\"'"
+            rule_body += f' ": " {t.expr_rule_name()} " "'
+        rule_body += ' "}" '
+
+        grammar.add_rule(f'{self.rule_name()} = {rule_body}')
+
+class JsonUnion(JsonType):
+    def __init__(self, types):
+        types = { repr(t) : t for t in types }
+        self.types = sorted(list(types.values()), key=repr)
+
+    def __repr__(self):
+        return " | ".join(map(repr, self.types))
+
+    def visit_types(self):
+        yield from super().visit_types()
+        for t in self.types:
+            yield from t.visit_types()
+
+    def add_to_grammar(self, grammar, user_types):
+        for t in self.types:
+            t.add_to_grammar(grammar, user_types)
+
+        rule_body = ''
+        for (i, t) in enumerate(self.types):
+            if i > 0:
+                rule_body += ' | '
+            rule_body += f'{t.expr_rule_name()}'
+
+        grammar.add_rule(f'{self.rule_name()} = {rule_body}')
+
+class JsonArray(JsonType):
+    def __init__(self, value_type):
+        self.value_type = value_type
+
+    def __repr__(self):
+        return f"Array<{self.value_type}>"
+
+    def visit_types(self):
+        yield from super().visit_types()
+        yield from self.value_type.visit_types()
+
+    def add_to_grammar(self, grammar, user_types):
+        self.value_type.add_to_grammar(grammar, user_types)
+        value_rule_name = self.value_type.rule_name()
+        value_expr_rule_name = self.value_type.expr_rule_name()
+        many_rule_name = f'{value_rule_name}_many'
+        many_expr_rule_name = f'<{many_rule_name}>'
+        grammar.add_rule(f'{many_rule_name} = {value_expr_rule_name} | {value_expr_rule_name} ", " {many_expr_rule_name}')
+        grammar.add_rule(f'{self.rule_name()} = "[" {many_expr_rule_name} "]"')
+
+class JsonUserType(JsonType):
+    def __init__(self, type_name):
+        self.type_name = type_name
+
+    def __repr__(self):
+        return self.type_name
+
+    def add_to_grammar(self, grammar, user_types):
+        user_type = user_types[self.type_name]
+        user_type.add_to_grammar(grammar, user_types)
+        grammar.add_rule(f'{self.rule_name()} = {user_type.expr_rule_name()}')
+
+class JsonStringLiteralType(JsonType):
+    def __init__(self, value):
+        self.value = value
+
+    def __repr__(self):
+        return f'"{self.value}"'
+
+    def add_to_grammar(self, grammar, user_types):
+        grammar.add_rule(f'{self.rule_name()} = "{self.value}"')
+
+class JsonSpec:
+    def __init__(self, types):
+        self.types = types
+
+    def grammar(self):
+        g = Grammar({}, None)
+        g.add_rule('nil = ')
+        g.add_rule('digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"')
+        g.add_rule('digits = <digit> | <digit> <digits>')
+        g.add_rule('sign = "-" | <nil>')
+        g.add_rule('alphanumspace = ' + " | ".join(map(lambda c: f'"{c}"', "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 ")))
+        g.add_rule('escaped_double_quote = \'\\"\'')
+        g.add_rule('string_char = <alphanumspace> | <escaped_double_quote>')
+        g.add_rule('string_chars = <nil> | <string_char> <string_chars>')
+
+        types_by_name = { n : t for (n, t) in self.types }
+
+        for (_, t) in self.types:
+            t.add_to_grammar(g, types_by_name)
+
+        main_type = self.types[-1][1]
+        main_type_name = main_type.rule_name()
+
+        g.main_rule = g.rules[main_type_name]
+
+        return g.grammar()
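+
+# Illustrative usage (relies on the parsing helpers defined below): the last
+# type declared in a spec becomes the grammar's main rule, e.g.
+#
+#   spec = json_spec('type Flag = boolean;')
+#   print(spec.grammar())   # serialized grammar whose main rule is "true" | "false"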
+
+# JSON spec parsing
+
+def parse_boolean_type():
+    def inner(s):
+        (s, _) = literal("boolean")(s)
+        return (s, JsonBoolean())
+    return inner
+
+def parse_integer_type():
+    def inner(s):
+        (s, _) = literal("integer")(s)
+        return (s, JsonInteger())
+    return inner
+
+def parse_unsigned_type():
+    def inner(s):
+        (s, _) = literal("unsigned")(s)
+        return (s, JsonUnsigned())
+    return inner
+
+def parse_float_type():
+    def inner(s):
+        (s, _) = alt(literal("float"), literal("double"), literal("number"))(s)
+        return (s, JsonFloat())
+    return inner
+
+def parse_string_type():
+    def inner(s):
+        (s, _) = literal("string")(s)
+        return (s, JsonString())
+    return inner
+
+def parse_null_type():
+    def inner(s):
+        (s, _) = literal("null")(s)
+        return (s, JsonNull())
+    return inner
+
+def parse_tuple_type():
+    def inner(s):
+        (s, _) = literal("[")(s)
+        (s, types) = intersperse(parse_type_body(), series(spaces(), literal(","), spaces()))(s)
+        (s, _) = series(maybe(literal(",")), spaces())(s)
+        (s, _) = literal("]")(s)
+        return (s, JsonTuple(types))
+    return inner
+
+def parse_object_key():
+    def inner(s):
+        (s, _) = literal('"')(s)
+        (s, key) = many1(alt(none_of('"'), literal('\\"')))(s)
+        (s, _) = literal('"')(s)
+        return (s, "".join(key))
+    return inner
+
+def parse_object_field():
+    def inner(s):
+        (s, key) = parse_object_key()(s)
+        (s, _) = series(span_spaces(), literal(":"), span_spaces())(s)
+        (s, value) = parse_type_body()(s)
+        return (s, (key, value))
+    return inner
+
+def parse_object_type():
+    def inner(s):
+        (s, _) = series(literal("{"), spaces())(s)
+        (s, fields) = intersperse(parse_object_field(), series(spaces(), literal(","), spaces()))(s)
+        (s, _) = series(maybe(literal(",")), spaces(), literal("}"))(s)
+        return (s, JsonObject(fields))
+    return inner
+
+def parse_array_type():
+    def inner(s):
+        (s, _) = literal("Array<")(s)
+        (s, value_type) = parse_type_body()(s)
+        (s, _) = literal(">")(s)
+        return (s, JsonArray(value_type))
+    return inner
+
+def parse_user_type():
+    def inner(s):
+        (s, type_name) = parse_type_name()(s)
+        return (s, JsonUserType(type_name))
+    return inner
+
+def parse_string_literal_type():
+    def inner(s):
+        (s, _) = literal('"')(s)
+        (s, value) = many1(alt(none_of('"'), literal('\\"')))(s)
+        (s, _) = literal('"')(s)
+        return (s, JsonStringLiteralType("".join(value)))
+    return inner
+
+def parse_single_type_body():
+    return alt(
+        parse_boolean_type(),
+        parse_integer_type(),
+        parse_unsigned_type(),
+        parse_null_type(),
+        parse_float_type(),
+        parse_string_type(),
+        parse_tuple_type(),
+        parse_object_type(),
+        parse_array_type(),
+        parse_string_literal_type(),
+        parse_user_type(),
+    )
+
+def parse_type_body():
+    def inner(s):
+        (s, types) = intersperse(parse_single_type_body(), series(span_spaces(), literal("|"), span_spaces()))(s)
+        if len(types) == 1:
+            return (s, types[0])
+        else:
+            return (s, JsonUnion(types))
+    return inner
+
+def parse_type_definition():
+    def inner(s):
+        (s, _) = series(spaces(), literal("type "))(s)
+        (s, type_name) = parse_type_name()(s)
+        (s, _) = series(spaces(), literal("="), spaces())(s)
+        (s, type_body) = parse_type_body()(s)
+        (s, _) = series(spaces(), literal(";"))(s)
+        return (s, (type_name, type_body))
+    return inner
+
+def json_spec(s):
+    (s, types) = many1(parse_type_definition())(s)
+    (s, _) = spaces()(s)
+    if s:
+        raise Exception("JSON spec had unparsed trailing characters:\n" + s)
+    return JsonSpec(types)
+
+HELP = """
+
+--grammar <file>   : parse a full grammar from a file
+--grammar <string> : parse a full grammar from a string
+
+The last rule in the grammar is the main rule.
+
+Example grammar:
+person = "Bob" | "Alice" | "Charlie"
+digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
+age = <digit> | <digit> <digit>
+sentence = <person> " is " <age> " years old"
+
+--json <file>   : parse a json spec from a file
+--json <string> : parse a json spec from a string
+
+The last type declaration in the json spec is the main type.
+
+Example JSON spec:
+type FirstNameLastName = [string, string];
+type Result = {
+    "country": string,
+    "population": integer,
+    "percent_retired": float,
+    "cities": Array<{ "city": string, "is_capital": boolean }>,
+    "head_of_state": FirstNameLastName | null,
+};
+
+""".strip()
+
+def main():
+    import sys
+    # modes: parse a full grammar, parse a json spec
+    if len(sys.argv) < 3 or sys.argv[1] in ["--help", "-h", "help"]:
+        print(HELP)
+        sys.exit(1)
+
+    if sys.argv[1] == "--grammar":
+        grammar_string = None
+        # try to open the file, otherwise interpret the argument as a string
+        try:
+            with open(sys.argv[2], "r") as f:
+                grammar_string = f.read()
+        except FileNotFoundError:
+            grammar_string = sys.argv[2]
+
+        (s, rules) = grammar()(grammar_string.strip())
+        if s:
+            print("Grammar had unparsed trailing characters:\n" + s)
+            sys.exit(1)
+        rules_by_name = { rule.rule_name : rule for rule in rules }
+        main_rule = rules[-1]
+        print(Grammar(rules_by_name, main_rule).grammar())
+
+    elif sys.argv[1] == "--json":
+        json_string = None
+        # try to open the file, otherwise interpret the argument as a string
+        try:
+            with open(sys.argv[2], "r") as f:
+                json_string = f.read()
+        except FileNotFoundError:
+            json_string = sys.argv[2]
+
+        print(json_spec(json_string).grammar())
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/llama.cpp b/llama.cpp
index 98f49abd7cf48..8e020379ee969 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -20,6 +20,8 @@
 #include <unordered_map>
 #include <queue>
 #include <cassert>
+#include <deque>
+#include <unordered_set>
 #include <cstring>
 #include <climits>
 #include <memory>
@@ -32,6 +34,7 @@
 #include <mutex>
 #include <sstream>
 #include <numeric>
+#include <stdexcept>

 #define LLAMA_USE_SCRATCH
 #define LLAMA_MAX_SCRATCH_BUFFERS 16
@@ -200,6 +203,24 @@ struct llama_vocab {

     std::unordered_map<token, id> token_to_id;
     std::vector<token_score> id_to_token;
+
+    // prefix trie over the vocabulary: each node maps the next byte to a child
+    // node and stores the ids of the tokens that end exactly at that node
+    struct token_trie {
+        std::unordered_map<char, token_trie> children;
+        std::vector<llama_vocab::id> tokens;
+
+        void insert(const llama_vocab::token & tok, llama_vocab::id id) {
+            token_trie * node = this;
+            for (char c : tok) {
+                if (node->children.count(c) == 0) {
+                    node->children[c] = token_trie();
+                }
+                node = &node->children.at(c);
+            }
+            node->tokens.push_back(id);
+        }
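+
+        // Illustrative (the ids here are made up): after insert("ab", 5) and
+        // insert("abc", 9), the trie holds the path 'a' -> 'b' -> 'c'; the node
+        // reached by "ab" stores id 5 and the node reached by "abc" stores id 9.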
+    };
+
+    token_trie trie;
 };

 struct llama_context {
@@ -465,6 +486,7 @@
         }

         vocab.token_to_id[word] = i;
+        vocab.trie.insert(word, i);

         auto & tok_score = vocab.id_to_token[i];
         tok_score.tok = std::move(word);
@@ -1820,7 +1842,6 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
     }
 }

-
 llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
     assert(ctx);
     auto N = float(llama_n_vocab(ctx));
@@ -2795,6 +2816,333 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
     return true;
 }

+//
+// Token constraints
+//
+
+typedef size_t rule_id;
+enum grammar_element_kind {
+    TERMINAL,
+    REFERENCE
+};
+
+struct grammar_element {
+    grammar_element_kind kind;
+    std::vector<char> terminal;
+    rule_id reference;
+
+    grammar_element(std::vector<char> terminal) : kind(grammar_element_kind::TERMINAL), terminal(terminal) {}
+    grammar_element(rule_id reference) : kind(grammar_element_kind::REFERENCE), reference(reference) {}
+
+    grammar_element(const grammar_element & other) {
+        kind = other.kind;
+        if (kind == grammar_element_kind::TERMINAL) {
+            terminal = other.terminal;
+        } else {
+            reference = other.reference;
+        }
+    }
+
+    grammar_element & operator=(const grammar_element & other) {
+        if (this != &other) {
+            kind = other.kind;
+            if (kind == grammar_element_kind::TERMINAL) {
+                terminal = other.terminal;
+            } else {
+                reference = other.reference;
+            }
+        }
+        return *this;
+    }
+
+    bool operator==(const grammar_element & other) const {
+        if (kind != other.kind) {
+            return false;
+        }
+        if (kind == grammar_element_kind::TERMINAL) {
+            return terminal == other.terminal;
+        }
+        return reference == other.reference;
+    }
+};
+
+size_t hash_combine(size_t seed, size_t value) {
+    return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
+}
+
+typedef std::vector<grammar_element> branch;
+typedef std::vector<branch> rule_body;
+typedef std::unordered_map<rule_id, rule_body> grammar_rules;
+typedef size_t parse_hash;
+struct partial_parse;
+typedef std::unordered_map<parse_hash, partial_parse> partial_parses;
+
+struct partial_parse {
+    std::shared_ptr<grammar_rules> grammar;
+    std::deque<char> remaining_literal;
+    std::deque<grammar_element> remaining_elements;
+
+    partial_parse(std::shared_ptr<grammar_rules> g, std::vector<grammar_element> rem_elements)
+        : grammar(g), remaining_elements(rem_elements.begin(), rem_elements.end()) {}
+
+    size_t hash() const {
+        std::hash<std::underlying_type<grammar_element_kind>::type> enum_hash;
+        size_t seed = 842502087;
+
+        for (const auto & elem : remaining_elements) {
+            seed = hash_combine(seed, enum_hash(static_cast<std::underlying_type<grammar_element_kind>::type>(elem.kind)));
+
+            if (elem.kind == grammar_element_kind::TERMINAL) {
+                std::hash<char> char_hash;
+                for (const auto & ch : elem.terminal) {
+                    seed = hash_combine(seed, char_hash(ch));
+                }
+            } else if (elem.kind == grammar_element_kind::REFERENCE) {
+                std::hash<rule_id> rule_name_hash;
+                seed = hash_combine(seed, rule_name_hash(elem.reference));
+            }
+        }
+
+        std::hash<char> char_hash;
+        for (const auto & ch : remaining_literal) {
+            seed = hash_combine(seed, char_hash(ch));
+        }
+
+        return seed;
+    }
+
+    bool equals(const partial_parse & other) const {
+        return remaining_literal == other.remaining_literal &&
+               remaining_elements == other.remaining_elements;
+    }
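+
+    // Advance this parse state by one character. References at the front of the
+    // element queue are expanded (forking one new parse state per alternative)
+    // until a literal is at the front; passing '\0' asks whether the parse can
+    // terminate here. Returns the set of surviving parse states, keyed by hash.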
+    partial_parses accept_char(char c) {
+        partial_parses result;
+        while (remaining_literal.empty()) {
+            if (remaining_elements.empty()) {
+                break;
+            }
+
+            grammar_element elem = remaining_elements.front();
+            remaining_elements.pop_front();
+
+            if (elem.kind == grammar_element_kind::TERMINAL) {
+                for (char ch : elem.terminal) {
+                    remaining_literal.push_back(ch);
+                }
+            } else if (elem.kind == grammar_element_kind::REFERENCE) {
+                auto rule = grammar->find(elem.reference);
+                if (rule != grammar->end()) {
+                    for (const auto & alternative : rule->second) {
+                        partial_parse new_parse(*this);
+                        for (auto it = alternative.rbegin(); it != alternative.rend(); ++it) {
+                            new_parse.remaining_elements.push_front(*it);
+                        }
+                        auto new_result = new_parse.accept_char(c);
+                        result.insert(new_result.begin(), new_result.end());
+                    }
+                    return result;
+                }
+            }
+        }
+
+        if (remaining_literal.empty()) {
+            if (c == '\0') {
+                result.insert({hash(), *this});
+            }
+        } else {
+            if (remaining_literal.front() == c) {
+                remaining_literal.pop_front();
+                result.insert({hash(), *this});
+            }
+        }
+        return result;
+    }
+};
+
+struct token_grammar {
+    partial_parses parse_map;
+
+    token_grammar(rule_id start_rule_name, std::shared_ptr<grammar_rules> grammar) {
+        rule_body start_rule = grammar->at(start_rule_name);
+        for (const auto & alternative : start_rule) {
+            partial_parse parse(grammar, alternative);
+            parse_map.insert({parse.hash(), parse});
+        }
+    }
+
+    token_grammar(const token_grammar & other) {
+        parse_map = other.parse_map;
+    }
+
+    bool accept(char c) {
+        partial_parses new_parse_map;
+        for (auto & parse_entry : parse_map) {
+            auto accepted = parse_entry.second.accept_char(c);
+            new_parse_map.insert(accepted.begin(), accepted.end());
+        }
+        if (new_parse_map.empty()) {
+            return false;
+        }
+        parse_map = new_parse_map;
+        return true;
+    }
+
+    bool accept_str(const char *s) {
+        if (strlen(s) == 0) {
+            return accept('\0');
+        }
+        for (const char *c = s; *c != '\0'; c++) {
+            if (!accept(*c)) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    std::unique_ptr<token_grammar> clone() const {
+        return std::unique_ptr<token_grammar>(new token_grammar(*this));
+    }
+
+    // Collect all the tokens that this token filter can accept
+    void collect_ids(const size_t depth, const llama_vocab::token_trie & trie, std::unordered_set<llama_vocab::id> & ids) {
+        // The depth > 0 check is to avoid collecting the empty string unless we are at the root of the trie
+        if (depth > 0 || clone()->accept('\0')) {
+            ids.insert(trie.tokens.begin(), trie.tokens.end());
+        }
+
+        for (auto it = trie.children.begin(); it != trie.children.end(); ++it) {
+            char c = it->first;
+            const auto & child_trie = it->second;
+
+            auto cloned_validator = clone();
+            if (cloned_validator->accept(c)) {
+                cloned_validator->collect_ids(depth + 1, child_trie, ids);
+            }
+        }
+    }
+};
+
+// Grammar of the grammar:
+// file: start_rule_name rule_count rule*
+// rule_count: int
+// rule: rule_name rule_body
+// rule_name: int
+// rule_body: alternative_count alternative*
+// alternative_count: int
+// alternative: element_count element*
+// element_count: int
+// element: terminal terminal_value | non_terminal rule_name
+// terminal: 0
+// non_terminal: 1
+// terminal_value: byte_count byte*
+// byte_count: int
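+//
+// Example (illustrative): the serialization "0 1 0 1 1 0 2 104 105" encodes a
+// grammar whose start rule is rule 0, with 1 rule total; rule 0 has 1
+// alternative of 1 element: a terminal of 2 bytes, 'h' (104) and 'i' (105).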
+std::unique_ptr<token_grammar> parse_token_grammar(const char *s) {
+    std::shared_ptr<grammar_rules> grammar = std::shared_ptr<grammar_rules>(new grammar_rules());
+    int start_rule_name;
+    int rule_count;
+    int bytes_read;
+    sscanf(s, "%d %d%n", &start_rule_name, &rule_count, &bytes_read);
+    s += bytes_read;
+
+    for (int i = 0; i < rule_count; i++) {
+        rule_id rule_name;
+        int alternative_count;
+        sscanf(s, "%zd %d%n", &rule_name, &alternative_count, &bytes_read);
+        s += bytes_read;
+
+        rule_body production_rule;
+
+        for (int j = 0; j < alternative_count; j++) {
+            int element_count;
+            sscanf(s, "%d%n", &element_count, &bytes_read);
+            s += bytes_read;
+
+            branch alternative;
+
+            for (int k = 0; k < element_count; k++) {
+                int element_type;
+                sscanf(s, "%d%n", &element_type, &bytes_read);
+                s += bytes_read;
+
+                if (element_type == 0) {
+                    int byte_count;
+                    sscanf(s, "%d%n", &byte_count, &bytes_read);
+                    s += bytes_read;
+
+                    std::vector<char> bytes;
+                    for (int l = 0; l < byte_count; l++) {
+                        int byte_int;
+                        sscanf(s, "%d%n", &byte_int, &bytes_read);
+                        s += bytes_read;
+                        bytes.push_back((char) byte_int);
+                    }
+                    grammar_element element_data(bytes);
+                    alternative.push_back(element_data);
+                } else if (element_type == 1) {
+                    rule_id rule_name;
+                    sscanf(s, "%zd%n", &rule_name, &bytes_read);
+                    s += bytes_read;
+                    grammar_element element_data(rule_name);
+                    alternative.push_back(element_data);
+                } else {
+                    fprintf(stderr, "%s: Invalid element type: %d\n", __func__, element_type);
+                    throw std::runtime_error("Invalid element type");
+                }
+            }
+
+            production_rule.push_back(alternative);
+        }
+
+        grammar->insert({rule_name, production_rule});
+    }
+
+    auto token_grammar_instance = std::unique_ptr<token_grammar>(new token_grammar(start_rule_name, grammar));
+    return token_grammar_instance;
+}
+
+void* llama_load_token_grammar_from_path(const char *path) {
+    llama_file file(path, "rb");
+    std::string s = file.read_string(file.size);
+    const char *c = s.c_str();
+    std::unique_ptr<token_grammar> unique_validator = parse_token_grammar(c);
+    return unique_validator.release();
+}
+
+void llama_grammar_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const void* filter_ptr) {
+    if (!filter_ptr) {
+        return;
+    }
+
+    const auto *filter = static_cast<const token_grammar *>(filter_ptr);
+
+    std::unordered_set<llama_vocab::id> valid_ids;
+    filter->clone()->collect_ids(0, ctx->vocab.trie, valid_ids);
+
+    for (size_t i = 0; i < candidates->size; ++i) {
+        auto candidate = candidates->data[i].id;
+        if (valid_ids.find(candidate) == valid_ids.end()) {
+            candidates->data[i].logit -= 1000.0f;
+        }
+    }
+}
+
+void llama_grammar_accept_token(struct llama_context * ctx, llama_token id, void* filter_ptr) {
+    if (!filter_ptr) {
+        return;
+    }
+
+    auto *filter = static_cast<token_grammar *>(filter_ptr);
+    auto as_str = ctx->vocab.id_to_token[id].tok.c_str();
+    if (!filter->accept_str(as_str)) {
+        throw std::runtime_error("filter rejected token");
+    }
+}
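+
+// Typical call pattern (mirrors examples/main/main.cpp): load the grammar once
+// with llama_load_token_grammar_from_path(), call llama_grammar_penalty() on
+// the candidate array before sampling, then llama_grammar_accept_token() once
+// a token has been chosen.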
+
 int llama_eval(
         struct llama_context * ctx,
         const llama_token * tokens,
diff --git a/llama.h b/llama.h
index 21cba8cf61061..9c81cbc4db2b6 100644
--- a/llama.h
+++ b/llama.h
@@ -191,6 +191,10 @@ extern "C" {
     LLAMA_API llama_token llama_token_eos();
     LLAMA_API llama_token llama_token_nl();

+    LLAMA_API void* llama_load_token_grammar_from_path(const char *path);
+    LLAMA_API void llama_grammar_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const void* filter_ptr);
+    LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, llama_token id, void* filter_ptr);
+
     // Sampling functions

     /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.