Spaces:
Sleeping
Sleeping
File size: 5,712 Bytes
f2019a4 23de817 f2019a4 82a9fe8 f2019a4 1e44fa4 846f47f feacba8 f2019a4 31559f1 208476f 31559f1 cc4fb7d 9e80596 f2019a4 9df01d3 9e80596 f2019a4 31559f1 f2019a4 31559f1 208476f f2019a4 459a15e 31559f1 f2019a4 9e80596 cc4fb7d 9e80596 f2019a4 9e80596 f2019a4 9e80596 f2019a4 9e80596 9934680 f2019a4 9e80596 459a15e 9e80596 dcbc7a2 b4ed09d 31559f1 22b51ff 31559f1 22b51ff 9e80596 22b51ff 9e80596 22b51ff 31559f1 22b51ff 9e80596 459a15e 208476f 9e80596 dcbc7a2 9e80596 dcbc7a2 9e80596 208476f 459a15e f2019a4 9e80596 f2019a4 27f6f13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import re
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from vllm import LLM, SamplingParams
import torch
import gradio as gr
import json
import os
import shutil
import requests
import numpy as np
import pandas as pd
from threading import Thread
from FlagEmbedding import BGEM3FlagModel
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModelForSequenceClassification
device = "cuda" if torch.cuda.is_available() else "cpu"

# Embedding model used to encode queries for the vector search.
embedding_model = BGEM3FlagModel(
    'BAAI/bge-m3',
    use_fp16=True,  # Setting use_fp16 to True speeds up computation with a slight performance degradation
)

# Precomputed embeddings and the passages they were computed from; the two
# are matched by position in vector_search.
embeddings = np.load("embeddings_albert_tchap.npy")
embeddings_data = pd.read_json("embeddings_albert_tchap.json")
embeddings_text = embeddings_data["text_with_context"].tolist()

# Classifier/router (deberta). It gets its own tokenizer name: `tokenizer`
# is bound to the generative model's tokenizer further down, and the router
# must be fed ids produced by the tokenizer it was loaded with.
classifier_model = AutoModelForSequenceClassification.from_pretrained("AgentPublic/chatrag-deberta")
classifier_tokenizer = AutoTokenizer.from_pretrained("AgentPublic/chatrag-deberta")

# Generative LLM (llama-based). Placement follows `device` instead of a
# hard-coded CUDA index; half precision is only requested on GPU.
model_name = "Pclanglais/Tchap"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
model = model.to(device)

system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nTu es Albert, l'agent conversationnel des services publics qui peut décrire des documents de référence ou aider à des tâches de rédaction<|eot_id|>"


# Function to guess whether we use the RAG or not.
def classification_chatrag(query):
    """Decide whether *query* should be answered with retrieval (RAG).

    The query is run through the router classifier, a sigmoid is applied to
    its logit, and retrieval is activated when the resulting probability,
    expressed as a rounded percentage, is above 50.

    Args:
        query: The user message to classify.

    Returns:
        True when retrieval should be used, False otherwise.
    """
    encoding = classifier_tokenizer(query, return_tensors="pt")
    encoding = {k: v.to(classifier_model.device) for k, v in encoding.items()}

    # Pure inference: no gradients needed.
    with torch.no_grad():
        logits = classifier_model(**encoding).logits

    # `.item()` requires a single value, i.e. one logit for one query.
    probability = torch.sigmoid(logits.squeeze().cpu()).item()

    status = round(probability * 100) > 50
    print("We activate RAG" if status else "We remove RAG")
    return status
