Commit e66c4c3
sergey21000 committed
1 Parent(s): aeed07a

Upload 2 files

Files changed (2)
  1. app.py +412 -393
  2. utils.py +506 -500
app.py CHANGED
@@ -1,394 +1,413 @@
1
- from typing import List, Optional
2
-
3
- import gradio as gr
4
- from langchain_core.vectorstores import VectorStore
5
-
6
- from config import (
7
- LLM_MODEL_REPOS,
8
- EMBED_MODEL_REPOS,
9
- SUBTITLES_LANGUAGES,
10
- GENERATE_KWARGS,
11
- )
12
-
13
- from utils import (
14
- load_llm_model,
15
- load_embed_model,
16
- load_documents_and_create_db,
17
- user_message_to_chatbot,
18
- update_user_message_with_context,
19
- get_llm_response,
20
- get_gguf_model_names,
21
- add_new_model_repo,
22
- clear_llm_folder,
23
- clear_embed_folder,
24
- get_memory_usage,
25
- )
26
-
27
-
28
- # ============ INTERFACE COMPONENT INITIALIZATION FUNCS ============
29
-
30
- def get_rag_settings(rag_mode: bool, render: bool = True):
31
- k = gr.Radio(
32
- choices=[1, 2, 3, 4, 5, 'all'],
33
- value=2,
34
- label='Number of relevant documents for search',
35
- visible=rag_mode,
36
- render=render,
37
- )
38
- score_threshold = gr.Slider(
39
- minimum=0,
40
- maximum=1,
41
- value=0.5,
42
- step=0.05,
43
- label='relevance_scores_threshold',
44
- visible=rag_mode,
45
- render=render,
46
- )
47
- return k, score_threshold
48
-
49
-
50
- def get_user_message_with_context(text: str, rag_mode: bool) -> gr.component:
51
- num_lines = len(text.split('\n'))
52
- max_lines = 10
53
- num_lines = max_lines if num_lines > max_lines else num_lines
54
- return gr.Textbox(
55
- text,
56
- visible=rag_mode,
57
- interactive=False,
58
- label='User Message With Context',
59
- lines=num_lines,
60
- )
61
-
62
-
63
- def get_system_prompt_component(interactive: bool) -> gr.Textbox:
64
- value = '' if interactive else 'System prompt is not supported by this model'
65
- return gr.Textbox(value=value, label='System prompt', interactive=interactive)
66
-
67
-
68
- def get_generate_args(do_sample: bool) -> List[gr.component]:
69
- generate_args = [
70
- gr.Slider(minimum=0.1, maximum=3, value=GENERATE_KWARGS['temperature'], step=0.1, label='temperature', visible=do_sample),
71
- gr.Slider(minimum=0, maximum=1, value=GENERATE_KWARGS['top_p'], step=0.01, label='top_p', visible=do_sample),
72
- gr.Slider(minimum=1, maximum=50, value=GENERATE_KWARGS['top_k'], step=1, label='top_k', visible=do_sample),
73
- gr.Slider(minimum=1, maximum=5, value=GENERATE_KWARGS['repeat_penalty'], step=0.1, label='repeat_penalty', visible=do_sample),
74
- ]
75
- return generate_args
76
-
77
-
78
- def get_rag_mode_component(db: Optional[VectorStore]) -> gr.Checkbox:
79
- value = visible = db is not None
80
- return gr.Checkbox(value=value, label='RAG Mode', scale=1, visible=visible)
81
-
82
-
83
- # ================ LOADING AND INITIALIZING MODELS ========================
84
-
85
- start_llm_model, start_support_system_role, load_log = load_llm_model(LLM_MODEL_REPOS[0], 'gemma-2-2b-it-Q8_0.gguf')
86
- start_embed_model, load_log = load_embed_model(EMBED_MODEL_REPOS[0])
87
-
88
-
89
-
90
- # ================== APPLICATION WEB INTERFACE ============================
91
-
92
- css = '''.gradio-container {width: 60% !important}'''
93
-
94
- with gr.Blocks(css=css) as interface:
95
-
96
- # ==================== GRADIO STATES ===============================
97
-
98
- documents = gr.State([])
99
- db = gr.State(None)
100
- user_message_with_context = gr.State('')
101
- support_system_role = gr.State(start_support_system_role)
102
- llm_model_repos = gr.State(LLM_MODEL_REPOS)
103
- embed_model_repos = gr.State(EMBED_MODEL_REPOS)
104
- llm_model = gr.State(start_llm_model)
105
- embed_model = gr.State(start_embed_model)
106
-
107
-
108
-
109
- # ==================== BOT PAGE =================================
110
-
111
- with gr.Tab(label='Chatbot'):
112
- with gr.Row():
113
- with gr.Column(scale=3):
114
- chatbot = gr.Chatbot(
115
- type='messages', # new in gradio 5+
116
- show_copy_button=True,
117
- bubble_full_width=False,
118
- height=480,
119
- )
120
- user_message = gr.Textbox(label='User')
121
-
122
- with gr.Row():
123
- user_message_btn = gr.Button('Send')
124
- stop_btn = gr.Button('Stop')
125
- clear_btn = gr.Button('Clear')
126
-
127
- # ------------- GENERATION PARAMETERS -------------------
128
-
129
- with gr.Column(scale=1, min_width=80):
130
- with gr.Group():
131
- gr.Markdown('History size')
132
- history_len = gr.Slider(
133
- minimum=0,
134
- maximum=5,
135
- value=0,
136
- step=1,
137
- info='Number of previous messages taken into account in history',
138
- label='history_len',
139
- show_label=False,
140
- )
141
-
142
- with gr.Group():
143
- gr.Markdown('Generation parameters')
144
- do_sample = gr.Checkbox(
145
- value=False,
146
- label='do_sample',
147
- info='Activate random sampling',
148
- )
149
- generate_args = get_generate_args(do_sample.value)
150
- do_sample.change(
151
- fn=get_generate_args,
152
- inputs=do_sample,
153
- outputs=generate_args,
154
- show_progress=False,
155
- )
156
-
157
- rag_mode = get_rag_mode_component(db=db.value)
158
- k, score_threshold = get_rag_settings(rag_mode=rag_mode.value, render=False)
159
- rag_mode.change(
160
- fn=get_rag_settings,
161
- inputs=[rag_mode],
162
- outputs=[k, score_threshold],
163
- )
164
- with gr.Row():
165
- k.render()
166
- score_threshold.render()
167
-
168
- # ---------------- SYSTEM PROMPT AND USER MESSAGE -----------
169
-
170
- with gr.Accordion('Prompt', open=True):
171
- system_prompt = get_system_prompt_component(interactive=support_system_role.value)
172
- user_message_with_context = get_user_message_with_context(text='', rag_mode=rag_mode.value)
173
-
174
- # ---------------- SEND, CLEAR AND STOP BUTTONS ------------
175
-
176
- generate_event = gr.on(
177
- triggers=[user_message.submit, user_message_btn.click],
178
- fn=user_message_to_chatbot,
179
- inputs=[user_message, chatbot],
180
- outputs=[user_message, chatbot],
181
- queue=False,
182
- ).then(
183
- fn=update_user_message_with_context,
184
- inputs=[chatbot, rag_mode, db, k, score_threshold],
185
- outputs=[user_message_with_context],
186
- ).then(
187
- fn=get_user_message_with_context,
188
- inputs=[user_message_with_context, rag_mode],
189
- outputs=[user_message_with_context],
190
- ).then(
191
- fn=get_llm_response,
192
- inputs=[chatbot, llm_model, user_message_with_context, rag_mode, system_prompt,
193
- support_system_role, history_len, do_sample, *generate_args],
194
- outputs=[chatbot],
195
- )
196
-
197
- stop_btn.click(
198
- fn=None,
199
- inputs=None,
200
- outputs=None,
201
- cancels=generate_event,
202
- queue=False,
203
- )
204
-
205
- clear_btn.click(
206
- fn=lambda: (None, ''),
207
- inputs=None,
208
- outputs=[chatbot, user_message_with_context],
209
- queue=False,
210
- )
211
-
212
-
213
-
214
- # ================= FILE DOWNLOAD PAGE =========================
215
-
216
- with gr.Tab(label='Load documents'):
217
- with gr.Row(variant='compact'):
218
- upload_files = gr.File(file_count='multiple', label='Loading text files')
219
- web_links = gr.Textbox(lines=6, label='Links to Web sites or YouTube')
220
-
221
- with gr.Row(variant='compact'):
222
- chunk_size = gr.Slider(50, 2000, value=500, step=50, label='Chunk size')
223
- chunk_overlap = gr.Slider(0, 200, value=20, step=10, label='Chunk overlap')
224
-
225
- subtitles_lang = gr.Radio(
226
- SUBTITLES_LANGUAGES,
227
- value=SUBTITLES_LANGUAGES[0],
228
- label='YouTube subtitle language',
229
- )
230
-
231
- load_documents_btn = gr.Button(value='Upload documents and initialize database')
232
- load_docs_log = gr.Textbox(label='Status of loading and splitting documents', interactive=False)
233
-
234
- load_documents_btn.click(
235
- fn=load_documents_and_create_db,
236
- inputs=[upload_files, web_links, subtitles_lang, chunk_size, chunk_overlap, embed_model],
237
- outputs=[documents, db, load_docs_log],
238
- ).success(
239
- fn=get_rag_mode_component,
240
- inputs=[db],
241
- outputs=[rag_mode],
242
- )
243
-
244
- gr.HTML("""<h3 style='text-align: center'>
245
- <a href="https://github.com/sergey21000/chatbot-rag" target='_blank'>GitHub Repository</a></h3>
246
- """)
247
-
248
-
249
-
250
- # ================= VIEW PAGE FOR ALL DOCUMENTS =================
251
-
252
- with gr.Tab(label='View documents'):
253
- view_documents_btn = gr.Button(value='Show downloaded text chunks')
254
- view_documents_textbox = gr.Textbox(
255
- lines=1,
256
- placeholder='To view chunks, load documents in the Load documents tab',
257
- label='Uploaded chunks',
258
- )
259
- sep = '=' * 20
260
- view_documents_btn.click(
261
- lambda documents: f'\n{sep}\n\n'.join([doc.page_content for doc in documents]),
262
- inputs=[documents],
263
- outputs=[view_documents_textbox],
264
- )
265
-
266
-
267
- # ============== GGUF MODELS DOWNLOAD PAGE =====================
268
-
269
- with gr.Tab('Load LLM model'):
270
- new_llm_model_repo = gr.Textbox(
271
- value='',
272
- label='Add repository',
273
- placeholder='Link to repository of HF models in GGUF format',
274
- )
275
- new_llm_model_repo_btn = gr.Button('Add repository')
276
- curr_llm_model_repo = gr.Dropdown(
277
- choices=LLM_MODEL_REPOS,
278
- value=None,
279
- label='HF Model Repository',
280
- )
281
- curr_llm_model_path = gr.Dropdown(
282
- choices=[],
283
- value=None,
284
- label='GGUF model file',
285
- )
286
- load_llm_model_btn = gr.Button('Loading and initializing model')
287
- load_llm_model_log = gr.Textbox(
288
- value=f'Model {LLM_MODEL_REPOS[0]} loaded at application startup',
289
- label='Model loading status',
290
- lines=6,
291
- )
292
-
293
- with gr.Group():
294
- gr.Markdown('Free up disk space by deleting all models except the currently selected one')
295
- clear_llm_folder_btn = gr.Button('Clear folder')
296
-
297
- new_llm_model_repo_btn.click(
298
- fn=add_new_model_repo,
299
- inputs=[new_llm_model_repo, llm_model_repos],
300
- outputs=[curr_llm_model_repo, load_llm_model_log],
301
- ).success(
302
- fn=lambda: '',
303
- inputs=None,
304
- outputs=[new_llm_model_repo],
305
- )
306
-
307
- curr_llm_model_repo.change(
308
- fn=get_gguf_model_names,
309
- inputs=[curr_llm_model_repo],
310
- outputs=[curr_llm_model_path],
311
- )
312
-
313
- load_llm_model_btn.click(
314
- fn=load_llm_model,
315
- inputs=[curr_llm_model_repo, curr_llm_model_path],
316
- outputs=[llm_model, support_system_role, load_llm_model_log],
317
- ).success(
318
- fn=lambda log: log + get_memory_usage(),
319
- inputs=[load_llm_model_log],
320
- outputs=[load_llm_model_log],
321
- ).then(
322
- fn=get_system_prompt_component,
323
- inputs=[support_system_role],
324
- outputs=[system_prompt],
325
- )
326
-
327
- clear_llm_folder_btn.click(
328
- fn=clear_llm_folder,
329
- inputs=[curr_llm_model_path],
330
- outputs=None,
331
- ).success(
332
- fn=lambda model_path: f'Models other than {model_path} removed',
333
- inputs=[curr_llm_model_path],
334
- outputs=None,
335
- )
336
-
337
-
338
- # ============== EMBEDDING MODELS DOWNLOAD PAGE =============
339
-
340
- with gr.Tab('Load embed model'):
341
- new_embed_model_repo = gr.Textbox(
342
- value='',
343
- label='Add repository',
344
- placeholder='Link to HF model repository',
345
- )
346
- new_embed_model_repo_btn = gr.Button('Add repository')
347
- curr_embed_model_repo = gr.Dropdown(
348
- choices=EMBED_MODEL_REPOS,
349
- value=None,
350
- label='HF model repository',
351
- )
352
-
353
- load_embed_model_btn = gr.Button('Loading and initializing model')
354
- load_embed_model_log = gr.Textbox(
355
- value=f'Model {EMBED_MODEL_REPOS[0]} loaded at application startup',
356
- label='Model loading status',
357
- lines=7,
358
- )
359
- with gr.Group():
360
- gr.Markdown('Free up disk space by deleting all models except the currently selected one')
361
- clear_embed_folder_btn = gr.Button('Clear folder')
362
-
363
- new_embed_model_repo_btn.click(
364
- fn=add_new_model_repo,
365
- inputs=[new_embed_model_repo, embed_model_repos],
366
- outputs=[curr_embed_model_repo, load_embed_model_log],
367
- ).success(
368
- fn=lambda: '',
369
- inputs=None,
370
- outputs=new_embed_model_repo,
371
- )
372
-
373
- load_embed_model_btn.click(
374
- fn=load_embed_model,
375
- inputs=[curr_embed_model_repo],
376
- outputs=[embed_model, load_embed_model_log],
377
- ).success(
378
- fn=lambda log: log + get_memory_usage(),
379
- inputs=[load_embed_model_log],
380
- outputs=[load_embed_model_log],
381
- )
382
-
383
- clear_embed_folder_btn.click(
384
- fn=clear_embed_folder,
385
- inputs=[curr_embed_model_repo],
386
- outputs=None,
387
- ).success(
388
- fn=lambda model_repo: f'Models other than {model_repo} removed',
389
- inputs=[curr_embed_model_repo],
390
- outputs=None,
391
- )
392
-
393
-
394
  interface.launch(server_name='0.0.0.0', server_port=7860) # debug=True
 
1
+ from typing import List, Tuple, Optional
2
+
3
+ import gradio as gr
4
+ from langchain_core.vectorstores import VectorStore
5
+
6
+ from config import (
7
+ LLM_MODEL_REPOS,
8
+ EMBED_MODEL_REPOS,
9
+ SUBTITLES_LANGUAGES,
10
+ GENERATE_KWARGS,
11
+ CONTEXT_TEMPLATE,
12
+ )
13
+
14
+ from utils import (
15
+ load_llm_model,
16
+ load_embed_model,
17
+ load_documents_and_create_db,
18
+ user_message_to_chatbot,
19
+ update_user_message_with_context,
20
+ get_llm_response,
21
+ get_gguf_model_names,
22
+ add_new_model_repo,
23
+ clear_llm_folder,
24
+ clear_embed_folder,
25
+ get_memory_usage,
26
+ )
27
+
28
+
29
+ # ============ INTERFACE COMPONENT INITIALIZATION FUNCS ============
30
+
31
+ def get_rag_mode_component(db: Optional[VectorStore]) -> gr.Checkbox:
32
+ value = visible = db is not None
33
+ return gr.Checkbox(value=value, label='RAG Mode', scale=1, visible=visible)
34
+
35
+
36
+ def get_rag_settings(
37
+ rag_mode: bool,
38
+ context_template_value: str,
39
+ render: bool = True,
40
+ ) -> Tuple[gr.component, ...]:
41
+
42
+ k = gr.Radio(
43
+ choices=[1, 2, 3, 4, 5, 'all'],
44
+ value=2,
45
+ label='Number of relevant documents for search',
46
+ visible=rag_mode,
47
+ render=render,
48
+ )
49
+ score_threshold = gr.Slider(
50
+ minimum=0,
51
+ maximum=1,
52
+ value=0.5,
53
+ step=0.05,
54
+ label='relevance_scores_threshold',
55
+ visible=rag_mode,
56
+ render=render,
57
+ )
58
+ context_template = gr.Textbox(
59
+ value=context_template_value,
60
+ label='Context Template',
61
+ lines=len(context_template_value.split('\n')),
62
+ visible=rag_mode,
63
+ render=render,
64
+ )
65
+ return k, score_threshold, context_template
66
+
67
+
68
+ def get_user_message_with_context(text: str, rag_mode: bool) -> gr.component:
69
+ num_lines = len(text.split('\n'))
70
+ max_lines = 10
71
+ num_lines = max_lines if num_lines > max_lines else num_lines
72
+ return gr.Textbox(
73
+ text,
74
+ visible=rag_mode,
75
+ interactive=False,
76
+ label='User Message With Context',
77
+ lines=num_lines,
78
+ )
79
+
80
+
81
+ def get_system_prompt_component(interactive: bool) -> gr.Textbox:
82
+ value = '' if interactive else 'System prompt is not supported by this model'
83
+ return gr.Textbox(value=value, label='System prompt', interactive=interactive)
84
+
85
+
86
+ def get_generate_args(do_sample: bool) -> List[gr.component]:
87
+ generate_args = [
88
+ gr.Slider(minimum=0.1, maximum=3, value=GENERATE_KWARGS['temperature'], step=0.1, label='temperature', visible=do_sample),
89
+ gr.Slider(minimum=0, maximum=1, value=GENERATE_KWARGS['top_p'], step=0.01, label='top_p', visible=do_sample),
90
+ gr.Slider(minimum=1, maximum=50, value=GENERATE_KWARGS['top_k'], step=1, label='top_k', visible=do_sample),
91
+ gr.Slider(minimum=1, maximum=5, value=GENERATE_KWARGS['repeat_penalty'], step=0.1, label='repeat_penalty', visible=do_sample),
92
+ ]
93
+ return generate_args
94
+
95
+
96
+ # ================ LOADING AND INITIALIZING MODELS ========================
97
+
98
+ start_llm_model, start_support_system_role, load_log = load_llm_model(LLM_MODEL_REPOS[0], 'gemma-2-2b-it-Q8_0.gguf')
99
+ start_embed_model, load_log = load_embed_model(EMBED_MODEL_REPOS[0])
100
+
101
+
102
+
103
+ # ================== APPLICATION WEB INTERFACE ============================
104
+
105
+ css = '''.gradio-container {width: 60% !important}'''
106
+
107
+ with gr.Blocks(css=css) as interface:
108
+
109
+ # ==================== GRADIO STATES ===============================
110
+
111
+ documents = gr.State([])
112
+ db = gr.State(None)
113
+ user_message_with_context = gr.State('')
114
+ support_system_role = gr.State(start_support_system_role)
115
+ llm_model_repos = gr.State(LLM_MODEL_REPOS)
116
+ embed_model_repos = gr.State(EMBED_MODEL_REPOS)
117
+ llm_model = gr.State(start_llm_model)
118
+ embed_model = gr.State(start_embed_model)
119
+
120
+
121
+
122
+ # ==================== BOT PAGE =================================
123
+
124
+ with gr.Tab(label='Chatbot'):
125
+ with gr.Row():
126
+ with gr.Column(scale=3):
127
+ chatbot = gr.Chatbot(
128
+ type='messages', # new in gradio 5+
129
+ show_copy_button=True,
130
+ bubble_full_width=False,
131
+ height=480,
132
+ )
133
+ user_message = gr.Textbox(label='User')
134
+
135
+ with gr.Row():
136
+ user_message_btn = gr.Button('Send')
137
+ stop_btn = gr.Button('Stop')
138
+ clear_btn = gr.Button('Clear')
139
+
140
+ # ------------- GENERATION PARAMETERS -------------------
141
+
142
+ with gr.Column(scale=1, min_width=80):
143
+ with gr.Group():
144
+ gr.Markdown('History size')
145
+ history_len = gr.Slider(
146
+ minimum=0,
147
+ maximum=5,
148
+ value=0,
149
+ step=1,
150
+ info='Number of previous messages taken into account in history',
151
+ label='history_len',
152
+ show_label=False,
153
+ )
154
+
155
+ with gr.Group():
156
+ gr.Markdown('Generation parameters')
157
+ do_sample = gr.Checkbox(
158
+ value=False,
159
+ label='do_sample',
160
+ info='Activate random sampling',
161
+ )
162
+ generate_args = get_generate_args(do_sample.value)
163
+ do_sample.change(
164
+ fn=get_generate_args,
165
+ inputs=do_sample,
166
+ outputs=generate_args,
167
+ show_progress=False,
168
+ )
169
+
170
+ rag_mode = get_rag_mode_component(db=db.value)
171
+ k, score_threshold, context_template = get_rag_settings(
172
+ rag_mode=rag_mode.value,
173
+ context_template_value=CONTEXT_TEMPLATE,
174
+ render=False,
175
+ )
176
+ rag_mode.change(
177
+ fn=get_rag_settings,
178
+ inputs=[rag_mode, context_template],
179
+ outputs=[k, score_threshold, context_template],
180
+ )
181
+
182
+ with gr.Row():
183
+ k.render()
184
+ score_threshold.render()
185
+
186
+ # ---------------- SYSTEM PROMPT AND USER MESSAGE -----------
187
+
188
+ with gr.Accordion('Prompt', open=True):
189
+ system_prompt = get_system_prompt_component(interactive=support_system_role.value)
190
+ context_template.render()
191
+ user_message_with_context = get_user_message_with_context(text='', rag_mode=rag_mode.value)
192
+
193
+ # ---------------- SEND, CLEAR AND STOP BUTTONS ------------
194
+
195
+ generate_event = gr.on(
196
+ triggers=[user_message.submit, user_message_btn.click],
197
+ fn=user_message_to_chatbot,
198
+ inputs=[user_message, chatbot],
199
+ outputs=[user_message, chatbot],
200
+ queue=False,
201
+ ).then(
202
+ fn=update_user_message_with_context,
203
+ inputs=[chatbot, rag_mode, db, k, score_threshold, context_template],
204
+ outputs=[user_message_with_context],
205
+ ).then(
206
+ fn=get_user_message_with_context,
207
+ inputs=[user_message_with_context, rag_mode],
208
+ outputs=[user_message_with_context],
209
+ ).then(
210
+ fn=get_llm_response,
211
+ inputs=[chatbot, llm_model, user_message_with_context, rag_mode, system_prompt,
212
+ support_system_role, history_len, do_sample, *generate_args],
213
+ outputs=[chatbot],
214
+ )
215
+
216
+ stop_btn.click(
217
+ fn=None,
218
+ inputs=None,
219
+ outputs=None,
220
+ cancels=generate_event,
221
+ queue=False,
222
+ )
223
+
224
+ clear_btn.click(
225
+ fn=lambda: (None, ''),
226
+ inputs=None,
227
+ outputs=[chatbot, user_message_with_context],
228
+ queue=False,
229
+ )
230
+
231
+
232
+
233
+ # ================= FILE DOWNLOAD PAGE =========================
234
+
235
+ with gr.Tab(label='Load documents'):
236
+ with gr.Row(variant='compact'):
237
+ upload_files = gr.File(file_count='multiple', label='Loading text files')
238
+ web_links = gr.Textbox(lines=6, label='Links to Web sites or YouTube')
239
+
240
+ with gr.Row(variant='compact'):
241
+ chunk_size = gr.Slider(50, 2000, value=500, step=50, label='Chunk size')
242
+ chunk_overlap = gr.Slider(0, 200, value=20, step=10, label='Chunk overlap')
243
+
244
+ subtitles_lang = gr.Radio(
245
+ SUBTITLES_LANGUAGES,
246
+ value=SUBTITLES_LANGUAGES[0],
247
+ label='YouTube subtitle language',
248
+ )
249
+
250
+ load_documents_btn = gr.Button(value='Upload documents and initialize database')
251
+ load_docs_log = gr.Textbox(label='Status of loading and splitting documents', interactive=False)
252
+
253
+ load_documents_btn.click(
254
+ fn=load_documents_and_create_db,
255
+ inputs=[upload_files, web_links, subtitles_lang, chunk_size, chunk_overlap, embed_model],
256
+ outputs=[documents, db, load_docs_log],
257
+ ).success(
258
+ fn=get_rag_mode_component,
259
+ inputs=[db],
260
+ outputs=[rag_mode],
261
+ )
262
+
263
+ gr.HTML("""<h3 style='text-align: center'>
264
+ <a href="https://github.com/sergey21000/chatbot-rag" target='_blank'>GitHub Repository</a></h3>
265
+ """)
266
+
267
+
268
+
269
+ # ================= VIEW PAGE FOR ALL DOCUMENTS =================
270
+
271
+ with gr.Tab(label='View documents'):
272
+ view_documents_btn = gr.Button(value='Show downloaded text chunks')
273
+ view_documents_textbox = gr.Textbox(
274
+ lines=1,
275
+ placeholder='To view chunks, load documents in the Load documents tab',
276
+ label='Uploaded chunks',
277
+ )
278
+ sep = '=' * 20
279
+ view_documents_btn.click(
280
+ lambda documents: f'\n{sep}\n\n'.join([doc.page_content for doc in documents]),
281
+ inputs=[documents],
282
+ outputs=[view_documents_textbox],
283
+ )
284
+
285
+
286
+ # ============== GGUF MODELS DOWNLOAD PAGE =====================
287
+
288
+ with gr.Tab('Load LLM model'):
289
+ new_llm_model_repo = gr.Textbox(
290
+ value='',
291
+ label='Add repository',
292
+ placeholder='Link to repository of HF models in GGUF format',
293
+ )
294
+ new_llm_model_repo_btn = gr.Button('Add repository')
295
+ curr_llm_model_repo = gr.Dropdown(
296
+ choices=LLM_MODEL_REPOS,
297
+ value=None,
298
+ label='HF Model Repository',
299
+ )
300
+ curr_llm_model_path = gr.Dropdown(
301
+ choices=[],
302
+ value=None,
303
+ label='GGUF model file',
304
+ )
305
+ load_llm_model_btn = gr.Button('Loading and initializing model')
306
+ load_llm_model_log = gr.Textbox(
307
+ value=f'Model {LLM_MODEL_REPOS[0]} loaded at application startup',
308
+ label='Model loading status',
309
+ lines=6,
310
+ )
311
+
312
+ with gr.Group():
313
+ gr.Markdown('Free up disk space by deleting all models except the currently selected one')
314
+ clear_llm_folder_btn = gr.Button('Clear folder')
315
+
316
+ new_llm_model_repo_btn.click(
317
+ fn=add_new_model_repo,
318
+ inputs=[new_llm_model_repo, llm_model_repos],
319
+ outputs=[curr_llm_model_repo, load_llm_model_log],
320
+ ).success(
321
+ fn=lambda: '',
322
+ inputs=None,
323
+ outputs=[new_llm_model_repo],
324
+ )
325
+
326
+ curr_llm_model_repo.change(
327
+ fn=get_gguf_model_names,
328
+ inputs=[curr_llm_model_repo],
329
+ outputs=[curr_llm_model_path],
330
+ )
331
+
332
+ load_llm_model_btn.click(
333
+ fn=load_llm_model,
334
+ inputs=[curr_llm_model_repo, curr_llm_model_path],
335
+ outputs=[llm_model, support_system_role, load_llm_model_log],
336
+ ).success(
337
+ fn=lambda log: log + get_memory_usage(),
338
+ inputs=[load_llm_model_log],
339
+ outputs=[load_llm_model_log],
340
+ ).then(
341
+ fn=get_system_prompt_component,
342
+ inputs=[support_system_role],
343
+ outputs=[system_prompt],
344
+ )
345
+
346
+ clear_llm_folder_btn.click(
347
+ fn=clear_llm_folder,
348
+ inputs=[curr_llm_model_path],
349
+ outputs=None,
350
+ ).success(
351
+ fn=lambda model_path: f'Models other than {model_path} removed',
352
+ inputs=[curr_llm_model_path],
353
+ outputs=None,
354
+ )
355
+
356
+
357
+ # ============== EMBEDDING MODELS DOWNLOAD PAGE =============
358
+
359
+ with gr.Tab('Load embed model'):
360
+ new_embed_model_repo = gr.Textbox(
361
+ value='',
362
+ label='Add repository',
363
+ placeholder='Link to HF model repository',
364
+ )
365
+ new_embed_model_repo_btn = gr.Button('Add repository')
366
+ curr_embed_model_repo = gr.Dropdown(
367
+ choices=EMBED_MODEL_REPOS,
368
+ value=None,
369
+ label='HF model repository',
370
+ )
371
+
372
+ load_embed_model_btn = gr.Button('Loading and initializing model')
373
+ load_embed_model_log = gr.Textbox(
374
+ value=f'Model {EMBED_MODEL_REPOS[0]} loaded at application startup',
375
+ label='Model loading status',
376
+ lines=7,
377
+ )
378
+ with gr.Group():
379
+ gr.Markdown('Free up disk space by deleting all models except the currently selected one')
380
+ clear_embed_folder_btn = gr.Button('Clear folder')
381
+
382
+ new_embed_model_repo_btn.click(
383
+ fn=add_new_model_repo,
384
+ inputs=[new_embed_model_repo, embed_model_repos],
385
+ outputs=[curr_embed_model_repo, load_embed_model_log],
386
+ ).success(
387
+ fn=lambda: '',
388
+ inputs=None,
389
+ outputs=new_embed_model_repo,
390
+ )
391
+
392
+ load_embed_model_btn.click(
393
+ fn=load_embed_model,
394
+ inputs=[curr_embed_model_repo],
395
+ outputs=[embed_model, load_embed_model_log],
396
+ ).success(
397
+ fn=lambda log: log + get_memory_usage(),
398
+ inputs=[load_embed_model_log],
399
+ outputs=[load_embed_model_log],
400
+ )
401
+
402
+ clear_embed_folder_btn.click(
403
+ fn=clear_embed_folder,
404
+ inputs=[curr_embed_model_repo],
405
+ outputs=None,
406
+ ).success(
407
+ fn=lambda model_repo: f'Models other than {model_repo} removed',
408
+ inputs=[curr_embed_model_repo],
409
+ outputs=None,
410
+ )
411
+
412
+
413
  interface.launch(server_name='0.0.0.0', server_port=7860) # debug=True
utils.py CHANGED
@@ -1,500 +1,506 @@
1
- import csv
2
- from pathlib import Path
3
- from shutil import rmtree
4
- from typing import List, Tuple, Dict, Union, Optional, Any, Iterable
5
- from tqdm import tqdm
6
-
7
- import psutil
8
- import requests
9
- from requests.exceptions import MissingSchema
10
-
11
- import torch
12
- import gradio as gr
13
-
14
- from llama_cpp import Llama
15
- from youtube_transcript_api import YouTubeTranscriptApi, NoTranscriptFound, TranscriptsDisabled
16
- from huggingface_hub import hf_hub_download, list_repo_tree, list_repo_files, repo_info, repo_exists, snapshot_download
17
-
18
- from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
19
- from langchain_community.vectorstores import FAISS
20
- from langchain_huggingface import HuggingFaceEmbeddings
21
-
22
- # imports for annotations
23
- from langchain.docstore.document import Document
24
- from langchain_core.embeddings import Embeddings
25
- from langchain_core.vectorstores import VectorStore
26
-
27
- from config import (
28
- LLM_MODELS_PATH,
29
- EMBED_MODELS_PATH,
30
- GENERATE_KWARGS,
31
- LOADER_CLASSES,
32
- CONTEXT_TEMPLATE,
33
- )
34
-
35
-
36
- # type annotations
37
- CHAT_HISTORY = List[Optional[Dict[str, Optional[str]]]]
38
- LLM_MODEL_DICT = Dict[str, Llama]
39
- EMBED_MODEL_DICT = Dict[str, Embeddings]
40
-
41
-
42
- # ===================== ADDITIONAL FUNCS =======================
43
-
44
- # getting the amount of used memory on disk, CPU and GPU
45
- def get_memory_usage() -> str:
46
- print_memory = ''
47
-
48
- memory_type = 'Disk'
49
- psutil_stats = psutil.disk_usage('.')
50
- memory_total = psutil_stats.total / 1024**3
51
- memory_usage = psutil_stats.used / 1024**3
52
- print_memory += f'{memory_type} Memory Usage: {memory_usage:.2f} / {memory_total:.2f} GB\n'
53
-
54
- memory_type = 'CPU'
55
- psutil_stats = psutil.virtual_memory()
56
- memory_total = psutil_stats.total / 1024**3
57
- memory_usage = memory_total - (psutil_stats.available / 1024**3)
58
- print_memory += f'{memory_type} Memory Usage: {memory_usage:.2f} / {memory_total:.2f} GB\n'
59
-
60
- if torch.cuda.is_available():
61
- memory_type = 'GPU'
62
- memory_free, memory_total = torch.cuda.mem_get_info()
63
- memory_usage = memory_total - memory_free
64
- print_memory += f'{memory_type} Memory Usage: {memory_usage / 1024**3:.2f} / {memory_total / 1024**3:.2f} GB\n'
65
-
66
- print_memory = f'---------------\n{print_memory}---------------'
67
- return print_memory
68
-
69
-
70
- # clearing the list of documents
71
- def clear_documents(documents: Iterable[Document]) -> Iterable[Document]:
72
- def clear_text(text: str) -> str:
73
- lines = text.split('\n')
74
- lines = [line for line in lines if len(line.strip()) > 2]
75
- text = '\n'.join(lines).strip()
76
- return text
77
-
78
- output_documents = []
79
- for document in documents:
80
- text = clear_text(document.page_content)
81
- if len(text) > 10:
82
- document.page_content = text
83
- output_documents.append(document)
84
- return output_documents
85
-
86
-
87
- # ===================== INTERFACE FUNCS =============================
88
-
89
-
90
- # ------------- LLM AND EMBEDDING MODELS LOADING ------------------------
91
-
92
- # downloading file by URL link and displaying progress bars tqdm and gradio
93
- def download_file(file_url: str, file_path: Union[str, Path]) -> None:
94
- response = requests.get(file_url, stream=True)
95
- if response.status_code != 200:
96
- raise Exception(f'The file is not available for download at the link: {file_url}')
97
- total_size = int(response.headers.get('content-length', 0))
98
- progress_tqdm = tqdm(desc='Loading GGUF file', total=total_size, unit='iB', unit_scale=True)
99
- progress_gradio = gr.Progress()
100
- completed_size = 0
101
- with open(file_path, 'wb') as file:
102
- for data in response.iter_content(chunk_size=4096):
103
- size = file.write(data)
104
- progress_tqdm.update(size)
105
- completed_size += size
106
- desc = f'Loading GGUF file, {completed_size/1024**3:.3f}/{total_size/1024**3:.3f} GB'
107
- progress_gradio(completed_size/total_size, desc=desc)
108
-
109
-
110
- # loading and initializing the GGUF model
111
- def load_llm_model(model_repo: str, model_file: str) -> Tuple[LLM_MODEL_DICT, str, str]:
112
- llm_model = None
113
- load_log = ''
114
- support_system_role = False
115
-
116
- if isinstance(model_file, list):
117
- load_log += 'No model selected\n'
118
- return {'llm_model': llm_model}, support_system_role, load_log
119
-
120
- if '(' in model_file:
121
- model_file = model_file.split('(')[0].rstrip()
122
-
123
- progress = gr.Progress()
124
- progress(0.3, desc='Step 1/2: Download the GGUF file')
125
- model_path = LLM_MODELS_PATH / model_file
126
-
127
- if model_path.is_file():
128
- load_log += f'Model {model_file} already loaded, reinitializing\n'
129
- else:
130
- try:
131
- gguf_url = f'https://huggingface.co/{model_repo}/resolve/main/{model_file}'
132
- download_file(gguf_url, model_path)
133
- load_log += f'Model {model_file} loaded\n'
134
- except Exception as ex:
135
- model_path = ''
136
- load_log += f'Error loading model, error code:\n{ex}\n'
137
-
138
- if model_path:
139
- progress(0.7, desc='Step 2/2: Initialize the model')
140
- try:
141
- llm_model = Llama(model_path=str(model_path), n_gpu_layers=-1, verbose=False)
142
- support_system_role = 'System role not supported' not in llm_model.metadata['tokenizer.chat_template']
143
- load_log += f'Model {model_file} initialized, max context size is {llm_model.n_ctx()} tokens\n'
144
- except Exception as ex:
145
- load_log += f'Error initializing model, error code:\n{ex}\n'
146
-
147
- llm_model = {'llm_model': llm_model}
148
- return llm_model, support_system_role, load_log
149
-
150
-
151
- # loading and initializing the embedding model
152
- def load_embed_model(model_repo: str) -> Tuple[Dict[str, HuggingFaceEmbeddings], str]:
153
- embed_model = None
154
- load_log = ''
155
-
156
- if isinstance(model_repo, list):
157
- load_log = 'No model selected'
158
- return embed_model, load_log
159
-
160
- progress = gr.Progress()
161
- folder_name = model_repo.replace('/', '_')
162
- folder_path = EMBED_MODELS_PATH / folder_name
163
- if Path(folder_path).is_dir():
164
- load_log += f'Reinitializing model {model_repo} \n'
165
- else:
166
- progress(0.5, desc='Step 1/2: Download model repository')
167
- snapshot_download(
168
- repo_id=model_repo,
169
- local_dir=folder_path,
170
- ignore_patterns='*.h5',
171
- )
172
- load_log += f'Model {model_repo} loaded\n'
173
-
174
- progress(0.7, desc='Step 2/2: Initialize the model')
175
- model_kwargs = {'device': 'cuda' if torch.cuda.is_available() else 'cpu'}
176
- embed_model = HuggingFaceEmbeddings(
177
- model_name=str(folder_path),
178
- model_kwargs=model_kwargs,
179
- # encode_kwargs={'normalize_embeddings': True},
180
- )
181
- load_log += f'Embeddings model {model_repo} initialized\n'
182
- load_log += f'Please upload documents and initialize database again\n'
183
- embed_model = {'embed_model': embed_model}
184
- return embed_model, load_log
185
-
186
-
187
- # adding a new HF repository new_model_repo to the current list of model_repos
188
- def add_new_model_repo(new_model_repo: str, model_repos: List[str]) -> Tuple[gr.Dropdown, str]:
189
- load_log = ''
190
- repo = new_model_repo.strip()
191
- if repo:
192
- repo = repo.split('/')[-2:]
193
- if len(repo) == 2:
194
- repo = '/'.join(repo).split('?')[0]
195
- if repo_exists(repo) and repo not in model_repos:
196
- model_repos.insert(0, repo)
197
- load_log += f'Model repository {repo} successfully added\n'
198
- else:
199
- load_log += 'Invalid HF repository name or model already in the list\n'
200
- else:
201
- load_log += 'Invalid link to HF repository\n'
202
- else:
203
- load_log += 'Empty line in HF repository field\n'
204
- model_repo_dropdown = gr.Dropdown(choices=model_repos, value=model_repos[0])
205
- return model_repo_dropdown, load_log
206
-
207
-
208
- # get list of GGUF models from HF repository
209
- def get_gguf_model_names(model_repo: str) -> gr.Dropdown:
210
- repo_files = list(list_repo_tree(model_repo))
211
- repo_files = [file for file in repo_files if file.path.endswith('.gguf')]
212
- model_paths = [f'{file.path} ({file.size / 1000 ** 3:.2f}G)' for file in repo_files]
213
- model_paths_dropdown = gr.Dropdown(
214
- choices=model_paths,
215
- value=model_paths[0],
216
- label='GGUF model file',
217
- )
218
- return model_paths_dropdown
219
-
220
-
221
- # delete model files and folders to clear space except for the current model gguf_filename
222
- def clear_llm_folder(gguf_filename: str) -> None:
223
- if gguf_filename is None:
224
- gr.Info(f'The name of the model file that does not need to be deleted is not selected.')
225
- return
226
- if '(' in gguf_filename:
227
- gguf_filename = gguf_filename.split('(')[0].rstrip()
228
- for path in LLM_MODELS_PATH.iterdir():
229
- if path.name == gguf_filename:
230
- continue
231
- if path.is_file():
232
- path.unlink(missing_ok=True)
233
- gr.Info(f'All files removed from directory {LLM_MODELS_PATH} except {gguf_filename}')
234
-
235
-
236
- # delete model folders to clear space except for the current model model_folder_name
237
- def clear_embed_folder(model_repo: str) -> None:
238
- if model_repo is None:
239
- gr.Info(f'The name of the model that does not need to be deleted is not selected.')
240
- return
241
- model_folder_name = model_repo.replace('/', '_')
242
- for path in EMBED_MODELS_PATH.iterdir():
243
- if path.name == model_folder_name:
244
- continue
245
- if path.is_dir():
246
- rmtree(path, ignore_errors=True)
247
- gr.Info(f'All directories have been removed from the {EMBED_MODELS_PATH} directory except {model_folder_name}')
248
-
249
-
250
- # ------------------------ YOUTUBE ------------------------
251
-
252
- # function to check availability of subtitles, if manual or automatic are available - returns True and logs
253
- # if subtitles are not available - returns False and logs
254
- def check_subtitles_available(yt_video_link: str, target_lang: str) -> Tuple[bool, str]:
255
- video_id = yt_video_link.split('watch?v=')[-1].split('&')[0]
256
- load_log = ''
257
- available = True
258
- try:
259
- transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
260
- try:
261
- transcript = transcript_list.find_transcript([target_lang])
262
- if transcript.is_generated:
263
- load_log += f'Automatic subtitles will be loaded, manual ones are not available for video {yt_video_link}\n'
264
- else:
265
- load_log += f'Manual subtitles will be downloaded for the video {yt_video_link}\n'
266
- except NoTranscriptFound:
267
- load_log += f'Subtitle language {target_lang} is not available for video {yt_video_link}\n'
268
- available = False
269
- except TranscriptsDisabled:
270
- load_log += f'Invalid video url ({yt_video_link}) or current server IP is blocked for YouTube\n'
271
- available = False
272
- return available, load_log
273
-
274
-
275
- # ------------- UPLOADING DOCUMENTS FOR RAG ------------------------
276
-
277
- # extract documents (in langchain Documents format) from downloaded files
278
- def load_documents_from_files(upload_files: List[str]) -> Tuple[List[Document], str]:
279
- load_log = ''
280
- documents = []
281
- for upload_file in upload_files:
282
- file_extension = f".{upload_file.split('.')[-1]}"
283
- if file_extension in LOADER_CLASSES:
284
- loader_class = LOADER_CLASSES[file_extension]
285
- loader_kwargs = {}
286
- if file_extension == '.csv':
287
- with open(upload_file) as csvfile:
288
- delimiter = csv.Sniffer().sniff(csvfile.read(4096)).delimiter
289
- loader_kwargs = {'csv_args': {'delimiter': delimiter}}
290
- try:
291
- load_documents = loader_class(upload_file, **loader_kwargs).load()
292
- documents.extend(load_documents)
293
- except Exception as ex:
294
- load_log += f'Error uploading file {upload_file}\n'
295
- load_log += f'Error code: {ex}\n'
296
- continue
297
- else:
298
- load_log += f'Unsupported file format {upload_file}\n'
299
- continue
300
- return documents, load_log
301
-
302
-
303
- # extracting documents (in langchain Documents format) from WEB links
304
- def load_documents_from_links(
305
- web_links: str,
306
- subtitles_lang: str,
307
- ) -> Tuple[List[Document], str]:
308
-
309
- load_log = ''
310
- documents = []
311
- loader_class_kwargs = {}
312
- web_links = [web_link.strip() for web_link in web_links.split('\n') if web_link.strip()]
313
- for web_link in web_links:
314
- if 'youtube.com' in web_link:
315
- available, log = check_subtitles_available(web_link, subtitles_lang)
316
- load_log += log
317
- if not available:
318
- continue
319
- loader_class = LOADER_CLASSES['youtube'].from_youtube_url
320
- loader_class_kwargs = {'language': subtitles_lang}
321
- else:
322
- loader_class = LOADER_CLASSES['web']
323
-
324
- try:
325
- if requests.get(web_link).status_code != 200:
326
- load_log += f'Link is not accessible via Python requests: {web_link}\n'
327
- continue
328
- load_documents = loader_class(web_link, **loader_class_kwargs).load()
329
- if len(load_documents) == 0:
330
- load_log += f'No text chunks were found at the link: {web_link}\n'
331
- continue
332
- documents.extend(load_documents)
333
- except MissingSchema:
334
- load_log += f'Invalid link: {web_link}\n'
335
- continue
336
- except Exception as ex:
337
- load_log += f'Error loading data by web loader at link: {web_link}\n'
338
- load_log += f'Error code: {ex}\n'
339
- continue
340
- return documents, load_log
341
-
342
-
343
- # uploading files and generating documents and databases
344
- def load_documents_and_create_db(
345
- upload_files: Optional[List[str]],
346
- web_links: str,
347
- subtitles_lang: str,
348
- chunk_size: int,
349
- chunk_overlap: int,
350
- embed_model_dict: EMBED_MODEL_DICT,
351
- ) -> Tuple[List[Document], Optional[VectorStore], str]:
352
-
353
- load_log = ''
354
- all_documents = []
355
- db = None
356
- progress = gr.Progress()
357
-
358
- embed_model = embed_model_dict.get('embed_model')
359
- if embed_model is None:
360
- load_log += 'Embeddings model not initialized, DB cannot be created'
361
- return all_documents, db, load_log
362
-
363
- if upload_files is None and not web_links:
364
- load_log = 'No files or links selected'
365
- return all_documents, db, load_log
366
-
367
- if upload_files is not None:
368
- progress(0.3, desc='Step 1/2: Upload documents from files')
369
- docs, log = load_documents_from_files(upload_files)
370
- all_documents.extend(docs)
371
- load_log += log
372
-
373
- if web_links:
374
- progress(0.3 if upload_files is None else 0.5, desc='Step 1/2: Upload documents via links')
375
- docs, log = load_documents_from_links(web_links, subtitles_lang)
376
- all_documents.extend(docs)
377
- load_log += log
378
-
379
- if len(all_documents) == 0:
380
- load_log += 'Download was interrupted because no documents were extracted\n'
381
- load_log += 'RAG mode cannot be activated'
382
- return all_documents, db, load_log
383
-
384
- load_log += f'Documents loaded: {len(all_documents)}\n'
385
- text_splitter = RecursiveCharacterTextSplitter(
386
- chunk_size=chunk_size,
387
- chunk_overlap=chunk_overlap,
388
- )
389
- documents = text_splitter.split_documents(all_documents)
390
- documents = clear_documents(documents)
391
- load_log += f'Documents are divided, number of text chunks: {len(documents)}\n'
392
-
393
- progress(0.7, desc='Step 2/2: Initialize DB')
394
- db = FAISS.from_documents(documents=documents, embedding=embed_model)
395
- load_log += 'DB is initialized, RAG mode is activated and can be activated in the Chatbot tab'
396
- return documents, db, load_log
397
-
398
-
399
- # ------------------ CHAT BOT FUNCTIONS ------------------------
400
-
401
- # adding a user message to the chat bot window
402
- def user_message_to_chatbot(user_message: str, chatbot: CHAT_HISTORY) -> Tuple[str, CHAT_HISTORY]:
403
- chatbot.append({'role': 'user', 'metadata': {'title': None}, 'content': user_message})
404
- return '', chatbot
405
-
406
-
407
- # formatting prompt with adding context if DB is available and RAG mode is enabled
408
- def update_user_message_with_context(
409
- chatbot: CHAT_HISTORY,
410
- rag_mode: bool,
411
- db: VectorStore,
412
- k: Union[int, str],
413
- score_threshold: float,
414
- ) -> Tuple[str, CHAT_HISTORY]:
415
-
416
- user_message = chatbot[-1]['content']
417
- user_message_with_context = ''
418
- if db is not None and rag_mode and user_message.strip():
419
- if k == 'all':
420
- k = len(db.docstore._dict)
421
- docs_and_distances = db.similarity_search_with_relevance_scores(
422
- user_message,
423
- k=k,
424
- score_threshold=score_threshold,
425
- )
426
- if len(docs_and_distances) > 0:
427
- retriever_context = '\n\n'.join([doc[0].page_content for doc in docs_and_distances])
428
- user_message_with_context = CONTEXT_TEMPLATE.format(
429
- user_message=user_message,
430
- context=retriever_context,
431
- )
432
- return user_message_with_context
433
-
434
-
435
- # model response generation
436
- def get_llm_response(
437
- chatbot: CHAT_HISTORY,
438
- llm_model_dict: LLM_MODEL_DICT,
439
- user_message_with_context: str,
440
- rag_mode: bool,
441
- system_prompt: str,
442
- support_system_role: bool,
443
- history_len: int,
444
- do_sample: bool,
445
- *generate_args,
446
- ) -> CHAT_HISTORY:
447
-
448
- llm_model = llm_model_dict.get('llm_model')
449
- if llm_model is None:
450
- gr.Info('Model not initialized')
451
- yield chatbot[:-1]
452
- return
453
-
454
- gen_kwargs = dict(zip(GENERATE_KWARGS.keys(), generate_args))
455
- gen_kwargs['top_k'] = int(gen_kwargs['top_k'])
456
- if not do_sample:
457
- gen_kwargs['top_p'] = 0.0
458
- gen_kwargs['top_k'] = 1
459
- gen_kwargs['repeat_penalty'] = 1.0
460
-
461
- user_message = chatbot[-1]['content']
462
- if not user_message.strip():
463
- yield chatbot[:-1]
464
- return
465
-
466
- if rag_mode:
467
- if user_message_with_context:
468
- user_message = user_message_with_context
469
- else:
470
- gr.Info((
471
- f'No documents relevant to the query were found, generation in RAG mode is not possible.\n'
472
- f'Try reducing search_score_threshold or disable RAG mode for normal generation'
473
- ))
474
- yield chatbot[:-1]
475
- return
476
-
477
- messages = []
478
- if support_system_role and system_prompt:
479
- messages.append({'role': 'system', 'metadata': {'title': None}, 'content': system_prompt})
480
-
481
- if history_len != 0:
482
- messages.extend(chatbot[:-1][-(history_len*2):])
483
-
484
- messages.append({'role': 'user', 'metadata': {'title': None}, 'content': user_message})
485
- stream_response = llm_model.create_chat_completion(
486
- messages=messages,
487
- stream=True,
488
- **gen_kwargs,
489
- )
490
- try:
491
- chatbot.append({'role': 'assistant', 'metadata': {'title': None}, 'content': ''})
492
- for chunk in stream_response:
493
- token = chunk['choices'][0]['delta'].get('content')
494
- if token is not None:
495
- chatbot[-1]['content'] += token
496
- yield chatbot
497
- except Exception as ex:
498
- gr.Info(f'Error generating response, error code: {ex}')
499
- yield chatbot[:-1]
500
- return
1
+ import csv
2
+ from pathlib import Path
3
+ from shutil import rmtree
4
+ from typing import List, Tuple, Dict, Union, Optional, Any, Iterable
5
+ from tqdm import tqdm
6
+
7
+ import psutil
8
+ import requests
9
+ from requests.exceptions import MissingSchema
10
+
11
+ import torch
12
+ import gradio as gr
13
+
14
+ from llama_cpp import Llama
15
+ from youtube_transcript_api import YouTubeTranscriptApi, NoTranscriptFound, TranscriptsDisabled
16
+ from huggingface_hub import hf_hub_download, list_repo_tree, list_repo_files, repo_info, repo_exists, snapshot_download
17
+
18
+ from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
19
+ from langchain_community.vectorstores import FAISS
20
+ from langchain_huggingface import HuggingFaceEmbeddings
21
+
22
+ # imports for annotations
23
+ from langchain.docstore.document import Document
24
+ from langchain_core.embeddings import Embeddings
25
+ from langchain_core.vectorstores import VectorStore
26
+
27
+ from config import (
28
+ LLM_MODELS_PATH,
29
+ EMBED_MODELS_PATH,
30
+ GENERATE_KWARGS,
31
+ LOADER_CLASSES,
32
+ )
33
+
34
+
35
+ # type annotations
36
+ CHAT_HISTORY = List[Optional[Dict[str, Optional[str]]]]
37
+ LLM_MODEL_DICT = Dict[str, Llama]
38
+ EMBED_MODEL_DICT = Dict[str, Embeddings]
39
+
40
+
41
+ # ===================== ADDITIONAL FUNCS =======================
42
+
43
+ # getting the amount of used memory on disk, CPU and GPU
44
+ def get_memory_usage() -> str:
45
+ print_memory = ''
46
+
47
+ memory_type = 'Disk'
48
+ psutil_stats = psutil.disk_usage('.')
49
+ memory_total = psutil_stats.total / 1024**3
50
+ memory_usage = psutil_stats.used / 1024**3
51
+ print_memory += f'{memory_type} Menory Usage: {memory_usage:.2f} / {memory_total:.2f} GB\n'
52
+
53
+ memory_type = 'CPU'
54
+ psutil_stats = psutil.virtual_memory()
55
+ memory_total = psutil_stats.total / 1024**3
56
+ memory_usage = memory_total - (psutil_stats.available / 1024**3)
57
+ print_memory += f'{memory_type} Menory Usage: {memory_usage:.2f} / {memory_total:.2f} GB\n'
58
+
59
+ if torch.cuda.is_available():
60
+ memory_type = 'GPU'
61
+ memory_free, memory_total = torch.cuda.mem_get_info()
62
+ memory_usage = memory_total - memory_free
63
+ print_memory += f'{memory_type} Menory Usage: {memory_usage / 1024**3:.2f} / {memory_total:.2f} GB\n'
64
+
65
+ print_memory = f'---------------\n{print_memory}---------------'
66
+ return print_memory
67
+
68
+
69
+ # clearing the list of documents
70
+ def clear_documents(documents: Iterable[Document]) -> Iterable[Document]:
71
+ def clear_text(text: str) -> str:
72
+ lines = text.split('\n')
73
+ lines = [line for line in lines if len(line.strip()) > 2]
74
+ text = '\n'.join(lines).strip()
75
+ return text
76
+
77
+ output_documents = []
78
+ for document in documents:
79
+ text = clear_text(document.page_content)
80
+ if len(text) > 10:
81
+ document.page_content = text
82
+ output_documents.append(document)
83
+ return output_documents
84
+
85
+
86
+ # ===================== INTERFACE FUNCS =============================
87
+
88
+
89
+ # ------------- LLM AND EMBEDDING MODELS LOADING ------------------------
90
+
91
+ # downloading file by URL link and displaying progress bars tqdm and gradio
92
+ def download_file(file_url: str, file_path: Union[str, Path]) -> None:
93
+ response = requests.get(file_url, stream=True)
94
+ if response.status_code != 200:
95
+ raise Exception(f'The file is not available for download at the link: {file_url}')
96
+ total_size = int(response.headers.get('content-length', 0))
97
+ progress_tqdm = tqdm(desc='Loading GGUF file', total=total_size, unit='iB', unit_scale=True)
98
+ progress_gradio = gr.Progress()
99
+ completed_size = 0
100
+ with open(file_path, 'wb') as file:
101
+ for data in response.iter_content(chunk_size=4096):
102
+ size = file.write(data)
103
+ progress_tqdm.update(size)
104
+ completed_size += size
105
+ desc = f'Loading GGUF file, {completed_size/1024**3:.3f}/{total_size/1024**3:.3f} GB'
106
+ progress_gradio(completed_size/total_size, desc=desc)
107
+
108
+
109
+ # loading and initializing the GGUF model
110
+ def load_llm_model(model_repo: str, model_file: str) -> Tuple[LLM_MODEL_DICT, str, str]:
111
+ llm_model = None
112
+ load_log = ''
113
+ support_system_role = False
114
+
115
+ if isinstance(model_file, list):
116
+ load_log += 'No model selected\n'
117
+ return {'llm_model': llm_model}, support_system_role, load_log
118
+
119
+ if '(' in model_file:
120
+ model_file = model_file.split('(')[0].rstrip()
121
+
122
+ progress = gr.Progress()
123
+ progress(0.3, desc='Step 1/2: Download the GGUF file')
124
+ model_path = LLM_MODELS_PATH / model_file
125
+
126
+ if model_path.is_file():
127
+ load_log += f'Model {model_file} already loaded, reinitializing\n'
128
+ else:
129
+ try:
130
+ gguf_url = f'https://huggingface.co/{model_repo}/resolve/main/{model_file}'
131
+ download_file(gguf_url, model_path)
132
+ load_log += f'Model {model_file} loaded\n'
133
+ except Exception as ex:
134
+ model_path = ''
135
+ load_log += f'Error loading model, error code:\n{ex}\n'
136
+
137
+ if model_path:
138
+ progress(0.7, desc='Step 2/2: Initialize the model')
139
+ try:
140
+ llm_model = Llama(model_path=str(model_path), n_gpu_layers=-1, verbose=False)
141
+ support_system_role = 'System role not supported' not in llm_model.metadata['tokenizer.chat_template']
142
+ load_log += f'Model {model_file} initialized, max context size is {llm_model.n_ctx()} tokens\n'
143
+ except Exception as ex:
144
+ load_log += f'Error initializing model, error code:\n{ex}\n'
145
+
146
+ llm_model = {'llm_model': llm_model}
147
+ return llm_model, support_system_role, load_log
148
+
149
+
150
+ # loading and initializing the embedding model
151
+ def load_embed_model(model_repo: str) -> Tuple[Dict[str, HuggingFaceEmbeddings], str]:
152
+ embed_model = None
153
+ load_log = ''
154
+
155
+ if isinstance(model_repo, list):
156
+ load_log = 'No model selected'
157
+ return embed_model, load_log
158
+
159
+ progress = gr.Progress()
160
+ folder_name = model_repo.replace('/', '_')
161
+ folder_path = EMBED_MODELS_PATH / folder_name
162
+ if Path(folder_path).is_dir():
163
+ load_log += f'Reinitializing model {model_repo} \n'
164
+ else:
165
+ progress(0.5, desc='Step 1/2: Download model repository')
166
+ snapshot_download(
167
+ repo_id=model_repo,
168
+ local_dir=folder_path,
169
+ ignore_patterns='*.h5',
170
+ )
171
+ load_log += f'Model {model_repo} loaded\n'
172
+
173
+ progress(0.7, desc='Step 2/2: Initialize the model')
174
+ model_kwargs = {'device': 'cuda' if torch.cuda.is_available() else 'cpu'}
175
+ embed_model = HuggingFaceEmbeddings(
176
+ model_name=str(folder_path),
177
+ model_kwargs=model_kwargs,
178
+ # encode_kwargs={'normalize_embeddings': True},
179
+ )
180
+ load_log += f'Embeddings model {model_repo} initialized\n'
181
+ load_log += f'Please upload documents and initialize database again\n'
182
+ embed_model = {'embed_model': embed_model}
183
+ return embed_model, load_log
184
+
185
+
186
+ # adding a new HF repository new_model_repo to the current list of model_repos
187
+ def add_new_model_repo(new_model_repo: str, model_repos: List[str]) -> Tuple[gr.Dropdown, str]:
188
+ load_log = ''
189
+ repo = new_model_repo.strip()
190
+ if repo:
191
+ repo = repo.split('/')[-2:]
192
+ if len(repo) == 2:
193
+ repo = '/'.join(repo).split('?')[0]
194
+ if repo_exists(repo) and repo not in model_repos:
195
+ model_repos.insert(0, repo)
196
+ load_log += f'Model repository {repo} successfully added\n'
197
+ else:
198
+ load_log += 'Invalid HF repository name or model already in the list\n'
199
+ else:
200
+ load_log += 'Invalid link to HF repository\n'
201
+ else:
202
+ load_log += 'Empty line in HF repository field\n'
203
+ model_repo_dropdown = gr.Dropdown(choices=model_repos, value=model_repos[0])
204
+ return model_repo_dropdown, load_log
205
+
206
+
207
+ # get list of GGUF models from HF repository
208
+ def get_gguf_model_names(model_repo: str) -> gr.Dropdown:
209
+ repo_files = list(list_repo_tree(model_repo))
210
+ repo_files = [file for file in repo_files if file.path.endswith('.gguf')]
211
+ model_paths = [f'{file.path} ({file.size / 1000 ** 3:.2f}G)' for file in repo_files]
212
+ model_paths_dropdown = gr.Dropdown(
213
+ choices=model_paths,
214
+ value=model_paths[0],
215
+ label='GGUF model file',
216
+ )
217
+ return model_paths_dropdown
218
+
219
+
220
+ # delete model files and folders to clear space except for the current model gguf_filename
221
+ def clear_llm_folder(gguf_filename: str) -> None:
222
+ if gguf_filename is None:
223
+ gr.Info(f'The name of the model file that does not need to be deleted is not selected.')
224
+ return
225
+ if '(' in gguf_filename:
226
+ gguf_filename = gguf_filename.split('(')[0].rstrip()
227
+ for path in LLM_MODELS_PATH.iterdir():
228
+ if path.name == gguf_filename:
229
+ continue
230
+ if path.is_file():
231
+ path.unlink(missing_ok=True)
232
+ gr.Info(f'All files removed from directory {LLM_MODELS_PATH} except {gguf_filename}')
233
+
234
+
235
+ # delete model folders to clear space except for the current model model_folder_name
236
+ def clear_embed_folder(model_repo: str) -> None:
237
+ if model_repo is None:
238
+ gr.Info(f'The name of the model that does not need to be deleted is not selected.')
239
+ return
240
+ model_folder_name = model_repo.replace('/', '_')
241
+ for path in EMBED_MODELS_PATH.iterdir():
242
+ if path.name == model_folder_name:
243
+ continue
244
+ if path.is_dir():
245
+ rmtree(path, ignore_errors=True)
246
+ gr.Info(f'All directories have been removed from the {EMBED_MODELS_PATH} directory except {model_folder_name}')
247
+
248
+
249
+ # ------------------------ YOUTUBE ------------------------
250
+
251
+ # function to check availability of subtitles, if manual or automatic are available - returns True and logs
252
+ # if subtitles are not available - returns False and logs
253
+ def check_subtitles_available(yt_video_link: str, target_lang: str) -> Tuple[bool, str]:
254
+ video_id = yt_video_link.split('watch?v=')[-1].split('&')[0]
255
+ load_log = ''
256
+ available = True
257
+ try:
258
+ transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
259
+ try:
260
+ transcript = transcript_list.find_transcript([target_lang])
261
+ if transcript.is_generated:
262
+ load_log += f'Automatic subtitles will be loaded, manual ones are not available for video {yt_video_link}\n'
263
+ else:
264
+ load_log += f'Manual subtitles will be downloaded for the video {yt_video_link}\n'
265
+ except NoTranscriptFound:
266
+ load_log += f'Subtitle language {target_lang} is not available for video {yt_video_link}\n'
267
+ available = False
268
+ except TranscriptsDisabled:
269
+ load_log += f'Invalid video url ({yt_video_link}) or current server IP is blocked for YouTube\n'
270
+ available = False
271
+ return available, load_log
272
+
273
+
274
+ # ------------- UPLOADING DOCUMENTS FOR RAG ------------------------
+
+ # extract documents (in langchain Document format) from uploaded files
+ def load_documents_from_files(upload_files: List[str]) -> Tuple[List[Document], str]:
+     load_log = ''
+     documents = []
+     for upload_file in upload_files:
+         file_extension = f".{upload_file.split('.')[-1]}"
+         if file_extension in LOADER_CLASSES:
+             loader_class = LOADER_CLASSES[file_extension]
+             loader_kwargs = {}
+             if file_extension == '.csv':
+                 with open(upload_file) as csvfile:
+                     delimiter = csv.Sniffer().sniff(csvfile.read(4096)).delimiter
+                 loader_kwargs = {'csv_args': {'delimiter': delimiter}}
+             try:
+                 load_documents = loader_class(upload_file, **loader_kwargs).load()
+                 documents.extend(load_documents)
+             except Exception as ex:
+                 load_log += f'Error loading file {upload_file}\n'
+                 load_log += f'Error: {ex}\n'
+                 continue
+         else:
+             load_log += f'Unsupported file format {upload_file}\n'
+             continue
+     return documents, load_log
+
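+ # LOADER_CLASSES, used above and below, is expected to map file extensions and the
+ # 'web'/'youtube' keys to langchain document loader classes, roughly like this
+ # (illustrative sketch only, not the actual configuration):
+ # LOADER_CLASSES = {'.pdf': PyPDFLoader, '.csv': CSVLoader, '.txt': TextLoader,
+ #                   'web': WebBaseLoader, 'youtube': YoutubeLoader}
+
+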
+ # extract documents (in langchain Document format) from web links
+ def load_documents_from_links(
+         web_links: str,
+         subtitles_lang: str,
+ ) -> Tuple[List[Document], str]:
+
+     load_log = ''
+     documents = []
+     loader_class_kwargs = {}
+     web_links = [web_link.strip() for web_link in web_links.split('\n') if web_link.strip()]
+     for web_link in web_links:
+         if 'youtube.com' in web_link:
+             available, log = check_subtitles_available(web_link, subtitles_lang)
+             load_log += log
+             if not available:
+                 continue
+             loader_class = LOADER_CLASSES['youtube'].from_youtube_url
+             loader_class_kwargs = {'language': subtitles_lang}
+         else:
+             loader_class = LOADER_CLASSES['web']
+             loader_class_kwargs = {}  # reset so YouTube-specific kwargs are not passed to the web loader
+
+         try:
+             if requests.get(web_link).status_code != 200:
+                 load_log += f'The link is not reachable via Python requests: {web_link}\n'
+                 continue
+             load_documents = loader_class(web_link, **loader_class_kwargs).load()
+             if len(load_documents) == 0:
+                 load_log += f'No text chunks were found at the link: {web_link}\n'
+                 continue
+             documents.extend(load_documents)
+         except MissingSchema:
+             load_log += f'Invalid link: {web_link}\n'
+             continue
+         except Exception as ex:
+             load_log += f'Error loading data by web loader at link: {web_link}\n'
+             load_log += f'Error: {ex}\n'
+             continue
+     return documents, load_log
+
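+ # Design note: the requests.get() call above is only a reachability probe; the page is
+ # fetched again by the loader itself. A lighter-weight probe could look like this
+ # (sketch, not part of the original logic):
+ # requests.head(web_link, allow_redirects=True, timeout=10).status_code == 200
+
+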
+ # load files and links, extract documents and create the vector DB
+ def load_documents_and_create_db(
+         upload_files: Optional[List[str]],
+         web_links: str,
+         subtitles_lang: str,
+         chunk_size: int,
+         chunk_overlap: int,
+         embed_model_dict: EMBED_MODEL_DICT,
+ ) -> Tuple[List[Document], Optional[VectorStore], str]:
+
+     load_log = ''
+     all_documents = []
+     db = None
+     progress = gr.Progress()
+
+     embed_model = embed_model_dict.get('embed_model')
+     if embed_model is None:
+         load_log += 'Embedding model is not initialized, DB cannot be created'
+         return all_documents, db, load_log
+
+     if upload_files is None and not web_links:
+         load_log = 'No files or links selected'
+         return all_documents, db, load_log
+
+     if upload_files is not None:
+         progress(0.3, desc='Step 1/2: Load documents from files')
+         docs, log = load_documents_from_files(upload_files)
+         all_documents.extend(docs)
+         load_log += log
+
+     if web_links:
+         progress(0.3 if upload_files is None else 0.5, desc='Step 1/2: Load documents from links')
+         docs, log = load_documents_from_links(web_links, subtitles_lang)
+         all_documents.extend(docs)
+         load_log += log
+
+     if len(all_documents) == 0:
+         load_log += 'Download was interrupted because no documents were extracted\n'
+         load_log += 'RAG mode cannot be activated'
+         return all_documents, db, load_log
+
+     load_log += f'Documents loaded: {len(all_documents)}\n'
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=chunk_size,
+         chunk_overlap=chunk_overlap,
+     )
+     documents = text_splitter.split_documents(all_documents)
+     documents = clear_documents(documents)
+     load_log += f'Documents are split, number of text chunks: {len(documents)}\n'
+
+     progress(0.7, desc='Step 2/2: Initialize DB')
+     db = FAISS.from_documents(documents=documents, embedding=embed_model)
+     load_log += 'DB is initialized, RAG mode can now be enabled in the Chatbot tab'
+     return documents, db, load_log
+
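+ # Usage sketch (all argument values are hypothetical; in the app this function is
+ # presumably wired to a Gradio event handler):
+ # documents, db, load_log = load_documents_and_create_db(
+ #     upload_files=['report.pdf'], web_links='', subtitles_lang='en',
+ #     chunk_size=500, chunk_overlap=100, embed_model_dict={'embed_model': embed_model},
+ # )
+
+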
+ # ------------------ CHATBOT FUNCTIONS ------------------------
+
+ # add the user message to the chatbot window
+ def user_message_to_chatbot(user_message: str, chatbot: CHAT_HISTORY) -> Tuple[str, CHAT_HISTORY]:
+     chatbot.append({'role': 'user', 'metadata': {'title': None}, 'content': user_message})
+     return '', chatbot
+
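+ # Note: the dict appended above follows the openai-style 'messages' format
+ # ({'role': ..., 'content': ...}) used by gr.Chatbot(type='messages'); the extra
+ # 'metadata' field is optional in that format and is left empty here.
+
+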
+ # format the prompt, adding context if a DB is available and RAG mode is enabled
+ def update_user_message_with_context(
+         chatbot: CHAT_HISTORY,
+         rag_mode: bool,
+         db: VectorStore,
+         k: Union[int, str],
+         score_threshold: float,
+         context_template: str,
+ ) -> str:
+
+     user_message = chatbot[-1]['content']
+     user_message_with_context = ''
+
+     if '{user_message}' not in context_template or '{context}' not in context_template:
+         gr.Info('Context template must include {user_message} and {context}')
+         return user_message_with_context
+
+     if db is not None and rag_mode and user_message.strip():
+         if k == 'all':
+             k = len(db.docstore._dict)
+         docs_and_distances = db.similarity_search_with_relevance_scores(
+             user_message,
+             k=k,
+             score_threshold=score_threshold,
+         )
+         if len(docs_and_distances) > 0:
+             retriever_context = '\n\n'.join([doc[0].page_content for doc in docs_and_distances])
+             user_message_with_context = context_template.format(
+                 user_message=user_message,
+                 context=retriever_context,
+             )
+     return user_message_with_context
+
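+ # Illustrative context_template for update_user_message_with_context above (the actual
+ # template is defined elsewhere in the app):
+ # context_template = 'Use the context to answer.\n\nContext:\n{context}\n\nQuestion: {user_message}'
+
+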
+ # generate the model response
+ def get_llm_response(
+         chatbot: CHAT_HISTORY,
+         llm_model_dict: LLM_MODEL_DICT,
+         user_message_with_context: str,
+         rag_mode: bool,
+         system_prompt: str,
+         support_system_role: bool,
+         history_len: int,
+         do_sample: bool,
+         *generate_args,
+ ) -> CHAT_HISTORY:
+
+     llm_model = llm_model_dict.get('llm_model')
+     if llm_model is None:
+         gr.Info('Model not initialized')
+         yield chatbot[:-1]
+         return
+
+     gen_kwargs = dict(zip(GENERATE_KWARGS.keys(), generate_args))
+     gen_kwargs['top_k'] = int(gen_kwargs['top_k'])
+     if not do_sample:
+         gen_kwargs['top_p'] = 0.0
+         gen_kwargs['top_k'] = 1
+         gen_kwargs['repeat_penalty'] = 1.0
+
+     user_message = chatbot[-1]['content']
+     if not user_message.strip():
+         yield chatbot[:-1]
+         return
+
+     if rag_mode:
+         if user_message_with_context:
+             user_message = user_message_with_context
+         else:
+             gr.Info((
+                 'No documents relevant to the query were found, generation in RAG mode is not possible.\n'
+                 'The context template may also be specified incorrectly.\n'
+                 'Try lowering score_threshold or disable RAG mode for normal generation'
+             ))
+             yield chatbot[:-1]
+             return
+
+     messages = []
+     if support_system_role and system_prompt:
+         messages.append({'role': 'system', 'metadata': {'title': None}, 'content': system_prompt})
+
+     if history_len != 0:
+         messages.extend(chatbot[:-1][-(history_len * 2):])
+
+     messages.append({'role': 'user', 'metadata': {'title': None}, 'content': user_message})
+     stream_response = llm_model.create_chat_completion(
+         messages=messages,
+         stream=True,
+         **gen_kwargs,
+     )
+     try:
+         chatbot.append({'role': 'assistant', 'metadata': {'title': None}, 'content': ''})
+         for chunk in stream_response:
+             token = chunk['choices'][0]['delta'].get('content')
+             if token is not None:
+                 chatbot[-1]['content'] += token
+                 yield chatbot
+     except Exception as ex:
+         gr.Info(f'Error generating response: {ex}')
+         yield chatbot[:-1]
+         return
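+
+
+ # Note on get_llm_response: when do_sample is False, greedy-like decoding is emulated by
+ # forcing top_p=0.0, top_k=1 and repeat_penalty=1.0, since llama-cpp-python's
+ # create_chat_completion takes sampling parameters rather than a do_sample flag; any
+ # temperature setting has essentially no effect once top_k is 1.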