#pragma once #include "llama-impl.h" #include struct llama_vocab; // grammar element type enum llama_gretype { // end of rule definition LLAMA_GRETYPE_END = 0, // start of alternate definition for rule LLAMA_GRETYPE_ALT = 1, // non-terminal element: reference to rule LLAMA_GRETYPE_RULE_REF = 2, // terminal element: character (code point) LLAMA_GRETYPE_CHAR = 3, // inverse char(s) ([^a], [^a-b] [^abc]) LLAMA_GRETYPE_CHAR_NOT = 4, // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to // be an inclusive range ([a-z]) LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, // modifies a preceding LLAMA_GRETYPE_CHAR or // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) LLAMA_GRETYPE_CHAR_ALT = 6, // any character (.) LLAMA_GRETYPE_CHAR_ANY = 7, }; typedef struct llama_grammar_element { enum llama_gretype type; uint32_t value; // Unicode code point or rule ID } llama_grammar_element; struct llama_partial_utf8 { uint32_t value; // bit value so far (unshifted) int n_remain; // num bytes remaining; -1 indicates invalid sequence }; struct llama_grammar_candidate { size_t index; const uint32_t * code_points; llama_partial_utf8 partial_utf8; }; using llama_grammar_rule = std::vector< llama_grammar_element>; using llama_grammar_stack = std::vector; using llama_grammar_rules = std::vector; using llama_grammar_stacks = std::vector; using llama_grammar_candidates = std::vector; const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); // takes a set of possible pushdown stacks on a grammar, which are required to // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those // positions void llama_grammar_accept( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks, uint32_t chr, llama_grammar_stacks & stacks_new); std::vector llama_grammar_reject_candidates_for_stack( const llama_grammar_rules & rules, const llama_grammar_stack & stack, const llama_grammar_candidates & candidates); struct llama_grammar_parser { std::map symbol_ids; llama_grammar_rules rules; llama_grammar_stack c_rules() const; uint32_t get_symbol_id(const char * src, size_t len); uint32_t generate_symbol_id(const std::string & base_name); void add_rule(uint32_t rule_id, const llama_grammar_rule & rule); const char * parse_alternates( const char * src, const std::string & rule_name, uint32_t rule_id, bool is_nested); const char * parse_sequence( const char * src, const std::string & rule_name, llama_grammar_rule & rule, bool is_nested); const char * parse_rule(const char * src); bool parse(const char * src); void print(FILE * file); }; struct llama_grammar { // note: allow null vocab for testing (not great) const llama_vocab * vocab; const llama_grammar_rules rules; // TODO: shared ptr llama_grammar_stacks stacks; // buffer for partially generated UTF-8 sequence from accepted tokens llama_partial_utf8 partial_utf8; }; // // internal API // // note: needed for tests (not great) struct llama_grammar * llama_grammar_init_impl( const struct llama_vocab * vocab, const llama_grammar_element ** rules, size_t n_rules, size_t start_rule_index); struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root); void llama_grammar_free_impl(struct llama_grammar * grammar); struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar); // TODO: move the API below as member functions of llama_grammar void llama_grammar_apply_impl( const struct llama_grammar & grammar, llama_token_data_array * cur_p); void llama_grammar_accept_impl( struct llama_grammar & grammar, llama_token token);