|
import os |
|
import platform |
|
import random |
|
import time |
|
from dataclasses import asdict, dataclass |
|
from pathlib import Path |
|
|
|
import gradio as gr |
|
import psutil |
|
from about_time import about_time |
|
from ctransformers import AutoModelForCausalLM |
|
from dl_hf_model import dl_hf_model |
|
from loguru import logger |
|
|
|
|
|
# Link to the quantized GGML weights file on the Hugging Face Hub.
URL = "https://huggingface.co/s3nh/WizardLM-1.0-Uncensored-Llama2-13b-GGML/blob/main/WizardLM-1.0-Uncensored-Llama2-13b.ggmlv3.q4_1.bin"

# Host probe: the node name mentions "golay" or "okteto", or a /kaggle
# directory exists. The trailing ``or 1`` makes the result truthy everywhere.
_ = (
    any(marker in platform.node() for marker in ("golay", "okteto"))
    or Path("/kaggle").exists()
    or 1
)

if _:
    # Same address as URL; this lower-case name is what the loader reads.
    url = URL
|
|
|
|
|
# Candidate prompt layouts. Each assignment below overwrites the previous one,
# so only the LAST ``prompt_template`` is live at runtime; the earlier ones are
# kept purely as a record of the layouts that have been tried.
prompt_template = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction: {user_prompt}
### Response:
"""

prompt_template = """System: You are a helpful,
respectful and honest assistant. Always answer as
helpfully as possible, while being safe. Your answers
should not include any harmful, unethical, racist,
sexist, toxic, dangerous, or illegal content. Please
ensure that your responses are socially unbiased and
positive in nature. If a question does not make any
sense, or is not factually coherent, explain why instead
of answering something not correct. If you don't know
the answer to a question, please don't share false
information.
User: {prompt}
Assistant: """

prompt_template = """System: You are a helpful assistant.
User: {prompt}
Assistant: """

prompt_template = """Question: {question}
Answer: Let's work this out in a step by step way to be sure we have the right answer."""

prompt_template = """[INST] <>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible assistant. Think step by step.
<>
What NFL team won the Super Bowl in the year Justin Bieber was born?
[/INST]"""

prompt_template = """[INST] <<SYS>>
You are an unhelpful assistant. Always answer as helpfully as possible. Think step by step. <</SYS>>
{question} [/INST]
"""

prompt_template = """[INST] <<SYS>>
You are a helpful assistant.
<</SYS>>
{question} [/INST]
"""

prompt_template = """### HUMAN:
{question}
### RESPONSE:"""

prompt_template = """<|prompt|>:{question}</s>
<|answer|>:"""

prompt_template = """SYSTEM:
USER: {question}
ASSISTANT: """

# Live template. The placeholder is ``{question}`` because ``generate`` fills
# the template with ``question=...``; the previous ``{prompt}`` placeholder
# made every ``str.format`` call raise ``KeyError: 'prompt'``.
prompt_template = """
User: {question}
Assistant: """
|
|
|
# Derive a stop marker from the live template: for every non-blank line keep
# the text up to and including its first colon, then take the entry for the
# second-to-last line. A template with fewer than two non-blank lines has no
# such entry, so fall back to an empty marker instead of raising IndexError.
_ = [elm for elm in prompt_template.splitlines() if elm.strip()]
_ = [elm.split(":")[0] + ":" for elm in _]
stop_string = _[-2] if len(_) >= 2 else ""

logger.debug(f"{stop_string=} not used")

# Leave one physical core free, but never go below one worker.
# psutil.cpu_count(logical=False) returns None when the count cannot be
# determined; treat that like a single core rather than failing on None - 1.
_ = psutil.cpu_count(logical=False)
cpu_count: int = max((_ or 1) - 1, 1)
logger.debug(f"{cpu_count=}")
|
|
|
# Bound up front; replaced below once the weights have been fetched and loaded.
LLM = None

try:
    # dl_hf_model yields a pair that is unpacked into a local model location
    # and a size figure (logged below with a "G" suffix).
    model_loc, file_size = dl_hf_model(url)
