import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)
import os
from threading import Thread
import spaces
import time

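# Hugging Face access token, read from the environment (Space secrets)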
token = os.environ["HF_TOKEN"]

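# 4-bit quantization via bitsandbytes, with fp16 compute, so the 8B model fits in modest GPU memory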
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
)

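# Fine-tuned Llama 3 8B (Dhenu agriculture SFT) loaded with the 4-bit config above,
# plus its tokenizer and the stop tokens used by the Llama 3 chat format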
model = AutoModelForCausalLM.from_pretrained(
    "KissanAI/llama3-8b-dhenu-0.1-sft-16bit", quantization_config=quantization_config, token=token
)
tok = AutoTokenizer.from_pretrained("KissanAI/llama3-8b-dhenu-0.1-sft-16bit", token=token)
terminators = [
    tok.eos_token_id,
    tok.convert_tokens_to_ids("<|eot_id|>")
]

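# Device used for the input tensors; the quantized model itself is placed at load time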
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    device = torch.device("cpu")
    print("Using CPU")

# model = model.to(device)
# Skipped: moving the 4-bit quantized model manually raises dispatch errors;
# the weights are already placed on the GPU when loading with bitsandbytes.


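# Streaming chat handler: formats the question with a single-turn prompt template,
# runs generation in a background thread, and yields partial text to the Gradio UI.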
@spaces.GPU()
def chat(message, history, temperature, do_sample, max_tokens):
    prompt_template = """
    You are a helpful Agricultural assistant for farmers. You are given the following input. Please complete the response briefly.
    ## Question:
    {}
    
    ## Response:
    {}"""
    start_time = time.time()
    chat = []
    # Multi-turn history via the tokenizer's chat template is disabled for now;
    # the model is prompted single-turn with the template above.
    # for item in history:
    #     chat.append({"role": "user", "content": item[0]})
    #     if item[1] is not None:
    #         chat.append({"role": "assistant", "content": item[1]})
    # chat.append({"role": "user", "content": message})
    # messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    
    model_inputs = tok(prompt_template.format(
        message, #input
        "" # response
    ), return_tensors="pt").to(device)
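    # Stream decoded tokens as they are produced, skipping the prompt and special tokens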
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=do_sample,  # honour the "Sampling" checkbox from the UI
        temperature=temperature,
        repetition_penalty=1.2,
        use_cache=False,
        eos_token_id=terminators,
    )
    
    if temperature == 0:
        generate_kwargs['do_sample'] = False
    
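    # Run generation in a background thread so the streamer can be consumed below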
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_text = ""
    first_token_time = None
    for new_text in streamer:
        if not first_token_time:
            first_token_time = time.time() - start_time
        partial_text += new_text
        yield partial_text

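    # Simple latency (time to first token) and throughput stats for the finished response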
    total_time = time.time() - start_time
    tokens = len(tok.tokenize(partial_text))
    tokens_per_second = tokens / total_time if total_time > 0 else 0

    timing_info = f"\n\nTime taken to first token: {first_token_time:.2f} seconds\nTokens per second: {tokens_per_second:.2f}"
    yield partial_text +  timing_info


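# Gradio chat UI: streaming chat fn plus temperature, sampling, and max-token controls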
demo = gr.ChatInterface(
    fn=chat,
    examples=[["I'm a farmer from Odisha, how do I take care of whitefly in my cotton crop?"]],
    # multimodal=False,
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Slider(
            minimum=0, maximum=1, step=0.1, value=0.5, label="Temperature", render=False
        ),
        gr.Checkbox(label="Sampling", value=False, render=False),
        gr.Slider(
            minimum=128,
            maximum=4096,
            step=1,
            value=512,
            label="Max new tokens",
            render=False,
        ),
    ],
    stop_btn="Stop Generation",
    title="Chat With LLMs",
    description="Now running [KissanAI/llama3-8b-dhenu-0.1-sft-16bit](https://huggingface.co/KissanAI/llama3-8b-dhenu-0.1-sft-16bit) in 4-bit",
)
demo.launch()