arabellastrange's picture
fixed git
0e9148c
raw
history blame
11.8 kB
import json
import linecache
import logging
import os
import re
import traceback
import tracemalloc
import nltk
import openai
# from tenacity import (retry,stop_after_attempt,stop_after_delay, wait_random_exponential)
from tenacity import *
# from agents.utils import read_file
logger = logging.getLogger("agent_logger")
openai.api_key = os.getenv('gpt_api_key')
# Google Programmable Search credentials (paid and ad free).
google_key = os.getenv("google_search_api_key")
# cx: the identifier of the Programmable Search Engine.
google_cx = os.getenv("google_cx_api_key")
GOOGLE = "google"
USER = "user"
ASSISTANT = "assistant"
MODEL = "gpt-3.5-turbo"
sites = {}  # dictionary of sites used; loaded from the "sites" JSON file below when readable
new_sites = {}  # second dictionary of sites used; starts empty (not loaded from the file)
try:
    with open("sites", "r", encoding="utf-8") as f:
        sites = json.loads(f.read())
except (OSError, ValueError):
    # Missing/unreadable file (OSError) or bad JSON / bad encoding (ValueError):
    # keep the empty dict. Narrowed from a bare except so Ctrl-C still works.
    print("Failed to read sites.")
# for experimenting with Vicuna
def display_top(snapshot, key_type="lineno", limit=10):
    """Log the largest allocation entries of a tracemalloc snapshot.

    Traces from "<frozen importlib._bootstrap>" and "<unknown>" are filtered
    out first. ``Snapshot.statistics`` returns entries largest first; the
    first *limit* are logged individually (rank, file:line, size in KiB, and
    the source line when linecache can find it), followed by a summary of
    the remaining entries and the grand total.

    Args:
        snapshot: a ``tracemalloc.Snapshot``.
        key_type: grouping key passed to ``Snapshot.statistics``.
        limit: number of entries to list individually.
    """
    excluded = (
        tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
        tracemalloc.Filter(False, "<unknown>"),
    )
    stats = snapshot.filter_traces(excluded).statistics(key_type)
    shown, remaining = stats[:limit], stats[limit:]

    logger.info("Top %s lines" % limit)
    for rank, entry in enumerate(shown, start=1):
        origin = entry.traceback[0]
        kib = entry.size / 1024
        logger.info(
            "#%s: %s:%s: %.1f KiB" % (rank, origin.filename, origin.lineno, kib)
        )
        source_line = linecache.getline(origin.filename, origin.lineno).strip()
        if source_line:
            logger.info(" %s" % source_line)

    if remaining:
        remaining_bytes = sum(entry.size for entry in remaining)
        logger.info("%s other: %.1f KiB" % (len(remaining), remaining_bytes / 1024))

    grand_total = sum(entry.size for entry in stats)
    logger.info("Total allocated size: %.1f KiB" % (grand_total / 1024))
class turn:
    """One message in a conversation plus optional metadata.

    Attributes:
        role: speaker role (default "assistant").
        message: message text.
        tldr: optional short summary.
        source: where the turn came from; compared against GOOGLE, USER and
            ASSISTANT by the ``is_*_turn`` helpers.
        keywords: list of keywords; each instance gets its own list.
    """

    def __init__(self, role="assistant", message="", tldr="", source="", keywords=None):
        self.role = role
        self.message = message
        self.tldr = tldr
        self.source = source
        # Fresh list per instance: a `[]` default would be one list shared by
        # every turn built without explicit keywords.
        self.keywords = [] if keywords is None else keywords

    def __str__(self):
        """Render non-empty fields as "r: <role> m: <message> s: <source> tldr: <tldr>"."""
        s = ""
        if self.role:
            s += "r: " + self.role
        if self.message:
            s += " m: " + self.message
        if self.source:
            s += " s: " + self.source
        if self.tldr:
            # Leading space keeps the tldr separated like the other fields.
            s += " tldr: " + self.tldr
        return s

    def is_google_turn(self):
        """Return True when this turn's source is GOOGLE."""
        return self.source is not None and self.source == GOOGLE

    def is_user_turn(self):
        """Return True when this turn's source is USER.

        NOTE(review): compares `source`, not `role`, against USER; confirm
        this is intended.
        """
        return self.source is not None and self.source == USER

    def is_assistant_turn(self):
        """Return True when this turn's source is ASSISTANT (compares `source`, not `role`)."""
        return self.source is not None and self.source == ASSISTANT
