mateoluksenberg committed
Commit afc04de
1 Parent(s): 1439cde

Create app.py

Files changed (1): app.py (+300, -0)
app.py ADDED
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
# HTTPException is used in test_endpoint below, so it must be imported here.
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List

import fitz  # PyMuPDF
import docx
from pptx import Presentation
MODEL_LIST = ["nikravan/glm-4vq"]
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = MODEL_LIST[0]
MODEL_NAME = "GLM-4vq"

TITLE = "<h1>AI CHAT DOCS</h1>"

DESCRIPTION = f"""
<center>
<p>
<br>
USANDO MODELO: <a href="https://hf.co/nikravan/glm-4vq">{MODEL_NAME}</a>
</p>
</center>"""

CSS = """
h1 {
    text-align: center;
    display: block;
}
"""

# Only the tokenizer is loaded at import time; the model itself is loaded
# inside the request handlers.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
def extract_text(path):
    return open(path, 'r').read()


def extract_pdf(path):
    doc = fitz.open(path)
    text = ""
    for page in doc:
        text += page.get_text()
    return text


def extract_docx(path):
    doc = docx.Document(path)
    data = []
    for paragraph in doc.paragraphs:
        data.append(paragraph.text)
    content = '\n\n'.join(data)
    return content


def extract_pptx(path):
    prs = Presentation(path)
    text = ""
    for slide in prs.slides:
        for shape in slide.shapes:
            if hasattr(shape, "text"):
                text += shape.text + "\n"
    return text


def mode_load(path):
    # Route a file to the matching extractor based on its extension; returns
    # ("doc", text) for documents or ("image", PIL.Image) for images.
    choice = ""
    file_type = path.split(".")[-1]
    print(file_type)
    if file_type in ["pdf", "txt", "py", "docx", "pptx", "json", "cpp", "md"]:
        if file_type == "pdf":
            content = extract_pdf(path)
        elif file_type == "docx":
            content = extract_docx(path)
        elif file_type == "pptx":
            content = extract_pptx(path)
        else:
            content = extract_text(path)
        choice = "doc"
        print(content[:100])
        return choice, content[:5000]  # cap document context at 5000 characters

    elif file_type in ["png", "jpg", "jpeg", "bmp", "tiff", "webp"]:
        content = Image.open(path).convert('RGB')
        choice = "image"
        return choice, content

    else:
        raise gr.Error("Oops, unsupported files.")
@spaces.GPU()
def stream_chat(message, history: list, temperature: float, max_length: int, top_p: float, top_k: int, penalty: float):

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True
    )

    print(f'message is - {message}')
    print(f'history is - {history}')
    conversation = []
    prompt_files = []
    if message["files"]:
        # A file accompanies the current message: load it and attach it.
        choice, contents = mode_load(message["files"][-1])
        if choice == "image":
            conversation.append({"role": "user", "image": contents, "content": message['text']})
        elif choice == "doc":
            format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message['text']
            conversation.append({"role": "user", "content": format_msg})
    else:
        if len(history) == 0:
            contents = None
            conversation.append({"role": "user", "content": message['text']})
        else:
            # Rebuild the conversation, remembering any file uploaded in an
            # earlier turn (a history entry with no answer is a file upload).
            for prompt, answer in history:
                if answer is None:
                    prompt_files.append(prompt[0])
                    conversation.extend([{"role": "user", "content": ""}, {"role": "assistant", "content": ""}])
                else:
                    conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
            if len(prompt_files) > 0:
                choice, contents = mode_load(prompt_files[-1])
            else:
                choice = ""
                conversation.append({"role": "user", "image": "", "content": message['text']})
            if choice == "image":
                conversation.append({"role": "user", "image": contents, "content": message['text']})
            elif choice == "doc":
                format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message['text']
                conversation.append({"role": "user", "content": format_msg})
    print(f"Conversation is -\n{conversation}")

    input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True,
                                              return_tensors="pt", return_dict=True).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        max_length=max_length,
        streamer=streamer,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        eos_token_id=[151329, 151336, 151338],
    )
    gen_kwargs = {**input_ids, **generate_kwargs}

    with torch.no_grad():
        # Run generation on a background thread so the streamer can yield
        # partial output incrementally.
        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            yield buffer
chatbot = gr.Chatbot()
chat_input = gr.MultimodalTextbox(
    interactive=True,
    placeholder="Enter message or upload a file ...",
    show_label=False,
)

EXAMPLES = [
    [{"text": "Resumir Documento"}],
    [{"text": "Explicar la Imagen"}],
    [{"text": "¿De qué es la foto?", "files": ["perro.jpg"]}],
    [{"text": "Quiero armar un JSON, solo el JSON sin texto, que contenga los datos de la primera mitad de la tabla de la imagen (las primeras 10 jurisdicciones 901-910). Ten en cuenta que los valores numéricos son decimales de cuatro dígitos. La tabla contiene las siguientes columnas: Codigo, Nombre, Fecha Inicio, Fecha Cese, Coeficiente Ingresos, Coeficiente Gastos y Coeficiente Unificado. La tabla puede contener valores vacíos, en ese caso dejarlos como null. Cada fila de la tabla representa una jurisdicción con sus respectivos valores."}]
]
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ChatMessage(BaseModel):
    text: str
    history: Optional[List] = []
    temperature: float = 0.8
    max_length: int = 4096
    top_p: float = 1.0
    top_k: int = 10
    penalty: float = 1.0

@app.post("/test/")
async def test_endpoint(message: dict):
    if "text" not in message:
        raise HTTPException(status_code=400, detail="Missing 'text' in request body")

    response = {"message": f"Received your message: {message['text']}"}
    return response

@app.post("/chat/")
async def chat_endpoint(message: ChatMessage, file: Optional[UploadFile] = None):
    # Load the model per request, mirroring stream_chat: only the tokenizer
    # exists at module scope, so referencing a global `model` here would
    # raise a NameError.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True
    )

    conversation = []
    if file:
        path = f"/tmp/{file.filename}"
        with open(path, "wb") as f:
            f.write(await file.read())
        choice, contents = mode_load(path)
        if choice == "image":
            conversation.append({"role": "user", "image": contents, "content": message.text})
        elif choice == "doc":
            format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message.text
            conversation.append({"role": "user", "content": format_msg})
    else:
        conversation.append({"role": "user", "content": message.text})

    input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True,
                                              return_tensors="pt", return_dict=True).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        max_length=message.max_length,
        streamer=streamer,
        do_sample=True,
        top_p=message.top_p,
        top_k=message.top_k,
        temperature=message.temperature,
        repetition_penalty=message.penalty,
        eos_token_id=[151329, 151336, 151338],
    )
    gen_kwargs = {**input_ids, **generate_kwargs}

    with torch.no_grad():
        # Unlike stream_chat, this endpoint drains the streamer and returns
        # the full response in one payload.
        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()
        buffer = ""
        for new_text in streamer:
            buffer += new_text
    return {"response": buffer}

with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.ChatInterface(
        fn=stream_chat,
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=1024,
                maximum=8192,
                step=1,
                value=4096,
                label="Max Length",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=10,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="Repetition penalty",
                render=False,
            ),
        ],
    )
    gr.Examples(EXAMPLES, [chat_input])

if __name__ == "__main__":
    import uvicorn
    # launch() blocks by default, which would keep uvicorn.run() below from
    # ever executing; prevent_thread_lock=True lets both servers come up.
    demo.queue(api_open=False).launch(show_api=False, share=False, prevent_thread_lock=True)
    uvicorn.run(app, host="0.0.0.0", port=8000)
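
A minimal sketch of how a client might call the FastAPI side once uvicorn is up. The base URL and example text are assumptions, not part of the commit; note also that /chat/ declares both a JSON body model and an optional UploadFile, a combination FastAPI normally serves via multipart form data, so the file-upload path may need reworking before it accepts requests like these:

import requests

BASE = "http://localhost:8000"  # assumed: matches uvicorn.run(app, host="0.0.0.0", port=8000)

# Sanity check against the /test/ echo endpoint.
r = requests.post(f"{BASE}/test/", json={"text": "hola"})
print(r.json())  # -> {"message": "Received your message: hola"}

# JSON-only chat request; fields mirror the ChatMessage defaults.
# Hypothetical call: see the caveat above about /chat/'s body/file mix.
payload = {"text": "Resumir Documento", "temperature": 0.8, "max_length": 4096}
r = requests.post(f"{BASE}/chat/", json=payload)
print(r.json().get("response"))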