|
#pragma once |
|
|
|
#include "llama-impl.h" |
|
|
|
#include <map> |
|
|
|
struct llama_vocab; |
|
|
|
|
|
enum llama_gretype { |
|
|
|
LLAMA_GRETYPE_END = 0, |
|
|
|
|
|
LLAMA_GRETYPE_ALT = 1, |
|
|
|
|
|
LLAMA_GRETYPE_RULE_REF = 2, |
|
|
|
|
|
LLAMA_GRETYPE_CHAR = 3, |
|
|
|
|
|
LLAMA_GRETYPE_CHAR_NOT = 4, |
|
|
|
|
|
|
|
LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, |
|
|
|
|
|
|
|
LLAMA_GRETYPE_CHAR_ALT = 6, |
|
|
|
|
|
LLAMA_GRETYPE_CHAR_ANY = 7, |
|
}; |
|
|
|
typedef struct llama_grammar_element { |
|
enum llama_gretype type; |
|
uint32_t value; |
|
} llama_grammar_element; |
|
|
|
struct llama_partial_utf8 { |
|
uint32_t value; |
|
int n_remain; |
|
}; |
|
|
|
struct llama_grammar_candidate { |
|
size_t index; |
|
const uint32_t * code_points; |
|
llama_partial_utf8 partial_utf8; |
|
}; |
|
|
|
using llama_grammar_rule = std::vector< llama_grammar_element>; |
|
using llama_grammar_stack = std::vector<const llama_grammar_element *>; |
|
|
|
using llama_grammar_rules = std::vector<llama_grammar_rule>; |
|
using llama_grammar_stacks = std::vector<llama_grammar_stack>; |
|
using llama_grammar_candidates = std::vector<llama_grammar_candidate>; |
|
|
|
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); |
|
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); |
|
|
|
|
|
|
|
|
|
|
|
void llama_grammar_accept( |
|
const llama_grammar_rules & rules, |
|
const llama_grammar_stacks & stacks, |
|
uint32_t chr, |
|
llama_grammar_stacks & stacks_new); |
|
|
|
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack( |
|
const llama_grammar_rules & rules, |
|
const llama_grammar_stack & stack, |
|
const llama_grammar_candidates & candidates); |
|
|
|
struct llama_grammar_parser { |
|
std::map<std::string, uint32_t> symbol_ids; |
|
|
|
llama_grammar_rules rules; |
|
|
|
llama_grammar_stack c_rules() const; |
|
|
|
uint32_t get_symbol_id(const char * src, size_t len); |
|
uint32_t generate_symbol_id(const std::string & base_name); |
|
|
|
void add_rule(uint32_t rule_id, const llama_grammar_rule & rule); |
|
|
|
const char * parse_alternates( |
|
const char * src, |
|
const std::string & rule_name, |
|
uint32_t rule_id, |
|
bool is_nested); |
|
|
|
const char * parse_sequence( |
|
const char * src, |
|
const std::string & rule_name, |
|
llama_grammar_rule & rule, |
|
bool is_nested); |
|
|
|
const char * parse_rule(const char * src); |
|
|
|
bool parse(const char * src); |
|
void print(FILE * file); |
|
}; |
|
|
|
struct llama_grammar { |
|
|
|
const llama_vocab * vocab; |
|
|
|
const llama_grammar_rules rules; |
|
llama_grammar_stacks stacks; |
|
|
|
|
|
llama_partial_utf8 partial_utf8; |
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
struct llama_grammar * llama_grammar_init_impl( |
|
const struct llama_vocab * vocab, |
|
const llama_grammar_element ** rules, |
|
size_t n_rules, |
|
size_t start_rule_index); |
|
|
|
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root); |
|
|
|
void llama_grammar_free_impl(struct llama_grammar * grammar); |
|
|
|
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar); |
|
|
|
|
|
void llama_grammar_apply_impl( |
|
const struct llama_grammar & grammar, |
|
llama_token_data_array * cur_p); |
|
|
|
void llama_grammar_accept_impl( |
|
struct llama_grammar & grammar, |
|
llama_token token); |
|
|