eljanmahammadli committed
Commit 593bb22 · 1 Parent(s): ba91632

inline citations and more

Files changed (2)
  1. ai_generate.py +104 -10
  2. app.py +20 -14
ai_generate.py CHANGED
@@ -1,3 +1,4 @@
+import gc
 import os
 from langchain_community.document_loaders import PyMuPDFLoader
 from langchain_core.documents import Document
@@ -18,6 +19,15 @@ from dotenv import load_dotenv
 from langchain_core.output_parsers import XMLOutputParser
 from langchain.prompts import ChatPromptTemplate
 import re
+import numpy as np
+import torch
+
+# pip install bm25s
+import bm25s
+from langchain_community.cross_encoders import HuggingFaceCrossEncoder
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain.retrievers.document_compressors import CrossEncoderReranker
+from langchain_core.messages import HumanMessage
 
 load_dotenv()
 
@@ -29,7 +39,17 @@ os.environ["GLOG_minloglevel"] = "2"
 CHUNK_SIZE = 1024
 CHUNK_OVERLAP = CHUNK_SIZE // 8
 K = 10
-FETCH_K = 20
+FETCH_K = 50
+
+model_kwargs = {"device": "cuda:1"}
+print("Loading embedding and reranker models...")
+embedding_function = SentenceTransformerEmbeddings(
+    model_name="mixedbread-ai/mxbai-embed-large-v1", model_kwargs=model_kwargs
+)
+# "sentence-transformers/all-MiniLM-L6-v2"
+# "mixedbread-ai/mxbai-embed-large-v1"
+reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base", model_kwargs=model_kwargs)
+compressor = CrossEncoderReranker(model=reranker, top_n=K)
 
 llm_model_translation = {
     "LLaMA 3": "llama3-70b-8192",
@@ -195,10 +215,9 @@ def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int
     return llm
 
 
-def create_db_with_langchain(path: list[str], url_content: dict):
+def create_db_with_langchain(path: list[str], url_content: dict, query: str):
     all_docs = []
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
-    embedding_function = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
     if path:
         for file in path:
             loader = PyMuPDFLoader(file)
@@ -214,18 +233,38 @@ def create_db_with_langchain(path: list[str], url_content: dict):
             docs = text_splitter.split_documents([doc])
             all_docs.extend(docs)
 
+    print(f"### Total number of documents before bm25s: {len(all_docs)}")
+
+    # if the number of docs is too high, we need to reduce it
+    num_max_docs = 250
+    if len(all_docs) > num_max_docs:
+        docs_raw = [doc.page_content for doc in all_docs]
+        retriever = bm25s.BM25(corpus=docs_raw)
+        retriever.index(bm25s.tokenize(docs_raw))
+        results, scores = retriever.retrieve(bm25s.tokenize(query), k=len(docs_raw), sorted=False)
+        top_indices = np.argpartition(scores[0], -num_max_docs)[-num_max_docs:]
+        all_docs = [all_docs[i] for i in top_indices]
+
     # print docs
     for idx, doc in enumerate(all_docs):
         print(f"Doc: {idx} | Length = {len(doc.page_content)}")
 
     assert len(all_docs) > 0, "No PDFs or scraped data provided"
     db = Chroma.from_documents(all_docs, embedding_function)
+    torch.cuda.empty_cache()
+    gc.collect()
     return db
 
 
+def pretty_print_docs(docs):
+    print(f"\n{'-' * 100}\n".join([f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)]))
+
+
 def generate_rag(
     prompt: str,
+    input_role: str,
     topic: str,
+    context: str,
     model: str,
     url_content: dict,
     path: list[str],
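
Note on the BM25 pre-filter introduced in this hunk: it is standard bm25s usage — index all chunks, score them against the query, and keep the top num_max_docs indices with argpartition (O(n), no full sort needed). A minimal self-contained sketch with an illustrative corpus and query (assumes pip install bm25s):

import bm25s
import numpy as np

corpus = [
    "Chroma stores dense embeddings for semantic search.",
    "BM25 is a lexical ranking function based on term statistics.",
    "Cross-encoders rerank query-document pairs jointly.",
]
query = "lexical ranking for search"

retriever = bm25s.BM25(corpus=corpus)
retriever.index(bm25s.tokenize(corpus))

# Score every document (k=len(corpus), unsorted), then keep the top-k indices.
results, scores = retriever.retrieve(bm25s.tokenize(query), k=len(corpus), sorted=False)
keep = 2  # stands in for num_max_docs
top_indices = np.argpartition(scores[0], -keep)[-keep:]
print([corpus[i] for i in top_indices])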
@@ -238,18 +277,24 @@ def generate_rag(
     if llm is None:
         print("Failed to load LLM. Aborting operation.")
         return None
-    db = create_db_with_langchain(path, url_content)
-    retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})
 
-    docs = retriever.get_relevant_documents(topic)
+    query = llm_wrapper(input_role, topic, context, model="OpenAI GPT 4o", task_type="rag", temperature=0.7)
+    print("### Query: ", query)
+    db = create_db_with_langchain(path, url_content, query)
+    retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K, "lambda_mult": 0.75})
+
+    # docs = retriever.get_relevant_documents(query)
+    compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
+    docs = compression_retriever.invoke(query)
+    pretty_print_docs(docs)
 
     formatted_docs = format_docs_xml(docs)
     rag_chain = RunnablePassthrough.assign(context=lambda _: formatted_docs) | xml_prompt | llm | XMLOutputParser()
     result = rag_chain.invoke({"input": prompt})
     citations = get_citations(result, docs)
-
-    db.delete_collection()  # delete otherwise there could be duplicates because of the cache
-
+    db.delete_collection()  # important, otherwise it will keep the documents in memory
+    torch.cuda.empty_cache()
+    gc.collect()
     return result, citations
 
 
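Retrieval in generate_rag is now two-stage: MMR over Chroma casts a wide net (FETCH_K candidates, with lambda_mult trading relevance against diversity), then the module-level cross-encoder rescores the survivors and keeps the top K. A condensed sketch of the same composition (illustrative documents; the lightweight all-MiniLM-L6-v2 embedder stands in for mxbai-embed-large-v1, and import paths for Chroma/SentenceTransformerEmbeddings are assumed to match this file's):

from langchain_core.documents import Document
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker

docs = [Document(page_content=t) for t in ("first chunk ...", "second chunk ...", "third chunk ...")]
db = Chroma.from_documents(docs, SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2"))

# Stage 1: MMR over the vector store; fetch_k > k widens the candidate pool.
retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 2, "fetch_k": 3, "lambda_mult": 0.75})

# Stage 2: the cross-encoder rescores each (query, document) pair and keeps top_n.
compressor = CrossEncoderReranker(model=HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base"), top_n=2)
compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
print(compression_retriever.invoke("an illustrative query"))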
@@ -271,7 +316,9 @@ def generate_base(
 
 def generate(
     prompt: str,
+    input_role: str,
     topic: str,
+    context: str,
     model: str,
     url_content: dict,
     path: list[str],
@@ -281,6 +328,53 @@ def generate(
     sys_message="",
 ):
     if path or url_content:
-        return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)
+        return generate_rag(
+            prompt, input_role, topic, context, model, url_content, path, temperature, max_length, api_key, sys_message
+        )
     else:
         return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)
+
+
+def llm_wrapper(
+    iam=None,
+    topic=None,
+    context=None,
+    temperature=1.0,
+    max_length=512,
+    api_key="",
+    model="OpenAI GPT 4o Mini",
+    task_type="internet",
+):
+    llm = load_llm(model, api_key, temperature, max_length)
+
+    if task_type == "rag":
+        system_message_content = """You are an AI assistant tasked with reformulating user inputs to improve the retrieval query in a RAG system.
+- Given the original user inputs, construct a query that is more specific, detailed, and likely to retrieve relevant information.
+- Generate the query as a complete sentence or question, not just as keywords, to ensure the retrieval process can find detailed and contextually relevant information.
+- You may enhance the query by adding related and relevant terms, but do not introduce new facts, such as dates, numbers, or assumed information, that were not provided in the input.
+
+**Inputs:**
+- **User Role**: {iam}
+- **Topic**: {topic}
+- **Context**: {context}
+
+**Only return the search query**."""
+    elif task_type == "internet":
+        system_message_content = """You are an AI assistant tasked with generating an optimized Google search query to help retrieve relevant websites, news, articles, and other sources of information.
+- You may enhance the query by adding related and relevant terms, but do not introduce new facts, such as dates, numbers, or assumed information, that were not provided in the input.
+- The query should be **concise** and include important **keywords** while incorporating **short phrases** or context where it improves the search.
+- Avoid the use of "site:" operators or narrowing the search to specific websites.
+
+**Inputs:**
+- **User Role**: {iam}
+- **Topic**: {topic}
+- **Context**: {context}
+
+**Only return the search query**.
+"""
+    else:
+        raise ValueError("Task type not recognized. Please specify 'rag' or 'internet'.")
+
+    human_message = HumanMessage(content=system_message_content.format(iam=iam, topic=topic, context=context))
+    response = llm.invoke([human_message])
+    return response.content.strip('"').strip("'")
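
Usage-wise, llm_wrapper turns the raw role/topic/context into a single retrieval or search query. An illustrative call with the signature above (argument values are made up):

query = llm_wrapper(
    iam="marketing analyst",  # hypothetical inputs
    topic="electric vehicle adoption",
    context="focus on charging infrastructure in Europe",
    model="OpenAI GPT 4o",
    task_type="rag",
    temperature=0.7,
)
# Returns one reformulated sentence/question with surrounding quotes stripped,
# which generate_rag then feeds to create_db_with_langchain and the reranking retriever.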
app.py CHANGED
@@ -3,6 +3,7 @@ nohup python3 app.py &
 export GOOGLE_APPLICATION_CREDENTIALS="gcp_creds.json"
 """
 
+import gc
 import re
 import uuid
 import json
@@ -23,13 +24,12 @@ if gr.NO_RELOAD:
     from humanize import humanize_text, device
     from utils import remove_special_characters, split_text_allow_complete_sentences_nltk
     from google_search import google_search, months, domain_list, build_date
-    from ai_generate import generate, citations_to_html, remove_citations, display_cited_text
+    from ai_generate import generate, citations_to_html, remove_citations, display_cited_text, llm_wrapper
 
-if gr.NO_RELOAD:
     nltk.download("punkt_tab")
 
     print(f"Using device: {device}")
-
+    print("Loading AI detection models...")
     models = {
         "Polygraf AI (Base Model)": AutoModelForSequenceClassification.from_pretrained(
             "polygraf-ai/bc-roberta-openai-2sent"
@@ -51,6 +51,7 @@ if gr.NO_RELOAD:
     TEXT_MC_MODEL_PATH = "polygraf-ai/mc-model"
     MC_LABEL_MAP = ["OpenAI GPT", "Mistral", "CLAUDE", "Gemini", "Grammar Enhancer"]
     text_mc_tokenizer = AutoTokenizer.from_pretrained(TEXT_MC_MODEL_PATH)
+    print("Loading Source detection model...")
    text_mc_model = AutoModelForSequenceClassification.from_pretrained(TEXT_MC_MODEL_PATH).to(device)
 
 
@@ -64,20 +65,20 @@ def generate_cited_html(cited_text, citations: dict):
     }
     .reference-btn {
         display: inline-block;
-        width: 25px;
-        height: 25px;
+        width: 20px; /* Reduced width */
+        height: 20px; /* Reduced height */
         border-radius: 50%;
-        background-color: #0000EE; /* Blue color for the button */
+        background-color: #e33a89; /* Pink color for the button */
         color: white;
         text-align: center;
-        line-height: 25px;
+        line-height: 20px; /* Adjusted line-height */
         cursor: pointer;
         font-weight: bold;
         margin-right: 5px;
         transition: background-color 0.3s ease, transform 0.3s ease;
     }
     .reference-btn:hover {
-        background-color: #1e90ff; /* Lighter blue on hover */
+        background-color: #ff69b4; /* Lighter pink on hover */
         transform: scale(1.1); /* Slightly enlarge on hover */
     }
     .reference-popup {
@@ -357,6 +358,8 @@ def predict(model, tokenizer, text):
     output = model(**tokens)
     output_norm = softmax(output.logits.detach().cpu().numpy(), 1)[0]
     output_norm = {"HUMAN": output_norm[0], "AI": output_norm[1]}
+    torch.cuda.empty_cache()
+    gc.collect()
     return output_norm
 
 
@@ -428,6 +431,8 @@ def predict_mc(text):
     ).to(device)
     output = text_mc_model(**tokens)
     output_norm = softmax(output.logits.detach().cpu().numpy(), 1)[0]
+    torch.cuda.empty_cache()
+    gc.collect()
     return output_norm
 
 
@@ -582,7 +587,9 @@ def generate_article(
     print("Generated Prompt...\n", prompt)
     article, citations = generate(
         prompt=prompt,
+        input_role=input_role,
         topic=topic,
+        context=context,
         model=ai_model,
         url_content=url_content,
         path=pdf_file_input,
@@ -692,7 +699,6 @@ def save_to_cloud_storage(
     article,
     topic,
     input_role,
-    topic_context,
     context,
     keywords,
     article_length,
@@ -725,7 +731,6 @@
         "metadata": {
             "topic": topic,
             "input_role": input_role,
-            "topic_context": topic_context,
             "context": context,
             "keywords": keywords,
             "article_length": article_length,
@@ -818,7 +823,9 @@ def generate_and_format(
     date_from = build_date(year_from, month_from, day_from)
     date_to = build_date(year_to, month_to, day_to)
     sorted_date = f"date:r:{date_from}:{date_to}"
-    final_query = topic
+    final_query = llm_wrapper(
+        input_role, topic, context, model="OpenAI GPT 4o", task_type="internet", temperature=0.7
+    )
     if include_sites:
         site_queries = [f"site:{site.strip()}" for site in include_sites.split(",")]
         final_query += " " + " OR ".join(site_queries)
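
The Google query is now LLM-built from role/topic/context instead of the bare topic, and only then extended with the user's site filters. An illustrative trace of that flow (all values made up):

final_query = llm_wrapper(input_role, topic, context, model="OpenAI GPT 4o", task_type="internet", temperature=0.7)
# e.g. final_query -> "electric vehicle charging infrastructure Europe adoption"
include_sites = "example.com, example.org"  # hypothetical user input
final_query += " " + " OR ".join(f"site:{s.strip()}" for s in include_sites.split(","))
# -> '... site:example.com OR site:example.org'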
@@ -827,10 +834,10 @@
         final_query += " " + " ".join(exclude_queries)
     print(f"Google Search Query: {final_query}")
     url_content = google_search(final_query, sorted_date, domains_to_include)
-    topic_context = topic + ", " + context
+    # topic_context = topic + ", " + context
     article, citations = generate_article(
         input_role,
-        topic_context,
+        topic,
         context,
         keywords,
         article_length,
@@ -866,7 +873,6 @@
         article,
         topic,
         input_role,
-        topic_context,
         context,
         keywords,
         article_length,