DeepMount00's picture
Update app.py
f24ac48 verified
raw
history blame
8.43 kB
import importlib.util
import os
import subprocess
import sys
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Install flash-attn at startup into the interpreter running this app.
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE asks flash-attn's setup not to compile its
# CUDA extension locally (presumably relying on a prebuilt wheel — confirm).
# The override is merged into the inherited environment: passing a bare dict as
# `env` would replace it entirely (dropping PATH, HOME, proxy settings, ...).
# Invoking pip as `sys.executable -m pip` (argument list, no shell) guarantees
# the package lands in this interpreter's environment. Failures are not raised
# here (check is left False), matching the original best-effort behaviour.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
# Generation limits. The input-length cap can be overridden through the
# MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
# NOTE(review): `device` is not used by the visible code (placement is handled by
# device_map="auto"); kept so any external reference keeps working.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_id = "DeepMount00/Lexora-Lite-3B"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# FlashAttention-2 needs both a CUDA device and an importable `flash_attn`
# package. When either is missing, pass None so transformers picks its default
# attention implementation instead of failing at load time.
# invalidate_caches() makes a package pip-installed after interpreter start visible.
importlib.invalidate_caches()
_use_flash_attention = (
    torch.cuda.is_available() and importlib.util.find_spec("flash_attn") is not None
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2" if _use_flash_attention else None,
    trust_remote_code=True,
)
# Inference only: disable dropout and other training-mode behaviour.
model.eval()
# Stylesheet handed to gr.Blocks(css=CUSTOM_CSS) in create_chat_interface().
# `.header` (plus its h1/p rules) styles the <div class="header"> banner in
# DESCRIPTION.
# NOTE(review): no component in this file sets elem_classes, so selectors such as
# `.chat-container`, `.controls-panel`, `.controls-button`, `.input-area`,
# `.textbox` and `.submit-button` may never match anything — confirm before
# relying on them.
CUSTOM_CSS = """
.container {
max-width: 1000px !important;
margin: auto !important;
}
.header {
text-align: center;
margin-bottom: 1rem;
padding: 1rem;
}
.header h1 {
font-size: 2rem;
font-weight: 600;
color: #1e293b;
margin-bottom: 0.5rem;
}
.header p {
color: #64748b;
font-size: 1rem;
}
.chat-container {
border-radius: 0.75rem;
background: white;
box-shadow: 0 1px 3px 0 rgb(0 0 0 / 0.1);
height: calc(100vh - 200px);
display: flex;
flex-direction: column;
}
.message-container {
padding: 1rem;
margin-bottom: 0.5rem;
}
.user-message {
background: #f8fafc;
border-left: 3px solid #2563eb;
padding: 1rem;
margin: 0.5rem 0;
border-radius: 0.5rem;
}
.assistant-message {
background: white;
border-left: 3px solid #64748b;
padding: 1rem;
margin: 0.5rem 0;
border-radius: 0.5rem;
}
.controls-panel {
position: fixed;
right: 1rem;
top: 1rem;
width: 300px;
background: white;
padding: 1rem;
border-radius: 0.5rem;
box-shadow: 0 1px 3px 0 rgb(0 0 0 / 0.1);
z-index: 1000;
display: none;
}
.controls-button {
position: fixed;
right: 1rem;
top: 1rem;
z-index: 1001;
background: #2563eb !important;
color: white !important;
padding: 0.5rem 1rem !important;
border-radius: 0.5rem !important;
font-size: 0.875rem !important;
font-weight: 500 !important;
}
.input-area {
border-top: 1px solid #e2e8f0;
padding: 1rem;
background: white;
border-radius: 0 0 0.75rem 0.75rem;
}
.textbox {
border: 1px solid #e2e8f0 !important;
border-radius: 0.5rem !important;
padding: 0.75rem !important;
font-size: 1rem !important;
box-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05) !important;
}
.textbox:focus {
border-color: #2563eb !important;
outline: none !important;
box-shadow: 0 0 0 2px rgba(37, 99, 235, 0.2) !important;
}
.submit-button {
background: #2563eb !important;
color: white !important;
padding: 0.5rem 1rem !important;
border-radius: 0.5rem !important;
font-size: 0.875rem !important;
font-weight: 500 !important;
transition: all 0.2s !important;
}
.submit-button:hover {
background: #1d4ed8 !important;
}
"""
# HTML banner shown at the top of the page (rendered through gr.Markdown).
# The wrapper div carries the `header` class so the `.header` rules in
# CUSTOM_CSS apply to it.
_DESCRIPTION_LINES = [
    "",
    '<div class="header">',
    "<h1>Lexora-Lite-3B Chat</h1>",
    "<p>An advanced Italian language model ready to assist you</p>",
    "</div>",
    "",
]
DESCRIPTION = "\n".join(_DESCRIPTION_LINES)
# Streaming chat handler used as the `fn` of gr.ChatInterface.
# @spaces.GPU(duration=90) requests a GPU for up to 90 s per call (ZeroGPU).
@spaces.GPU(duration=90)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_message: str = "",
    max_new_tokens: int = 2048,
    temperature: float = 0.0001,
    top_p: float = 1.0,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Stream the model's reply to *message*, yielding the text produced so far.

    Args:
        message: The new user turn.
        chat_history: Prior turns. Either ``(user, assistant)`` pairs (Gradio
            "tuples" history) or ``{"role": ..., "content": ...}`` dicts
            (Gradio "messages" history).
        system_message: Content of the system turn that opens the prompt.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. Values ``<= 0`` select greedy decoding.
        top_p: Nucleus-sampling probability mass (only used when sampling).
        top_k: Number of highest-probability tokens kept (only used when sampling).
        repetition_penalty: Penalty on tokens already present in the sequence
            (1.0 means no penalty).

    Yields:
        The full reply decoded so far, after each streamed chunk.
    """
    conversation = [{"role": "system", "content": system_message}]
    for turn in chat_history:
        if isinstance(turn, dict):
            # "messages"-style history is already in chat-template form. Tuple
            # unpacking it would silently yield the keys ("role", "content").
            conversation.append({"role": turn["role"], "content": turn["content"]})
        else:
            user, assistant = turn
            conversation.extend(
                [
                    {"role": "user", "content": user},
                    {"role": "assistant", "content": assistant},
                ]
            )
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens so the latest user turn survives.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    # One unpadded sequence: every position is a real token. Passing the mask
    # explicitly avoids transformers inferring it from the pad token.
    attention_mask = torch.ones_like(input_ids)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        # UI sliders may deliver floats; generate() expects integer counts.
        max_new_tokens=int(max_new_tokens),
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            top_p=top_p,
            top_k=int(top_k),
            temperature=temperature,
        )
    else:
        # Sampling with temperature <= 0 is rejected by transformers; treat it
        # as a request for deterministic (greedy) decoding instead.
        generate_kwargs["do_sample"] = False

    # model.generate blocks, so run it in a worker thread and consume the
    # streamer here to yield partial text as it arrives.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    worker.join()
def create_chat_interface():
    """Build the Gradio Blocks app for chatting with the model.

    The chat is a ``gr.ChatInterface`` driven by ``generate``. Gradio calls
    ``generate(message, history, *additional_inputs)``, so the controls in
    ``additional_inputs`` are matched to generate's parameters purely by
    position. Their order must therefore be: system_message, max_new_tokens,
    temperature, top_p, top_k, repetition_penalty. All controls are hidden,
    which pins generation to their default values.

    Returns:
        The configured ``gr.Blocks`` instance. It is not yet queued or launched.
    """
    theme = gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="slate",
        neutral_hue="slate",
        font=gr.themes.GoogleFont("Inter"),
        radius_size=gr.themes.sizes.radius_sm,
    )
    with gr.Blocks(css=CUSTOM_CSS, theme=theme) as demo:
        gr.Markdown(DESCRIPTION)
        gr.ChatInterface(
            fn=generate,
            additional_inputs=[
                gr.Textbox(
                    value="",
                    label="System Message",
                    visible=False,
                ),
                # Must directly follow the system message so each later control
                # lines up with its parameter in generate().
                gr.Slider(
                    label="Max new tokens",
                    minimum=1,
                    maximum=MAX_MAX_NEW_TOKENS,
                    step=1,
                    value=DEFAULT_MAX_NEW_TOKENS,
                    visible=False,
                ),
                gr.Slider(
                    label="Temperature",
                    minimum=0,
                    maximum=1.0,
                    step=0.1,
                    value=0.0001,
                    visible=False,
                ),
                gr.Slider(
                    label="Top-p",
                    minimum=0.05,
                    maximum=1.0,
                    step=0.05,
                    value=1.0,
                    visible=False,
                ),
                gr.Slider(
                    label="Top-k",
                    minimum=1,
                    maximum=1000,
                    step=1,
                    value=50,
                    visible=False,
                ),
                gr.Slider(
                    label="Repetition Penalty",
                    minimum=1.0,
                    maximum=2.0,
                    step=0.05,
                    value=1.0,
                    visible=False,
                ),
            ],
            examples=[
                ["Ciao! Come stai?"],
                ["Raccontami una breve storia."],
                ["Qual è il tuo piatto italiano preferito?"],
            ],
            cache_examples=False,
        )
    # The caller queues and launches the app, so the Blocks object must be returned.
    return demo
if __name__ == "__main__":
    # Build the UI, enable request queuing (at most 20 waiting requests),
    # then start the web server.
    demo = create_chat_interface()
    demo.queue(max_size=20)
    demo.launch()