truongghieu committed on
Commit
d49f601
1 Parent(s): 4bb61f7

Update app.py

Files changed (1)
  1. app.py +132 -2
app.py CHANGED
@@ -1,3 +1,133 @@
- import gradio as gr

- gr.Interface.load("truongghieu/deci-finetuned").launch()
+ import os
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
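+
+ # Read the Hugging Face access token from the environment (raises KeyError if the secret is unset)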
+ token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
+
+ model_id = 'truongghieu/deci-finetuned'
+
+ SYSTEM_PROMPT_TEMPLATE = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
+ ### Instruction:
+ {instruction}
+ ### Response:
+ """
+
+ DESCRIPTION = """
+ # <p style="text-align: center; color: #292b47;"> 🤖 <span style='color: #3264ff;'>DeciLM-6B-Instruct:</span> A Fast Instruction-Tuned Model 💨 </p>
+ <span style='color: #292b47;'>Welcome to <a href="https://huggingface.co/Deci/DeciLM-6b-instruct" style="color: #3264ff;">DeciLM-6B-Instruct</a>! DeciLM-6B-Instruct is a 6B-parameter instruction-tuned language model released under the Llama license. It is instruction-tuned rather than chat-tuned: prompt it with an instruction that describes a task, and it will generate a response that completes the task.</span>
+ <p><span style='color: #292b47;'>Learn more about the base model <a href="https://deci.ai/blog/decilm-15-times-faster-than-llama2-nas-generated-llm-with-variable-gqa/" style="color: #3264ff;">DeciLM-6B.</a></span></p>
+ """
+
+ if not torch.cuda.is_available():
+     DESCRIPTION += 'You need a GPU to run this example. Try it on Colab: https://bit.ly/decilm-instruct-nb'
+
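+ # Load the model in float16 with automatic device placement when a GPU is available; otherwise skip loading so the UI can still render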
+ if torch.cuda.is_available():
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         torch_dtype=torch.float16,
+         device_map='auto',
+         trust_remote_code=True,
+         use_auth_token=token,
+     )
+ else:
+     model = None
+
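+ # Use the EOS token for padding, since the tokenizer may not define a pad token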
+ tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # Construct the full prompt by filling the instruction into the prompt template
+ def get_prompt_with_template(message: str) -> str:
+     return SYSTEM_PROMPT_TEMPLATE.format(instruction=message)
+
+ # Generate the model's full completion for a single instruction
+ def generate_model_response(message: str) -> str:
+     if model is None:
+         return 'A GPU is required to generate a response.'
+     prompt = get_prompt_with_template(message)
+     inputs = tokenizer(prompt, return_tensors='pt')
+     if torch.cuda.is_available():
+         inputs = inputs.to('cuda')
+     # Generation settings are fixed here: beam sampling with two beams and no repeated 4-grams
+     output = model.generate(
+         **inputs,
+         max_new_tokens=3000,
+         num_beams=2,
+         no_repeat_ngram_size=4,
+         early_stopping=True,
+         do_sample=True,
+     )
+     return tokenizer.decode(output[0], skip_special_tokens=True)
+
+ # Extract the content that follows the "### Response:" marker
+ def extract_response_content(full_response: str) -> str:
+     response_start_index = full_response.find("### Response:")
+     if response_start_index != -1:
+         return full_response[response_start_index + len("### Response:"):].strip()
+     return full_response
+
+ # End-to-end helper: generate a completion, then keep only the response section
+ def get_response_with_template(message: str) -> str:
+     full_response = generate_model_response(message)
+     return extract_response_content(full_response)
+
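+ # Build the Gradio UI: description, output box, input row with submit, clear button, cached examples, and a footer banner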
+ with gr.Blocks(css="style.css") as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(value='Duplicate Space for private use',
+                        elem_id='duplicate-button')
+     with gr.Group():
+         chatbot = gr.Textbox(label='DeciLM-6B-Instruct Output:')
+         with gr.Row():
+             textbox = gr.Textbox(
+                 container=False,
+                 show_label=False,
+                 placeholder='Type an instruction...',
+                 scale=10,
+                 elem_id="textbox"
+             )
+             submit_button = gr.Button(
+                 '💬 Submit',
+                 variant='primary',
+                 scale=1,
+                 min_width=0,
+                 elem_id="submit_button"
+             )
+
+     # Clear both the input and the output textboxes
+     clear_button = gr.Button(
+         '🗑️ Clear',
+         variant='secondary',
+     )
+
+     clear_button.click(
+         fn=lambda: ('', ''),
+         outputs=[textbox, chatbot],
+         queue=False,
+         api_name=False,
+     )
+
+     submit_button.click(
+         fn=get_response_with_template,
+         inputs=textbox,
+         outputs=chatbot,
+         queue=False,
+         api_name=False,
+     )
+
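+     # Note: cache_examples=True runs get_response_with_template on every example at startup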
+     gr.Examples(
+         examples=[
+             'Write detailed instructions for making chocolate chip pancakes.',
+             'Write a 250-word article about your love of pancakes.',
+             'Explain the plot of Back to the Future in three sentences.',
+             'How do I make a trap beat?',
+             'A step-by-step guide to learning Python in one month.',
+         ],
+         inputs=textbox,
+         outputs=chatbot,
+         fn=get_response_with_template,
+         cache_examples=True,
+         elem_id="examples"
+     )
+
+     gr.HTML(label="Keep in touch", value="<img src='https://huggingface.co/spaces/Deci/DeciLM-6b-instruct/resolve/main/deci-coder-banner.png' alt='Keep in touch' style='display: block; color: #292b47; margin: auto; max-width: 800px;'>")
+
+ demo.launch()