# Vector search over the database
def vector_search(sentence_query):
    """Return the stored passage most similar to *sentence_query*.

    The query is encoded with the embedding model, compared by cosine
    similarity against every row of ``embeddings``, and the entry of
    ``embeddings_text`` at the best-scoring position is returned.

    Args:
        sentence_query: The text to search for.

    Returns:
        The entry of ``embeddings_text`` whose embedding is closest to the
        query (the passage text, not its vector).
    """
    encoded = embedding_model.encode(
        sentence_query,
        batch_size=12,
        max_length=256,  # A smaller value speeds up encoding if long inputs are not needed.
    )

    # cosine_similarity expects 2-D inputs, so the query becomes one row.
    query_vector = encoded['dense_vecs'].reshape(1, -1)
    scores = cosine_similarity(query_vector, embeddings)

    # Position of the highest similarity score.
    best_index = np.argmax(scores)
    return embeddings_text[best_index]
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation on a configurable set of token ids.

    Args:
        stop_ids: Token ids that terminate generation when they are the last
            token of the first sequence. Defaults to ``(29, 0)``, the values
            that were previously hard-coded.
    """

    def __init__(self, stop_ids=(29, 0)):
        super().__init__()
        # NOTE(review): the default ids are kept for backward compatibility;
        # confirm they are the intended stop tokens for the tokenizer in use.
        self.stop_ids = frozenset(stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence is a stop id."""
        return int(input_ids[0][-1]) in self.stop_ids
def _build_prompt(turns, source_text=None):
    """Render conversation turns as a single prompt string (without the system prompt).

    Args:
        turns: Sequence of ``[user_text, assistant_text]`` pairs; the last
            pair is the ongoing turn, whose assistant text is empty.
        source_text: Retrieved passage appended to the last user message,
            or None when retrieval is not used.

    Returns:
        The concatenated turns, ending on the open assistant header of the
        last turn so that generation continues from there.
    """
    last_index = len(turns) - 1
    parts = []
    for index, (user_text, assistant_text) in enumerate(turns):
        question = "<|start_header_id|>user<|end_header_id|>\n\n" + user_text
        # Once we target the ongoing post we add the source.
        if index == last_index and source_text is not None:
            question += "\n\n### Source ###\n" + source_text
        answer = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + assistant_text
        # Completed assistant turns are closed like every other message;
        # the ongoing one stays open for the model to fill in.
        if index != last_index:
            answer += "<|eot_id|>"
        parts.append(question + answer)
    return "".join(parts)


def predict(message, history):
    """Stream the assistant's reply to *message*.

    Args:
        message: The latest user message.
        history: Previous turns as ``[user_text, assistant_text]`` pairs.

    Yields:
        The reply accumulated so far, once per streamed chunk.

    Returns:
        The full prompt sent to the model (as the generator's return value).
    """
    # For now, retrieval is only driven by the first message of the
    # conversation. It is derived from the history on every call rather than
    # cached in module globals, so concurrent conversations cannot see each
    # other's source and a call with existing history never hits an
    # undefined global.
    first_message = history[0][0] if history else message
    assess_rag = classification_chatrag(first_message)
    source_text = vector_search(first_message) if assess_rag else None

    history_transformer_format = history + [[message, ""]]
    print(history_transformer_format)

    messages = _build_prompt(history_transformer_format, source_text)
    print(messages)
    messages = system_prompt + messages

    # Inputs follow the model's device instead of a hard-coded "cuda".
    model_inputs = tokenizer([messages], return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        # Greedy decoding: sampling knobs (top_p, temperature) have no effect
        # when do_sample is False, so they are not passed.
        do_sample=False,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )

    # generate() blocks, so it runs in a worker thread while this generator
    # relays the streamed text.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_message = ""
    for new_token in streamer:
        if new_token != '<':
            partial_message += new_token
            yield partial_message

    # The streamer is exhausted once generation is over; reap the worker.
    t.join()
    return messages
# Define the Gradio interface
title = "Tchap"
description = "Le chatbot du service public"
# NOTE(review): each example carries a temperature next to the message, but
# the interface declares no additional inputs, so the examples are not wired
# in here — confirm the intended shape before passing them to ChatInterface.
examples = [
    [
        "Qui peut bénéficier de l'AIP?",  # user_message
        0.7,  # temperature
    ]
]

# The title and description defined above are passed to the interface
# instead of being left unused.
demo = gr.ChatInterface(predict, title=title, description=description)
demo.launch()