Hjgugugjhuhjggg committed on
Commit a306750 · verified · 1 Parent(s): 7bcbbce

Create app.py

Files changed (1)
  1. app.py +202 -0
app.py ADDED
@@ -0,0 +1,202 @@
+ from pydantic import BaseModel
+ from llama_cpp import Llama
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ import re
+ import httpx
+ import asyncio
+ import gradio as gr
+ import os
+ import gptcache
+ from dotenv import load_dotenv
+ from fastapi import FastAPI, Request
+ from fastapi.responses import JSONResponse
+ import uvicorn
+ from threading import Thread
+
+ load_dotenv()
+
+ HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
+
+ global_data = {
+     'models': {},
+     'tokens': {
+         'eos': 'eos_token',
+         'pad': 'pad_token',
+         'padding': 'padding_token',
+         'unk': 'unk_token',
+         'bos': 'bos_token',
+         'sep': 'sep_token',
+         'cls': 'cls_token',
+         'mask': 'mask_token'
+     },
+     'model_metadata': {},
+     'max_tokens': {},
+     'tokenizers': {},
+     'model_params': {},
+     'model_size': {},
+     'model_ftype': {},
+     'n_ctx_train': {},
+     'n_embd': {},
+     'n_layer': {},
+     'n_head': {},
+     'n_head_kv': {},
+     'n_rot': {},
+     'n_swa': {},
+     'n_embd_head_k': {},
+     'n_embd_head_v': {},
+     'n_gqa': {},
+     'n_embd_k_gqa': {},
+     'n_embd_v_gqa': {},
+     'f_norm_eps': {},
+     'f_norm_rms_eps': {},
+     'f_clamp_kqv': {},
+     'f_max_alibi_bias': {},
+     'f_logit_scale': {},
+     'n_ff': {},
+     'n_expert': {},
+     'n_expert_used': {},
+     'causal_attn': {},
+     'pooling_type': {},
+     'rope_type': {},
+     'rope_scaling': {},
+     'freq_base_train': {},
+     'freq_scale_train': {},
+     'n_ctx_orig_yarn': {},
+     'rope_finetuned': {},
+     'ssm_d_conv': {},
+     'ssm_d_inner': {},
+     'ssm_d_state': {},
+     'ssm_dt_rank': {},
+     'ssm_dt_b_c_rms': {},
+     'vocab_type': {},
+     'model_type': {}
+ }
+
+ model_configs = [
+     {
+         "repo_id": "Hjgugugjhuhjggg/testing_semifinal-Q2_K-GGUF",
+         "filename": "testing_semifinal-q2_k.gguf",
+         "name": "testing"
+     }
+ ]
+
+ class ModelManager:
+     def __init__(self):
+         self.models = {}
+
+     def load_model(self, model_config):
+         if model_config['name'] not in self.models:
+             try:
+                 self.models[model_config['name']] = Llama.from_pretrained(
+                     repo_id=model_config['repo_id'],
+                     filename=model_config['filename'],
+                     use_auth_token=HUGGINGFACE_TOKEN,
+                     n_threads=8,
+                     n_gpu_layers=0  # keep inference on CPU (replaces the unsupported use_gpu flag)
+                 )
+             except Exception:
+                 pass  # a model that fails to download or load is simply skipped
+
+     def load_all_models(self):
+         with ThreadPoolExecutor() as executor:
+             for config in model_configs:
+                 executor.submit(self.load_model, config)
+         return self.models
+
+
+ model_manager = ModelManager()
+
+ global_data['models'] = model_manager.load_all_models()
+
+ class ChatRequest(BaseModel):
+     message: str
+
+ def normalize_input(input_text):
+     return input_text.strip()
+
+ def remove_duplicates(text):
+     lines = text.split('\n')
+     unique_lines = []
+     seen_lines = set()
+     for line in lines:
+         if line not in seen_lines:
+             unique_lines.append(line)
+             seen_lines.add(line)
+     return '\n'.join(unique_lines)
+
+ def cache_response(func):
+     # Simple in-memory memoization. The original gptcache.get/gptcache.set calls
+     # are not part of gptcache's public API, so a plain dict is used here instead.
+     _cache = {}
+     def wrapper(*args, **kwargs):
+         cache_key = f"{args}-{kwargs}"
+         if cache_key in _cache:
+             return _cache[cache_key]
+         response = func(*args, **kwargs)
+         _cache[cache_key] = response
+         return response
+     return wrapper
+
+ @cache_response
+ def generate_model_response(model, inputs):
+     try:
+         response = model(inputs)
+         return remove_duplicates(response['choices'][0]['text'])
+     except Exception:
+         return ""
+
+ def remove_repetitive_responses(responses):
+     unique_responses = {}
+     for response in responses:
+         if response['model'] not in unique_responses:
+             unique_responses[response['model']] = response['response']
+     return unique_responses
+
+ async def process_message(message):
+     inputs = normalize_input(message)
+     with ThreadPoolExecutor() as executor:
+         futures = [
+             executor.submit(generate_model_response, model, inputs)
+             for model in global_data['models'].values()
+         ]
+         # Pair each model name with its own future (as_completed would yield futures
+         # in completion order and mismatch names with responses).
+         responses = [
+             {'model': model_name, 'response': future.result()}
+             for model_name, future in zip(global_data['models'].keys(), futures)
+         ]
+     unique_responses = remove_repetitive_responses(responses)
+     formatted_response = ""
+     for model, response in unique_responses.items():
+         formatted_response += f"**{model}:**\n{response}\n\n"
+     return formatted_response
+
+ app = FastAPI()
+
+ @app.post("/generate")
+ async def generate(request: ChatRequest):
+     response = await process_message(request.message)
+     return JSONResponse(content={"response": response})
+
+ def run_uvicorn():
+     uvicorn.run(app, host="0.0.0.0", port=7860)
+
+ iface = gr.Interface(
+     fn=process_message,
+     inputs=gr.Textbox(lines=2, placeholder="Enter your message here..."),
+     outputs=gr.Markdown(),
+     title="Multi-Model LLM API (CPU Optimized)",
+     description="Enter a message and get responses from multiple LLMs using CPU."
+ )
+
+ def run_gradio():
+     iface.launch(server_port=7862, prevent_thread_lock=True)
+
+ # Run both servers (FastAPI on port 7860, Gradio on port 7862)
+ if __name__ == "__main__":
+     Thread(target=run_uvicorn).start()
+     Thread(target=run_gradio).start()
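
For quick verification once the app is running, a minimal client sketch for the /generate endpoint could look like the following. The endpoint path and the {"message": ...} request body come from the diff above; the host/port, the API_URL constant, and the ask() helper are illustrative assumptions, and httpx is reused because app.py already imports it.

import httpx

# Hypothetical local address; adjust to wherever the FastAPI server is reachable.
API_URL = "http://localhost:7860/generate"

def ask(message: str) -> str:
    # The endpoint expects the ChatRequest schema: {"message": "..."}
    resp = httpx.post(API_URL, json={"message": message}, timeout=120.0)
    resp.raise_for_status()
    # The response body is {"response": "<markdown-formatted model outputs>"}
    return resp.json()["response"]

if __name__ == "__main__":
    print(ask("Hello, what can you do?"))

The generous timeout reflects that generation runs on CPU and may take well over the default few seconds per request.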