import json
import os
import re
import shutil
from threading import Thread

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
from FlagEmbedding import BGEM3FlagModel
from sklearn.metrics.pairwise import cosine_similarity
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
from vllm import LLM, SamplingParams

# Target device for the generative model: the GPU when available, else the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Embedding model (BGE-M3) used by the vector search.
# use_fp16=True speeds up computation with a slight performance degradation.
embedding_model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)

# Precomputed corpus embeddings and the matching texts.
# NOTE(review): row i of the .npy file is assumed to match row i of the JSON file.
embeddings = np.load("embeddings_albert_tchap.npy")
embeddings_data = pd.read_json("embeddings_albert_tchap.json")
embeddings_text = embeddings_data["text_with_context"].tolist()

# Classifier/router (DeBERTa) deciding whether a query should use RAG.
classifier_model = AutoModelForSequenceClassification.from_pretrained("AgentPublic/chatrag-deberta")
classifier_tokenizer = AutoTokenizer.from_pretrained("AgentPublic/chatrag-deberta")

# The generative LLM (Llama-based).
model_name = "Pclanglais/Tchap"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Keep float16 on GPU as before; fall back to float32 on CPU, where
# half-precision inference support is limited.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
# Use the device computed above instead of a hard-coded 'cuda:0'.
model = model.to(device)

# Llama-3 style system turn prepended to every prompt.
system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nTu es Albert, l'agent conversationnel des services publics qui peut décrire des documents de référence ou aider à des tâches de rédaction<|eot_id|>"
# Sources of the latest answer; replaced by predict() on each call.
source_text = "Les sources utilisées par Albert-Tchap vont apparaître ici"

#Function to guess whether we use the RAG or not.
import functools
import html


def classification_chatrag(query):
    """Decide whether *query* should be answered with retrieval (RAG).

    Runs the DeBERTa router (`classifier_model`) on the query and applies a
    sigmoid to its logit. The result is turned into a whole percentage, and RAG
    is activated when that value is strictly above 50.

    Returns:
        True to activate RAG, False otherwise.
    """
    print(query)
    # truncation=True keeps over-long queries within the tokenizer's configured
    # maximum length instead of failing inside the model.
    encoding = classifier_tokenizer(query, return_tensors="pt", truncation=True)
    encoding = {key: value.to(classifier_model.device) for key, value in encoding.items()}
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = classifier_model(**encoding).logits
    # .item() below requires the head to produce exactly one logit.
    probability = torch.sigmoid(logits.squeeze().cpu())
    float_value = round(probability.item() * 100)
    print(float_value)
    if float_value > 50:
        print("We activate RAG")
        return True
    print("We remove RAG")
    return False


def vector_search(sentence_query):
    """Return the corpus text whose embedding is most similar to *sentence_query*.

    Encodes the query with `embedding_model` (dense vectors only) and compares
    it by cosine similarity against the precomputed `embeddings`. Returns the
    entry of `embeddings_text` at the best-scoring index.
    """
    query_embedding = embedding_model.encode(
        sentence_query,
        batch_size=12,
        # A smaller max_length speeds up encoding if long queries are not needed.
        max_length=256,
    )["dense_vecs"]
    # cosine_similarity expects 2-D input: one row for the single query.
    similarities = cosine_similarity(query_embedding.reshape(1, -1), embeddings)
    closest_doc_index = int(np.argmax(similarities))
    return embeddings_text[closest_doc_index]


class StopOnTokens(StoppingCriteria):
    """Stop generation when the last generated token id is one of `stop_ids`."""

    # NOTE(review): these ids are tokenizer-specific; confirm which tokens they
    # encode for the Tchap tokenizer.
    stop_ids = (29, 0)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the first sequence of the batch is inspected.
        return int(input_ids[0, -1]) in self.stop_ids


# Last RAG decision taken by predict(); defined up front so it always exists.
assess_rag = False


@functools.lru_cache(maxsize=256)
def _rag_context(first_message):
    """Return (use_rag, sources) for a conversation opened by *first_message*.

    The database is only queried for the opening message of a conversation.
    The result is cached so later turns reuse the decision instead of relying
    on state left over from another conversation.
    """
    if classification_chatrag(first_message):
        return True, vector_search(first_message)
    return False, "Albert-Tchap n'utilise pas de sources comme votre requête n'a pas l'air d'en recueillir."


def _build_prompt(history, message, use_rag, sources):
    """Render the system prompt, past turns and the ongoing *message* as one prompt.

    Past (user, assistant) pairs are each closed with <|eot_id|>, so earlier
    replies are properly terminated. The ongoing message ends with an open
    assistant header for the model to continue. When *use_rag* is true, the
    message gets *sources* appended under a "### Source ###" heading.
    """
    user_header = "<|start_header_id|>user<|end_header_id|>\n\n"
    assistant_header = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    parts = [system_prompt]
    for user_text, assistant_text in history:
        parts.append(user_header + user_text + assistant_header + assistant_text + "<|eot_id|>")
    question = user_header + message
    if use_rag:
        question += "\n\n### Source ###\n" + sources
    parts.append(question + assistant_header)
    return "".join(parts)


def predict(message, history):
    """Stream the model's answer to *message*, given the chat *history*.

    Args:
        message: the user's new message.
        history: previous turns as [user_text, assistant_text] pairs
            (Gradio ChatInterface "tuples" history).

    Yields:
        The answer accumulated so far, once per generated chunk.

    Side effects:
        Updates the module globals `assess_rag` and `source_text`.
    """
    global source_text
    global assess_rag
    # The RAG decision and the sources belong to the conversation's first message.
    first_message = history[0][0] if history else message
    assess_rag, source_text = _rag_context(first_message)

    prompt = _build_prompt(history, message, assess_rag, source_text)
    print(prompt)
    # The prompt already starts with <|begin_of_text|>, so the tokenizer must
    # not add another BOS token.
    model_inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    # Greedy decoding (do_sample=False), so temperature and top_p would be ignored.
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=False,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )
    # generate() runs in a background thread while the streamer yields text here.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    partial_message = ""
    for new_token in streamer:
        if new_token != '<':
            partial_message += new_token
            yield partial_message
    return prompt, source_text


# Define the Gradio interface
title = "Tchap"
description = "Le chatbot du service public"
# NOTE(review): each example carries a temperature, but predict() takes no
# temperature input, so these examples are not wired into the interface.
examples = [
    [
        "Qui peut bénéficier de l'AIP?",  # user_message
        0.7,  # temperature
    ]
]


def _render_sources():
    """Return the current `source_text` as escaped HTML for the sources panel."""
    return '<div style="white-space: pre-wrap">' + html.escape(source_text) + "</div>"


with gr.Blocks() as demo:
    with gr.Row():
        # ChatInterface takes no `inputs=` argument; predict(message, history)
        # is all it needs.
        chat_box = gr.ChatInterface(fn=predict, title=title, description=description)
        source_display = gr.HTML(_render_sources())
    # Refresh the sources panel whenever the chat transcript changes.
    chat_box.chatbot.change(fn=_render_sources, inputs=None, outputs=source_display)

if __name__ == "__main__":
    demo.queue().launch()