import os
import subprocess
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MIN_TOKENS = 128       # lower bound of the "Max new tokens" slider
MAX_TOKENS = 8192      # upper bound of the "Max new tokens" slider
DEFAULT_TOKENS = 2048  # default value of the "Max new tokens" slider
DURATION = 60          # GPU allocation window in seconds for each @spaces.GPU call

# Install flash attention; the env var skips compiling the CUDA kernels at install time.
# os.environ is merged in so pip still sees PATH and the rest of the environment.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
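
# Note (assumption, not applied here): for the installed flash-attn to actually be used at
# inference, the model would typically be loaded in half precision with an explicit
# attention backend, e.g.
#   AutoModelForCausalLM.from_pretrained(
#       "microsoft/Phi-3-mini-128k-instruct",
#       torch_dtype=torch.bfloat16,
#       attn_implementation="flash_attention_2",
#       trust_remote_code=True,
#   )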

# Load model and tokenizer once when the app starts
model_token = os.environ["HF_TOKEN"]
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-128k-instruct",
    token=model_token,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct", token=model_token)

# Set device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Format exceptions as a plain-text chat message (gr.ChatInterface expects a string reply)
def handle_error(error):
    return f"Error: {error}"

# Define chat function with input validation and error handling
@spaces.GPU(duration=DURATION)
def chat(message, history, temperature, do_sample, max_tokens):
    try:
        # Validate input
        if not message:
            raise ValueError("Please enter a message")
        if temperature < 0 or temperature > 1:
            raise ValueError("Temperature must be between 0 and 1")
        if max_tokens < MIN_TOKENS or max_tokens > MAX_TOKENS:
            raise ValueError(f"Max tokens must be between {MIN_TOKENS} and {MAX_TOKENS}")

        # Rebuild the conversation in chat-template format (history is a list of [user, assistant] pairs)
        conversation = []
        for user_msg, assistant_msg in history:
            conversation.append({"role": "user", "content": user_msg})
            if assistant_msg is not None:
                conversation.append({"role": "assistant", "content": assistant_msg})
        conversation.append({"role": "user", "content": message})

        # Apply the model's chat template and tokenize the prompt
        prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            model_inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            do_sample=do_sample,
            temperature=temperature,
            eos_token_id=[tokenizer.eos_token_id],
        )

        # Run generation in a background thread so tokens can be streamed as they arrive
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()

        # Yield partial responses
        partial_text = ""
        for new_text in streamer:
            partial_text += new_text
            yield partial_text

        # Yield final response
        yield partial_text

    except Exception as e:
        yield handle_error(e)
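
# Hypothetical sanity check outside Gradio (chat() is a generator; the last yield is the full reply):
#   *_, reply = chat("Write a haiku about GPUs.", history=[], temperature=0.7,
#                    do_sample=True, max_tokens=256)
#   print(reply)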

# Create Gradio interface
demo = gr.ChatInterface(
    fn=chat,
    examples=[["Write me a poem about Machine Learning."]],
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Slider(
            minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
        ),
        gr.Checkbox(label="Sampling", value=True, render=False),
        gr.Slider(
            minimum=MIN_TOKENS,
            maximum=MAX_TOKENS,
            step=1,
            value=DEFAULT_TOKENS,
            label="Max new tokens",
            render=False,
        ),
    ],
    stop_btn="Stop Generation",
    title="Chat With LLMs",
    description="Now Running [microsoft/Phi-3-mini-128k-instruct](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct)",
)

# Launch Gradio app
demo.launch()
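
# To run outside a Hugging Face Space (assumption: this file is saved as app.py),
# HF_TOKEN must be set because the model and tokenizer are fetched with token=os.environ["HF_TOKEN"]:
#   HF_TOKEN=<your-token> python app.py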