File size: 1,623 Bytes
2b6b154
 
 
 
d24a963
 
1b3204d
 
00df41d
d24a963
145ecb9
 
 
 
ac962aa
145ecb9
2b6b154
 
4581f8a
145ecb9
4581f8a
145ecb9
 
 
 
 
 
 
 
 
 
 
afd19d1
145ecb9
 
4581f8a
145ecb9
 
afd19d1
145ecb9
 
 
2b6b154
09e1b8b
145ecb9
1a33cc7
1b3204d
3b50be6
5d32a98
129a216
5d32a98
4581f8a
 
 
 
 
36cb54d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
import spaces
import torch

import transformers
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face Hub checkpoint served by this app.
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

# Build the text-generation pipeline once at import time; chat_function reuses
# it (and its tokenizer) for every request. Weights are loaded in bfloat16.
# NOTE(review): the pipeline is placed on the CPU even though chat_function is
# decorated with @spaces.GPU; the ZeroGPU pattern is to move models to "cuda"
# at import time. Confirm whether CPU placement is intended.
pipeline = transformers.pipeline(
    task="text-generation",
    model=model_name,
    device="cpu",
    model_kwargs=dict(torch_dtype=torch.bfloat16),
)

def _history_to_messages(history):
    """Convert Gradio chat history into chat-template message dicts.

    Handles both shapes Gradio uses for chat history:

    * a list of ``[user_text, assistant_text]`` pairs, where either element
      may be ``None``;
    * a list of ``{"role": ..., "content": ...}`` dicts.

    Turns whose content is not a plain string (e.g. file attachments) are
    skipped, because the text-only chat template cannot render them.

    Args:
        history: Prior conversation turns, or ``None``/empty for a new chat.

    Returns:
        A list of ``{"role": str, "content": str}`` dicts in chronological
        order.
    """
    messages = []
    for turn in history or ():
        if isinstance(turn, dict):
            content = turn.get("content")
            if isinstance(content, str):
                messages.append({"role": turn["role"], "content": content})
            continue
        user_text, assistant_text = turn
        if isinstance(user_text, str):
            messages.append({"role": "user", "content": user_text})
        if isinstance(assistant_text, str):
            messages.append({"role": "assistant", "content": assistant_text})
    return messages


@spaces.GPU
def chat_function(message, history, system_prompt, max_new_tokens, temperature):
    """Generate one assistant reply for the Gradio chat UI.

    Args:
        message: The user's latest message.
        history: Earlier turns supplied by the chat interface. They are
            replayed to the model so replies stay in conversational context.
        system_prompt: Text placed in the leading ``system`` message.
        max_new_tokens: Upper bound on generated tokens. It is cast to ``int``
            because slider values may arrive as floats.
        temperature: Sampling temperature; 0.1 is added before use.

    Returns:
        Only the newly generated text, with the prompt removed.
    """
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    prompt = pipeline.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Stop on either the tokenizer's EOS token or Llama 3's end-of-turn
    # token <|eot_id|>.
    terminators = [
        pipeline.tokenizer.eos_token_id,
        pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    outputs = pipeline(
        prompt,
        max_new_tokens=int(max_new_tokens),
        eos_token_id=terminators,
        do_sample=True,
        # The offset keeps the temperature strictly positive when the caller
        # passes 0, since sampling needs temperature > 0.
        temperature=temperature + 0.1,
        top_p=0.9,
    )
    # generated_text starts with the prompt; return only the completion.
    return outputs[0]["generated_text"][len(prompt):]

# UI components, created in the same order the original call built them.
_chatbot = gr.Chatbot(height=400)
_textbox = gr.Textbox(placeholder="Enter message here", container=False, scale=7)

# Extra controls whose values are passed to chat_function after
# (message, history), in this order: system prompt, max tokens, temperature.
_extra_inputs = [
    gr.Textbox("You are helpful AI.", label="System Prompt"),
    gr.Slider(512, 4096, label="Max New Tokens"),
    gr.Slider(0, 1, label="Temperature"),
]

_description = """
    To Learn about Fine-tuning Llama-3-8B, Check https://exnrt.com/blog/ai/finetune-llama3-8b/.
    """

demo = gr.ChatInterface(
    chat_function,
    chatbot=_chatbot,
    textbox=_textbox,
    title="Meta-Llama-3-8B-Instruct",
    description=_description,
    additional_inputs=_extra_inputs,
)
demo.launch()