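"""Gradio app for building a conversational chatbot over uploaded PDF documents.

Pipeline: load PDFs -> split into chunks -> embed into a Chroma vector store ->
answer questions with a ConversationalRetrievalChain backed by a Hugging Face
inference endpoint.
"""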
import gradio as gr
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from pathlib import Path
import chromadb
from unidecode import unidecode
import re

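# Hugging Face Hub repo IDs of the instruction-tuned LLMs offered in the UI;
# list_llm_simple keeps only the model name for display.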
list_llm = [
    "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "google/gemma-7b-it", "google/gemma-2b-it",
    "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "tiiuae/falcon-7b-instruct"
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

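# Load each PDF and split the combined pages into overlapping text chunks.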
def load_doc(list_file_path, chunk_size, chunk_overlap):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_documents(pages)

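# Embed the chunks and index them in an ephemeral (in-memory) Chroma collection.
# db_type 0 selects a multilingual MiniLM sentence-transformer; any other value
# selects an Italian BERT model.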
def create_db(splits, collection_name, db_type):
    if db_type == 0:  # Multilingual MiniLM
        embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
    else:  # Italian BERT
        embedding = HuggingFaceEmbeddings(model_name="dbmdz/bert-base-italian-xxl-uncased")
    
    new_client = chromadb.EphemeralClient()
    return Chroma.from_documents(documents=splits, embedding=embedding, client=new_client, collection_name=collection_name)

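# Wire up the conversational RAG chain: a hosted Hugging Face inference
# endpoint as the LLM, a buffer memory for the dialogue, and the Chroma
# store as retriever.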
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    progress(0.5, desc="Initializing HF Hub...")
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    
    memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
    retriever = vector_db.as_retriever()
    
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff", 
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )

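# Derive a Chroma-safe collection name from a file name: ASCII only, at least
# 3 characters, starting and ending with an alphanumeric character.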
def create_collection_name(filepath):
    collection_name = Path(filepath).stem
    collection_name = unidecode(collection_name.replace(" ", "-"))
    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)[:50]
    if len(collection_name) < 3:
        collection_name += 'xyz'
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
    return collection_name

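# Build the vector database from the uploaded files; the collection is named
# after the first file.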
def initialize_database(list_file_obj, chunk_size, chunk_overlap, db_type, progress=gr.Progress()):
    list_file_path = [x.name for x in list_file_obj if x is not None]
    collection_name = create_collection_name(list_file_path[0])
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    vector_db = create_db(doc_splits, collection_name, db_type)
    return vector_db, collection_name, "Completed!"

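# Map the selected radio index back to a repo ID and build the QA chain.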
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Completed!"

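# Flatten Gradio's (user, bot) history tuples into plain strings. The current
# `message` is passed to the chain separately, so it is unused here.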
def format_chat_history(message, chat_history):
    return [f"User: {user_message}\nAssistant: {bot_message}" for user_message, bot_message in chat_history]

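# Run one QA turn: query the chain, strip any "Helpful Answer:" prefix from
# the response, and surface up to five source snippets with 1-based page numbers.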
def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"].split("Helpful Answer:")[-1]
    response_sources = response["source_documents"]
    sources = [(source.page_content.strip(), source.metadata["page"] + 1) for source in response_sources[:5]]
    new_history = history + [(message, response_answer)]
    
    # Ensure we always return 5 sources and 5 pages
    source_texts = [source[0] for source in sources] + [''] * (5 - len(sources))
    source_pages = [source[1] for source in sources] + [0] * (5 - len(sources))
    
    return (qa_chain, gr.update(value=""), new_history, 
            *source_texts[:5],  # Unpack exactly 5 source texts
            *source_pages[:5])  # Unpack exactly 5 source pages

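# Reset the UI: clear the message box and chat history, blank the five source
# textboxes, and zero the five page numbers (12 outputs in total).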
def clear_conversation():
    return gr.update(value=""), [], "", "", "", "", "", 0, 0, 0, 0, 0

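# Assemble the four-step Gradio UI: upload PDFs, build the vector database,
# initialize the QA chain, then chat.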
def demo():
    with gr.Blocks(theme="base") as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()
        
        gr.Markdown("# Creatore di Chatbot basato su PDF")
        
        with gr.Tab("Passo 1 - Carica PDF"):
            document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Carica i tuoi documenti PDF")
        
        with gr.Tab("Passo 2 - Elabora Documenti"):
            db_type = gr.Radio(["ChromaDB (Multilingual MiniLM Embedding)", "ChromaDB (Italian BERT Embedding)"], label="Tipo di database vettoriale", value="ChromaDB (Multilingual MiniLM Embedding)", type="index")
            with gr.Accordion("Opzioni Avanzate - Divisione del testo del documento", open=False):
                slider_chunk_size = gr.Slider(100, 1000, 1000, step=20, label="Dimensione del chunk")
                slider_chunk_overlap = gr.Slider(10, 200, 100, step=10, label="Sovrapposizione del chunk")
            db_progress = gr.Textbox(label="Inizializzazione del database vettoriale", value="Nessuna")
            db_btn = gr.Button("Genera database vettoriale")
            
        with gr.Tab("Passo 3 - Inizializza catena QA"):
            llm_btn = gr.Radio(list_llm_simple, label="Modelli LLM", value=list_llm_simple[4], type="index")
            with gr.Accordion("Opzioni avanzate - Modello LLM", open=False):
                slider_temperature = gr.Slider(0.01, 1.0, 0.3, step=0.1, label="Temperatura")
                slider_maxtokens = gr.Slider(224, 4096, 1024, step=32, label="Token massimi")
                slider_topk = gr.Slider(1, 10, 3, step=1, label="Campioni top-k")
            language_btn = gr.Radio(["Italiano", "Inglese"], label="Lingua", value="Italiano", type="index")
            llm_progress = gr.Textbox(value="Nessuna", label="Inizializzazione catena QA")
            qachain_btn = gr.Button("Inizializza catena di Domanda e Risposta")
            
        with gr.Tab("Passo 4 - Chatbot"):
            chatbot = gr.Chatbot(height=300)
            with gr.Accordion("Opzioni avanzate - Riferimenti ai documenti", open=False):
                doc_sources = [gr.Textbox(label=f"Riferimento {i+1}", lines=2, container=True, scale=20) for i in range(5)]
                source_pages = [gr.Number(label="Pagina", scale=1) for _ in range(5)]
            msg = gr.Textbox(placeholder="Inserisci il messaggio (es. 'Di cosa tratta questo documento?')", container=True)
            submit_btn = gr.Button("Invia messaggio")
            clear_btn = gr.Button("Cancella conversazione")
           
        db_btn.click(initialize_database, inputs=[document, slider_chunk_size, slider_chunk_overlap, db_type], outputs=[vector_db, collection_name, db_progress])
        qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress])

        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot] + doc_sources + source_pages)
        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot] + doc_sources + source_pages)
        clear_btn.click(clear_conversation, inputs=[], outputs=[msg, chatbot] + doc_sources + source_pages)
    
    demo.queue().launch(debug=True)

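# Note: HuggingFaceEndpoint authenticates via the HUGGINGFACEHUB_API_TOKEN
# environment variable, which must be set before launching the app.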
if __name__ == "__main__":
    demo()