except Exception as exc_:
    # Without the weights there is nothing to serve: log and exit non-zero.
    logger.error(exc_)
    raise SystemExit(1) from exc_
else:
    LLM = AutoModelForCausalLM.from_pretrained(model_loc, model_type="llama")
    logger.info(f"done load llm {model_loc=} {file_size=}G")
|
|
|
# Switch the process to Shanghai local time for timestamps.
os.environ["TZ"] = "Asia/Shanghai"
try:
    time.tzset()
except AttributeError:
    # time.tzset is only provided on Unix; on Windows the attribute is missing.
    # The warning belongs here only -- it used to be emitted inside the ``try``
    # body as well, so it was logged even when tzset() had succeeded.
    logger.warning("Windows, cant run time.tzset()")
|
|
|
|
|
@dataclass
class GenerationConfig:
    """Bundle of generation options with their default values.

    Instances are turned into a plain dict with ``dataclasses.asdict`` and
    splatted as keyword arguments into the model call, so the field names
    double as that call's keyword names. Field order is part of the
    positional constructor signature.
    """

    # Sampling controls.
    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.9
    repetition_penalty: float = 1.0

    # Upper bound on the number of generated tokens.
    max_new_tokens: int = 512

    # Seed value passed through to the model call.
    seed: int = 42

    # Passed through as-is; callers in this file set it to True per request.
    reset: bool = False

    # When true the model call is expected to return chunks incrementally.
    stream: bool = True
|
|
|
|
|
|
|
|
|
def generate(
    question: str,
    llm=LLM,
    config: GenerationConfig = GenerationConfig(),
):
    """Run model inference, will return a Generator if streaming is true.

    Args:
        question: Text inserted into ``prompt_template``.
        llm: Callable invoked as ``llm(prompt, **options)``; defaults to the
            module-level model bound when this function was defined.
        config: Generation options; expanded with ``asdict`` into keyword
            arguments. The shared default instance is only read, never
            mutated, here.

    Returns:
        Whatever ``llm`` returns for the formatted prompt.
    """
    # The templates in this file name their slot either ``{question}`` or
    # ``{prompt}``. Supplying both keywords lets either spelling format
    # cleanly (str.format ignores unused keyword arguments); passing only
    # ``question=`` raised KeyError for a ``{prompt}`` template.
    prompt = prompt_template.format(question=question, prompt=question)

    return llm(prompt, **asdict(config))


# Record the default generation options once at import time.
logger.debug(f"{asdict(GenerationConfig())=}")
|
|
|
|
|
def user(user_message, history):
    """Open a new chat turn and echo the message back.

    Appends ``[user_message, None]`` to ``history`` in place (the ``None``
    slot is where the reply goes later) and returns the unchanged message
    together with that same history list.
    """
    pending_turn = [user_message, None]
    history.append(pending_turn)
    return user_message, history
|
|
|
|
|
def user1(user_message, history):
    """Open a new chat turn and hand back an empty message.

    Appends ``[user_message, None]`` to ``history`` in place and returns an
    empty string (in place of the message) together with that same list.
    """
    pending_turn = [user_message, None]
    history.append(pending_turn)
    return "", history
|
|
|
def bot_(history):
    """Model-free stand-in responder that fakes a streamed reply.

    Picks one canned phrase at random, "types" ``<user message>: <phrase>``
    into the reply slot of the last turn one character at a time (pausing
    20 ms and yielding ``history`` after each character), then replaces the
    reply with the bare phrase and yields once more.
    """
    canned_replies = ["How are you?", "I love you", "I'm very hungry"]

    user_message = history[-1][0]
    resp = random.choice(canned_replies)
    typed_out = user_message + ": " + resp

    # Start from an empty reply so characters can be appended to it.
    history[-1][1] = ""
    for character in typed_out:
        history[-1][1] += character
        time.sleep(0.02)
        yield history

    # Final frame: only the canned phrase remains.
    history[-1][1] = resp
    yield history