# @retry(wait=wait_random_exponential(min=1, max=2), stop=(stop_after_delay(15) | stop_after_attempt(2)))
def chatCompletion_with_backoff(**kwargs):
    """Create a chat completion, forwarding *kwargs* to the OpenAI client.

    Uses ``openai.chat.completions.create``, the same client API as
    ``ask_gpt``. The legacy ``openai.ChatCompletion.create`` was removed in
    openai>=1.0 and raises there instead of making a request.
    """
    return openai.chat.completions.create(**kwargs)
def ask_gpt(model, gpt_message, max_tokens, temp, top_p):
    """Send one chat-completion request and return the reply text.

    Args:
        model: chat model name passed to the API.
        gpt_message: list of {"role": ..., "content": ...} message dicts.
        max_tokens: passed through as ``max_tokens``.
        temp: passed through as ``temperature``.
        top_p: passed through as ``top_p``.

    Returns:
        The first choice's content with leading " ,:." characters stripped,
        or None if the request failed or produced no text content.
    """
    try:
        completion = openai.chat.completions.create(
            model=model,
            messages=gpt_message,
            max_tokens=max_tokens,
            temperature=temp,
            top_p=top_p,
        )
    except Exception:
        # Best effort: report the failure and signal it to callers with None.
        traceback.print_exc()
        logger.info("no response")
        return None

    content = None
    if completion is not None and completion.choices:
        # content can be None (e.g. a tool-call or filtered reply).
        content = completion.choices[0].message.content
    if content is None:
        logger.info("no response")
        return None

    response = content.lstrip(" ,:.")
    logger.info(response)
    return response
def ask_gpt_with_retries(model, gpt_message, tokens, temp, timeout, tries):
    """Call ``ask_gpt`` up to *tries* times until it returns a non-None reply.

    Args:
        model: chat model name.
        gpt_message: list of message dicts forwarded to ``ask_gpt``.
        tokens: forwarded as ``max_tokens``.
        temp: forwarded as ``temp`` (temperature).
        timeout: seconds after which no new attempt is started. Checked
            between attempts; an in-flight request is not interrupted.
        tries: maximum number of attempts.

    Returns:
        The reply text, or None if every attempt failed.
    """
    from tenacity import (
        Retrying,
        retry_if_exception_type,
        retry_if_result,
        stop_after_attempt,
        stop_after_delay,
    )

    retryer = Retrying(
        stop=(stop_after_delay(timeout) | stop_after_attempt(tries)),
        # ask_gpt reports failures by returning None rather than raising, so
        # retry on a None result as well as on any escaped exception.
        retry=(retry_if_exception_type() | retry_if_result(lambda result: result is None)),
        # Once attempts are exhausted, return None (ask_gpt's failure value)
        # instead of raising tenacity.RetryError.
        retry_error_callback=lambda retry_state: None,
    )
    return retryer(
        ask_gpt,
        model=model,
        gpt_message=gpt_message,
        max_tokens=tokens,
        temp=temp,
        top_p=1,
    )
INFORMATION_QUERY = "information query"
INTENTS = []


