junyi_bot_external / utils /work_flow_controller.py
ChenyuRabbitLove's picture
feat/formatter
66b707b
raw
history blame
No virus
5.49 kB
import os
import json
import logging
import hashlib
import pandas as pd
from .gpt_processor import (
EmbeddingGenerator,
KeywordsGenerator,
Summarizer,
TopicsGenerator,
Translator,
)
from .pdf_processor import PDFProcessor
# Dispatch table used by WorkFlowController: maps a file extension (without
# the leading dot) to the processor class that parses files of that type.
processors = dict(pdf=PDFProcessor)
class WorkFlowController:
    """Run uploaded files through the processing pipeline and persist results.

    For every uploaded file the controller picks a processor class by file
    extension (via the module-level ``processors`` mapping), translates the
    content to Chinese when the processor reports it is not Chinese, attaches
    per-page embeddings and a document summary, and finally writes the
    combined result to ``<uid>_knowledge_base.json`` and
    ``<uid>_knowledge_base.csv`` in the current working directory.

    Attributes:
        file_paths: paths taken from the ``.name`` attribute of each object
            in ``file_src``.
        uid: identifier used as the prefix of the output file names.
        files_info: mapping of base file name -> processed file dict.
        file_processor: processor class used for the most recently handled
            file (only set when at least one file was given).
        json_result_path: path of the written JSON file.
        csv_result_path: path of the written CSV file.
    """

    def __init__(self, file_src, uid) -> None:
        """Process every file in *file_src* and dump the results for *uid*.

        Args:
            file_src: iterable of uploaded-file objects; only their ``.name``
                attribute (a filesystem path) is used.
            uid: identifier used to name the output JSON/CSV files.

        Raises:
            KeyError: if a file's extension has no entry in ``processors``.
        """
        self.file_paths = [x.name for x in file_src]
        self.uid = uid
        print(self.file_paths)
        self.files_info = {}
        for file_path in self.file_paths:
            # os.path understands the platform separator, unlike split("/").
            file_name = os.path.basename(file_path)
            # splitext only looks at the final path component, and lower()
            # lets "report.PDF" be handled like "report.pdf".
            file_format = os.path.splitext(file_path)[1].lstrip(".").lower()
            self.file_processor = processors[file_format]
            file = self.file_processor(file_path).file_info
            file = self.__process_file(file)
            self.files_info[file_name] = file
        self.__dump_to_json()
        self.__dump_to_csv()

    def __get_summary(self, file: dict):
        """Store a summary of the full text under ``"summarized_content"``.

        Mutates and returns *file*.
        """
        summarizer = Summarizer()
        file["summarized_content"] = summarizer.summarize(file["file_full_content"])
        return file

    def __get_keywords(self, file: dict):
        """Store keywords extracted from the full text under ``"keywords"``.

        Mutates and returns *file*. Not currently called by ``__process_file``.
        """
        keywords_generator = KeywordsGenerator()
        file["keywords"] = keywords_generator.extract_keywords(
            file["file_full_content"]
        )
        return file

    def __get_topics(self, file: dict):
        """Store topics extracted from the full text under ``"topics"``.

        Mutates and returns *file*. Not currently called by ``__process_file``.
        """
        topics_generator = TopicsGenerator()
        file["topics"] = topics_generator.extract_topics(file["file_full_content"])
        return file

    def __get_embedding(self, file: dict):
        """Attach an embedding to every page of *file*.

        Each page dict in ``file["file_content"]`` gets a ``"page_embedding"``
        entry computed from its ``"page_content"``. Mutates and returns *file*.
        """
        embedding_generator = EmbeddingGenerator()
        # Walk the page dicts directly instead of rebuilding 1-based integer
        # keys, so this works however the processor keys its pages.
        for page in file["file_content"].values():
            page["page_embedding"] = embedding_generator.get_embedding(
                page["page_content"]
            )
        return file

    def __translate_to_chinese(self, file: dict):
        """Translate every page to Chinese and rebuild the full text.

        Each page's ``"page_content"`` is replaced by its translation, and
        ``file["file_full_content"]`` becomes the concatenation of the
        translated pages in iteration order. Mutates and returns *file*.
        """
        translator = Translator()
        translated_pages = []
        for page_number, page in enumerate(file["file_content"].values(), start=1):
            print("Translating page: " + str(page_number))
            page["page_content"] = translator.translate_to_chinese(
                page["page_content"]
            )
            translated_pages.append(page["page_content"])
        # join avoids the repeated-concatenation cost of building with "+".
        file["file_full_content"] = "".join(translated_pages)
        return file

    def __process_file(self, file: dict):
        """Run the pipeline: optional translation, embeddings, then summary.

        Translation happens only when ``file["is_chinese"]`` is falsy, so the
        embeddings and summary are always computed on Chinese text.
        """
        if not file["is_chinese"]:
            print("Translating to chinese...")
            file = self.__translate_to_chinese(file)
        print("Getting embedding...")
        file = self.__get_embedding(file)
        print("Getting summary...")
        file = self.__get_summary(file)
        return file

    def __dump_to_json(self):
        """Write ``files_info`` to ``<uid>_knowledge_base.json`` in the CWD.

        Sets ``self.json_result_path`` to the written file's path.
        """
        json_path = os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.json")
        with open(json_path, "w", encoding="utf-8") as f:
            print("Dumping to json, the path is: " + json_path)
            self.json_result_path = json_path
            # ensure_ascii=False keeps Chinese text human-readable in the file.
            json.dump(self.files_info, f, indent=4, ensure_ascii=False)

    def __construct_knowledge_base_dataframe(self):
        """Flatten ``files_info`` into one DataFrame row per page.

        Returns:
            A pandas DataFrame with columns ``file_name``, ``page_num``,
            ``page_content`` and ``page_embedding`` (empty, but with those
            columns, when there are no pages).
        """
        columns = [
            "file_name",
            "page_num",
            "page_content",
            "page_embedding",
        ]
        rows = []
        for content in self.files_info.values():
            for page_details in content["file_content"].values():
                rows.append(
                    {
                        "file_name": content["file_name"],
                        "page_num": page_details["page_num"],
                        "page_content": page_details["page_content"],
                        "page_embedding": page_details["page_embedding"],
                    }
                )
        return pd.DataFrame(rows, columns=columns)

    def __dump_to_csv(self):
        """Write the per-page table to ``<uid>_knowledge_base.csv`` in the CWD.

        Sets ``self.csv_result_path`` to the written file's path.
        """
        df = self.__construct_knowledge_base_dataframe()
        csv_path = os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.csv")
        df.to_csv(csv_path, index=False)
        print("Dumping to csv, the path is: " + csv_path)
        self.csv_result_path = csv_path

    def __get_file_name(self, file_src):
        """Return an MD5 hex digest over the contents of all files in *file_src*.

        Files are hashed in order of their base name, so the digest does not
        depend on upload order. Contents are fed back to back with no
        separator. Not currently called by this class.
        """
        file_paths = sorted((x.name for x in file_src), key=os.path.basename)
        md5_hash = hashlib.md5()
        for file_path in file_paths:
            with open(file_path, "rb") as f:
                # Read in chunks so large uploads are not loaded at once.
                while chunk := f.read(8192):
                    md5_hash.update(chunk)
        return md5_hash.hexdigest()