|
'''
This file was copied verbatim from the following PR to the Transformers library:

https://github.com/huggingface/transformers/pull/27557

Author: Saibo-creator
Author GitHub: https://github.com/Saibo-creator

All credits go to the author.
'''
|
|
|
import logging |
|
import re |
|
import time |
|
from abc import ABC |
|
from functools import lru_cache |
|
from typing import Dict, List |
|
|
|
import torch |
|
|
|
from modules import shared |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
########################
# EBNF Grammar Parsing #
########################

END_OF_ALTERNATE_MARKER = 0
|
END_OF_RULE_MARKER = 0 |
|
TO_BE_FILLED_MARKER = 0 |
|
REF_RULE_MARKER = 1 |
|
LITERAL_MARKER = 2 |
|
|
|
|
|
class ParseState: |
|
def __init__(self): |
|
self.symbol_ids = {} |
|
self.grammar_encoding = [] |
|
|
|
|
|
def get_symbol_id(state, src): |
|
if src not in state.symbol_ids: |
|
state.symbol_ids[src] = len(state.symbol_ids) |
|
return state.symbol_ids[src] |
|
|
|
|
|
def generate_symbol_id(state, base_name): |
|
next_id = len(state.symbol_ids) |
|
state.symbol_ids[base_name + "_" + str(next_id)] = next_id |
|
return next_id |
|
|
|
|
|
def is_word_char(c): |
|
return c.isalnum() or c == "-" or c == "_" |
|
|
|
|
|
def hex_to_int(c): |
|
if c.isdigit(): |
|
return int(c) |
|
elif "a" <= c.lower() <= "f": |
|
return ord(c.lower()) - ord("a") + 10 |
|
return -1 |
|
|
|
|
|
def remove_leading_white_space(src, newline_ok): |
|
""" |
|
Skips over whitespace and comments in the input string. |
|
This function processes the input string, skipping over any spaces, tabs, |
|
and content following a '#' character, which denotes a comment. The parsing |
|
of a comment continues until the end of the line (denoted by newline characters |
|
'\r' or '\n'). If the 'newline_ok' parameter is set to False, the function |
|
will stop processing and return the remaining string upon encountering a |
|
newline character, otherwise it will skip over newline characters as well. |
|
Parameters: |
|
src (str): The input string to be processed. |
|
newline_ok (bool): A flag indicating whether encountering a newline character |
|
should stop the parsing (False) or if it should be skipped (True). |
|
Returns: |
|
str: The remaining portion of the input string after skipping whitespace and comments. |
|
""" |
|
pos = 0 |
|
while pos < len(src) and (src[pos].isspace() or src[pos] == "#"): |
|
if src[pos] == "#": |
|
while pos < len(src) and src[pos] not in ("\r", "\n"): |
|
pos += 1 |
|
else: |
|
if not newline_ok and src[pos] in ("\r", "\n"): |
|
break |
|
pos += 1 |
|
return src[pos:] |
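
# Illustrative examples (not executed; inputs are assumptions):
#
#   remove_leading_white_space("# comment\nroot ::= x", newline_ok=True)
#   # -> "root ::= x"
#   remove_leading_white_space("  \nroot ::= x", newline_ok=False)
#   # -> "\nroot ::= x"  (stops at the newline)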
|
|
|
|
|
def parse_name(src): |
|
pos = 0 |
|
while pos < len(src) and is_word_char(src[pos]): |
|
pos += 1 |
|
if pos == 0: |
|
raise RuntimeError("expecting name at " + src) |
|
return src[:pos], src[pos:] |
|
|
|
|
|
def parse_char(src):
    """
    Parse the leading char from the input string.
    :param src: input string
    :return: (char, remaining_src)
    """
    if not src:
        raise RuntimeError("unexpected end of input")
    if src[0] == "\\":
        esc = src[1]
        if esc == "x":
            first = hex_to_int(src[2])
            if first > -1:
                second = hex_to_int(src[3])
                if second > -1:
                    # decode \xNN to a one-char string so callers can call ord() on it
                    return chr((first << 4) + second), src[4:]
            raise RuntimeError("expecting \\xNN at " + src)
        elif esc in ('"', "[", "]"):
            return esc, src[2:]
        elif esc == "r":
            return "\r", src[2:]
        elif esc == "n":
            return "\n", src[2:]
        elif esc == "t":
            return "\t", src[2:]
        raise RuntimeError("unknown escape at " + src)
    return src[0], src[1:]
|
|
|
|
|
def parse_sequence(state, src, rule_name, outbuf, is_nested):
    out_start_pos = len(outbuf)

    # sequence size, will be replaced at end when known
    outbuf.append(TO_BE_FILLED_MARKER)

    last_sym_start = len(outbuf)
    remaining_src = src
    while remaining_src:
        if remaining_src[0] == '"':  # literal string
            remaining_src = remaining_src[1:]
            last_sym_start = len(outbuf)
            while remaining_src[0] != '"':
                char, remaining_src = parse_char(remaining_src)

                # each char of a literal is encoded as a "range" of char - char
                outbuf.append(LITERAL_MARKER)
                outbuf.append(ord(char))
                outbuf.append(ord(char))
            remaining_src = remove_leading_white_space(remaining_src[1:], is_nested)
        elif remaining_src[0] == "[":  # char range(s)
            remaining_src = remaining_src[1:]
            last_sym_start = len(outbuf)
            # num chars in range - replaced at end of loop
            outbuf.append(TO_BE_FILLED_MARKER)
            while remaining_src[0] != "]":
                char, remaining_src = parse_char(remaining_src)
                # range from char to char (potentially)
                outbuf.append(ord(char))
                if remaining_src[0] == "-" and remaining_src[1] != "]":
                    endchar_pair, remaining_src = parse_char(remaining_src[1:])
                    outbuf.append(ord(endchar_pair))
                else:
                    # chars that aren't part of a c1-c2 range are doubled (i.e., c-c)
                    outbuf.append(ord(char))
            # replace num chars placeholder with the actual count
            outbuf[last_sym_start] = len(outbuf) - last_sym_start - 1
            remaining_src = remove_leading_white_space(remaining_src[1:], is_nested)
        elif is_word_char(remaining_src[0]):  # rule reference
            name, remaining_src = parse_name(remaining_src)
            ref_rule_id = get_symbol_id(state, name)
            remaining_src = remove_leading_white_space(remaining_src, is_nested)
            last_sym_start = len(outbuf)
            outbuf.append(REF_RULE_MARKER)
            outbuf.append(ref_rule_id)
        elif remaining_src[0] == "(":  # grouping
            # parse nested alternates into a synthesized rule
            remaining_src = remove_leading_white_space(remaining_src[1:], True)
            sub_rule_id = generate_symbol_id(state, rule_name)
            remaining_src = parse_alternates(state, remaining_src, rule_name, sub_rule_id, True)
            last_sym_start = len(outbuf)
            # output a reference to the synthesized rule
            outbuf.append(REF_RULE_MARKER)
            outbuf.append(sub_rule_id)
            if remaining_src[0] != ")":
                raise RuntimeError("expecting ')' at " + remaining_src)
            remaining_src = remove_leading_white_space(remaining_src[1:], is_nested)
        elif remaining_src[0] in ("*", "+", "?"):  # repetition operator
            if len(outbuf) - out_start_pos - 1 == 0:
                raise RuntimeError("expecting preceding item to */+/? at " + remaining_src)
            out_grammar = state.grammar_encoding

            # apply transformation to the previous symbol (last_sym_start - end)
            # according to the rewrite rules:
            # S* --> S' ::= S S' |
            # S+ --> S' ::= S S' | S
            # S? --> S' ::= S |
            sub_rule_id = generate_symbol_id(state, rule_name)
            out_grammar.append(sub_rule_id)
            sub_rule_start = len(out_grammar)
            # placeholder for size of 1st alternate
            out_grammar.append(TO_BE_FILLED_MARKER)
            # add preceding symbol to the generated rule
            out_grammar.extend(outbuf[last_sym_start:])
            if remaining_src[0] in ("*", "+"):
                # cause the generated rule to recurse
                out_grammar.append(REF_RULE_MARKER)
                out_grammar.append(sub_rule_id)
            # apply actual size of 1st alternate
            out_grammar[sub_rule_start] = len(out_grammar) - sub_rule_start
            # mark end of 1st alternate
            out_grammar.append(END_OF_ALTERNATE_MARKER)
            sub_rule_start = len(out_grammar)
            # placeholder for size of 2nd alternate
            out_grammar.append(TO_BE_FILLED_MARKER)
            if remaining_src[0] == "+":
                # add preceding symbol as an alternate only for '+' (at least one)
                out_grammar.extend(outbuf[last_sym_start:])
            # apply actual size of 2nd alternate
            out_grammar[sub_rule_start] = len(out_grammar) - sub_rule_start
            # mark end of 2nd alternate, then end of rule
            out_grammar.append(END_OF_ALTERNATE_MARKER)
            out_grammar.append(END_OF_RULE_MARKER)

            # in the original rule, replace the previous symbol with a reference
            # to the generated rule
            outbuf[last_sym_start:] = [REF_RULE_MARKER, sub_rule_id]

            remaining_src = remove_leading_white_space(remaining_src[1:], is_nested)
        else:
            break
    # apply actual size of this alternate sequence
    outbuf[out_start_pos] = len(outbuf) - out_start_pos
    # mark end of alternate
    outbuf.append(END_OF_ALTERNATE_MARKER)
    return remaining_src
|
|
|
|
|
def parse_alternates(state, src, rule_name, rule_id, is_nested):
    outbuf = []
    remaining_src = parse_sequence(state, src, rule_name, outbuf, is_nested)
    while remaining_src and remaining_src[0] == "|":
        remaining_src = remove_leading_white_space(remaining_src[1:], True)
        remaining_src = parse_sequence(state, remaining_src, rule_name, outbuf, is_nested)

    # rule id, then the encoded alternates, then the end-of-rule marker
    state.grammar_encoding.append(rule_id)
    state.grammar_encoding.extend(outbuf)
    state.grammar_encoding.append(END_OF_RULE_MARKER)
    return remaining_src
|
|
|
|
|
def parse_rule(state, src): |
|
name, remaining_src = parse_name(src) |
|
remaining_src = remove_leading_white_space(remaining_src, False) |
|
rule_id = get_symbol_id(state, name) |
|
|
|
if remaining_src[:3] != "::=": |
|
raise RuntimeError("expecting ::= at " + remaining_src) |
|
remaining_src = remove_leading_white_space(remaining_src[3:], True) |
|
|
|
remaining_src = parse_alternates(state, remaining_src, name, rule_id, False) |
|
|
|
    if remaining_src and remaining_src[0] == "\r":
        remaining_src = remaining_src[2:] if remaining_src[1:2] == "\n" else remaining_src[1:]
|
elif remaining_src and remaining_src[0] == "\n": |
|
remaining_src = remaining_src[1:] |
|
elif remaining_src: |
|
raise RuntimeError("expecting newline or end at " + remaining_src) |
|
return remove_leading_white_space(remaining_src, True) |
|
|
|
|
|
def parse_ebnf(src):
    try:
        state = ParseState()
        grammar_repr = remove_leading_white_space(src, True)
        last_grammar_repr = ""
        while grammar_repr:
            if last_grammar_repr:
                last_parsed_rule_len = len(last_grammar_repr) - len(grammar_repr)
                logger.debug(f"last_parsed_rule: {last_grammar_repr[:last_parsed_rule_len]}")
            last_grammar_repr = grammar_repr
            grammar_repr = parse_rule(state, grammar_repr)
        # 0xFFFF marks the end of the whole grammar encoding
        state.grammar_encoding.append(0xFFFF)
        return state
    except RuntimeError as err:
        logger.warning("error parsing grammar: %s", err)
        return ParseState()
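
# A minimal usage sketch (illustrative; the grammar string is an assumption):
#
#   state = parse_ebnf('root ::= "a" | "b"')
#   state.symbol_ids        # {'root': 0}
#   state.grammar_encoding  # flat int encoding, terminated by 0xFFFF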
|
|
|
|
|
def print_rule(file, grammar_encoding, index, symbol_id_names): |
|
rule_id = grammar_encoding[index] |
|
print(f"<{index}>{symbol_id_names[rule_id]} ::=", end=" ", file=file) |
|
pos = index + 1 |
|
while grammar_encoding[pos]: |
|
if pos - 1 > index: |
|
print("|", end=" ", file=file) |
|
pos += 1 |
|
while grammar_encoding[pos]: |
|
if grammar_encoding[pos] == REF_RULE_MARKER: |
|
ref_rule_id = grammar_encoding[pos + 1] |
|
print( |
|
f"<{pos}>{symbol_id_names[ref_rule_id]}", |
|
end=" ", |
|
file=file, |
|
) |
|
pos += 2 |
|
else: |
|
print("<{}>[".format(pos), end="", file=file) |
|
num_chars = grammar_encoding[pos] |
|
pos += 1 |
|
|
|
for i in range(0, num_chars, 2): |
|
print("{}-".format(chr(grammar_encoding[pos + i])), end="", file=file) |
|
if i + 1 < num_chars: |
|
print("{}".format(chr(grammar_encoding[pos + i + 1])), end="", file=file) |
|
print("]", end=" ", file=file) |
|
pos += num_chars |
|
pos += 1 |
|
print(file=file) |
|
return pos + 1 |
|
|
|
|
|
def print_grammar(file, state): |
|
pos = 0 |
|
symbol_id_names = {v: k for k, v in state.symbol_ids.items()} |
|
print("Grammar Rules:", file=file) |
|
|
|
while state.grammar_encoding[pos] != 0xFFFF: |
|
pos = print_rule(file, state.grammar_encoding, pos, symbol_id_names) |
|
pos = 0 |
|
print("\nBinary representation:", file=file) |
|
while state.grammar_encoding[pos] != 0xFFFF: |
|
print(f"{state.grammar_encoding[pos]:04x}", end=" ", file=file) |
|
pos += 1 |
|
print("ffff\n") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
######################
# Grammar Constraint #
######################


class GrammarConstraint(ABC):
|
def __init__(self, grammar_str, start_rule_name, tokenizer): |
|
self.tt = 0 |
|
self.nt = 0 |
|
state = parse_ebnf(grammar_str) |
|
grammar_encoding = state.grammar_encoding |
|
self.start_rule_id = state.symbol_ids.get(start_rule_name) |
|
|
|
self.eos_token_id = tokenizer.eos_token_id |
|
self.token_trie = TokenTrie(tokenizer) |
|
self.tokenizer = tokenizer |
|
self.grammar_encoding = grammar_encoding |
|
|
|
pos = 0 |
|
rules: Dict[int, int] = {} |
|
|
|
        while grammar_encoding[pos] != 0xFFFF:
            rule_id = grammar_encoding[pos]

            # map each rule_id to the position of its rule in the flat encoding
            rules[rule_id] = pos
            pos += 1

            # skip over the alternates of this rule: each alternate starts with
            # its size, so advance by size + 1 until the end-of-rule marker (0)
            while grammar_encoding[pos]:
                pos += 1 + grammar_encoding[pos]

            # skip the end-of-rule marker itself
            pos += 1
|
|
|
self.start_rule_pos = rules[self.start_rule_id] |
|
self.rules_pos_dict: Dict[int, int] = rules |
|
|
|
    def init_stacks(self):
        # a stack is a list of positions into grammar_encoding; start_rule_pos
        # points at the rule id, +1 is the size of the first alternate, so
        # +2 is the first element of the first alternate
        stack = [self.start_rule_pos + 2]
        # tuple so the result can be memoized by the lru_cache on advance_stack
        return self.advance_stack(tuple(stack))
|
|
|
|
|
|
|
|
|
|
|
|
|
    # Expand the stack until every returned stack has a terminal (char range)
    # on top, resolving rule references into all of their alternates.
    @lru_cache(maxsize=32768)
    def advance_stack(self, stack):
        stack = list(stack)

        # an empty stack means the grammar has been fully matched
        if len(stack) == 0:
            return [stack]

        # we only ever need to look at the top of the stack
        pos = stack[-1]

        # a char range starts with its number of chars (always >= 2), so any
        # value greater than REF_RULE_MARKER is a terminal that needs no advancing
        if self.grammar_encoding[pos] > 1:
            return [stack]

        # otherwise the top is REF_RULE_MARKER followed by the referenced rule id
        referenced_rule_id = self.grammar_encoding[pos + 1]

        # position of the first alternate of the referenced rule (skip the rule id)
        subpos = self.rules_pos_dict[referenced_rule_id] + 1
        stacks: List[List[int]] = []

        # loop over the alternates of the referenced rule
        while self.grammar_encoding[subpos]:
            new_stack = stack[:-1]
            if self.grammar_encoding[pos + 2]:
                # if the rule reference is not at the end of its sequence,
                # push the position to continue at once the rule is matched
                new_stack.append(pos + 2)

            # if the alternate is non-empty, push its first element
            if self.grammar_encoding[subpos + 1]:
                new_stack.append(subpos + 1)
            stacks.extend(self.advance_stack(tuple(new_stack)))

            # move to the next alternate (its size field plus its elements)
            subpos += self.grammar_encoding[subpos] + 1
        return stacks
|
|
|
def accept_char(self, *args, **kwargs): |
|
"""Process a byte according to the grammar rules.""" |
|
raise NotImplementedError |
|
|
|
def accept_token_id(self, *args, **kwargs): |
|
"""Process a token according to the grammar rules.""" |
|
raise NotImplementedError |
|
|
|
def filter_vocab(self, *args, **kwargs): |
|
raise NotImplementedError |
|
|
|
|
|
class IncrementalGrammarConstraint(GrammarConstraint): |
|
def __init__(self, grammar_str, start_rule_name, tokenizer): |
|
super().__init__(grammar_str, start_rule_name, tokenizer) |
|
|
|
    def accept_char(self, byte, stacks):
        new_stacks = []
        for stack in stacks:
            # an empty stack means the grammar has already been completed
            if not stack:
                continue

            pos = stack[-1]
            num_chars = self.grammar_encoding[pos]

            # the top of an advanced stack is always a char-range terminal;
            # check whether the byte falls into any of its ranges
            pos += 1
            found = False
            for i in range(0, num_chars, 2):
                if self.grammar_encoding[pos + i] <= byte <= self.grammar_encoding[pos + i + 1]:
                    found = True
                    break
            if not found:
                continue

            # the byte was accepted: pop the terminal, push the rest of the
            # sequence (if any), and re-advance the stack
            pos += num_chars
            new_stack = stack[:-1]
            if self.grammar_encoding[pos]:
                new_stack.append(pos)
            new_stacks.extend(self.advance_stack(tuple(new_stack)))

        return new_stacks
|
|
|
def accept_string(self, string: str, stacks: List[List[int]]): |
|
_bytes = bytes(string, "utf-8") |
|
for byte in _bytes: |
|
stacks = self.accept_char(byte, stacks) |
|
return stacks |
|
|
|
    def accept_token_id(self, token_id: int, stacks: List[List[int]]):
        if token_id == self.eos_token_id:
            if stacks and all(len(stack) != 0 for stack in stacks):
                raise Exception(
                    f"At least one stack should be empty when EOS is reached. However, "
                    f"the stacks are {stacks}"
                )
            return []

        for byte in self.token_trie.id2str(token_id):
            stacks = self.accept_char(byte, stacks)

        return stacks
|
|
|
def accept_token_ids(self, token_ids: List[int], stacks: List[List[int]], as_string=True): |
|
if as_string: |
|
string = self.tokenizer.decode(token_ids) |
|
stacks = self.accept_string(string, stacks) |
|
else: |
|
for token_id in token_ids: |
|
stacks = self.accept_token_id(token_id, stacks) |
|
return stacks |
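
    # Illustrative sketch (names are assumptions): advancing the parser state
    # token by token during generation.
    #
    #   stacks = constraint.init_stacks()
    #   for token_id in generated_ids:
    #       stacks = constraint.accept_token_id(token_id, stacks)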
|
|
|
def batch_filter_vocab(self, batch_stacks, device): |
|
batch_acceptance = [] |
|
for stacks in batch_stacks: |
|
batch_acceptance.append(self.filter_vocab(stacks, device)) |
|
return torch.stack(batch_acceptance) |
|
|
|
    def filter_vocab(self, stacks, device):
        if not stacks:
            # no stacks means no tokens are accepted: return an all-False mask
            vocab_size = len(self.token_trie)
            logger.debug("sum of acceptance: 0")
            return torch.zeros(vocab_size, dtype=torch.bool, device=device)

        acceptance_matrix = torch.cat([self.token_acceptance_for_stack(tuple(stack), device) for stack in stacks])
        # a token is accepted if it is accepted by any of the stacks
        acceptance = acceptance_matrix.reshape(len(stacks), -1).any(dim=0)
        logger.debug(f"sum of acceptance: {acceptance.sum()}")
        return acceptance
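
    # Illustrative sketch (tensor names are assumptions): masking next-token
    # scores with the acceptance mask during sampling.
    #
    #   acceptance = constraint.filter_vocab(stacks, scores.device)
    #   scores = scores.masked_fill(~acceptance, -float("inf"))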
|
|
|
|
|
@lru_cache(maxsize=None) |
|
def pos_char_acceptance(self, pos): |
|
acceptance = [False] * 256 |
|
num_chars = self.grammar_encoding[pos] |
|
pos += 1 |
|
for i in range(0, num_chars, 2): |
|
start = self.grammar_encoding[pos + i] |
|
end = self.grammar_encoding[pos + i + 1] |
|
for j in range(start, end + 1): |
|
acceptance[j] = True |
|
return acceptance |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # Walk the token trie once for this stack and record, for every token id,
    # whether the grammar can accept all of the token's bytes from here.
    @lru_cache(maxsize=32768)
    def token_acceptance_for_stack(self, stack, device):
        st = time.time()
        stack = list(stack)

        accepts = [False] * len(self.token_trie)
        # EOS is acceptable only once the grammar has been fully matched
        accepts[self.eos_token_id] = len(stack) == 0
|
if len(stack) == 0: |
|
logger.debug("empty stack") |
|
|
|
def traverse_trie(trie, stacks): |
|
for byte, next_trie in trie.items(): |
|
if byte == LEAF: |
|
token_id = next_trie |
|
if token_id != self.eos_token_id: |
|
accepts[token_id] = bool(stacks) |
|
continue |
|
|
|
new_stacks = [] |
|
for stk in stacks: |
|
if not stk: |
|
continue |
|
|
|
pos = stk[-1] |
|
num_chars = self.grammar_encoding[pos] |
|
|
|
if not self.pos_char_acceptance(pos)[byte]: |
|
continue |
|
|
|
pos += num_chars + 1 |
|
new_stack = stk[:-1] |
|
if self.grammar_encoding[pos]: |
|
new_stack.append(pos) |
|
new_stacks.extend(self.advance_stack(tuple(new_stack))) |
|
|
|
if new_stacks: |
|
traverse_trie(next_trie, new_stacks) |
|
|
|
traverse_trie(self.token_trie.trie, [stack]) |
|
|
|
et = time.time() - st |
|
x = torch.tensor(accepts, dtype=torch.bool, device=device) |
|
self.tt += et |
|
self.nt += 1 |
|
return x |
|
|
|
|
|
class StaticGrammarConstraint(GrammarConstraint): |
|
def __init__(self, grammar_str, start_rule_name, tokenizer): |
|
super().__init__(grammar_str, start_rule_name, tokenizer) |
|
|
|
def accept_char(self): |
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LEAF = -1 |
|
|
|
|
|
class TokenTrie: |
|
def __init__(self, tokenizer): |
|
self.eos_token_id = tokenizer.eos_token_id |
|
self.tokens = [] |
|
self.trie = {} |
|
self.load_tokens(tokenizer) |
|
|
|
def id2str(self, token_id): |
|
return self.tokens[token_id] |
|
|
|
def __len__(self): |
|
return len(self.tokens) |
|
|
|
def load_tokens(self, tokenizer): |
|
def replace_hex(match): |
|
hex_value = match.group(1) |
|
return chr(int(hex_value, 16)) |
|
|
|
if "gpt2" in tokenizer.__class__.__name__.lower(): |
|
special = tokenizer.additional_special_tokens_ids |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
            # decode with clean_up_tokenization_spaces=False so every token keeps
            # exactly one byte representation; the default cleanup merges
            # sequences like ' .' into '.', which would corrupt the trie
            def fmt_token(id):
                if id in special:
                    return None
                return bytes(tokenizer.decode([id], clean_up_tokenization_spaces=False), "utf-8")
|
|
|
elif "llama" in tokenizer.__class__.__name__.lower(): |
|
|
|
            def fmt_token(id):
                token = tokenizer.convert_ids_to_tokens(id)
                token = re.sub(r"<0x([0-9a-fA-F]{2})>", replace_hex, token)
                # SentencePiece marks a word-initial space with U+2581 ('▁')
                token = token.replace("▁", " ")
                return bytes(token, "utf-8")
|
|
|
else: |
|
print("Warning: unrecognized tokenizer: using default token formatting") |
|
|
|
def fmt_token(id): |
|
token = tokenizer.convert_ids_to_tokens(id) |
|
return bytes(token, "utf-8") |
|
|
|
|
|
|
|
self.tokens = [fmt_token(i) for i in range(len(tokenizer.get_vocab()))] |
|
for token_id, token_bytes in enumerate(self.tokens): |
|
if token_bytes is not None: |
|
self.insert_into_trie(self.trie, token_bytes, token_id) |
|
|
|
def insert_into_trie(self, trie, token_bytes, token_id): |
|
current = trie |
|
for byte in token_bytes: |
|
if byte not in current: |
|
current[byte] = {} |
|
current = current[byte] |
|
current[LEAF] = token_id |
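
    # Illustrative layout (bytes are assumptions): after inserting b"ab" as
    # token id 7, the trie contains {97: {98: {LEAF: 7}}}; interior keys are
    # byte values, and the LEAF (-1) key stores the token id.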
|
|
|
|
|
@lru_cache(maxsize=5) |
|
def initialize_grammar(grammar_string): |
|
return IncrementalGrammarConstraint(grammar_string.strip(), start_rule_name="root", tokenizer=shared.tokenizer) |
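
# A minimal end-to-end sketch (illustrative; assumes shared.tokenizer holds a
# loaded Hugging Face tokenizer):
#
#   constraint = initialize_grammar('root ::= "yes" | "no"')
#   stacks = constraint.init_stacks()
#   mask = constraint.filter_vocab(stacks, torch.device("cpu"))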
|
|