from climateqa.engine.keywords import make_keywords_chain
from climateqa.engine.llm import get_llm
from climateqa.knowledge.openalex import OpenAlex
from climateqa.engine.chains.answer_rag import make_rag_papers_chain
from front.utils import make_html_papers
from climateqa.engine.reranker import get_reranker
# Shared clients, created once when the module is imported and reused by the
# functions below.
oa = OpenAlex()  # used for search, rerank and network building in find_papers
llm = get_llm(
    provider="openai",
    max_tokens=1024,
    temperature=0.0,
)  # shared by the keyword chain and the paper-summary chain
reranker = get_reranker("nano")  # default reranker for find_papers
# (column name, width) pairs for the papers table, in display order.
# Kept as one private table so names and widths cannot drift apart.
_PAPERS_TABLE_COLUMNS = (
    ("id", 100),
    ("title", 300),
    ("doi", 100),
    ("publication_year", 100),
    ("abstract", 500),
    ("is_oa", 50),
)
papers_cols = [name for name, _width in _PAPERS_TABLE_COLUMNS]
papers_cols_widths = [width for _name, width in _PAPERS_TABLE_COLUMNS]
def generate_keywords(query):
    """Ask the LLM for search keywords for *query* and AND them together.

    Args:
        query: Free-text user question.

    Returns:
        The items of the keyword chain's ``"keywords"`` output joined with
        ``" AND "``; this string is what ``find_papers`` passes to
        ``oa.search``.
    """
    keyword_chain = make_keywords_chain(llm)
    extracted = keyword_chain.invoke(query)
    return " AND ".join(extracted["keywords"])
async def find_papers(query, after, relevant_content_sources, reranker=reranker):
    """Search OpenAlex for papers related to *query* and stream a summary.

    Async generator producing three outputs: the HTML cards of the top-ranked
    papers, the HTML of the paper network, and a summary string that grows as
    the summary chain streams tokens.

    Args:
        query: The user's question. Used to generate search keywords, to
            rerank results and as the question given to the summary chain.
        after: Forwarded to ``oa.search`` as ``after`` (presumably a
            publication-date lower bound — confirm against ``OpenAlex.search``).
        relevant_content_sources: Collection of enabled sources; nothing is
            searched unless it contains ``"OpenAlex"``.
        reranker: Reranker passed to ``oa.rerank``; defaults to the
            module-level reranker.

    Yields:
        ``(docs_html, network_html, summary)`` tuples. If OpenAlex is not
        enabled, or no paper with a non-empty abstract is found, a single
        ``("", "", "")`` is yielded.
    """
    if "OpenAlex" not in relevant_content_sources:
        yield "", "", ""
        return

    top_k = 10  # papers rendered as cards and passed to the summary chain
    summary = ""

    # NOTE(review): generate_keywords and oa.search are called without await,
    # so they run synchronously and block the event loop while they work;
    # consider moving them to a worker thread.
    keywords = generate_keywords(query)
    df_works = oa.search(keywords, after=after)
    print(f"Found {len(df_works)} papers")

    if not df_works.empty:
        # Keep only papers with a usable (non-missing, non-empty) abstract.
        df_works = df_works.dropna(subset=["abstract"])
        df_works = df_works[df_works["abstract"] != ""].reset_index(drop=True)

    # Re-check after filtering: every hit may have lacked an abstract.
    if df_works.empty:
        print("No papers found")
        # Yield empty outputs (as in the disabled-source case) so the caller
        # receives an update instead of nothing from this generator.
        yield "", "", ""
        return

    df_works = oa.rerank(query, df_works, reranker)
    # Reset the index so row labels equal rank position after sorting,
    # whichever kind of indexing the downstream helpers use.
    df_works = df_works.sort_values("rerank_score", ascending=False).reset_index(drop=True)

    # Fewer than top_k papers can survive the abstract filter; only request
    # cards for rows that actually exist.
    n_cards = min(top_k, len(df_works))
    docs_html = "".join(make_html_papers(df_works, i) for i in range(n_cards))

    graph = oa.make_network(df_works)
    height = "750px"
    network = oa.show_network(graph, color_by="rerank_score", notebook=False, height=height)
    network_html = network.generate_html()
    # Presumably swapped to double quotes so the HTML can be embedded inside a
    # single-quoted attribute — confirm.
    network_html = network_html.replace("'", "\"")
    # NOTE(review): css_to_inject is empty, so this concatenation is a no-op.
    css_to_inject = ""
    network_html = network_html + css_to_inject
    # NOTE(review): this assignment discards the network HTML built above, so
    # the network is computed but never shown. An embedding template and the
    # injected CSS appear to have been lost — confirm and restore them here.
    network_html = f""""""

    docs = df_works["content"].head(top_k).tolist()

    # First update: paper cards and network, with the summary still empty.
    yield docs_html, network_html, summary

    chain = make_rag_papers_chain(llm)
    result = chain.astream_log({"question": query, "docs": docs, "language": "English"})
    # Log operations at this path carry the text streamed by the chain's
    # StrOutputParser, i.e. the summary tokens.
    path_answer = "/logs/StrOutputParser/streamed_output/-"
    async for patch in result:
        # Inspect every op in the patch: indexing only ops[0] would raise on
        # an empty patch and silently skip tokens in later ops.
        for op in patch.ops:
            if op["path"] == path_answer:
                summary += op["value"]
                yield docs_html, network_html, summary