import torch
import pandas as pd

from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM
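
# A minimal sketch (not part of the original pipeline) of how "embeddings.csv"
# is assumed to have been produced: each text chunk is embedded with the same
# SentenceTransformer model and written out next to its "sentence_chunk" text.
# The column names match what RAG reads back; everything else is illustrative.
def build_embeddings_csv(chunks: list[str], filename: str = "embeddings.csv") -> None:
    model = SentenceTransformer("all-mpnet-base-v2")
    rows = []
    for chunk in chunks:
        embedding = model.encode(chunk, convert_to_tensor=True)
        # Store the values as a "[v1, v2, ...]" string so get_embeddings()
        # can recover them by splitting on the brackets.
        rows.append({"sentence_chunk": chunk, "embeddings": str(embedding.tolist())})
    pd.DataFrame(rows).to_csv(filename, index=False)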


class RAG:

    def __init__(self):
        self.model_id = "microsoft/Phi-3-mini-128k-instruct"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.embedding_model_name = "all-mpnet-base-v2"
        self.embeddings_filename = "embeddings.csv"

        # Load the pre-computed chunk embeddings once and keep both a
        # DataFrame and a list-of-dicts view of the same data.
        self.data_pd = pd.read_csv(self.embeddings_filename)
        self.data_dict = self.data_pd.to_dict(orient="records")

        self.data_embeddings = self.get_embeddings()

        self.embedding_model = SentenceTransformer(
            model_name_or_path=self.embedding_model_name, device=self.device
        )

        self.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=self.model_id
        )
        self.llm_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=self.model_id, trust_remote_code=True
        ).to(self.device)

    def get_embeddings(self) -> torch.Tensor:
        """Parse the embeddings column of the CSV back into a stacked tensor."""
        data_embeddings = []

        # Each row stores its embedding as a bracketed string of values
        # (e.g. "[0.1, 0.2, ...]"), so pull out the part between the brackets
        # and convert it back into a tensor.
        for tensor_str in self.data_pd["embeddings"]:
            values_str = tensor_str.split("[")[1].split("]")[0]
            values_list = [float(val) for val in values_str.split(",")]
            data_embeddings.append(torch.tensor(values_list))

        return torch.stack(data_embeddings).to(self.device)

    def retrieve_relevant_resource(self, user_query: str, k: int = 5):
        """Return the top-k similarity scores and indices for a user query."""
        query_embedding = self.embedding_model.encode(
            user_query, convert_to_tensor=True
        ).to(self.device)

        # all-mpnet-base-v2 returns normalised embeddings, so the dot product
        # behaves like cosine similarity here.
        dot_scores = util.dot_score(a=query_embedding, b=self.data_embeddings)[0]
        scores, indices = torch.topk(dot_scores, k=k)
        return scores, indices

    def prompt_formatter(self, query: str, context_items: list[dict]) -> str:
        """Augment the query with text-based context from context_items."""
        context = "- " + "\n- ".join(item["sentence_chunk"] for item in context_items)

        base_prompt = """You are a friendly lawyer chatbot who always responds in the style of a judge.
Based on the following context items, please answer the user query.

Context items:
{context}

Relevant passages: <extract relevant passages from the context here>"""

        base_prompt = base_prompt.format(context=context)

        # Phi-3 expects a chat-style conversation: the augmented instructions
        # go in the system turn and the raw user query in the user turn.
        dialogue_template = [
            {
                "role": "system",
                "content": base_prompt,
            },
            {
                "role": "user",
                "content": query,
            },
        ]

        prompt = self.tokenizer.apply_chat_template(
            conversation=dialogue_template,
            tokenize=False,
            add_generation_prompt=True,
        )
        return prompt

    def query(self, user_text: str) -> str:
        """Retrieve relevant context for user_text and generate an answer."""
        scores, indices = self.retrieve_relevant_resource(user_text)
        context_items = [self.data_dict[i] for i in indices.tolist()]
        prompt = self.prompt_formatter(query=user_text, context_items=context_items)

        input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        outputs = self.llm_model.generate(**input_ids, max_new_tokens=512)

        # Keep only the assistant's reply from the decoded output.
        output_text = self.tokenizer.decode(outputs[0])
        output_text = output_text.split("<|assistant|>")[1].split("</s>")[0]
        return output_text
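

# A minimal usage sketch, assuming an "embeddings.csv" with "sentence_chunk"
# and "embeddings" columns already exists in the working directory and that a
# GPU (or patience on CPU) is available. The example question is illustrative.
if __name__ == "__main__":
    rag = RAG()
    answer = rag.query("What are the tenant's obligations under the lease agreement?")
    print(answer)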