Commit 593bb22 · 1 Parent(s): ba91632
inline citations and more

Files changed:
- ai_generate.py +104 -10
- app.py +20 -14
ai_generate.py CHANGED

@@ -1,3 +1,4 @@
+import gc
 import os
 from langchain_community.document_loaders import PyMuPDFLoader
 from langchain_core.documents import Document

@@ -18,6 +19,15 @@ from dotenv import load_dotenv
 from langchain_core.output_parsers import XMLOutputParser
 from langchain.prompts import ChatPromptTemplate
 import re
+import numpy as np
+import torch
+
+# pip install bm25s
+import bm25s
+from langchain_community.cross_encoders import HuggingFaceCrossEncoder
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain.retrievers.document_compressors import CrossEncoderReranker
+from langchain_core.messages import HumanMessage

 load_dotenv()

@@ -29,7 +39,17 @@ os.environ["GLOG_minloglevel"] = "2"
 CHUNK_SIZE = 1024
 CHUNK_OVERLAP = CHUNK_SIZE // 8
 K = 10
-FETCH_K =
+FETCH_K = 50
+
+model_kwargs = {"device": "cuda:1"}
+print("Loading embedding and reranker models...")
+embedding_function = SentenceTransformerEmbeddings(
+    model_name="mixedbread-ai/mxbai-embed-large-v1", model_kwargs=model_kwargs
+)
+# "sentence-transformers/all-MiniLM-L6-v2"
+# "mixedbread-ai/mxbai-embed-large-v1"
+reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base", model_kwargs=model_kwargs)
+compressor = CrossEncoderReranker(model=reranker, top_n=K)

 llm_model_translation = {
     "LLaMA 3": "llama3-70b-8192",

@@ -195,10 +215,9 @@ def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int
     return llm


-def create_db_with_langchain(path: list[str], url_content: dict):
+def create_db_with_langchain(path: list[str], url_content: dict, query: str):
     all_docs = []
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
-    embedding_function = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
     if path:
         for file in path:
             loader = PyMuPDFLoader(file)

@@ -214,18 +233,38 @@ def create_db_with_langchain(path: list[str], url_content: dict):
             docs = text_splitter.split_documents([doc])
             all_docs.extend(docs)

+    print(f"### Total number of documents before bm25s: {len(all_docs)}")
+
+    # if the number of docs is too high, we need to reduce it
+    num_max_docs = 250
+    if len(all_docs) > num_max_docs:
+        docs_raw = [doc.page_content for doc in all_docs]
+        retriever = bm25s.BM25(corpus=docs_raw)
+        retriever.index(bm25s.tokenize(docs_raw))
+        results, scores = retriever.retrieve(bm25s.tokenize(query), k=len(docs_raw), sorted=False)
+        top_indices = np.argpartition(scores[0], -num_max_docs)[-num_max_docs:]
+        all_docs = [all_docs[i] for i in top_indices]
+
     # print docs
     for idx, doc in enumerate(all_docs):
         print(f"Doc: {idx} | Length = {len(doc.page_content)}")

     assert len(all_docs) > 0, "No PDFs or scrapped data provided"
     db = Chroma.from_documents(all_docs, embedding_function)
+    torch.cuda.empty_cache()
+    gc.collect()
     return db


+def pretty_print_docs(docs):
+    print(f"\n{'-' * 100}\n".join([f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)]))
+
+
 def generate_rag(
     prompt: str,
+    input_role: str,
     topic: str,
+    context: str,
     model: str,
     url_content: dict,
     path: list[str],
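Note on the bm25s pre-filter added in the hunk above: before anything is embedded into Chroma, the chunk set is now capped at num_max_docs = 250 by scoring every chunk against the reformulated query with BM25 and keeping only the highest-scoring ones. Below is a minimal standalone sketch of the same idea; the toy texts and the cap of 3 are illustrative, it retrieves the top k directly instead of argpartition-ing over all scores, and it assumes bm25s returns document indices when no corpus is attached to the retriever.

import bm25s

chunks = [
    "Chroma stores dense embeddings for similarity search.",
    "BM25 is a lexical ranking function built on term frequencies.",
    "Cross-encoders rerank query-document pairs jointly.",
    "MMR balances relevance and diversity during retrieval.",
    "PyMuPDF extracts text from PDF files page by page.",
]
query = "How does lexical BM25 ranking work?"
num_max_docs = 3  # the commit uses 250

# Index the chunks once, then keep only the chunks that score highest for the query.
retriever = bm25s.BM25()
retriever.index(bm25s.tokenize(chunks))
indices, scores = retriever.retrieve(bm25s.tokenize(query), k=num_max_docs)
filtered_chunks = [chunks[i] for i in indices[0]]
print(filtered_chunks)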
@@ -238,18 +277,24 @@ def generate_rag(
     if llm is None:
         print("Failed to load LLM. Aborting operation.")
         return None
-    db = create_db_with_langchain(path, url_content)
-    retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})

+    query = llm_wrapper(input_role, topic, context, model="OpenAI GPT 4o", task_type="rag", temperature=0.7)
+    print("### Query: ", query)
+    db = create_db_with_langchain(path, url_content, query)
+    retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K, "lambda_mult": 0.75})
+
+    # docs = retriever.get_relevant_documents(query)
+    compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
+    docs = compression_retriever.invoke(query)
+    print(pretty_print_docs(docs))

     formatted_docs = format_docs_xml(docs)
     rag_chain = RunnablePassthrough.assign(context=lambda _: formatted_docs) | xml_prompt | llm | XMLOutputParser()
     result = rag_chain.invoke({"input": prompt})
     citations = get_citations(result, docs)
+    db.delete_collection()  # important, othwerwise it will keep the documents in memory
+    torch.cuda.empty_cache()
+    gc.collect()
     return result, citations


@@ -271,7 +316,9 @@ def generate_base(

 def generate(
     prompt: str,
+    input_role: str,
     topic: str,
+    context: str,
     model: str,
     url_content: dict,
     path: list[str],

@@ -281,6 +328,53 @@
     sys_message="",
 ):
     if path or url_content:
-        return generate_rag(
+        return generate_rag(
+            prompt, input_role, topic, context, model, url_content, path, temperature, max_length, api_key, sys_message
+        )
     else:
         return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)
+
+
+def llm_wrapper(
+    iam=None,
+    topic=None,
+    context=None,
+    temperature=1.0,
+    max_length=512,
+    api_key="",
+    model="OpenAI GPT 4o Mini",
+    task_type="internet",
+):
+    llm = load_llm(model, api_key, temperature, max_length)
+
+    if task_type == "rag":
+        system_message_content = """You are an AI assistant tasked with reformulating user inputs to improve retrieval query in a RAG system.
+- Given the original user inputs, construct query to be more specific, detailed, and likely to retrieve relevant information.
+- Generate the query as a complete sentence or question, not just as keywords, to ensure the retrieval process can find detailed and contextually relevant information.
+- You may enhance the query by adding related and relevant terms, but do not introduce new facts, such as dates, numbers, or assumed information, that were not provided in the input.
+
+**Inputs:**
+- **User Role**: {iam}
+- **Topic**: {topic}
+- **Context**: {context}
+
+**Only return the search query**."""
+    elif task_type == "internet":
+        system_message_content = """You are an AI assistant tasked with generating an optimized Google search query to help retrieve relevant websites, news, articles, and other sources of information.
+- You may enhance the query by adding related and relevant terms, but do not introduce new facts, such as dates, numbers, or assumed information, that were not provided in the input.
+- The query should be **concise** and include important **keywords** while incorporating **short phrases** or context where it improves the search.
+- Avoid the use of "site:" operators or narrowing search by specific websites.
+
+**Inputs:**
+- **User Role**: {iam}
+- **Topic**: {topic}
+- **Context**: {context}
+
+**Only return the search query**.
+"""
+    else:
+        raise ValueError("Task type not recognized. Please specify 'rag' or 'internet'.")
+
+    human_message = HumanMessage(content=system_message_content.format(iam=iam, topic=topic, context=context))
+    response = llm.invoke([human_message])
+    return response.content.strip('"').strip("'")
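Note on the reworked retrieval path in generate_rag: the retriever now over-fetches FETCH_K = 50 candidates from Chroma with MMR, and a cross-encoder (CrossEncoderReranker wrapped in a ContextualCompressionRetriever) trims them back down to K = 10 before the XML prompt is assembled. A minimal sketch of the same wiring is below; the small public models and toy documents stand in for the commit's mxbai-embed-large-v1 and bge-reranker-base on cuda:1, so treat it as an outline rather than the app's exact configuration.

from langchain_core.documents import Document
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker

docs = [
    Document(page_content="BM25 is a lexical scoring function over term frequencies."),
    Document(page_content="Cross-encoders score query-document pairs jointly and rerank them."),
    Document(page_content="MMR trades off relevance against redundancy when picking results."),
]

embedding = SentenceTransformerEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
db = Chroma.from_documents(docs, embedding)

# Over-fetch with MMR, then let the cross-encoder keep the best top_n documents.
retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 2, "fetch_k": 3, "lambda_mult": 0.75})
reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
compressor = CrossEncoderReranker(model=reranker, top_n=2)
compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)

for doc in compression_retriever.invoke("How does cross-encoder reranking work?"):
    print(doc.page_content)

db.delete_collection()  # the commit does the same after each request so documents are not kept in memory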
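The new llm_wrapper helper is called from two places: generate_rag uses task_type="rag" to turn the user's role, topic, and context into a full-sentence retrieval query, and app.py's generate_and_format uses task_type="internet" to build a keyword-style Google query. A hedged usage sketch follows; the role, topic, and context values are made up, and it assumes the same API keys that load_llm already relies on are configured.

from ai_generate import llm_wrapper

inputs = dict(
    iam="content marketing manager",          # hypothetical user role
    topic="AI-generated text detection",      # hypothetical topic
    context="drafting a blog post for enterprise customers",  # hypothetical context
    model="OpenAI GPT 4o",
    temperature=0.7,
)

rag_query = llm_wrapper(task_type="rag", **inputs)          # full-sentence query for the vector store
search_query = llm_wrapper(task_type="internet", **inputs)  # keyword-style query for Google search
print(rag_query)
print(search_query)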
app.py CHANGED

@@ -3,6 +3,7 @@ nohup python3 app.py &
 export GOOGLE_APPLICATION_CREDENTIALS="gcp_creds.json"
 """

+import gc
 import re
 import uuid
 import json

@@ -23,13 +24,12 @@ if gr.NO_RELOAD:
     from humanize import humanize_text, device
     from utils import remove_special_characters, split_text_allow_complete_sentences_nltk
     from google_search import google_search, months, domain_list, build_date
-    from ai_generate import generate, citations_to_html, remove_citations, display_cited_text
+    from ai_generate import generate, citations_to_html, remove_citations, display_cited_text, llm_wrapper

-if gr.NO_RELOAD:
     nltk.download("punkt_tab")

     print(f"Using device: {device}")
+    print("Loading AI detection models...")
     models = {
         "Polygraf AI (Base Model)": AutoModelForSequenceClassification.from_pretrained(
             "polygraf-ai/bc-roberta-openai-2sent"

@@ -51,6 +51,7 @@
     TEXT_MC_MODEL_PATH = "polygraf-ai/mc-model"
     MC_LABEL_MAP = ["OpenAI GPT", "Mistral", "CLAUDE", "Gemini", "Grammar Enhancer"]
     text_mc_tokenizer = AutoTokenizer.from_pretrained(TEXT_MC_MODEL_PATH)
+    print("Loading Source detection model...")
     text_mc_model = AutoModelForSequenceClassification.from_pretrained(TEXT_MC_MODEL_PATH).to(device)


@@ -64,20 +65,20 @@ def generate_cited_html(cited_text, citations: dict):
     }
     .reference-btn {
         display: inline-block;
-        width:
-        height:
+        width: 20px; /* Reduced width */
+        height: 20px; /* Reduced height */
         border-radius: 50%;
-        background-color: #
+        background-color: #e33a89; /* Pink color for the button */
         color: white;
         text-align: center;
-        line-height:
+        line-height: 20px; /* Adjusted line-height */
         cursor: pointer;
         font-weight: bold;
         margin-right: 5px;
         transition: background-color 0.3s ease, transform 0.3s ease;
     }
     .reference-btn:hover {
-        background-color: #
+        background-color: #ff69b4; /* Lighter pink on hover */
         transform: scale(1.1); /* Slightly enlarge on hover */
     }
     .reference-popup {

@@ -357,6 +358,8 @@ def predict(model, tokenizer, text):
     output = model(**tokens)
     output_norm = softmax(output.logits.detach().cpu().numpy(), 1)[0]
     output_norm = {"HUMAN": output_norm[0], "AI": output_norm[1]}
+    torch.cuda.empty_cache()
+    gc.collect()
     return output_norm


@@ -428,6 +431,8 @@ def predict_mc(text):
     ).to(device)
     output = text_mc_model(**tokens)
     output_norm = softmax(output.logits.detach().cpu().numpy(), 1)[0]
+    torch.cuda.empty_cache()
+    gc.collect()
     return output_norm


@@ -582,7 +587,9 @@ def generate_article(
     print("Generated Prompt...\n", prompt)
     article, citations = generate(
         prompt=prompt,
+        input_role=input_role,
         topic=topic,
+        context=context,
         model=ai_model,
         url_content=url_content,
         path=pdf_file_input,

@@ -692,7 +699,6 @@ def save_to_cloud_storage(
     article,
     topic,
     input_role,
-    topic_context,
     context,
     keywords,
     article_length,

@@ -725,7 +731,6 @@
     "metadata": {
         "topic": topic,
         "input_role": input_role,
-        "topic_context": topic_context,
         "context": context,
         "keywords": keywords,
         "article_length": article_length,

@@ -818,7 +823,9 @@ def generate_and_format(
     date_from = build_date(year_from, month_from, day_from)
     date_to = build_date(year_to, month_to, day_to)
     sorted_date = f"date:r:{date_from}:{date_to}"
-    final_query =
+    final_query = llm_wrapper(
+        input_role, topic, context, model="OpenAI GPT 4o", task_type="internet", temperature=0.7
+    )
     if include_sites:
         site_queries = [f"site:{site.strip()}" for site in include_sites.split(",")]
         final_query += " " + " OR ".join(site_queries)

@@ -827,10 +834,10 @@
     final_query += " " + " ".join(exclude_queries)
     print(f"Google Search Query: {final_query}")
     url_content = google_search(final_query, sorted_date, domains_to_include)
-    topic_context = topic + ", " + context
+    # topic_context = topic + ", " + context
     article, citations = generate_article(
         input_role,
+        topic,
         context,
         keywords,
         article_length,

@@ -866,7 +873,6 @@
     article,
     topic,
     input_role,
-    topic_context,
     context,
     keywords,
     article_length,
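Note on the new search-query flow in generate_and_format: instead of concatenating topic and context by hand (the old topic_context line), the commit asks llm_wrapper with task_type="internet" for a Google-style query and then appends the include/exclude site filters to it. A small sketch of that string assembly follows; the site values are made up, and the shape of exclude_queries is an assumption here, since the code that builds it is unchanged and not shown in this diff.

include_sites = "nature.com, arxiv.org"   # hypothetical user input
exclude_sites = "reddit.com"              # hypothetical user input
final_query = "impact of AI text detectors on academic publishing"  # stand-in for the llm_wrapper(...) output

if include_sites:
    site_queries = [f"site:{site.strip()}" for site in include_sites.split(",")]
    final_query += " " + " OR ".join(site_queries)
if exclude_sites:
    # assumed construction; the hunks above only show how exclude_queries is joined
    exclude_queries = [f"-site:{site.strip()}" for site in exclude_sites.split(",")]
    final_query += " " + " ".join(exclude_queries)

print(final_query)
# impact of AI text detectors on academic publishing site:nature.com OR site:arxiv.org -site:reddit.com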
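Note on the memory-cleanup pattern that recurs throughout this commit (predict, predict_mc, create_db_with_langchain, generate_rag): after each heavy model call, gc.collect() and torch.cuda.empty_cache() are invoked so cached CUDA allocations do not accumulate across Gradio requests. A generic sketch of the same idea is below; the no_grad context and the explicit del are additions for illustration and are not part of the commit.

import gc
import torch

def run_and_release(model, tokens):
    # Run inference without tracking gradients, move results to CPU, then free GPU memory.
    with torch.no_grad():
        output = model(**tokens)
        probs = output.logits.softmax(dim=-1).cpu()
    del output
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return probs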