mateoluksenberg commited on
Commit
5ba94c6
1 Parent(s): 43a1f6e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +250 -0
app.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import gradio as gr
4
+ import spaces
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
+ import os
7
+ from threading import Thread
8
+
9
+ import pymupdf
10
+ import docx
11
+ from pptx import Presentation
12
+
13
+
14
+ MODEL_LIST = ["nikravan/glm-4vq"]
15
+
16
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
17
+ MODEL_ID = MODEL_LIST[0]
18
+ MODEL_NAME = "GLM-4vq"
19
+
20
+ TITLE = "<h1>3ML-bot</h1>"
21
+
22
+ DESCRIPTION = f"""
23
+ <center>
24
+ <p>😊 A Multi-Modal Multi-Lingual(3ML) Chat.
25
+ <br>
26
+ 🚀 MODEL NOW: <a href="https://hf.co/nikravan/glm-4vq">{MODEL_NAME}</a>
27
+ </center>"""
28
+
29
+ CSS = """
30
+ h1 {
31
+ text-align: center;
32
+ display: block;
33
+ }
34
+ """
35
+
36
+
37
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
38
+
39
+
40
+
41
+ def extract_text(path):
42
+ return open(path, 'r').read()
43
+
44
+
45
+ def extract_pdf(path):
46
+ doc = pymupdf.open(path)
47
+ text = ""
48
+ for page in doc:
49
+ text += page.get_text()
50
+ return text
51
+
52
+
53
+ def extract_docx(path):
54
+ doc = docx.Document(path)
55
+ data = []
56
+ for paragraph in doc.paragraphs:
57
+ data.append(paragraph.text)
58
+ content = '\n\n'.join(data)
59
+ return content
60
+
61
+
62
+ def extract_pptx(path):
63
+ prs = Presentation(path)
64
+ text = ""
65
+ for slide in prs.slides:
66
+ for shape in slide.shapes:
67
+ if hasattr(shape, "text"):
68
+ text += shape.text + "\n"
69
+ return text
70
+
71
+
72
+ def mode_load(path):
73
+ choice = ""
74
+ file_type = path.split(".")[-1]
75
+ print(file_type)
76
+ if file_type in ["pdf", "txt", "py", "docx", "pptx", "json", "cpp", "md"]:
77
+ if file_type.endswith("pdf"):
78
+ content = extract_pdf(path)
79
+ elif file_type.endswith("docx"):
80
+ content = extract_docx(path)
81
+ elif file_type.endswith("pptx"):
82
+ content = extract_pptx(path)
83
+ else:
84
+ content = extract_text(path)
85
+ choice = "doc"
86
+ print(content[:100])
87
+ return choice, content[:5000]
88
+
89
+
90
+ elif file_type in ["png", "jpg", "jpeg", "bmp", "tiff", "webp"]:
91
+ content = Image.open(path).convert('RGB')
92
+ choice = "image"
93
+ return choice, content
94
+
95
+ else:
96
+ raise gr.Error("Oops, unsupported files.")
97
+
98
+
99
+ @spaces.GPU()
100
+ def stream_chat(message, history: list, temperature: float, max_length: int, top_p: float, top_k: int, penalty: float):
101
+
102
+ model = AutoModelForCausalLM.from_pretrained(
103
+ MODEL_ID,
104
+ torch_dtype=torch.bfloat16,
105
+ low_cpu_mem_usage=True,
106
+ trust_remote_code=True
107
+ )
108
+
109
+ print(f'message is - {message}')
110
+ print(f'history is - {history}')
111
+ conversation = []
112
+ prompt_files = []
113
+ if message["files"]:
114
+ choice, contents = mode_load(message["files"][-1])
115
+ if choice == "image":
116
+ conversation.append({"role": "user", "image": contents, "content": message['text']})
117
+ elif choice == "doc":
118
+ format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message['text']
119
+ conversation.append({"role": "user", "content": format_msg})
120
+ else:
121
+ if len(history) == 0:
122
+ # raise gr.Error("Please upload an image first.")
123
+ contents = None
124
+ conversation.append({"role": "user", "content": message['text']})
125
+ else:
126
+ # image = Image.open(history[0][0][0])
127
+ for prompt, answer in history:
128
+ if answer is None:
129
+ prompt_files.append(prompt[0])
130
+ conversation.extend([{"role": "user", "content": ""}, {"role": "assistant", "content": ""}])
131
+ else:
132
+ conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
133
+ if len(prompt_files) > 0:
134
+ choice, contents = mode_load(prompt_files[-1])
135
+ else:
136
+ choice = ""
137
+ conversation.append({"role": "user", "image": "", "content": message['text']})
138
+
139
+
140
+ if choice == "image":
141
+ conversation.append({"role": "user", "image": contents, "content": message['text']})
142
+ elif choice == "doc":
143
+ format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message['text']
144
+ conversation.append({"role": "user", "content": format_msg})
145
+ print(f"Conversation is -\n{conversation}")
146
+
147
+ input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True,
148
+ return_tensors="pt", return_dict=True).to(model.device)
149
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
150
+
151
+ generate_kwargs = dict(
152
+ max_length=max_length,
153
+ streamer=streamer,
154
+ do_sample=True,
155
+ top_p=top_p,
156
+ top_k=top_k,
157
+ temperature=temperature,
158
+ repetition_penalty=penalty,
159
+ eos_token_id=[151329, 151336, 151338],
160
+ )
161
+ gen_kwargs = {**input_ids, **generate_kwargs}
162
+
163
+ with torch.no_grad():
164
+ thread = Thread(target=model.generate, kwargs=gen_kwargs)
165
+ thread.start()
166
+ buffer = ""
167
+ for new_text in streamer:
168
+ buffer += new_text
169
+ yield buffer
170
+
171
+
172
+ chatbot = gr.Chatbot(
173
+ #rtl=True,
174
+ )
175
+ chat_input = gr.MultimodalTextbox(
176
+ interactive=True,
177
+ placeholder="Enter message or upload a file ...",
178
+ show_label=False,
179
+ #rtl=True,
180
+
181
+
182
+
183
+ )
184
+ EXAMPLES = [
185
+ [{"text": "Write a poem about spring season in French Language", }],
186
+ [{"text": "what does this chart mean?", "files": ["sales.png"]}],
187
+ [{"text": "¿Qué está escrito a mano en esta foto?", "files": ["receipt1.png"]}],
188
+ [{"text": "در مورد این عکس توضیح بده و بگو این چه فصلی می تواند باشد", "files": ["nature.jpg"]}]
189
+ ]
190
+
191
+ with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
192
+ gr.HTML(TITLE)
193
+ gr.HTML(DESCRIPTION)
194
+ gr.ChatInterface(
195
+ fn=stream_chat,
196
+ multimodal=True,
197
+
198
+
199
+ textbox=chat_input,
200
+ chatbot=chatbot,
201
+ fill_height=True,
202
+ additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
203
+ additional_inputs=[
204
+ gr.Slider(
205
+ minimum=0,
206
+ maximum=1,
207
+ step=0.1,
208
+ value=0.8,
209
+ label="Temperature",
210
+ render=False,
211
+ ),
212
+ gr.Slider(
213
+ minimum=1024,
214
+ maximum=8192,
215
+ step=1,
216
+ value=4096,
217
+ label="Max Length",
218
+ render=False,
219
+ ),
220
+ gr.Slider(
221
+ minimum=0.0,
222
+ maximum=1.0,
223
+ step=0.1,
224
+ value=1.0,
225
+ label="top_p",
226
+ render=False,
227
+ ),
228
+ gr.Slider(
229
+ minimum=1,
230
+ maximum=20,
231
+ step=1,
232
+ value=10,
233
+ label="top_k",
234
+ render=False,
235
+ ),
236
+ gr.Slider(
237
+ minimum=0.0,
238
+ maximum=2.0,
239
+ step=0.1,
240
+ value=1.0,
241
+ label="Repetition penalty",
242
+ render=False,
243
+ ),
244
+ ],
245
+ ),
246
+ gr.Examples(EXAMPLES, [chat_input])
247
+
248
+ if __name__ == "__main__":
249
+
250
+ demo.queue(api_open=False).launch(show_api=False, share=False, )#server_name="0.0.0.0", )