|
|
|
|
|
def bot(history):
    """Stream the model's reply to the latest user turn into ``history``.

    Reads the user message from ``history[-1][0]``, iterates over the chunks
    produced by ``generate`` and, after every chunk, stores the text so far
    (prefixed with the time to the first chunk) in ``history[-1][1]`` and
    yields ``history``. Each chunk is also echoed to stdout. A final yield
    replaces the reply with the full text followed by a timing summary.

    Yields:
        The same ``history`` list, mutated in place.
    """
    user_message = history[-1][0]
    response = []

    logger.debug(f"{user_message=}")

    with about_time() as atime:
        awaiting_first_chunk = True
        prefix = ""
        then = time.time()

        logger.debug("about to generate")

        config = GenerationConfig(reset=True)
        for elm in generate(user_message, config=config):
            if awaiting_first_chunk:
                # Measure latency to the first chunk exactly once.
                logger.debug("in the loop")
                prefix = f"({time.time() - then:.2f}s) "
                awaiting_first_chunk = False
                print(prefix, end="", flush=True)
                logger.debug(f"{prefix=}")
            print(elm, end="", flush=True)

            response.append(elm)
            history[-1][1] = prefix + "".join(response)
            yield history

    full_text = "".join(response)
    if full_text:
        _ = (
            f"(time elapsed: {atime.duration_human}, "
            f"{atime.duration/len(full_text):.2f}s/char)"
        )
    else:
        # No characters were produced: a per-character rate is undefined and
        # computing it used to raise ZeroDivisionError, so report time only.
        _ = f"(time elapsed: {atime.duration_human})"

    history[-1][1] = full_text + f"\n{_}"
    yield history
|
|
|
|
|
def predict_api(prompt):
    """Answer a single prompt in one shot for the API endpoint.

    Builds a non-streaming ``GenerationConfig`` and passes ``prompt`` to
    ``generate``. Any exception is logged and its ``exc=...`` rendering is
    returned in place of the reply, so this function does not raise for
    failures inside the ``try`` body.
    """
    logger.debug(f"{prompt=}")

    api_options = {
        "temperature": 0.2,
        "top_k": 10,
        "top_p": 0.9,
        "repetition_penalty": 1.0,
        "max_new_tokens": 512,
        "seed": 42,
        "reset": True,
        "stream": False,
    }

    try:
        config = GenerationConfig(**api_options)
        response = generate(prompt, config=config)
        logger.debug(f"api: {response=}")
    except Exception as exc:
        logger.error(exc)
        response = f"{exc=}"

    return response
|
|
|
|
|
# Extra CSS handed to the Gradio page: gradient buttons plus two small-text
# helper classes (``disclaimer`` and ``xsmall``).
css = """
.importantButton {
background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
border: none !important;
}
.importantButton:hover {
background: linear-gradient(45deg, #ff00e0,#8500ff, #6e00ff) !important;
border: none !important;
}
.disclaimer {font-variant-caps: all-small-caps; font-size: xx-small;}
.xsmall {font-size: x-small;}
"""

# Sample passage of English text.
etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """

# Clickable example prompts; one single-element row per example because the
# examples widget feeds exactly one input component.
examples_list = [
    ["Send an email requesting that people use language models responsibly."],
    ["Write a shouting match between Julius Caesar and Napoleon"],
    ["Write a theory to explain why cat never existed"],
    ["write a story about a grain of sand as it watches millions of years go by"],
    ["What are 3 popular chess openings?"],
    ["write a conversation between the sun and pluto"],
    # Fixed the missing space in "andhere’s".
    ["Did you know that Yann LeCun dropped a rap album last year? We listened to it and here’s what we thought:"],
]
|
|
|
logger.info("start block")

# Page layout and event wiring for the chat UI.
# NOTE(review): the nesting of the ``with`` blocks below was reconstructed
# from a source whose indentation had been flattened -- confirm the layout
# (in particular which containers the buttons live in) against the running app.
with gr.Blocks(
    title=f"{Path(model_loc).name}",
    theme=gr.themes.Soft(text_size="sm", spacing_size="sm"),
    css=css,
) as block:

    with gr.Accordion("🎈 Info", open=False):
        # Closing tag now matches the opening <h5> (it used to be </h4>).
        gr.Markdown(
            f"""<h5><center>{Path(model_loc).name}</center></h5>
