from fastapi import FastAPI, UploadFile, File
from fastapi.responses import FileResponse
from datasets import load_dataset
from fastapi.middleware.cors import CORSMiddleware
import pdfplumber
import pytesseract
# Loading
import os
import zipfile
import shutil
from os import makedirs, getcwd
from os.path import join, exists, dirname
import torch
import json
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
NUM_PROC = os.cpu_count()
parent_path = dirname(getcwd())
temp_path = join(parent_path, 'temp')
if not exists(temp_path):
    makedirs(temp_path)
# Determine device based on GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
import logging
logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
logging.getLogger("haystack").setLevel(logging.INFO)
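# Local, file-based Qdrant store. recreate_index=True wipes any existing index on
# startup, and use_sparse_embeddings=True enables hybrid (dense + sparse) retrieval.
# embedding_dim=384 matches the output size of paraphrase-multilingual-MiniLM-L12-v2.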
document_store = QdrantDocumentStore(
    path="database",
    recreate_index=True,
    use_sparse_embeddings=True,
    embedding_dim=384
)
def extract_zip(zip_path, target_folder):
    """
    Extracts all files from a ZIP archive and returns a list of their paths.

    Args:
        zip_path (str): Path to the ZIP file.
        target_folder (str): Folder where the files will be extracted.

    Returns:
        List[str]: List of extracted file paths.
    """
    extracted_files = []
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(target_folder)
        for filename in zip_ref.namelist():
            extracted_files.append(os.path.join(target_folder, filename))
    return extracted_files
def extract_text_from_pdf(pdf_path):
    with pdfplumber.open(pdf_path) as pdf:
        text = ""
        for page in pdf.pages:
            # extract_text() returns None for pages with no text layer
            text += page.extract_text() or ""
    return text
def extract_ocr_text_from_pdf(pdf_path):
    # Requires poppler (for pdf2image) and a Tesseract install with the
    # Vietnamese ('vie') language data.
    from pdf2image import convert_from_path
    images = convert_from_path(pdf_path)
    text = ""
    for image in images:
        text += pytesseract.image_to_string(image, lang='vie')
    return text
# Hypothetical route path; the original route registration is not shown in this file.
@app.post("/uploadfile/")
async def create_upload_file(text_field: str, file: UploadFile = File(...), ocr: bool = False):
    # Imports
    import time
    from haystack import Document, Pipeline
    from haystack.components.writers import DocumentWriter
    from haystack.components.preprocessors import DocumentSplitter, DocumentCleaner
    from haystack.components.joiners import DocumentJoiner
    from haystack_integrations.components.retrievers.qdrant import QdrantHybridRetriever
    from haystack.document_stores.types import DuplicatePolicy
    from haystack_integrations.components.embedders.fastembed import (
        FastembedTextEmbedder,
        FastembedDocumentEmbedder,
        FastembedSparseTextEmbedder,
        FastembedSparseDocumentEmbedder
    )
    start_time = time.time()
    file_savePath = join(temp_path, file.filename)
    with open(file_savePath, 'wb') as f:
        shutil.copyfileobj(file.file, f)
    documents = []
    # Here you can do other operations on the saved file as needed
    if file_savePath.endswith('.json'):
        # JSON Lines input: one object per line, with the text under `text_field`
        with open(file_savePath) as fd:
            for line in fd:
                obj = json.loads(line)
                document = Document(content=obj[text_field], meta=obj)
                documents.append(document)
    elif file_savePath.endswith('.zip'):
        extracted_files_list = extract_zip(file_savePath, temp_path)
        print("Extracted files:")
        for file_path in extracted_files_list:
            if file_path.endswith('.pdf'):
                if ocr:
                    text = extract_ocr_text_from_pdf(file_path)
                else:
                    text = extract_text_from_pdf(file_path)
                obj = {text_field: text, 'file_path': file_path}
                document = Document(content=obj[text_field], meta=obj)
                documents.append(document)
    else:
        raise NotImplementedError("This feature is not supported yet")
    # Indexing
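    # The indexing pipeline fans documents through cleaning and word-level
    # splitting, then computes BOTH a sparse embedding (BM42) and a dense
    # embedding (multilingual MiniLM) per chunk before writing to the hybrid
    # Qdrant store.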
    indexing = Pipeline()
    document_joiner = DocumentJoiner()
    document_cleaner = DocumentCleaner()
    document_splitter = DocumentSplitter(split_by="word", split_length=1000, split_overlap=0)
    indexing.add_component("document_joiner", document_joiner)
    indexing.add_component("document_cleaner", document_cleaner)
    indexing.add_component("document_splitter", document_splitter)
    indexing.add_component("sparse_doc_embedder", FastembedSparseDocumentEmbedder(model="Qdrant/bm42-all-minilm-l6-v2-attentions"))
    indexing.add_component("dense_doc_embedder", FastembedDocumentEmbedder(model="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"))
    indexing.add_component("writer", DocumentWriter(document_store=document_store, policy=DuplicatePolicy.OVERWRITE))
    indexing.connect("document_joiner", "document_cleaner")
    indexing.connect("document_cleaner", "document_splitter")
    indexing.connect("document_splitter", "sparse_doc_embedder")
    indexing.connect("sparse_doc_embedder", "dense_doc_embedder")
    indexing.connect("dense_doc_embedder", "writer")
    indexing.run({"document_joiner": {"documents": documents}})
    end_time = time.time()
    elapsed_time = end_time - start_time
    return {"filename": file.filename, "message": "Done", "execution_time": elapsed_time}
# Hypothetical route path; the original route registration is not shown in this file.
@app.get("/search")
def search(prompt: str):
    import time
    from haystack import Document, Pipeline
    from haystack_integrations.components.retrievers.qdrant import QdrantHybridRetriever
    from haystack_integrations.components.embedders.fastembed import (
        FastembedTextEmbedder,
        FastembedSparseTextEmbedder
    )
    from haystack.components.rankers import TransformersSimilarityRanker
    from haystack.components.joiners import DocumentJoiner
    from haystack.components.generators import OpenAIGenerator
    from haystack.utils import Secret
    from haystack.components.builders import PromptBuilder
    start_time = time.time()
    # Querying
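    # RAG query flow: embed the prompt sparsely and densely, retrieve hybrid
    # candidates from Qdrant, rerank them with a cross-encoder, then feed the
    # top documents plus the question into the LLM via the prompt template.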
    template = """
    Given the following information, answer the question.
    Context:
    {% for document in documents %}
    {{ document.content }}
    {% endfor %}
    Question: {{question}}
    Answer:
    """
    prompt_builder = PromptBuilder(template=template)
    generator = OpenAIGenerator(
        api_key=Secret.from_env_var("OCTOAI_TOKEN"),
        api_base_url="https://text.octoai.run/v1",
        model="meta-llama-3-70b-instruct",
        generation_kwargs={"max_tokens": 512}
    )
    querying = Pipeline()
    querying.add_component("sparse_text_embedder", FastembedSparseTextEmbedder(model="Qdrant/bm42-all-minilm-l6-v2-attentions"))
    querying.add_component("dense_text_embedder", FastembedTextEmbedder(
        model="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
        prefix="Represent this sentence for searching relevant passages: "))
    querying.add_component("retriever", QdrantHybridRetriever(document_store=document_store))
    querying.add_component("document_joiner", DocumentJoiner())
    querying.add_component("ranker", TransformersSimilarityRanker(model="BAAI/bge-m3"))
    querying.add_component("prompt_builder", prompt_builder)
    querying.add_component("llm", generator)
    querying.connect("sparse_text_embedder.sparse_embedding", "retriever.query_sparse_embedding")
    querying.connect("dense_text_embedder.embedding", "retriever.query_embedding")
    querying.connect("retriever", "document_joiner")
    querying.connect("document_joiner", "ranker")
    querying.connect("ranker.documents", "prompt_builder.documents")
    querying.connect("prompt_builder", "llm")
    querying.debug = True
    results = querying.run(
        {
            "dense_text_embedder": {"text": prompt},
            "sparse_text_embedder": {"text": prompt},
            "ranker": {"query": prompt},
            "prompt_builder": {"question": prompt}
        }
    )
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")
    return results
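# With this pipeline layout, the generated answer should be available as
# results["llm"]["replies"][0]: OpenAIGenerator returns its completions
# in a "replies" list.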
# Hypothetical route path; the original route registration is not shown in this file.
@app.get("/download-database")
async def download_database():
    import time
    start_time = time.time()
    # Path to the database directory
    database_dir = join(os.getcwd(), 'database')
    # Path for the zip file
    zip_path = join(os.getcwd(), 'database.zip')
    # Create a zip file of the database directory
    shutil.make_archive(zip_path.replace('.zip', ''), 'zip', database_dir)
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")
    # Return the zip file as a response for download
    return FileResponse(zip_path, media_type='application/zip', filename='database.zip')
# Hypothetical route path; the original route registration is not shown in this file.
@app.post("/convert")
async def convert_upload_file(file: UploadFile = File(...)):
    import pytesseract
    from pdf2image import convert_from_path
    file_savePath = join(temp_path, file.filename)
    with open(file_savePath, 'wb') as f:
        shutil.copyfileobj(file.file, f)
    # Convert PDF pages to images
    images = convert_from_path(file_savePath)
    text = ""
    # Extract text from the images with Vietnamese OCR
    for image in images:
        ocr_text = pytesseract.image_to_string(image, lang='vie')
        text = text + ocr_text + '\n'
    return text
def get_type_name(element):
    return type(element).__name__

def filter_by_type(elements, type_name):
    return [element for element in elements if get_type_name(element) == type_name]
import re

def extract_value_from_text(text, pattern_str):
    pattern = re.compile(pattern_str)
    match = pattern.search(text)
    if match:
        return match.group(0)  # group(0) returns the entire match
    else:
        return None
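# Hypothetical usage: extract_value_from_text("Số: 123/2020/QH14", r"\d+/\d+/\w+")
# would return "123/2020/QH14".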
def filter_by_labels(elements, labels, pattern_str):
    for element in elements:
        for label in labels:
            if label.lower() in element.text.lower():
                return extract_value_from_text(element.text, pattern_str)
    return None

def filter_by_values(elements, values):
    for element in elements:
        for value in values:
            if value.lower() in element.text.lower():
                return value
    return None
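# Each schema entry first selects elements by unstructured layout type
# ('layout_type'), then either extracts a regex match from elements containing
# one of the 'labels' ('labels' + 'format'), matches one of a fixed set of
# 'values', or falls back to the first matching element's text.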
def get_elements_by_schemas(elements, schemas):
    result_elements = []
    for schema in schemas:
        result_element = {}
        filtered_by_type_elements = filter_by_type(elements, schema['layout_type'])
        if 'labels' in schema:
            filtered_by_label_value = filter_by_labels(filtered_by_type_elements, schema['labels'], schema['format'])
            if filtered_by_label_value is not None:
                result_element[schema['name']] = filtered_by_label_value
                result_elements.append(result_element)
        elif 'values' in schema:
            filtered_by_value = filter_by_values(filtered_by_type_elements, schema['values'])
            if filtered_by_value is not None:
                result_element[schema['name']] = filtered_by_value
                result_elements.append(result_element)
        else:
            if filtered_by_type_elements:
                result_element[schema['name']] = filtered_by_type_elements[0].text
                result_elements.append(result_element)
    return result_elements
# Hypothetical route path; the original route registration is not shown in this file.
@app.post("/extract")
async def extract_upload_file(file: UploadFile = File(...)):
    from unstructured.partition.pdf import partition_pdf
    file_savePath = join(temp_path, file.filename)
    with open(file_savePath, 'wb') as f:
        shutil.copyfileobj(file.file, f)
    # Returns a List[Element] present in the pages of the parsed pdf document
    elements = partition_pdf(file_savePath, languages=["vie"])
    schemas = [
        {'name': 'publisher', 'layout_type': 'Title', 'position': 0, 'from_last': False},
        # Note: 'label' here is not the 'labels' key that get_elements_by_schemas
        # checks, so this entry falls back to the first matching element's text.
        {'name': 'number', 'layout_type': 'Text', 'position': 0, 'from_last': False, 'label': ['Số', 'Luật số']}
    ]
    return get_elements_by_schemas(elements, schemas)
# Hypothetical route path; the original route registration is not shown in this file.
@app.get("/")
def api_home():
    return {'detail': 'Welcome to FastAPI Qdrant importer!'}
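# Minimal local entry point, a sketch under assumptions: Spaces usually launch
# the app externally (e.g. `uvicorn app:app`), and port 7860 is the usual
# Spaces default rather than anything this file specifies.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)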