File size: 5,081 Bytes
0fafb5e
6ac164c
37ca5d0
91b03f9
37ca5d0
 
 
 
b17ecc2
a445827
00a2ac7
 
a445827
37ca5d0
 
 
 
 
 
e3f498d
37ca5d0
 
 
 
55bc99c
37ca5d0
 
 
b17ecc2
37ca5d0
a445827
37ca5d0
77b3a6a
5aae577
00a2ac7
37ca5d0
 
00a2ac7
37ca5d0
77b3a6a
37ca5d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00a2ac7
 
 
 
 
 
37ca5d0
 
 
00a2ac7
 
 
37ca5d0
00a2ac7
 
 
37ca5d0
708397c
37ca5d0
00a2ac7
37ca5d0
 
 
 
 
 
 
00a2ac7
6ac164c
37ca5d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ac164c
37ca5d0
 
 
 
e3f498d
37ca5d0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


# Set the TF_ENABLE_ONEDNN_OPTS switch to "0" for this process
# (presumably to silence/disable TensorFlow's oneDNN custom ops).
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

# Markdown rendered at the top of the demo page.
DESCRIPTION = """\
# Llama 3.2 3B Instruct
Llama 3.2 3B is Meta's latest iteration of open LLMs.
This is a demo of [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), fine-tuned for instruction following.
For more details, please check [our post](https://huggingface.co/blog/llama32).
"""

# Limits shared by the UI sliders and the generation function.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Hub checkpoint to serve, plus the (optional) token used to download it.
model_id = "Mikhil-jivus/Llama-32-8B-FineTuned-Instruct-v1"
access_token = os.environ.get("HF_TOKEN")

# Use the first CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# The tokenizer reuses EOS as its padding token and pads on the right.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=access_token)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load the weights in bfloat16 directly onto the chosen device and switch the
# module to inference mode (`eval()` returns the module itself).
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=access_token,
    torch_dtype=torch.bfloat16,
    device_map=device,
).eval()


@spaces.GPU(duration=90)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` as it is being generated.

    Args:
        message: The new user turn.
        chat_history: Previous ``(user, assistant)`` turns, oldest first.
        system_prompt: Text placed in the leading ``system`` message.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept for sampling.
        repetition_penalty: Penalty applied to already-generated tokens.

    Yields:
        The reply accumulated so far; each yielded string extends the
        previous one with the newly decoded text.
    """
    conversation = [{"role": "system", "content": system_prompt}]
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})

    # Set pad_token_id if it's not already set
    if tokenizer.pad_token_id is None:
        tokenizer.padding_side = 'right'
        tokenizer.pad_token = tokenizer.eos_token

    # Ask for a dict so the attention mask comes back alongside the ids.
    # Requesting it explicitly also keeps the code independent of whether the
    # installed library returns a bare tensor or a dict by default.
    encoded = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        add_special_tokens=True,
        return_tensors="pt",
        padding=True,
        return_attention_mask=True,
        return_dict=True,
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded.get("attention_mask")
    if attention_mask is None:
        # A single, unpadded prompt: every position is a real token.
        attention_mask = torch.ones_like(input_ids)

    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens; the mask must be cut identically so
        # both tensors stay aligned.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    input_ids = input_ids.to(model.device)
    attention_mask = attention_mask.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        # pad_token == eos_token here, so the mask cannot be inferred from the
        # ids; pass it explicitly.
        attention_mask=attention_mask,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Generation runs in a worker thread so this generator can forward text
    # from the streamer while tokens are still being produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    # The streamer is exhausted, so generation is finishing; wait for the
    # worker to exit before returning.
    t.join()


# One row per sampling control: (label, minimum, maximum, step, default).
# The rows follow the order of the extra parameters of `generate` that come
# after the system prompt.
_SLIDER_SPECS = [
    ("Max new tokens", 1, MAX_MAX_NEW_TOKENS, 1, DEFAULT_MAX_NEW_TOKENS),
    ("Temperature", 0.1, 4.0, 0.1, 0.6),
    ("Top-p (nucleus sampling)", 0.05, 1.0, 0.05, 0.9),
    ("Top-k", 1, 1000, 1, 50),
    ("Repetition penalty", 1.0, 2.0, 0.05, 1.2),
]

# Sample prompts offered below the chat box.
_EXAMPLE_PROMPTS = [
    "Hello there! How are you doing?",
    "Can you explain briefly to me what is the Python programming language?",
    "Explain the plot of Cinderella in a sentence.",
    "How many hours does it take a man to eat a Helicopter?",
    "Write a 100-word article on 'Benefits of Open-Source in AI research'",
]

_system_prompt_box = gr.Textbox(
    label="System Prompt",
    placeholder="Enter system prompt here...",
    lines=2,
)
_sampling_sliders = [
    gr.Slider(label=label, minimum=low, maximum=high, step=step, value=default)
    for label, low, high, step, default in _SLIDER_SPECS
]

# Chat UI wired to `generate`; the system prompt box and the sliders are
# passed to it as additional inputs, in that order.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[_system_prompt_box, *_sampling_sliders],
    stop_btn=None,
    examples=[[prompt] for prompt in _EXAMPLE_PROMPTS],
    cache_examples=False,
)

# Page layout: description, a "duplicate" button, then the chat interface.
demo = gr.Blocks(css="style.css", fill_height=True)
with demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()

if __name__ == "__main__":
    # Enable the request queue (at most 20 waiting requests), then serve.
    demo.queue(max_size=20)
    demo.launch()