import gradio as gr from langchain.tools import Tool from langchain_community.utilities import GoogleSearchAPIWrapper import os from langchain.tools import Tool from langchain_community.utilities import GoogleSearchAPIWrapper def get_search(query:str="", k:int=1): # get the top-k resources with google search = GoogleSearchAPIWrapper(k=k) def search_results(query): return search.results(query, k) tool = Tool( name="Google Search Snippets", description="Search Google for recent results.", func=search_results, ) ref_text = tool.run(query) if 'Result' not in ref_text[0].keys(): return ref_text else: return None from langchain_community.document_transformers import Html2TextTransformer from langchain_community.document_loaders import AsyncHtmlLoader def get_page_content(link:str): loader = AsyncHtmlLoader([link]) docs = loader.load() html2text = Html2TextTransformer() docs_transformed = html2text.transform_documents(docs) if len(docs_transformed) > 0: return docs_transformed[0].page_content else: return None import tiktoken def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int: """Returns the number of tokens in a text string.""" encoding = tiktoken.get_encoding(encoding_name) num_tokens = len(encoding.encode(string)) return num_tokens def chunk_text_by_sentence(text, chunk_size=2048): """Chunk the $text into sentences with less than 2k tokens.""" sentences = text.split('. ') chunked_text = [] curr_chunk = [] # 逐句添加文本片段,确保每个段落都小于2k个token for sentence in sentences: if num_tokens_from_string(". ".join(curr_chunk)) + num_tokens_from_string(sentence) + 2 <= chunk_size: curr_chunk.append(sentence) else: chunked_text.append(". ".join(curr_chunk)) curr_chunk = [sentence] # 添加最后一个片段 if curr_chunk: chunked_text.append(". ".join(curr_chunk)) return chunked_text[0] def chunk_text_front(text, chunk_size = 2048): ''' get the first `trunk_size` token of text ''' chunked_text = "" tokens = num_tokens_from_string(text) if tokens < chunk_size: return text else: ratio = float(chunk_size) / tokens char_num = int(len(text) * ratio) return text[:char_num] def chunk_texts(text, chunk_size = 2048): ''' trunk the text into n parts, return a list of text [text, text, text] ''' tokens = num_tokens_from_string(text) if tokens < chunk_size: return [text] else: texts = [] n = int(tokens/chunk_size) + 1 # 计算每个部分的长度 part_length = len(text) // n # 如果不能整除,则最后一个部分会包含额外的字符 extra = len(text) % n parts = [] start = 0 for i in range(n): # 对于前extra个部分,每个部分多分配一个字符 end = start + part_length + (1 if i < extra else 0) parts.append(text[start:end]) start = end return parts from datetime import datetime from openai import OpenAI import openai import os chatgpt_system_prompt = f''' You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture. Knowledge cutoff: 2023-04 Current date: {datetime.now().strftime('%Y-%m-%d')} ''' def get_draft(question): # Getting the draft answer draft_prompt = ''' IMPORTANT: Try to answer this question/instruction with step-by-step thoughts and make the answer more structural. Use `\n\n` to split the answer into several paragraphs. Just respond to the instruction directly. DO NOT add additional explanations or introducement in the answer unless you are asked to. ''' # openai_client = OpenAI(api_key=openai.api_key) openai_client = OpenAI(api_key = os.getenv('OPENAI_API_KEY')) draft = openai_client.chat.completions.create( model="gpt-3.5-turbo", messages=[ { "role": "system", "content": chatgpt_system_prompt }, { "role": "user", "content": f"{question}" + draft_prompt } ], temperature = 1.0 ).choices[0].message.content return draft def split_draft(draft, split_char = '\n\n'): # 将draft切分为多个段落 # split_char: '\n\n' draft_paragraphs = draft.split(split_char) draft_paragraphs = [d for d in draft_paragraphs if d] # print(f"The draft answer has {len(draft_paragraphs)}") return draft_paragraphs def get_query(question, answer): query_prompt = ''' I want to verify the content correctness of the given question, especially the last sentences. Please summarize the content with the corresponding question. This summarization will be used as a query to search with Bing search engine. The query should be short but need to be specific to promise Bing can find related knowledge or pages. You can also use search syntax to make the query short and clear enough for the search engine to find relevant language data. Try to make the query as relevant as possible to the last few sentences in the content. **IMPORTANT** Just output the query directly. DO NOT add additional explanations or introducement in the answer unless you are asked to. ''' # openai_client = OpenAI(api_key = openai.api_key) openai_client = OpenAI(api_key = os.getenv('OPENAI_API_KEY')) query = openai_client.chat.completions.create( model="gpt-3.5-turbo", messages=[ { "role": "system", "content": chatgpt_system_prompt }, { "role": "user", "content": f"##Question: {question}\n\n##Content: {answer}\n\n##Instruction: {query_prompt}" } ], temperature = 1.0 ).choices[0].message.content return query def get_content(query): res = get_search(query, 1) if not res: print(">>> No good Google Search Result was found") return None search_results = res[0] link = search_results['link'] # title, snippet res = get_page_content(link) if not res: print(f">>> No content was found in {link}") return None retrieved_text = res trunked_texts = chunk_texts(retrieved_text, 1500) trunked_texts = [trunked_text.replace('\n', " ") for trunked_text in trunked_texts] return trunked_texts def get_revise_answer(question, answer, content): revise_prompt = ''' I want to revise the answer according to retrieved related text of the question in WIKI pages. You need to check whether the answer is correct. If you find some errors in the answer, revise the answer to make it better. If you find some necessary details are ignored, add it to make the answer more plausible according to the related text. If you find the answer is right and do not need to add more details, just output the original answer directly. **IMPORTANT** Try to keep the structure (multiple paragraphs with its subtitles) in the revised answer and make it more structual for understanding. Add more details from retrieved text to the answer. Split the paragraphs with `\n\n` characters. Just output the revised answer directly. DO NOT add additional explanations or annoucement in the revised answer unless you are asked to. ''' # openai_client = OpenAI(api_key = openai.api_key) openai_client = OpenAI(api_key = os.getenv('OPENAI_API_KEY')) revised_answer = openai_client.chat.completions.create( model="gpt-3.5-turbo", messages=[ { "role": "system", "content": chatgpt_system_prompt }, { "role": "user", "content": f"##Existing Text in Wiki Web: {content}\n\n##Question: {question}\n\n##Answer: {answer}\n\n##Instruction: {revise_prompt}" } ], temperature = 1.0 ).choices[0].message.content return revised_answer def get_query_wrapper(q, question, answer): result = get_query(question, answer) q.put(result) # 将结果放入队列 def get_content_wrapper(q, query): result = get_content(query) q.put(result) # 将结果放入队列 def get_revise_answer_wrapper(q, question, answer, content): result = get_revise_answer(question, answer, content) q.put(result) from multiprocessing import Process, Queue def run_with_timeout(func, timeout, *args, **kwargs): q = Queue() # 创建一个Queue对象用于进程间通信 # 创建一个进程来执行传入的函数,将Queue和其他*args、**kwargs作为参数传递 p = Process(target=func, args=(q, *args), kwargs=kwargs) p.start() # 等待进程完成或超时 p.join(timeout) if p.is_alive(): print(f"{datetime.now()} [INFO] 函数{str(func)}执行已超时({timeout}s),正在终止进程...") p.terminate() # 终止进程 p.join() # 确保进程已经终止 result = None # 超时情况下,我们没有结果 else: print(f"{datetime.now()} [INFO] 函数{str(func)}执行成功完成") result = q.get() # 从队列中获取结果 return result from difflib import unified_diff from IPython.display import display, HTML def generate_diff_html(text1, text2): diff = unified_diff(text1.splitlines(keepends=True), text2.splitlines(keepends=True), fromfile='text1', tofile='text2') diff_html = "" for line in diff: if line.startswith('+'): diff_html += f"
{line.rstrip()}
" elif line.startswith('-'): diff_html += f"
{line.rstrip()}
" elif line.startswith('@'): diff_html += f"
{line.rstrip()}
" else: diff_html += f"{line.rstrip()}
" return diff_html newline_char = '\n' def rat(question): print(f"{datetime.now()} [INFO] 生成草稿中...") draft = get_draft(question) print(f"{datetime.now()} [INFO] 获得草稿") # print(f"##################### DRAFT #######################") # print(draft) # print(f"##################### END #######################") print(f"{datetime.now()} [INFO] 处理草稿...") draft_paragraphs = split_draft(draft) print(f"{datetime.now()} [INFO] 草稿被切分为{len(draft_paragraphs)}部分") answer = "" for i, p in enumerate(draft_paragraphs): print(str(i)*80) print(f"{datetime.now()} [INFO] 修改第{i+1}/{len(draft_paragraphs)}部分...") answer = answer + '\n\n' + p # print(f"[{i}/{len(draft_paragraphs)}] Original Answer:\n{answer.replace(newline_char, ' ')}") # query = get_query(question, answer) print(f"{datetime.now()} [INFO] 生成对应Query...") res = run_with_timeout(get_query_wrapper, 3, question, answer) if not res: print(f"{datetime.now()} [INFO] 生成检索词超时,跳过后续步骤...") continue else: query = res print(f">>> {i}/{len(draft_paragraphs)} Query: {query.replace(newline_char, ' ')}") print(f"{datetime.now()} [INFO] 获取网页内容...") # content = get_content(query) res = run_with_timeout(get_content_wrapper, 5, query) if not res: print(f"{datetime.now()} [INFO] 获取网页内容超时,跳过后续步骤...") continue else: content = res for j, c in enumerate(content): if j > 2: break print(f"{datetime.now()} [INFO] 根据网页内容修改对应答案...[{j}/{min(len(content),3)}]") # answer = get_revise_answer(question, answer, c) res = run_with_timeout(get_revise_answer_wrapper, 15, question, answer, c) if not res: print(f"{datetime.now()} [INFO] 修改答案超时,跳过后续步骤...") continue else: diff_html = generate_diff_html(answer, res) display(HTML(diff_html)) answer = res print(f"{datetime.now()} [INFO] 答案修改完成[{j}/{min(len(content),3)}]") # print(f"[{i}/{len(draft_paragraphs)}] REVISED ANSWER:\n {answer.replace(newline_char, ' ')}") # print() return draft, answer # return answer page_title = "RAT: Retrieval Augmented Thoughts Elicit Context-Aware Reasoning in Long-Horizon Generation" page_md = """ # RAT: Retrieval Augmented Thoughts Elicit Context-Aware Reasoning in Long-Horizon Generation We explore how iterative revising a chain of thoughts with the help of information retrieval significantly improves large language models' reasoning and generation ability in long-horizon generation tasks, while hugely mitigating hallucination. In particular, the proposed method — retrieval-augmented thoughts (RAT) — revises each thought step one by one with retrieved information relevant to the task query, the current and the past thought steps, after the initial zero-shot CoT is generated. Applying RAT to various base models substantially improves their performances on various long-horizon generation tasks; on average of relatively increasing rating scores by 13.63% on code generation, 16.96% on mathematical reasoning, 19.2% on creative writing, and 42.78% on embodied task planning. Feel free to try our demo! """ def clear_func(): return "", "", "" def set_openai_api_key(api_key): if api_key and api_key.startswith("sk-") and len(api_key) > 50: os.environ["OPENAI_API_KEY"] = api_key with gr.Blocks(title = page_title) as demo: gr.Markdown(page_md) with gr.Row(): chatgpt_box = gr.Textbox( label = "ChatGPT", placeholder = "Response from ChatGPT with zero-shot chain-of-thought.", elem_id = "chatgpt" ) with gr.Row(): stream_box = gr.Textbox( label = "Streaming", placeholder = "Interactive response with RAT...", elem_id = "stream", lines = 10, visible = False ) with gr.Row(): rat_box = gr.Textbox( label = "RAT", placeholder = "Final response with RAT ...", elem_id = "rat", lines = 6 ) with gr.Column(elem_id="instruction_row"): with gr.Row(): instruction_box = gr.Textbox( label = "instruction", placeholder = "Enter your instruction here", lines = 2, elem_id="instruction", interactive=True, visible=True ) with gr.Row(): model_radio = gr.Radio(["gpt-3.5-turbo", "gpt-4", "GPT-4-turbo"], elem_id="model_radio", value="gpt-3.5-turbo", label='GPT model: ', show_label=True,interactive=True, visible=True) openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key (sk-...) and hit Enter", show_label=False, lines=1, type='password') openai_api_key_textbox.change(set_openai_api_key, inputs=[openai_api_key_textbox], outputs=[]) with gr.Row(): submit_btn = gr.Button( value="submit", visible=True, interactive=True ) clear_btn = gr.Button( value="clear", visible=True, interactive=True ) regenerate_btn = gr.Button( value="regenerate", visible=True, interactive=True ) submit_btn.click( fn = rat, inputs = [instruction_box], outputs = [chatgpt_box, rat_box] ) clear_btn.click( fn = clear_func, inputs = [], outputs = [instruction_box, chatgpt_box, rat_box] ) regenerate_btn.click( fn = rat, inputs = [instruction_box], outputs = [chatgpt_box, rat_box] ) examples = gr.Examples( examples=[ "I went to the supermarket yesterday.", "Helen is a good swimmer."], inputs=[instruction_box] ) demo.launch(server_name="0.0.0.0", debug=True)