def find_intent(response):
    """Return the first entry of INTENTS mentioned in *response*.

    Each intent is tested as a substring of ``response.lower()``, so an
    intent containing uppercase letters can never match. Returns
    INFORMATION_QUERY when no intent matches (including when INTENTS is
    empty).
    """
    # The module globals are only read here, so no `global` statement is needed.
    mentioned = (intent for intent in INTENTS if intent in response.lower())
    return next(mentioned, INFORMATION_QUERY)
def find_query(response):
    """Extract the quoted search phrase from a model response.

    Searches after a case-insensitive "phrase:" label when present, otherwise
    from the start. Takes the first double-quoted string, falling back to the
    first single-quoted one.

    Returns:
        ``(phrase, remainder)``, where remainder is the text after the
        phrase's closing quote. When nothing is quoted, returns
        ``("", response)``. A None response yields ``("", "")``.
    """
    if response is None:
        return "", ""
    # Regex with IGNORECASE rather than response.lower().find(): lower() can
    # change string length for some non-ASCII characters, skewing indices.
    label = re.search("phrase:", response, flags=re.IGNORECASE)
    start = label.end() if label else 0
    match = re.compile(r'"([^"]*)"').search(response, start)
    if match is None:
        match = re.compile(r"'([^']*)'").search(response, start)
    if match is not None:
        # Slice at the match end so the remainder follows *this* quoted
        # phrase, not an earlier occurrence of the same text.
        return match.group(1), response[match.end():]
    logger.info("no quoted text, returning original query string: %s", response)
    return "", response
def find_keywords(response, query_phrase, orig_phrase):
    """Extract the keyword list from a "Keywords: a, b, c" section of *response*.

    The section starts after the first case-insensitive "keyword" and ends at
    a case-insensitive "Named-Entities:" if present. Everything up to the
    first ":" (normally the "s:" of "Keywords:") is dropped. The rest is split
    on commas, and each item is trimmed of spaces and ":,.\\t\\n". An item is
    kept if it is longer than three characters or starts with an uppercase
    letter.

    Args:
        response: model output remaining after the search phrase.
        query_phrase: currently unused; kept for interface compatibility.
        orig_phrase: currently unused; kept for interface compatibility.

    Returns:
        List of keyword strings, or "" when *response* has no "keyword" marker.
    """
    k_index = response.lower().find("keyword")
    if k_index < 0:
        # "" (not []) preserved for existing callers.
        return ""
    keyword_string = response[k_index + len("keyword"):]
    nm_index = keyword_string.lower().find("named-entities:")
    if nm_index >= 0:
        keyword_string = keyword_string[:nm_index].rstrip()
    c_index = keyword_string.find(":")
    keyword_string = keyword_string[c_index + 1:]
    keywords = []
    for candidate in keyword_string.split(","):
        keyword = candidate.strip(" :,.\t\n")
        if len(keyword) > 3 or keyword[:1].isupper():
            keywords.append(keyword)
    return keywords
# NOTE: `prefix` and `suffix` are not assigned in this module (TODO confirm).
# Python resolves global names when the function runs, not when it is
# compiled. So importing this module succeeds, but calling split_interaction
# raises NameError unless `prefix` and `suffix` were bound as globals of this
# module beforehand. Per the original author, importing them here directly
# caused an import loop, so these references are intentionally left alone.
def split_interaction(interaction):
    """Split *interaction* into (query, response) around the prefix/suffix markers.

    Returns:
        ``(query, response)``, both left-stripped: the query is the text
        between the end of `prefix` and the start of `suffix`, and the
        response is everything after `suffix`. Returns ``("", "")`` when
        either marker is missing.
    """
    qs = interaction.find(prefix)
    rs = interaction.find(suffix)
    if qs >= 0 and rs >= 0:
        # Offset from the prefix's actual position (it need not be at index 0).
        query = interaction[qs + len(prefix): rs].lstrip()
        response = interaction[rs + len(suffix):].lstrip()
        return query, response
    logger.info("can't parse: %s", interaction)
    return "", ""
