import os
import re
import warnings
from typing import Union

import PyPDF2
import pandas as pd
import torch
import gradio as gr
from datasets import Dataset
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

warnings.filterwarnings("ignore")

ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
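# DPR is a bi-encoder: the context encoder embeds passages and the question
# encoder embeds queries into the same 768-dimensional space, so retrieval
# reduces to a nearest-neighbour search between query and passage vectors.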
def process_pdfs(parent_dir: Union[str, list]):
    """Processes the PDF files and returns a dataframe with one row per
    (at most) 512-character chunk of each page's text."""
    df = pd.DataFrame(columns=["title", "text"])
    if isinstance(parent_dir, str):
        parent_dir = [parent_dir]
    for file_path in parent_dir:
        if not file_path.lower().endswith(".pdf"):  # reject non-PDF files
            raise Exception("only pdf files are supported")
        file_name = file_path.split("/")[-1]
        # open the pdf and read it page by page
        with open(file_path, "rb") as pdf_file:
            pdf_reader = PyPDF2.PdfReader(pdf_file)
            for i, page in enumerate(pdf_reader.pages):
                # extract text from the page and normalise whitespace
                txt = page.extract_text()
                txt = txt.replace("\n", "")  # strip newlines
                txt = txt.replace("\t", "")  # strip tabs
                txt = re.sub(r" +", " ", txt)  # collapse repeated spaces
                # chunk to 512 characters as a rough proxy for the 512-token
                # input limit of "facebook/dpr-ctx_encoder-single-nq-base"
                while len(txt) > 512:
                    new_data = pd.DataFrame([[f"{file_name}-page-{i}", txt[:512]]], columns=["title", "text"])
                    df = pd.concat([df, new_data], ignore_index=True)
                    txt = txt[512:]
                if txt:  # keep the final (or only) chunk instead of dropping it
                    new_data = pd.DataFrame([[f"{file_name}-page-{i}", txt]], columns=["title", "text"])
                    df = pd.concat([df, new_data], ignore_index=True)
    return df
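# Usage sketch (hypothetical file name): process_pdfs("report.pdf") yields one
# row per <=512-character chunk, e.g. ("report.pdf-page-0", "<page text>").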
def process(example):
    """Embeds one text chunk of the dataset and returns its DPR embedding."""
    try:
        # truncate defensively in case a 512-character chunk exceeds 512 tokens
        tokens = ctx_tokenizer(example["text"], truncation=True, max_length=512, return_tensors="pt")
        # pooler_output of shape (1, 768) -> 1-D numpy vector
        embed = ctx_encoder(**tokens)[0][0].detach().numpy()
        return {"embeddings": embed}
    except Exception as e:
        raise Exception(f"error in process: {e}")
def process_dataset(df):
    """Processes the dataframe and returns a Dataset with a FAISS index."""
    if len(df) == 0:
        raise Exception("empty pdf files, or can't read text from them")
    ds = Dataset.from_pandas(df)
    ds = ds.map(process)  # embed every chunk
    ds.add_faiss_index(column="embeddings")  # index embeddings for similarity search
    return ds
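# The in-memory FAISS index built over the "embeddings" column is what makes
# ds.get_nearest_examples(...) in search() below possible.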
def search(query, ds, k=3):
    """Embeds the query and returns the k most similar chunks as markdown."""
    try:
        tokens = q_tokenizer(query, truncation=True, max_length=512, return_tensors="pt")
        query_embed = q_encoder(**tokens)[0][0].detach().numpy()
        scores, retrieved_examples = ds.get_nearest_examples("embeddings", query_embed, k=k)
        out = (
            f"**title** : {retrieved_examples['title'][0]},\n"
            f"content: {retrieved_examples['text'][0]}\n\n\n"
            f"**similar resources:** {retrieved_examples['title']}"
        )
    except Exception as e:
        out = f"error in search: {e}"
    return out
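# Usage sketch (hypothetical query): search("what is dense retrieval?", ds)
# returns markdown with the top chunk plus the titles of all k retrieved hits.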
def predict(query, file_paths, k=3):
    """Builds the index from the uploaded PDFs and returns the chunks most
    similar to the query."""
    try:
        df = process_pdfs(file_paths)
        ds = process_dataset(df)
        out = search(query, ds, k=k)
    except Exception as e:
        out = f"error in predict: {e}"
    return out
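# Note: predict() rebuilds the dataframe, embeddings and FAISS index on every
# call; caching the Dataset per upload would avoid re-embedding the same PDFs
# for repeated queries.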
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center'> PDF Search Engine </h1>")
    with gr.Row():
        with gr.Column():
            files = gr.Files(label="Upload PDFs", type="filepath", file_count="multiple")
            query = gr.Text(label="query")
            with gr.Accordion("number of references", open=False):
                k = gr.Number(value=3, show_label=False, precision=0, minimum=1, container=False)
            button = gr.Button("search")
        with gr.Column():
            output = gr.Markdown(label="output")
    button.click(predict, [query, files, k], outputs=output)

demo.launch()