farmax committed on
Commit
55c700d
·
verified ·
1 Parent(s): 2ed44bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -9
app.py CHANGED
@@ -28,9 +28,12 @@ def load_doc(list_file_path, chunk_size, chunk_overlap):
28
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
29
  return text_splitter.split_documents(pages)
30
 
31
- def create_db(splits, collection_name):
32
- # Use the lightweight MiniLM model for embeddings
33
- embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
 
 
34
  new_client = chromadb.EphemeralClient()
35
  return Chroma.from_documents(documents=splits, embedding=embedding, client=new_client, collection_name=collection_name)
36
 
@@ -67,11 +70,11 @@ def create_collection_name(filepath):
67
  collection_name = collection_name[:-1] + 'Z'
68
  return collection_name
69
 
70
- def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
71
  list_file_path = [x.name for x in list_file_obj if x is not None]
72
  collection_name = create_collection_name(list_file_path[0])
73
  doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
74
- vector_db = create_db(doc_splits, collection_name)
75
  return vector_db, collection_name, "Completed!"
76
 
77
  def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
@@ -113,7 +116,7 @@ def demo():
113
  document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Carica i tuoi documenti PDF")
114
 
115
  with gr.Tab("Passo 2 - Elabora Documenti"):
116
- db_btn = gr.Radio(["ChromaDB (MiniLM Embedding)"], label="Tipo di database vettoriale", value="ChromaDB (MiniLM Embedding)", type="index")
117
  with gr.Accordion("Opzioni Avanzate - Divisione del testo del documento", open=False):
118
  slider_chunk_size = gr.Slider(100, 1000, 1000, step=20, label="Dimensione del chunk")
119
  slider_chunk_overlap = gr.Slider(10, 200, 100, step=10, label="Sovrapposizione del chunk")
@@ -133,13 +136,13 @@ def demo():
133
  with gr.Tab("Passo 4 - Chatbot"):
134
  chatbot = gr.Chatbot(height=300)
135
  with gr.Accordion("Opzioni avanzate - Riferimenti ai documenti", open=False):
136
- doc_sources = [gr.Textbox(label=f"Riferimento {i+1}", lines=2, container=True, scale=20) for i in range(5)] # Modificato da range(3) a range(5)
137
- source_pages = [gr.Number(label="Pagina", scale=1) for _ in range(5)] # Modificato da range(3) a range(5)
138
  msg = gr.Textbox(placeholder="Inserisci il messaggio (es. 'Di cosa tratta questo documento?')", container=True)
139
  submit_btn = gr.Button("Invia messaggio")
140
  clear_btn = gr.Button("Cancella conversazione")
141
 
142
- db_btn.click(initialize_database, inputs=[document, slider_chunk_size, slider_chunk_overlap], outputs=[vector_db, collection_name, db_progress])
143
  qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress])
144
 
145
  submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot] + doc_sources + source_pages)
 
28
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
29
  return text_splitter.split_documents(pages)
30
 
31
def create_db(splits, collection_name, db_type):
    """Build an in-memory Chroma vector store from document splits.

    Args:
        splits: chunked documents produced by the text splitter.
        collection_name: name to give the Chroma collection.
        db_type: radio-button index — 0 selects the multilingual MiniLM
            embedding model, any other value the Italian BERT model.

    Returns:
        A Chroma vector store backed by an ephemeral (non-persistent) client.
    """
    # Radio option 0 -> multilingual MiniLM; otherwise Italian BERT.
    model_name = (
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
        if db_type == 0
        else "dbmdz/bert-base-italian-xxl-uncased"
    )
    embedding = HuggingFaceEmbeddings(model_name=model_name)
    # Ephemeral client: the index lives only in memory for this session.
    client = chromadb.EphemeralClient()
    return Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=client,
        collection_name=collection_name,
    )
39
 
 
70
  collection_name = collection_name[:-1] + 'Z'
71
  return collection_name
72
 
73
def initialize_database(list_file_obj, chunk_size, chunk_overlap, db_type, progress=gr.Progress()):
    """Create the vector database from the uploaded PDF files.

    Args:
        list_file_obj: uploaded file objects from the gr.Files component
            (individual entries may be None).
        chunk_size: character length of each text chunk.
        chunk_overlap: character overlap between consecutive chunks.
        db_type: embedding-model index forwarded to create_db
            (0 = multilingual MiniLM, otherwise Italian BERT).
        progress: Gradio progress tracker (Gradio injects/uses it; kept for
            the UI contract).

    Returns:
        Tuple of (vector_db, collection_name, status_message) consumed by
        the Gradio outputs.

    Raises:
        ValueError: if no valid file was uploaded.
    """
    list_file_path = [x.name for x in list_file_obj if x is not None]
    # Guard: list_file_path[0] below would otherwise raise a bare IndexError
    # when the upload list is empty or contains only None entries.
    if not list_file_path:
        raise ValueError("No valid PDF file was uploaded.")
    collection_name = create_collection_name(list_file_path[0])
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    vector_db = create_db(doc_splits, collection_name, db_type)
    return vector_db, collection_name, "Completed!"
79
 
80
  def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
 
116
  document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Carica i tuoi documenti PDF")
117
 
118
  with gr.Tab("Passo 2 - Elabora Documenti"):
119
+ db_type = gr.Radio(["ChromaDB (Multilingual MiniLM Embedding)", "ChromaDB (Italian BERT Embedding)"], label="Tipo di database vettoriale", value="ChromaDB (Multilingual MiniLM Embedding)", type="index")
120
  with gr.Accordion("Opzioni Avanzate - Divisione del testo del documento", open=False):
121
  slider_chunk_size = gr.Slider(100, 1000, 1000, step=20, label="Dimensione del chunk")
122
  slider_chunk_overlap = gr.Slider(10, 200, 100, step=10, label="Sovrapposizione del chunk")
 
136
  with gr.Tab("Passo 4 - Chatbot"):
137
  chatbot = gr.Chatbot(height=300)
138
  with gr.Accordion("Opzioni avanzate - Riferimenti ai documenti", open=False):
139
+ doc_sources = [gr.Textbox(label=f"Riferimento {i+1}", lines=2, container=True, scale=20) for i in range(5)]
140
+ source_pages = [gr.Number(label="Pagina", scale=1) for _ in range(5)]
141
  msg = gr.Textbox(placeholder="Inserisci il messaggio (es. 'Di cosa tratta questo documento?')", container=True)
142
  submit_btn = gr.Button("Invia messaggio")
143
  clear_btn = gr.Button("Cancella conversazione")
144
 
145
+ db_btn.click(initialize_database, inputs=[document, slider_chunk_size, slider_chunk_overlap, db_type], outputs=[vector_db, collection_name, db_progress])
146
  qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress])
147
 
148
  submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot] + doc_sources + source_pages)