# junyi_bot_external/utils/work_flow_controller.py
import os
import json
import logging
import hashlib

import pandas as pd

from .gpt_processor import (EmbeddingGenerator, KeywordsGenerator, Summarizer,
                            TopicsGenerator, Translator)
from .pdf_processor import PDFProcessor

# registry mapping a file extension to the processor class used to parse it
processors = {
    'pdf': PDFProcessor,
}
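
# Additional formats could be supported by registering their processor here,
# e.g. a hypothetical DocxProcessor (not part of this repo):
#     processors['docx'] = DocxProcessor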


class WorkFlowController:
    def __init__(self, file_src) -> None:
        # file_src is a list of uploaded file objects (each exposes a .name
        # attribute holding its local path)
        # self.file_paths = self.__get_file_name(file_src)
        self.file_paths = [x.name for x in file_src]
        print(self.file_paths)

        self.files_info = {}

        for file_path in self.file_paths:
            file_name = file_path.split('/')[-1]
            # normalise the extension so uppercase suffixes such as '.PDF'
            # still match the processor registry
            file_format = file_path.split('.')[-1].lower()
            self.file_processor = processors[file_format]
            file = self.file_processor(file_path).file_info
            file = self.__process_file(file)
            self.files_info[file_name] = file

        self.__dump_to_json()
        self.__dump_to_csv()

def __get_summary(self, file: dict):
# get summary from file content
summarizer = Summarizer()
file['summarized_content'] = summarizer.summarize(file['file_full_content'])
return file

    def __get_keywords(self, file: dict):
# get keywords from file content
keywords_generator = KeywordsGenerator()
file['keywords'] = keywords_generator.extract_keywords(file['file_full_content'])
return file

    def __get_topics(self, file: dict):
# get topics from file content
topics_generator = TopicsGenerator()
file['topics'] = topics_generator.extract_topics(file['file_full_content'])
return file

    def __get_embedding(self, file):
        # compute an embedding for each page; file_content is keyed by 1-based
        # page number, so iterate over the keys directly
        embedding_generator = EmbeddingGenerator()
        for page_num in file['file_content']:
            file['file_content'][page_num]['page_embedding'] = embedding_generator.get_embedding(
                file['file_content'][page_num]['page_content'])
        return file

    def __translate_to_chinese(self, file: dict):
        # translate every page to Chinese and rebuild file_full_content from
        # the translated pages
        translator = Translator()
        file['file_full_content'] = ''
        for page_num in file['file_content']:
            file['file_content'][page_num]['page_content'] = translator.translate_to_chinese(
                file['file_content'][page_num]['page_content'])
            file['file_full_content'] += file['file_content'][page_num]['page_content']
        return file

    def __process_file(self, file: dict):
        # translate non-Chinese files first, then enrich them with page
        # embeddings and a summary of the full content
        if not file['is_chinese']:
            file = self.__translate_to_chinese(file)
        file = self.__get_embedding(file)
        file = self.__get_summary(file)
        # file = self.__get_keywords(file)
        # file = self.__get_topics(file)
        return file

    def __dump_to_json(self):
        self.result_path = os.path.join(os.getcwd(), 'knowledge_base.json')
        print("Dumping to json, the path is: " + self.result_path)
        with open(self.result_path, 'w', encoding='utf-8') as f:
            json.dump(self.files_info, f, indent=4, ensure_ascii=False)

def __construct_knowledge_base_dataframe(self):
rows = []
        for file_name, content in self.files_info.items():
file_full_content = content["file_full_content"]
for page_num, page_details in content["file_content"].items():
row = {
"file_name": content["file_name"],
"page_num": page_details["page_num"],
"page_content": page_details["page_content"],
"page_embedding": page_details["page_embedding"],
"file_full_content": file_full_content,
}
rows.append(row)
columns = ["file_name", "page_num", "page_content", "page_embedding", "file_full_content"]
df = pd.DataFrame(rows, columns=columns)
return df

    def __dump_to_csv(self):
        df = self.__construct_knowledge_base_dataframe()
        self.csv_result_path = os.path.join(os.getcwd(), 'knowledge_base.csv')
        df.to_csv(self.csv_result_path, index=False)
        print("Dumping to csv, the path is: " + self.csv_result_path)

    def __get_file_name(self, file_src):
        # currently unused (see the commented-out call in __init__): hash all
        # uploaded files into a single md5 digest
        file_paths = [x.name for x in file_src]
        file_paths.sort(key=lambda x: os.path.basename(x))

        md5_hash = hashlib.md5()
        for file_path in file_paths:
            with open(file_path, "rb") as f:
                while chunk := f.read(8192):
                    md5_hash.update(chunk)
        return md5_hash.hexdigest()
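

if __name__ == '__main__':
    # Minimal usage sketch, assuming file_src mirrors what an upload handler
    # (e.g. Gradio) passes in: objects exposing a .name attribute that points
    # at a local file path. 'sample.pdf' is a placeholder, not a file that
    # ships with this repo.
    class _Upload:
        def __init__(self, name):
            self.name = name

    controller = WorkFlowController([_Upload('sample.pdf')])
    print(controller.result_path)      # knowledge_base.json in the working dir
    print(controller.csv_result_path)  # knowledge_base.csv in the working dir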