khulnasoft committed on
Commit c107db2
1 Parent(s): 91837e5

Create awesome_chat.py

Files changed (1)
  1. awesome_chat.py +939 -0
awesome_chat.py ADDED
@@ -0,0 +1,939 @@
import base64
import copy
import datetime
from io import BytesIO
import io
import os
import random
import time
import traceback
import uuid
import requests
import re
import json
import logging
import argparse
import yaml
from PIL import Image, ImageDraw
from diffusers.utils import load_image
from pydub import AudioSegment
import threading
from queue import Queue
from get_token_ids import get_token_ids_for_task_parsing, get_token_ids_for_choose_model, count_tokens, get_max_context_length
from huggingface_hub.inference_api import InferenceApi
from huggingface_hub.inference_api import ALL_TASKS
from models_server import models, status
from functools import partial
from huggingface_hub import Repository

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml.dev")
parser.add_argument("--mode", type=str, default="cli")
args = parser.parse_args()

if __name__ != "__main__":
    args.config = "config.gradio.yaml"

config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader)

if not os.path.exists("logs"):
    os.mkdir("logs")

now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

DATASET_REPO_URL = "https://huggingface.co/datasets/deepcode-ai/HuggingSpace_logs"
LOG_HF_TOKEN = os.environ.get("LOG_HF_TOKEN")
if LOG_HF_TOKEN:
    repo = Repository(
        local_dir="logs", clone_from=DATASET_REPO_URL, use_auth_token=LOG_HF_TOKEN
    )

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.handlers = []
logger.propagate = False

handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
if config["debug"]:
    handler.setLevel(logging.DEBUG)
logger.addHandler(handler)

log_file = config["log_file"]
if log_file:
    log_file = log_file.replace("TIMESTAMP", now)
    filehandler = logging.FileHandler(log_file)
    filehandler.setLevel(logging.DEBUG)
    filehandler.setFormatter(formatter)
    logger.addHandler(filehandler)

LLM = config["model"]
use_completion = config["use_completion"]

# gpt-3.5-turbo has no dedicated token-id table here, so fall back to the text-davinci-003 encoding for token counting and logit-bias ids
LLM_encoding = LLM
if LLM == "gpt-3.5-turbo":
    LLM_encoding = "text-davinci-003"
task_parsing_highlight_ids = get_token_ids_for_task_parsing(LLM_encoding)
choose_model_highlight_ids = get_token_ids_for_choose_model(LLM_encoding)

# ENDPOINT              MODEL NAME
# /v1/chat/completions  gpt-4, gpt-4-0314, gpt-4-32k, gpt-4-32k-0314, gpt-3.5-turbo, gpt-3.5-turbo-0301
# /v1/completions       text-davinci-003, text-davinci-002, text-curie-001, text-babbage-001, text-ada-001, davinci, curie, babbage, ada

if use_completion:
    api_name = "completions"
else:
    api_name = "chat/completions"

if not config["dev"]:
    if not config["openai"]["key"].startswith("sk-") and not config["openai"]["key"] == "gradio":
        raise ValueError("Incorrect OpenAI key. Please check your config.yaml file.")
    OPENAI_KEY = config["openai"]["key"]
    endpoint = f"https://api.openai.com/v1/{api_name}"
    if OPENAI_KEY.startswith("sk-"):
        HEADER = {
            "Authorization": f"Bearer {OPENAI_KEY}"
        }
    else:
        HEADER = None
else:
    endpoint = f"{config['local']['endpoint']}/v1/{api_name}"
    HEADER = None

PROXY = None
if config["proxy"]:
    PROXY = {
        "https": config["proxy"],
    }

inference_mode = config["inference_mode"]

parse_task_demos_or_presteps = open(config["demos_or_presteps"]["parse_task"], "r").read()
choose_model_demos_or_presteps = open(config["demos_or_presteps"]["choose_model"], "r").read()
response_results_demos_or_presteps = open(config["demos_or_presteps"]["response_results"], "r").read()

parse_task_prompt = config["prompt"]["parse_task"]
choose_model_prompt = config["prompt"]["choose_model"]
response_results_prompt = config["prompt"]["response_results"]

parse_task_tprompt = config["tprompt"]["parse_task"]
choose_model_tprompt = config["tprompt"]["choose_model"]
response_results_tprompt = config["tprompt"]["response_results"]
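
# For reference, a minimal sketch of the config keys this module reads (the real
# config.yaml / config.gradio.yaml ship with the repo; the values below are
# illustrative assumptions, not the committed defaults):
#
#   debug: false
#   log_file: logs/debug_TIMESTAMP.log
#   model: gpt-3.5-turbo
#   use_completion: false
#   dev: false
#   openai:
#     key: gradio              # or an sk-... key
#   proxy: null
#   inference_mode: hybrid     # "local", "huggingface", or anything else for both
#   local_deployment: full     # "minimal" skips local status checks
#   num_candidate_models: 5
#   max_description_length: 100
#   local:
#     endpoint: http://localhost:8005
#   logit_bias:
#     parse_task: 0.1
#     choose_model: 5
#   demos_or_presteps: {parse_task: ..., choose_model: ..., response_results: ...}
#   prompt: {parse_task: ..., choose_model: ..., response_results: ...}
#   tprompt: {parse_task: ..., choose_model: ..., response_results: ...}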

MODELS = [json.loads(line) for line in open("data/p0_models.jsonl", "r").readlines()]
MODELS_MAP = {}
for model in MODELS:
    tag = model["task"]
    if tag not in MODELS_MAP:
        MODELS_MAP[tag] = []
    MODELS_MAP[tag].append(model)
METADATAS = {}
for model in MODELS:
    METADATAS[model["id"]] = model

def convert_chat_to_completion(data):
    messages = data.pop('messages', [])
    tprompt = ""
    if messages[0]['role'] == "system":
        tprompt = messages[0]['content']
        messages = messages[1:]
    final_prompt = ""
    for message in messages:
        if message['role'] == "user":
            final_prompt += ("<im_start>" + "user" + "\n" + message['content'] + "<im_end>\n")
        elif message['role'] == "assistant":
            final_prompt += ("<im_start>" + "assistant" + "\n" + message['content'] + "<im_end>\n")
        else:
            final_prompt += ("<im_start>" + "system" + "\n" + message['content'] + "<im_end>\n")
    final_prompt = tprompt + final_prompt
    final_prompt = final_prompt + "<im_start>assistant"
    data["prompt"] = final_prompt
    data['stop'] = data.get('stop', ["<im_end>"])
    data['max_tokens'] = data.get('max_tokens', max(get_max_context_length(LLM) - count_tokens(LLM_encoding, final_prompt), 1))
    return data
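
# Illustrative example (for documentation only, not executed): given
#   {"messages": [{"role": "system", "content": "S"}, {"role": "user", "content": "U"}]}
# convert_chat_to_completion() produces a ChatML-style completion payload roughly like
#   {"prompt": "S<im_start>user\nU<im_end>\n<im_start>assistant", "stop": ["<im_end>"], "max_tokens": ...}
# with max_tokens set to whatever remains of the configured model's context window.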

def send_request(data):
    global HEADER
    openaikey = data.pop("openaikey")
    if use_completion:
        data = convert_chat_to_completion(data)
    if openaikey and openaikey.startswith("sk-"):
        HEADER = {
            "Authorization": f"Bearer {openaikey}"
        }

    response = requests.post(endpoint, json=data, headers=HEADER, proxies=PROXY)
    logger.debug(response.text.strip())
    if "choices" not in response.json():
        return response.json()
    if use_completion:
        return response.json()["choices"][0]["text"].strip()
    else:
        return response.json()["choices"][0]["message"]["content"].strip()
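
# Callers always pass an "openaikey" entry alongside the OpenAI payload; a minimal
# sketch of the expected shape (hypothetical values):
#   send_request({"model": LLM, "messages": [{"role": "user", "content": "hi"}],
#                 "temperature": 0, "openaikey": "sk-..."})
# The key is popped before the request and, when it looks like a real key, replaces
# the module-level Authorization header.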

def replace_slot(text, entries):
    for key, value in entries.items():
        if not isinstance(value, str):
            value = str(value)
        text = text.replace("{{" + key + "}}", value.replace('"', "'").replace('\n', "").replace('\\', '\\\\'))
    return text

def find_json(s):
    s = s.replace("\'", "\"")
    start = s.find("{")
    end = s.rfind("}")
    res = s[start:end+1]
    res = res.replace("\n", "")
    return res

def field_extract(s, field):
    try:
        field_rep = re.compile(f'{field}.*?:.*?"(.*?)"', re.IGNORECASE)
        extracted = field_rep.search(s).group(1).replace("\"", "\'")
    except:
        field_rep = re.compile(f'{field}:\ *"(.*?)"', re.IGNORECASE)
        extracted = field_rep.search(s).group(1).replace("\"", "\'")
    return extracted

def get_id_reason(choose_str):
    reason = field_extract(choose_str, "reason")
    id = field_extract(choose_str, "id")
    choose = {"id": id, "reason": reason}
    return id.strip(), reason.strip(), choose

def record_case(success, **args):
    if not success:
        return
    f = open(f"logs/log_success_{now}.jsonl", "a")
    log = args
    f.write(json.dumps(log) + "\n")
    f.close()
    if LOG_HF_TOKEN:
        commit_url = repo.push_to_hub(blocking=False)

def image_to_bytes(img_url):
    img_byte = io.BytesIO()
    type = img_url.split(".")[-1]
    load_image(img_url).save(img_byte, format="png")
    img_data = img_byte.getvalue()
    return img_data

def resource_has_dep(command):
    args = command["args"]
    for _, v in args.items():
        if "<GENERATED>" in v:
            return True
    return False

def fix_dep(tasks):
    for task in tasks:
        args = task["args"]
        task["dep"] = []
        for k, v in args.items():
            if "<GENERATED>" in v:
                dep_task_id = int(v.split("-")[1])
                if dep_task_id not in task["dep"]:
                    task["dep"].append(dep_task_id)
        if len(task["dep"]) == 0:
            task["dep"] = [-1]
    return tasks

def unfold(tasks):
    flag_unfold_task = False
    try:
        for task in tasks:
            for key, value in task["args"].items():
                if "<GENERATED>" in value:
                    generated_items = value.split(",")
                    if len(generated_items) > 1:
                        flag_unfold_task = True
                        for item in generated_items:
                            new_task = copy.deepcopy(task)
                            dep_task_id = int(item.split("-")[1])
                            new_task["dep"] = [dep_task_id]
                            new_task["args"][key] = item
                            tasks.append(new_task)
                        tasks.remove(task)
    except Exception as e:
        print(e)
        traceback.print_exc()
        logger.debug("unfold task failed.")

    if flag_unfold_task:
        logger.debug(f"unfold tasks: {tasks}")

    return tasks
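
# Hedged sketch of the planner output these helpers expect (the exact schema comes
# from the parse_task prompt shipped with the repo; ids and args here are made up):
#   tasks = [
#       {"task": "image-to-text", "id": 0, "dep": [-1], "args": {"image": "example.jpg"}},
#       {"task": "text-to-speech", "id": 1, "dep": [0], "args": {"text": "<GENERATED>-0"}},
#   ]
# fix_dep() rebuilds each "dep" list from "<GENERATED>-<id>" references, and unfold()
# duplicates a task whose argument references several generated resources at once
# (e.g. "<GENERATED>-0,<GENERATED>-1"), giving each copy a single dependency.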

def chitchat(messages, openaikey=None):
    data = {
        "model": LLM,
        "messages": messages,
        "openaikey": openaikey
    }
    return send_request(data)

def parse_task(context, input, openaikey=None):
    demos_or_presteps = parse_task_demos_or_presteps
    messages = json.loads(demos_or_presteps)
    for message in messages:
        if not isinstance(message["content"], str):
            message["content"] = json.dumps(message["content"], ensure_ascii=False)
    messages.insert(0, {"role": "system", "content": parse_task_tprompt})

    # cut chat logs
    start = 0
    while start <= len(context):
        history = context[start:]
        prompt = replace_slot(parse_task_prompt, {
            "input": input,
            "context": history
        })
        messages.append({"role": "user", "content": prompt})
        history_text = "<im_end>\nuser<im_start>".join([m["content"] for m in messages])
        num = count_tokens(LLM_encoding, history_text)
        if get_max_context_length(LLM) - num > 800:
            break
        messages.pop()
        start += 2

    logger.debug(messages)
    data = {
        "model": LLM,
        "messages": messages,
        "temperature": 0,
        "logit_bias": {item: config["logit_bias"]["parse_task"] for item in task_parsing_highlight_ids},
        "openaikey": openaikey
    }
    return send_request(data)

def choose_model(input, task, metas, openaikey=None):
    prompt = replace_slot(choose_model_prompt, {
        "input": input,
        "task": task,
        "metas": metas,
    })
    demos_or_presteps = replace_slot(choose_model_demos_or_presteps, {
        "input": input,
        "task": task,
        "metas": metas
    })
    messages = json.loads(demos_or_presteps)
    messages.insert(0, {"role": "system", "content": choose_model_tprompt})
    messages.append({"role": "user", "content": prompt})
    logger.debug(messages)
    data = {
        "model": LLM,
        "messages": messages,
        "temperature": 0,
        "logit_bias": {item: config["logit_bias"]["choose_model"] for item in choose_model_highlight_ids},  # 5
        "openaikey": openaikey
    }
    return send_request(data)


def response_results(input, results, openaikey=None):
    results = [v for k, v in sorted(results.items(), key=lambda item: item[0])]
    prompt = replace_slot(response_results_prompt, {
        "input": input,
    })
    demos_or_presteps = replace_slot(response_results_demos_or_presteps, {
        "input": input,
        "processes": results
    })
    messages = json.loads(demos_or_presteps, strict=False)
    messages.insert(0, {"role": "system", "content": response_results_tprompt})
    messages.append({"role": "user", "content": prompt})
    logger.debug(messages)
    data = {
        "model": LLM,
        "messages": messages,
        "temperature": 0,
        "openaikey": openaikey
    }
    return send_request(data)

def huggingface_model_inference(model_id, data, task, huggingfacetoken=None):
    if huggingfacetoken is None:
        HUGGINGFACE_HEADERS = {}
    else:
        HUGGINGFACE_HEADERS = {
            "Authorization": f"Bearer {huggingfacetoken}",
        }
    task_url = f"https://api-inference.huggingface.co/models/{model_id}"  # InferenceApi does not yet support some tasks
    inference = InferenceApi(repo_id=model_id, token=huggingfacetoken)

    # NLP tasks
    if task == "question-answering":
        inputs = {"question": data["text"], "context": (data["context"] if "context" in data else "")}
        result = inference(inputs)
    if task == "sentence-similarity":
        inputs = {"source_sentence": data["text1"], "target_sentence": data["text2"]}
        result = inference(inputs)
    if task in ["text-classification", "token-classification", "text2text-generation", "summarization", "translation", "conversational", "text-generation"]:
        inputs = data["text"]
        result = inference(inputs)

    # CV tasks
    if task == "visual-question-answering" or task == "document-question-answering":
        img_url = data["image"]
        text = data["text"]
        img_data = image_to_bytes(img_url)
        img_base64 = base64.b64encode(img_data).decode("utf-8")
        json_data = {}
        json_data["inputs"] = {}
        json_data["inputs"]["question"] = text
        json_data["inputs"]["image"] = img_base64
        result = requests.post(task_url, headers=HUGGINGFACE_HEADERS, json=json_data).json()
        # result = inference(inputs) # not support

    if task == "image-to-image":
        img_url = data["image"]
        img_data = image_to_bytes(img_url)
        # result = inference(data=img_data) # not support
        HUGGINGFACE_HEADERS["Content-Length"] = str(len(img_data))
        r = requests.post(task_url, headers=HUGGINGFACE_HEADERS, data=img_data)
        result = r.json()
        if "path" in result:
            result["generated image"] = result.pop("path")

    if task == "text-to-image":
        inputs = data["text"]
        img = inference(inputs)
        name = str(uuid.uuid4())[:4]
        img.save(f"public/images/{name}.png")
        result = {}
        result["generated image"] = f"/images/{name}.png"

    if task == "image-segmentation":
        img_url = data["image"]
        img_data = image_to_bytes(img_url)
        image = Image.open(BytesIO(img_data))
        predicted = inference(data=img_data)
        colors = []
        for i in range(len(predicted)):
            colors.append((random.randint(100, 255), random.randint(100, 255), random.randint(100, 255), 155))
        for i, pred in enumerate(predicted):
            label = pred["label"]
            mask = pred.pop("mask").encode("utf-8")
            mask = base64.b64decode(mask)
            mask = Image.open(BytesIO(mask), mode='r')
            mask = mask.convert('L')

            layer = Image.new('RGBA', mask.size, colors[i])
            image.paste(layer, (0, 0), mask)
        name = str(uuid.uuid4())[:4]
        image.save(f"public/images/{name}.jpg")
        result = {}
        result["generated image with segmentation mask"] = f"/images/{name}.jpg"
        result["predicted"] = predicted

    if task == "object-detection":
        img_url = data["image"]
        img_data = image_to_bytes(img_url)
        predicted = inference(data=img_data)
        image = Image.open(BytesIO(img_data))
        draw = ImageDraw.Draw(image)
        labels = list(item['label'] for item in predicted)
        color_map = {}
        for label in labels:
            if label not in color_map:
                color_map[label] = (random.randint(0, 255), random.randint(0, 100), random.randint(0, 255))
        for label in predicted:
            box = label["box"]
            draw.rectangle(((box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])), outline=color_map[label["label"]], width=2)
            draw.text((box["xmin"]+5, box["ymin"]-15), label["label"], fill=color_map[label["label"]])
        name = str(uuid.uuid4())[:4]
        image.save(f"public/images/{name}.jpg")
        result = {}
        result["generated image with predicted box"] = f"/images/{name}.jpg"
        result["predicted"] = predicted

    if task in ["image-classification"]:
        img_url = data["image"]
        img_data = image_to_bytes(img_url)
        result = inference(data=img_data)

    if task == "image-to-text":
        img_url = data["image"]
        img_data = image_to_bytes(img_url)
        HUGGINGFACE_HEADERS["Content-Length"] = str(len(img_data))
        r = requests.post(task_url, headers=HUGGINGFACE_HEADERS, data=img_data)
        result = {}
        if "generated_text" in r.json()[0]:
            result["generated text"] = r.json()[0].pop("generated_text")

    # AUDIO tasks
    if task == "text-to-speech":
        inputs = data["text"]
        response = inference(inputs, raw_response=True)
        # response = requests.post(task_url, headers=HUGGINGFACE_HEADERS, json={"inputs": text})
        name = str(uuid.uuid4())[:4]
        with open(f"public/audios/{name}.flac", "wb") as f:
            f.write(response.content)
        result = {"generated audio": f"/audios/{name}.flac"}
    if task in ["automatic-speech-recognition", "audio-to-audio", "audio-classification"]:
        audio_url = data["audio"]
        audio_data = requests.get(audio_url, timeout=10).content
        response = inference(data=audio_data, raw_response=True)
        result = response.json()
        if task == "audio-to-audio":
            content = None
            type = None
            for k, v in result[0].items():
                if k == "blob":
                    content = base64.b64decode(v.encode("utf-8"))
                if k == "content-type":
                    type = "audio/flac".split("/")[-1]
            audio = AudioSegment.from_file(BytesIO(content))
            name = str(uuid.uuid4())[:4]
            audio.export(f"public/audios/{name}.{type}", format=type)
            result = {"generated audio": f"/audios/{name}.{type}"}
    return result

def local_model_inference(model_id, data, task):
    inference = partial(models, model_id)
    # controlnet
    if model_id.startswith("lllyasviel/sd-controlnet-"):
        img_url = data["image"]
        text = data["text"]
        results = inference({"img_url": img_url, "text": text})
        if "path" in results:
            results["generated image"] = results.pop("path")
        return results
    if model_id.endswith("-control"):
        img_url = data["image"]
        results = inference({"img_url": img_url})
        if "path" in results:
            results["generated image"] = results.pop("path")
        return results

    if task == "text-to-video":
        results = inference(data)
        if "path" in results:
            results["generated video"] = results.pop("path")
        return results

    # NLP tasks
    if task == "question-answering" or task == "sentence-similarity":
        results = inference(json=data)
        return results
    if task in ["text-classification", "token-classification", "text2text-generation", "summarization", "translation", "conversational", "text-generation"]:
        results = inference(json=data)
        return results

    # CV tasks
    if task == "depth-estimation":
        img_url = data["image"]
        results = inference({"img_url": img_url})
        if "path" in results:
            results["generated depth image"] = results.pop("path")
        return results
    if task == "image-segmentation":
        img_url = data["image"]
        results = inference({"img_url": img_url})
        results["generated image with segmentation mask"] = results.pop("path")
        return results
    if task == "image-to-image":
        img_url = data["image"]
        results = inference({"img_url": img_url})
        if "path" in results:
            results["generated image"] = results.pop("path")
        return results
    if task == "text-to-image":
        results = inference(data)
        if "path" in results:
            results["generated image"] = results.pop("path")
        return results
    if task == "object-detection":
        img_url = data["image"]
        predicted = inference({"img_url": img_url})
        if "error" in predicted:
            return predicted
        image = load_image(img_url)
        draw = ImageDraw.Draw(image)
        labels = list(item['label'] for item in predicted)
        color_map = {}
        for label in labels:
            if label not in color_map:
                color_map[label] = (random.randint(0, 255), random.randint(0, 100), random.randint(0, 255))
        for label in predicted:
            box = label["box"]
            draw.rectangle(((box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])), outline=color_map[label["label"]], width=2)
            draw.text((box["xmin"]+5, box["ymin"]-15), label["label"], fill=color_map[label["label"]])
        name = str(uuid.uuid4())[:4]
        image.save(f"public/images/{name}.jpg")
        results = {}
        results["generated image with predicted box"] = f"/images/{name}.jpg"
        results["predicted"] = predicted
        return results
    if task in ["image-classification", "image-to-text", "document-question-answering", "visual-question-answering"]:
        img_url = data["image"]
        text = None
        if "text" in data:
            text = data["text"]
        results = inference({"img_url": img_url, "text": text})
        return results
    # AUDIO tasks
    if task == "text-to-speech":
        results = inference(data)
        if "path" in results:
            results["generated audio"] = results.pop("path")
        return results
    if task in ["automatic-speech-recognition", "audio-to-audio", "audio-classification"]:
        audio_url = data["audio"]
        results = inference({"audio_url": audio_url})
        return results


def model_inference(model_id, data, hosted_on, task, huggingfacetoken=None):
    if huggingfacetoken:
        HUGGINGFACE_HEADERS = {
            "Authorization": f"Bearer {huggingfacetoken}",
        }
    else:
        HUGGINGFACE_HEADERS = None
    if hosted_on == "unknown":
        r = status(model_id)
        logger.debug("Local Server Status: " + str(r))
        if "loaded" in r and r["loaded"]:
            hosted_on = "local"
        else:
            huggingfaceStatusUrl = f"https://api-inference.huggingface.co/status/{model_id}"
            r = requests.get(huggingfaceStatusUrl, headers=HUGGINGFACE_HEADERS, proxies=PROXY).json()
            logger.debug("Huggingface Status: " + str(r))
            if "loaded" in r and r["loaded"]:
                hosted_on = "huggingface"
    try:
        if hosted_on == "local":
            inference_result = local_model_inference(model_id, data, task)
        elif hosted_on == "huggingface":
            inference_result = huggingface_model_inference(model_id, data, task, huggingfacetoken)
    except Exception as e:
        print(e)
        traceback.print_exc()
        inference_result = {"error": {"message": str(e)}}
    return inference_result


def get_model_status(model_id, url, headers, queue=None):
    endpoint_type = "huggingface" if "huggingface" in url else "local"
    if "huggingface" in url:
        r = requests.get(url, headers=headers, proxies=PROXY).json()
    else:
        r = status(model_id)
    if "loaded" in r and r["loaded"]:
        if queue:
            queue.put((model_id, True, endpoint_type))
        return True
    else:
        if queue:
            queue.put((model_id, False, None))
        return False

def get_avaliable_models(candidates, topk=10, huggingfacetoken=None):
    all_available_models = {"local": [], "huggingface": []}
    threads = []
    result_queue = Queue()
    HUGGINGFACE_HEADERS = {
        "Authorization": f"Bearer {huggingfacetoken}",
    }
    for candidate in candidates:
        model_id = candidate["id"]

        if inference_mode != "local":
            huggingfaceStatusUrl = f"https://api-inference.huggingface.co/status/{model_id}"
            thread = threading.Thread(target=get_model_status, args=(model_id, huggingfaceStatusUrl, HUGGINGFACE_HEADERS, result_queue))
            threads.append(thread)
            thread.start()

        if inference_mode != "huggingface" and config["local_deployment"] != "minimal":
            thread = threading.Thread(target=get_model_status, args=(model_id, "", {}, result_queue))
            threads.append(thread)
            thread.start()

    result_count = len(threads)
    while result_count:
        model_id, status, endpoint_type = result_queue.get()
        if status and model_id not in all_available_models["local"] + all_available_models["huggingface"]:
            all_available_models[endpoint_type].append(model_id)
        if len(all_available_models["local"] + all_available_models["huggingface"]) >= topk:
            break
        result_count -= 1

    for thread in threads:
        thread.join()

    return all_available_models
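
# Sketch of how the availability check is typically driven (values are illustrative):
#   candidates = MODELS_MAP["image-classification"][:20]
#   available = get_avaliable_models(candidates, topk=config["num_candidate_models"])
#   # -> {"local": [...model ids...], "huggingface": [...model ids...]}
# Each candidate is probed in a separate thread via get_model_status(), and the scan
# stops early once `topk` models have been confirmed as loaded.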

def collect_result(command, choose, inference_result):
    result = {"task": command}
    result["inference result"] = inference_result
    result["choose model result"] = choose
    logger.debug(f"inference result: {inference_result}")
    return result


def run_task(input, command, results, openaikey=None, huggingfacetoken=None):
    id = command["id"]
    args = command["args"]
    task = command["task"]
    deps = command["dep"]
    if deps[0] != -1:
        dep_tasks = [results[dep] for dep in deps]
    else:
        dep_tasks = []

    logger.debug(f"Run task: {id} - {task}")
    logger.debug("Deps: " + json.dumps(dep_tasks))

    if deps[0] != -1:
        if "image" in args and "<GENERATED>-" in args["image"]:
            resource_id = int(args["image"].split("-")[1])
            if "generated image" in results[resource_id]["inference result"]:
                args["image"] = results[resource_id]["inference result"]["generated image"]
        if "audio" in args and "<GENERATED>-" in args["audio"]:
            resource_id = int(args["audio"].split("-")[1])
            if "generated audio" in results[resource_id]["inference result"]:
                args["audio"] = results[resource_id]["inference result"]["generated audio"]
        if "text" in args and "<GENERATED>-" in args["text"]:
            resource_id = int(args["text"].split("-")[1])
            if "generated text" in results[resource_id]["inference result"]:
                args["text"] = results[resource_id]["inference result"]["generated text"]

    text = image = audio = None
    for dep_task in dep_tasks:
        if "generated text" in dep_task["inference result"]:
            text = dep_task["inference result"]["generated text"]
            logger.debug("Detect the generated text of dependency task (from results): " + text)
        elif "text" in dep_task["task"]["args"]:
            text = dep_task["task"]["args"]["text"]
            logger.debug("Detect the text of dependency task (from args): " + text)
        if "generated image" in dep_task["inference result"]:
            image = dep_task["inference result"]["generated image"]
            logger.debug("Detect the generated image of dependency task (from results): " + image)
        elif "image" in dep_task["task"]["args"]:
            image = dep_task["task"]["args"]["image"]
            logger.debug("Detect the image of dependency task (from args): " + image)
        if "generated audio" in dep_task["inference result"]:
            audio = dep_task["inference result"]["generated audio"]
            logger.debug("Detect the generated audio of dependency task (from results): " + audio)
        elif "audio" in dep_task["task"]["args"]:
            audio = dep_task["task"]["args"]["audio"]
            logger.debug("Detect the audio of dependency task (from args): " + audio)

    if "image" in args and "<GENERATED>" in args["image"]:
        if image:
            args["image"] = image
    if "audio" in args and "<GENERATED>" in args["audio"]:
        if audio:
            args["audio"] = audio
    if "text" in args and "<GENERATED>" in args["text"]:
        if text:
            args["text"] = text

    for resource in ["image", "audio"]:
        if resource in args and not args[resource].startswith("public/") and len(args[resource]) > 0 and not args[resource].startswith("http"):
            args[resource] = f"public/{args[resource]}"

    if "-text-to-image" in command['task'] and "text" not in args:
        logger.debug("control-text-to-image task, but text is empty, so we use control-generation instead.")
        control = task.split("-")[0]

        if control == "seg":
            task = "image-segmentation"
            command['task'] = task
        elif control == "depth":
            task = "depth-estimation"
            command['task'] = task
        else:
            task = f"{control}-control"

    command["args"] = args
    logger.debug(f"parsed task: {command}")

    if task.endswith("-text-to-image") or task.endswith("-control"):
        if inference_mode != "huggingface":
            if task.endswith("-text-to-image"):
                control = task.split("-")[0]
                best_model_id = f"lllyasviel/sd-controlnet-{control}"
            else:
                best_model_id = task
            hosted_on = "local"
            reason = "ControlNet is the best model for this task."
            choose = {"id": best_model_id, "reason": reason}
            logger.debug(f"chosen model: {choose}")
        else:
            logger.warning(f"Task {command['task']} is not available. ControlNet needs to be deployed locally.")
            record_case(success=False, **{"input": input, "task": command, "reason": f"Task {command['task']} is not available. ControlNet needs to be deployed locally.", "op": "message"})
            inference_result = {"error": f"service related to ControlNet is not available."}
            results[id] = collect_result(command, "", inference_result)
            return False
    elif task in ["summarization", "translation", "conversational", "text-generation", "text2text-generation"]:  # ChatGPT can do
        best_model_id = "ChatGPT"
        reason = "ChatGPT performs well on some NLP tasks as well."
        choose = {"id": best_model_id, "reason": reason}
        messages = [{
            "role": "user",
            "content": f"[ {input} ] contains a task in JSON format {command}, 'task' indicates the task type and 'args' indicates the arguments required for the task. Don't explain the task to me, just help me do it and give me the result. The result must be in text form without any urls."
        }]
        response = chitchat(messages, openaikey)
        results[id] = collect_result(command, choose, {"response": response})
        return True
    else:
        if task not in MODELS_MAP:
            logger.warning(f"no available models on {task} task.")
            record_case(success=False, **{"input": input, "task": command, "reason": f"task not support: {command['task']}", "op": "message"})
            inference_result = {"error": f"{command['task']} not found in available tasks."}
            results[id] = collect_result(command, "", inference_result)
            return False

        candidates = MODELS_MAP[task][:20]
        all_avaliable_models = get_avaliable_models(candidates, config["num_candidate_models"], huggingfacetoken)
        all_avaliable_model_ids = all_avaliable_models["local"] + all_avaliable_models["huggingface"]
        logger.debug(f"available models on {command['task']}: {all_avaliable_models}")

        if len(all_avaliable_model_ids) == 0:
            logger.warning(f"no available models on {command['task']}")
            record_case(success=False, **{"input": input, "task": command, "reason": f"no available models: {command['task']}", "op": "message"})
            inference_result = {"error": f"no available models on {command['task']} task."}
            results[id] = collect_result(command, "", inference_result)
            return False

        all_avaliable_model_ids = all_avaliable_model_ids[:1]
        if len(all_avaliable_model_ids) == 1:
            best_model_id = all_avaliable_model_ids[0]
            hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
            reason = "Only one model available."
            choose = {"id": best_model_id, "reason": reason}
            logger.debug(f"chosen model: {choose}")
        else:
            cand_models_info = [
                {
                    "id": model["id"],
                    "inference endpoint": all_avaliable_models.get(
                        "local" if model["id"] in all_avaliable_models["local"] else "huggingface"
                    ),
                    "likes": model.get("likes"),
                    "description": model.get("description", "")[:config["max_description_length"]],
                    "language": model.get("language"),
                    "tags": model.get("tags"),
                }
                for model in candidates
                if model["id"] in all_avaliable_model_ids
            ]

            choose_str = choose_model(input, command, cand_models_info, openaikey)
            logger.debug(f"chosen model: {choose_str}")
            try:
                choose = json.loads(choose_str)
                reason = choose["reason"]
                best_model_id = choose["id"]
                hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
            except Exception as e:
                logger.warning(f"the response [ {choose_str} ] is not a valid JSON, try to find the model id and reason in the response.")
                choose_str = find_json(choose_str)
                best_model_id, reason, choose = get_id_reason(choose_str)
                hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
    inference_result = model_inference(best_model_id, args, hosted_on, command['task'], huggingfacetoken)

    if "error" in inference_result:
        logger.warning(f"Inference error: {inference_result['error']}")
        record_case(success=False, **{"input": input, "task": command, "reason": f"inference error: {inference_result['error']}", "op": "message"})
        results[id] = collect_result(command, choose, inference_result)
        return False

    results[id] = collect_result(command, choose, inference_result)
    return True

def chat_huggingface(messages, openaikey=None, huggingfacetoken=None, return_planning=False, return_results=False):
    start = time.time()
    context = messages[:-1]
    input = messages[-1]["content"]
    logger.info("*" * 80)
    logger.info(f"input: {input}")

    task_str = parse_task(context, input, openaikey)
    logger.info(task_str)

    if "error" in task_str:
        return str(task_str), {}
    else:
        task_str = task_str.strip()

    try:
        tasks = json.loads(task_str)
    except Exception as e:
        logger.debug(e)
        response = chitchat(messages, openaikey)
        record_case(success=False, **{"input": input, "task": task_str, "reason": "task parsing fail", "op": "chitchat"})
        return response, {}

    if task_str == "[]":  # using LLM response for empty task
        record_case(success=False, **{"input": input, "task": [], "reason": "task parsing fail: empty", "op": "chitchat"})
        response = chitchat(messages, openaikey)
        return response, {}

    if len(tasks) == 1 and tasks[0]["task"] in ["summarization", "translation", "conversational", "text-generation", "text2text-generation"]:
        record_case(success=True, **{"input": input, "task": tasks, "reason": "single NLP task handled by chitchat", "op": "chitchat"})
        response = chitchat(messages, openaikey)
        best_model_id = "ChatGPT"
        reason = "ChatGPT performs well on some NLP tasks as well."
        choose = {"id": best_model_id, "reason": reason}
        return response, collect_result(tasks[0], choose, {"response": response})


    tasks = unfold(tasks)
    tasks = fix_dep(tasks)
    logger.debug(tasks)

    if return_planning:
        return tasks

    results = {}
    threads = []
    tasks = tasks[:]
    d = dict()
    retry = 0
    while True:
        num_threads = len(threads)
        for task in tasks:
            dep = task["dep"]
            # logger.debug(f"d.keys(): {d.keys()}, dep: {dep}")
            for dep_id in dep:
                if dep_id >= task["id"]:
                    task["dep"] = [-1]
                    dep = [-1]
                    break
            if len(list(set(dep).intersection(d.keys()))) == len(dep) or dep[0] == -1:
                tasks.remove(task)
                thread = threading.Thread(target=run_task, args=(input, task, d, openaikey, huggingfacetoken))
                thread.start()
                threads.append(thread)
        if num_threads == len(threads):
            time.sleep(0.5)
            retry += 1
        if retry > 80:
            logger.debug("User has waited too long; breaking out of the scheduling loop.")
            break
        if len(tasks) == 0:
            break
    for thread in threads:
        thread.join()

    results = d.copy()

    logger.debug(results)
    if return_results:
        return results

    response = response_results(input, results, openaikey).strip()

    end = time.time()
    during = end - start

    answer = {"message": response}
    record_case(success=True, **{"input": input, "task": task_str, "results": results, "response": response, "during": during, "op": "response"})
    logger.info(f"response: {response}")
    return response, results
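
# A minimal usage sketch (an illustrative assumption, not part of this commit's CLI
# or Gradio wiring): chat_huggingface expects OpenAI-style chat messages and returns
# the final response plus the per-task results, or only the task plan when asked.
#
#   messages = [{"role": "user", "content": "Describe example.jpg and read it aloud"}]
#   response, results = chat_huggingface(messages, openaikey="sk-...", huggingfacetoken="hf_...")
#   plan = chat_huggingface(messages, openaikey="sk-...", return_planning=True)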