Spaces:
Running
Running
import torch | |
from modules import chat, shared | |
from modules.text_generation import ( | |
decode, | |
encode, | |
generate_reply, | |
) | |
from transformers import LogitsProcessor | |
import gradio as gr | |
# Value of input_ids.shape[-1] captured by tokenizer_modifier(); MyLogits
# subtracts it from the current sequence length to measure reply growth.
initial_size = 0

# Extension settings. "min_length" is overwritten by the slider created
# in ui() whenever the user moves it.
params = {
    "display_name": "Long replies",
    "is_tab": False,
    "min_length": 120,
}
class MyLogits(LogitsProcessor):
    """
    Adjusts the next-token scores before sampling so that the newline
    token is pushed down while the sequence has grown by fewer than
    params["min_length"] positions since initial_size was recorded.
    Installed by the logits_processor_modifier function below.
    """

    def __init__(self):
        # Use the last id returned when encoding a newline; any ids the
        # tokenizer emits in front of it are ignored.
        newline_tokens = shared.tokenizer.encode('\n')
        self.newline_id = newline_tokens[-1]

    def __call__(self, input_ids, scores):
        # Number of positions added to input_ids since the length stored
        # in the module-level initial_size.
        grown_by = input_ids.shape[-1] - initial_size
        if grown_by >= params["min_length"]:
            # Long enough already: leave the scores untouched.
            return scores

        # Too short: write a very low score for the newline token in place
        # (the same tensor is returned).
        scores[..., self.newline_id] = -1000
        return scores
def history_modifier(history):
    """
    Hook for rewriting the chat history (chat mode only).
    This extension passes the history through untouched.
    """
    unchanged_history = history
    return unchanged_history
def state_modifier(state):
    """
    Hook for rewriting the state dictionary, which holds the UI input
    values such as sliders and checkboxes.
    This extension returns the state as received.
    """
    unchanged_state = state
    return unchanged_state
def chat_input_modifier(text, visible_text, state):
    """
    Hook for rewriting the user input in chat mode: visible_text is the
    displayed string and text is the internal representation that ends
    up in the prompt.
    This extension returns both values as received; state is not used.
    """
    passthrough = (text, visible_text)
    return passthrough
def input_modifier(string, state):
    """
    Hook for rewriting the prompt. In default/notebook modes it receives
    the whole prompt; in chat mode it behaves like chat_input_modifier
    but only for "text" (named "string" here), not "visible_text".
    This extension returns the string as received; state is not used.
    """
    unchanged_prompt = string
    return unchanged_prompt
def bot_prefix_modifier(string, state):
    """
    Hook for rewriting the prefix of the next bot reply in chat mode
    (by default something like "Bot Name:").
    This extension returns the prefix as received; state is not used.
    """
    unchanged_prefix = string
    return unchanged_prefix
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
    """
    Hook for rewriting the input ids and embeds (used by the multimodal
    extension to put image embeddings in the prompt). Only used by
    loaders that use the transformers library for sampling.

    This extension does not alter any argument; it only stores the size
    of the last dimension of input_ids in the module-level initial_size
    so that MyLogits can measure growth from that point.
    """
    global initial_size

    last_dim_size = input_ids.shape[-1]
    initial_size = last_dim_size

    return prompt, input_ids, input_embeds
def logits_processor_modifier(processor_list, input_ids):
    """
    Hook for registering logits processors, which can read and modify
    the next-token probabilities. Only used by loaders that use the
    transformers library for sampling.

    Appends a fresh MyLogits instance to processor_list in place and
    returns that same list; input_ids is not used.
    """
    newline_suppressor = MyLogits()
    processor_list.append(newline_suppressor)
    return processor_list
def output_modifier(string, state):
    """
    Hook for rewriting the LLM output before it is presented. In chat
    mode the modified version goes into history['visible'] and the
    original version goes into history['internal'].
    This extension returns the output as received; state is not used.
    """
    unchanged_output = string
    return unchanged_output
def custom_generate_chat_prompt(user_input, state, **kwargs):
    """
    Stands in for the function that builds the prompt from the chat
    history (chat mode only).
    This extension forwards all arguments to chat.generate_chat_prompt
    and returns its result unmodified.
    """
    return chat.generate_chat_prompt(user_input, state, **kwargs)
def custom_css():
    """
    Supplies a CSS string that is appended to the webui's CSS.
    This extension contributes no styles, so the string is empty.
    """
    extra_css = ''
    return extra_css
def custom_js():
    """
    Supplies a javascript string that is appended to the webui's
    javascript.
    This extension contributes no script, so the string is empty.
    """
    extra_js = ''
    return extra_js
def setup():
    """
    Runs a single time, when the extension is imported.
    This extension needs no one-time initialisation.
    """
    return None
def ui():
    """
    Runs when the UI is drawn. Custom gradio elements and their event
    handlers belong here. Gradio component docs:
    https://gradio.app/docs/

    Creates a "Minimum reply length" slider (0 to 800, step 10) that
    starts at params['min_length'] and writes every change back to
    params['min_length'].
    """
    length_slider = gr.Slider(
        0,
        800,
        step=10,
        value=params['min_length'],
        label='Minimum reply length',
    )
    # No output component: the handler only updates the shared params dict.
    length_slider.change(
        lambda value: params.update({'min_length': value}),
        length_slider,
        None,
    )