def findnth(haystack, needle, n):
    """Return the index of occurrence number *n* (zero-based) of *needle* in *haystack*.

    Returns -1 when *haystack* has fewer than n + 1 occurrences.
    """
    pieces = haystack.split(needle, n + 1)
    enough = len(pieces) > n + 1
    # Everything after the target occurrence is the last piece, so the
    # occurrence starts len(needle) characters before that tail.
    return len(haystack) - len(pieces[-1]) - len(needle) if enough else -1
def extract_site(url):
    """Return the second-to-last dot-separated label of *url*'s host.

    The host part is the text before the third "/", e.g.
    "https://www.example.com/page" -> "example". Multi-part suffixes are not
    special-cased ("www.example.co.uk" -> "co"). A host with no "." is
    returned whole, minus its scheme. Returns "" when *url* has no third "/"
    (or it is at index <= 2).
    """
    site = ""
    base = findnth(url, "/", 2)
    if base > 2:
        labels = url[:base].split(".")
        # A dotless host ("http://localhost") has a single label; use it
        # rather than leaving a list behind.
        site = labels[-2] if len(labels) > 1 else labels[0]
        site = site.replace("https://", "")
        site = site.replace("http://", "")
    return site
def extract_domain(url):
    """Return the last two dot-separated labels of *url*'s host.

    The host part is the text before the third "/", e.g.
    "https://www.example.com/page" -> "example.com". A host with no "." is
    returned whole, minus its scheme. Returns "" when *url* has no third "/"
    (or it is at index <= 2).
    """
    # Initialized up front: previously `domain` was unbound on the no-slash path.
    domain = ""
    base = findnth(url, "/", 2)
    if base > 2:
        labels = url[:base].split(".")
        domain = labels[-2] + "." + labels[-1] if len(labels) > 1 else labels[0]
        domain = domain.replace("https://", "")
        domain = domain.replace("http://", "")
    return domain
def part_of_keyword(word, keywords):
    """Return True if *word* occurs as a substring of any entry in *keywords*."""
    return any(word in candidate for candidate in keywords)
# Instruction prompt sent by get_search_phrase_and_keywords. It asks the model
# for a Google search phrase plus keywords and named entities in a
# "Phrase: / Keywords: / Named-Entities:" layout, which find_query ("phrase:")
# and find_keywords ("keyword", "Named-Entities:") then parse.
keyword_prompt = 'Perform two tasks on the following text. First, rewrite the <text> as an effective google search phrase. Second, analyze text and list keywords and named-entities found. Return the result as: Phrase: "<google search phrase>"\nKeywords: <list of keywords>\nNamed-Entities: <list of Named-Entities>'
def get_search_phrase_and_keywords(query_string, chat_history):
    """Ask the model for a search phrase and keywords describing *query_string*.

    Args:
        query_string: the user's question.
        chat_history: currently unused; kept for interface compatibility.

    Returns:
        ``(query_phrase, keywords)``. query_phrase is "" when no quoted phrase
        was found. keywords is find_keywords' result (a list, or "" when the
        reply has no keyword section). Returns ``("", [])`` when the model
        gave no response.
    """
    gpt_message = [
        {"role": "user", "content": keyword_prompt},
        {"role": "user", "content": "Text\n" + query_string},
        # Pre-filled assistant turn, presumably to prime the reply format.
        {"role": "assistant", "content": "Phrase:"},
    ]
    response_text = ask_gpt_with_retries(
        MODEL, gpt_message, tokens=150, temp=0.3, timeout=6, tries=2
    )
    logger.info(response_text)
    if response_text is None:
        # ask_gpt signals failure with None, which find_query cannot parse.
        logger.info("no model response; no search phrase or keywords")
        return "", []
    # Useful for explainability: the phrase/keywords could be shown to the user.
    query_phrase, remainder = find_query(response_text)
    logger.info("PHRASE:" + query_phrase)
    keywords = find_keywords(remainder, query_phrase, query_string)
    logger.info("KEYWORDS:" + ", ".join(keywords))
    return query_phrase, keywords