Most examples are meant for another model.
You probably should try to test
some related prompts.""",
            elem_classes="xsmall",
        )

    chatbot = gr.Chatbot(height=500)

    with gr.Row():
        with gr.Column(scale=5):
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Ask me anything (press Shift+Enter or click Submit to send)",
                show_label=False,
                lines=6,
                max_lines=30,
                show_copy_button=True,
            )
        with gr.Column(scale=1, min_width=50):
            with gr.Row():
                submit = gr.Button("Submit", elem_classes="xsmall")
                stop = gr.Button("Stop", visible=True)
                clear = gr.Button("Clear History", visible=True)

    # Hidden row: these controls are created but have no handlers attached.
    with gr.Row(visible=False):
        with gr.Accordion("Advanced Options:", open=False):
            with gr.Row():
                with gr.Column(scale=2):
                    system = gr.Textbox(
                        label="System Prompt",
                        value=prompt_template,
                        show_label=False,
                        container=False,
                    )
                with gr.Column():
                    with gr.Row():
                        change = gr.Button("Change System Prompt")
                        reset = gr.Button("Reset System Prompt")

    with gr.Accordion("Example Inputs", open=True):
        examples = gr.Examples(
            examples=examples_list,
            inputs=[msg],
            examples_per_page=40,
        )

    with gr.Accordion("Disclaimer", open=True):
        model_name = Path(model_loc).name
        # Every piece that mentions the model name is an f-string now (the
        # second piece used to print a literal "{_}"), and the first piece
        # ends with a space so "PURPOSE" and "WITHOUT" no longer run together.
        gr.Markdown(
            "Disclaimer: I AM NOT RESPONSIBLE FOR ANY PROMPT PROVIDED BY USER AND PROMPT RETURNED FROM THE MODEL. THIS APP SHOULD BE USED FOR EDUCATIONAL PURPOSE "
            f"WITHOUT ANY OFFENSIVE, AGGRESSIVE INTENTS. {model_name} can produce factually incorrect output, and should not be relied on to produce "
            f"factually accurate information. {model_name} was trained on various public datasets; while great efforts "
            "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
            "biased, or otherwise offensive outputs.",
            elem_classes=["disclaimer"],
        )

    # Submitting from the text box: ``user`` records the turn (and returns the
    # message unchanged), then ``bot`` streams the reply into the chatbot.
    msg_submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=True,
        show_progress="full",
    ).then(bot, chatbot, chatbot, queue=True)

    # Clicking Submit: same flow, but ``user1`` returns an empty message.
    submit_click_event = submit.click(
        fn=user1,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=True,
        show_progress="full",
    ).then(bot, chatbot, chatbot, queue=True)

    # Stop cancels whichever of the two generation chains is running.
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[msg_submit_event, submit_click_event],
        queue=False,
    )

    # Clear resets the chatbot component to None.
    clear.click(lambda: None, None, chatbot, queue=False)

    # Invisible widgets whose only purpose is to expose ``predict_api``
    # under the API name "api".
    with gr.Accordion("For Chat/Translation API", open=False, visible=False):
        input_text = gr.Text()
        api_btn = gr.Button("Go", variant="primary")
        out_text = gr.Text()

    api_btn.click(
        predict_api,
        input_text,
        out_text,
        api_name="api",
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Parked notes on sizing the worker count from RAM / core count. This is just
# a string bound to the throwaway name; none of it executes.
_ = """
# _ = int(psutil.virtual_memory().total / 10**9 // file_size - 1)
# concurrency_count = max(_, 1)
if psutil.cpu_count(logical=False) >= 8:
# concurrency_count = max(int(32 / file_size) - 1, 1)
else:
# concurrency_count = max(int(16 / file_size) - 1, 1)
# """

# Fixed at a single concurrent worker.
concurrency_count = 1
logger.info(f"{concurrency_count=}")

# Enable the request queue (at most 5 waiting requests), then start serving.
block.queue(
    concurrency_count=concurrency_count,
    max_size=5,
).launch(debug=True)