# Law-LM / rag.py
import torch
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM

class RAG:
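    """Retrieval-augmented generation over pre-computed legal text embeddings.

    Embeds a user query, retrieves the most similar text chunks from
    embeddings.csv, and prompts Phi-3-mini-128k-instruct to answer the
    query grounded in those chunks.
    """
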
    def __init__(self):
        self.model_id = "microsoft/Phi-3-mini-128k-instruct"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.embedding_model_name = "all-mpnet-base-v2"
        self.embeddings_filename = "embeddings.csv"
        # Load the pre-computed chunk embeddings once and reuse the same
        # DataFrame for both the record dicts and the embedding tensors.
        self.data_pd = pd.read_csv(self.embeddings_filename)
        self.data_dict = self.data_pd.to_dict(orient="records")
        self.data_embeddings = self.get_embeddings()
        # Embedding model
        self.embedding_model = SentenceTransformer(model_name_or_path=self.embedding_model_name, device=self.device)
        # Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.model_id)
        # LLM
        self.llm_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=self.model_id, trust_remote_code=True
        ).to(self.device)
    def get_embeddings(self) -> torch.Tensor:
        """Parses the embeddings stored as strings in the CSV into a single tensor."""
        data_embeddings = []
        for tensor_str in self.data_pd["embeddings"]:
            # Each cell holds a stringified tensor, e.g. "tensor([0.1, 0.2, ...])";
            # recover the comma-separated floats between the brackets.
            values_str = tensor_str.split("[")[1].split("]")[0]
            values_list = [float(val) for val in values_str.split(",")]
            data_embeddings.append(torch.tensor(values_list))
        return torch.stack(data_embeddings).to(self.device)
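
    # A minimal sketch (not part of the original pipeline) of how an
    # embeddings CSV in the format get_embeddings() expects could be built.
    # The "sentence_chunk" and "embeddings" column names mirror how this
    # class reads the file; chunking is assumed to happen upstream.
    def build_embeddings_csv(self, chunks: list[str], filename: str = "embeddings.csv"):
        """Encodes text chunks and writes them to a CSV this class can load."""
        embeddings = self.embedding_model.encode(chunks, convert_to_tensor=True)
        records = [
            # str(list) yields "[0.1, 0.2, ...]", which the bracket-splitting
            # parser in get_embeddings() can read back.
            {"sentence_chunk": chunk, "embeddings": str(embedding.tolist())}
            for chunk, embedding in zip(chunks, embeddings.cpu())
        ]
        pd.DataFrame(records).to_csv(filename, index=False)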

    def retrieve_relevant_resource(self, user_query: str, k: int = 5):
        """Returns the top-k similarity scores and indices for a user query."""
        query_embedding = self.embedding_model.encode(user_query, convert_to_tensor=True).to(self.device)
        # Dot-product similarity against every stored chunk embedding.
        dot_scores = util.dot_score(a=query_embedding, b=self.data_embeddings)[0]
        scores, indices = torch.topk(dot_scores, k=k)
        return scores, indices

    def prompt_formatter(self, query: str, context_items: list[dict]) -> str:
        """Augments the query with text-based context from context_items."""
        # Join the context items into one bulleted list
        context = "- " + "\n- ".join([item["sentence_chunk"] for item in context_items])
base_prompt = """You are a friendly lawyer chatbot who always responds in the style of a judge
Based on the following context items, please answer the query.
\nNow use the following context items to answer the user query:
{context}
\nRelevant passages: <extract relevant passages from the context here>"""
# Update base prompt with context items and query
base_prompt = base_prompt.format(context=context)
        # Create the prompt template for the instruction-tuned model
        dialogue_template = [
            {"role": "system", "content": base_prompt},
            {"role": "user", "content": query},
        ]
        # Apply the model's chat template
        prompt = self.tokenizer.apply_chat_template(
            conversation=dialogue_template,
            tokenize=False,
            add_generation_prompt=True,
        )
        return prompt

    def query(self, user_text: str):
        """Runs retrieval and generation, returning the model's answer text."""
        scores, indices = self.retrieve_relevant_resource(user_text)
        context_items = [self.data_dict[i] for i in indices.tolist()]
        prompt = self.prompt_formatter(query=user_text, context_items=context_items)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        outputs = self.llm_model.generate(**inputs, max_new_tokens=512)
        # Decode only the newly generated tokens, dropping the prompt and any
        # special tokens such as <|assistant|> and end-of-sequence markers.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
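

# A minimal usage sketch, assuming an embeddings.csv with "sentence_chunk"
# and "embeddings" columns sits next to this file; the example question is
# illustrative only.
if __name__ == "__main__":
    rag = RAG()
    print(rag.query("What does the law say about breach of contract?"))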