def reform(elements, max_len=500, min_len=4):
    """Regroup page-text elements (e.g. from unstructured.partition_html) into larger chunks.

    Each element is converted with ``str()``.
    - Elements shorter than *min_len* characters are dropped.
    - Elements shorter than *max_len* are kept whole.
    - Longer elements are split on newlines, skipping blank lines.
    - Any line still *max_len* or longer is split into sentences with
      ``nltk.sent_tokenize``; a single sentence may still exceed *max_len*.

    The pieces are then packed greedily, in order. A piece is appended to the
    current paragraph (preceded by a space) while
    ``len(piece) + len(paragraph) < max_len``; otherwise the paragraph is
    closed and the piece starts a new one. The final paragraph gets a
    trailing period and newline.

    Args:
        elements: iterable of objects whose ``str()`` is page text.
        max_len: chunking threshold in characters (default 500).
        min_len: minimum element length to keep (default 4).

    Returns:
        List of paragraph strings.
    """
    texts = []
    total_elem_len = 0
    for element in elements:
        text = str(element)
        total_elem_len += len(text)
        if len(text) < min_len:
            continue
        if len(text) < max_len:
            texts.append(text)
            continue
        for subtext in text.split("\n"):
            if not subtext.strip():
                # Blank lines would only inject stray spaces into paragraphs.
                continue
            if len(subtext) < max_len:
                texts.append(subtext)
            else:
                texts.extend(nltk.sent_tokenize(subtext))

    # Reassemble the short pieces into chunks.
    paragraphs = []
    paragraph = ""
    for text in texts:
        if len(text) + len(paragraph) < max_len:
            paragraph += " " + text
        else:
            if paragraph:
                paragraphs.append(paragraph)
            paragraph = text
    if paragraph:
        paragraphs.append(paragraph + ".\n")

    total_pp_len = sum(len(p) for p in paragraphs)
    if total_pp_len > 1.2 * total_elem_len:
        logger.info(
            f"******** reform out > reform in. out: {total_pp_len}, in: {total_elem_len}"
        )
    return paragraphs
def get_actions(text):
    """Parse ``Action: <agent>(<query>)`` directives from a model response.

    For each "Action:" occurrence:
    - The agent is the stripped text before the next "(".
    - The query is the text between that "(" and the next ")".
    - A directive is logged and skipped if no "(" follows (at least one
      character after "Action:") or the parentheses are empty or unclosed.

    Returns:
        List of ``[agent, query, action_text]`` lists in order of appearance.
        action_text is the span of *text* from "Action:" through the closing ")".
    """
    actions = []
    for match in re.finditer("Action:", text):
        after = match.end()
        action = text[after:]
        agent = None
        query = None
        query_start = action.find("(")
        if query_start > 0:
            agent = action[:query_start].strip()
            query_end = action[query_start + 1:].find(")")
            if query_end > 0:
                query = action[query_start + 1: query_start + 1 + query_end]
                # The closing ")" sits at absolute offset
                # after + query_start + 1 + query_end; slice one past it.
                action = text[match.start(): after + query_start + query_end + 2]
        if agent is None or query is None:
            logger.info(
                "can't parse action, skipping: %s",
                text[match.start(): match.start() + 48],
            )
            continue
        actions.append([agent, query, action])
    return actions
if __name__ == "__main__":
    # Manual smoke test: makes a real API call using the key from the
    # gpt_api_key environment variable. Enable INFO logging so the
    # PHRASE/KEYWORDS log lines are visible, then print the result.
    logging.basicConfig(level=logging.INFO)
    phrase, keywords = get_search_phrase_and_keywords(
        "Would I like the video game Forspoken, given that I like Final Fantasy VII?",
        [],
    )
    print(phrase, keywords)