Daniel Marques commited on
Commit
0d6b303
·
1 Parent(s): 61d38da

feat: add ministral model

Browse files
constants.py CHANGED
@@ -32,8 +32,8 @@ CHROMA_SETTINGS = Settings(
32
  )
33
 
34
  # Context Window and Max New Tokens
35
- CONTEXT_WINDOW_SIZE = 3000
36
- MAX_NEW_TOKENS = CONTEXT_WINDOW_SIZE # int(CONTEXT_WINDOW_SIZE/4)
37
 
38
  #### If you get a "not enough space in the buffer" error, you should reduce the values below, start with half of the original values and keep halving the value until the error stops appearing
39
 
 
32
  )
33
 
34
  # Context Window and Max New Tokens
35
+ CONTEXT_WINDOW_SIZE = 4096
36
+ MAX_NEW_TOKENS = 1024 # int(CONTEXT_WINDOW_SIZE/4)
37
 
38
  #### If you get a "not enough space in the buffer" error, you should reduce the values below, start with half of the original values and keep halving the value until the error stops appearing
39
 
redis-implements/main.py DELETED
@@ -1,258 +0,0 @@
1
- from typing import Any, Dict, List, Union
2
-
3
- import os
4
- import glob
5
- import shutil
6
- import subprocess
7
- import redis
8
- import torch
9
- import concurrent.futures
10
- import json
11
-
12
- from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
13
- from fastapi.staticfiles import StaticFiles
14
-
15
- from pydantic import BaseModel
16
-
17
- # langchain
18
- from langchain.chains import RetrievalQA
19
- from langchain.embeddings import HuggingFaceInstructEmbeddings
20
- from langchain.callbacks.base import BaseCallbackHandler
21
- from langchain.schema import LLMResult
22
- from langchain.vectorstores import Chroma
23
-
24
- from prompt_template_utils import get_prompt_template
25
- from load_models import load_model
26
-
27
- from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME, PATH_NAME_SOURCE_DIRECTORY, SHOW_SOURCES
28
-
29
- class Predict(BaseModel):
30
- prompt: str
31
-
32
- class Delete(BaseModel):
33
- filename: str
34
-
35
- if torch.backends.mps.is_available():
36
- DEVICE_TYPE = "mps"
37
- elif torch.cuda.is_available():
38
- DEVICE_TYPE = "cuda"
39
- else:
40
- DEVICE_TYPE = "cpu"
41
-
42
- EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": DEVICE_TYPE})
43
- DB = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=EMBEDDINGS, client_settings=CHROMA_SETTINGS)
44
- RETRIEVER = DB.as_retriever()
45
-
46
- redisClient = redis.Redis(host='localhost', port=6379, db=0)
47
-
48
- class MyCustomSyncHandler(BaseCallbackHandler):
49
- def __init__(self, redisClient):
50
- self.message = ''
51
- self.redisClient = redisClient
52
-
53
- def on_llm_new_token(self, token: str, **kwargs) -> Any:
54
- self.message += token
55
- self.redisClient.publish(f'{kwargs["tags"][0]}', self.message)
56
-
57
- def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
58
- print("on_llm_end end")
59
- self.redisClient.publish(f'{kwargs["tags"][0]}', 'end')
60
-
61
- def on_llm_error(
62
- self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
63
- ) -> Any:
64
- print("on_llm_error end")
65
- self.redisClient.publish(f'{kwargs["tags"][0]}', 'end')
66
-
67
- def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
68
- print("on_chain_end end")
69
- self.redisClient.publish(f'{kwargs["tags"][0]}', 'end')
70
-
71
- handleCallback = MyCustomSyncHandler(redisClient)
72
-
73
- LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[handleCallback])
74
-
75
- prompt, memory = get_prompt_template(promptTemplate_type="llama", history=True)
76
-
77
- QA = RetrievalQA.from_chain_type(
78
- llm=LLM,
79
- chain_type="stuff",
80
- retriever=RETRIEVER,
81
- return_source_documents=SHOW_SOURCES,
82
- chain_type_kwargs={
83
- "prompt": prompt,
84
- "memory": memory
85
- },
86
- )
87
-
88
- app = FastAPI(title="homepage-app")
89
- api_app = FastAPI(title="api app")
90
-
91
- app.mount("/api", api_app, name="api")
92
- app.mount("/", StaticFiles(directory="static",html = True), name="static")
93
-
94
- @api_app.get("/training")
95
- def run_ingest_route():
96
- global DB
97
- global RETRIEVER
98
- global QA
99
-
100
- try:
101
- if os.path.exists(PERSIST_DIRECTORY):
102
- try:
103
- shutil.rmtree(PERSIST_DIRECTORY)
104
- except OSError as e:
105
- raise HTTPException(status_code=500, detail=f"Error: {e.filename} - {e.strerror}.")
106
- else:
107
- raise HTTPException(status_code=500, detail="The directory does not exist")
108
-
109
- run_langest_commands = ["python", "ingest.py"]
110
-
111
- if DEVICE_TYPE == "cpu":
112
- run_langest_commands.append("--device_type")
113
- run_langest_commands.append(DEVICE_TYPE)
114
-
115
- result = subprocess.run(run_langest_commands, capture_output=True)
116
-
117
- if result.returncode != 0:
118
- raise HTTPException(status_code=400, detail="Script execution failed: {}")
119
-
120
- # load the vectorstore
121
- DB = Chroma(
122
- persist_directory=PERSIST_DIRECTORY,
123
- embedding_function=EMBEDDINGS,
124
- client_settings=CHROMA_SETTINGS,
125
- )
126
-
127
- RETRIEVER = DB.as_retriever()
128
-
129
- QA = RetrievalQA.from_chain_type(
130
- llm=LLM,
131
- chain_type="stuff",
132
- retriever=RETRIEVER,
133
- return_source_documents=SHOW_SOURCES,
134
- chain_type_kwargs={
135
- "prompt": prompt,
136
- "memory": memory
137
- },
138
- )
139
-
140
- return {"response": "The training was successfully completed"}
141
- except Exception as e:
142
- raise HTTPException(status_code=500, detail=f"Error occurred: {str(e)}")
143
-
144
- @api_app.get("/api/files")
145
- def get_files():
146
- upload_dir = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
147
- files = glob.glob(os.path.join(upload_dir, '*'))
148
-
149
- return {"directory": upload_dir, "files": files}
150
-
151
- @api_app.delete("/api/delete_document")
152
- def delete_source_route(data: Delete):
153
- filename = data.filename
154
- path_source_documents = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
155
- file_to_delete = f"{path_source_documents}/{filename}"
156
-
157
- if os.path.exists(file_to_delete):
158
- try:
159
- os.remove(file_to_delete)
160
- print(f"{file_to_delete} has been deleted.")
161
-
162
- return {"message": f"{file_to_delete} has been deleted."}
163
- except OSError as e:
164
- raise HTTPException(status_code=400, detail=print(f"error: {e}."))
165
- else:
166
- raise HTTPException(status_code=400, detail=print(f"The file {file_to_delete} does not exist."))
167
-
168
- @api_app.post('/predict')
169
- async def predict(data: Predict):
170
- global QA
171
- user_prompt = data.prompt
172
- if user_prompt:
173
- res = QA(user_prompt)
174
-
175
- answer, docs = res["result"], res["source_documents"]
176
-
177
- prompt_response_dict = {
178
- "Prompt": user_prompt,
179
- "Answer": answer,
180
- }
181
-
182
- prompt_response_dict["Sources"] = []
183
- for document in docs:
184
- prompt_response_dict["Sources"].append(
185
- (os.path.basename(str(document.metadata["source"])), str(document.page_content))
186
- )
187
-
188
- return {"response": prompt_response_dict}
189
- else:
190
- raise HTTPException(status_code=400, detail="Prompt Incorrect")
191
-
192
- @api_app.post("/save_document/")
193
- async def create_upload_file(file: UploadFile):
194
- # Get the file size (in bytes)
195
- file.file.seek(0, 2)
196
- file_size = file.file.tell()
197
-
198
- # move the cursor back to the beginning
199
- await file.seek(0)
200
-
201
- if file_size > 10 * 1024 * 1024:
202
- # more than 10 MB
203
- raise HTTPException(status_code=400, detail="File too large")
204
-
205
- content_type = file.content_type
206
-
207
- if content_type not in [
208
- "text/plain",
209
- "text/markdown",
210
- "text/x-markdown",
211
- "text/csv",
212
- "application/msword",
213
- "application/pdf",
214
- "application/vnd.ms-excel",
215
- "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
216
- "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
217
- "text/x-python",
218
- "application/x-python-code"]:
219
- raise HTTPException(status_code=400, detail="Invalid file type")
220
-
221
- upload_dir = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
222
- if not os.path.exists(upload_dir):
223
- os.makedirs(upload_dir)
224
-
225
- dest = os.path.join(upload_dir, file.filename)
226
-
227
- with open(dest, "wb") as buffer:
228
- shutil.copyfileobj(file.file, buffer)
229
-
230
- return {"filename": file.filename}
231
-
232
- @api_app.websocket("/ws/{client_id}")
233
- async def websocket_endpoint(websocket: WebSocket, client_id: int):
234
- global QA
235
-
236
- await websocket.accept()
237
-
238
- try:
239
- while True:
240
- prompt = await websocket.receive_text()
241
- pubsub = redisClient.pubsub()
242
- pubsub.subscribe(f'{client_id}')
243
-
244
- with concurrent.futures.ThreadPoolExecutor() as executor:
245
- executor.submit(QA(inputs=prompt, return_only_outputs=True, tags=f'{client_id}', include_run_info=True, callbacks=[handleCallback]))
246
-
247
- for item in pubsub.listen():
248
- if item["type"] == "message":
249
- message = item["data"].decode('utf-8')
250
- if message == "end": pubsub.unsubscribe({client_id})
251
- await websocket.send_text(f'{message}')
252
-
253
-
254
-
255
- except WebSocketDisconnect:
256
- print('disconnect')
257
- except RuntimeError as error:
258
- print(error)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
websocket/socketManager.py CHANGED
@@ -146,10 +146,3 @@ class WebSocketManager:
146
  data = message['data'].decode('utf-8')
147
  await socket.send_text(data)
148
 
149
- async def get_instance_qa(self, room_id: str, QA: Any):
150
- if room_id in self.qa:
151
- return self.qa[room_id]
152
-
153
- self.qa[room_id] = QA
154
- return self.qa[room_id]
155
-
 
146
  data = message['data'].decode('utf-8')
147
  await socket.send_text(data)
148