jianghuyihei committed on
Commit
e3a17c0
1 Parent(s): a81bf47

delete async

Files changed (6)
  1. .gitattributes copy +0 -35
  2. LLM.py +43 -8
  3. agents.py +34 -73
  4. app.py +1 -1
  5. main.py +4 -9
  6. searcher/sementic_search.py +49 -121
.gitattributes copy DELETED
@@ -1,35 +0,0 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
LLM.py CHANGED
@@ -123,7 +123,13 @@ class openai_llm(base_llm):
                 input=text,
                 timeout= 180
             )
-            return embbeding.data[0].embedding
+            embbeding = embbeding.data
+            if len(embbeding) == 0:
+                return None
+            elif len(embbeding) == 1:
+                return embbeding[0].embedding
+            else:
+                return [e.embedding for e in embbeding]
         except Exception as e:
             print(f"get embbeding failed: {e}")
             print(e)
@@ -147,7 +153,13 @@ class openai_llm(base_llm):
                 input=text,
                 timeout= 180
             )
-            return embbeding.data[0].embedding
+            embbeding = embbeding.data
+            if len(embbeding) == 0:
+                return None
+            elif len(embbeding) == 1:
+                return embbeding[0].embedding
+            else:
+                return [e.embedding for e in embbeding]
         except Exception as e:
             await asyncio.sleep(0.1)
             print(f"get embbeding failed: {e}")
@@ -178,9 +190,32 @@
 
 
 if __name__ == "__main__":
-    llm = gemini_llm(api_key="")
-    prompt = """
-    """
-    messages = [{"role":"user","content":prompt}]
-    response = asyncio.run(llm.response_async(messages))
-    print(response)
+    import os
+    import yaml
+
+    def cal_cosine_similarity_matric(matric1, matric2):
+        if isinstance(matric1, list):
+            matric1 = np.array(matric1)
+        if isinstance(matric2, list):
+            matric2 = np.array(matric2)
+        if len(matric1.shape) == 1:
+            matric1 = matric1.reshape(1, -1)
+        if len(matric2.shape) == 1:
+            matric2 = matric2.reshape(1, -1)
+        dot_product = np.dot(matric1, matric2.T)
+        norm1 = np.linalg.norm(matric1, axis=1)
+        norm2 = np.linalg.norm(matric2, axis=1)
+
+        cos_sim = dot_product / np.outer(norm1, norm2)
+        scores = cos_sim.flatten()
+        # return a list
+        return scores.tolist()
+
+    texts = ["What is the capital of France?","What is the capital of Spain?", "What is the capital of Italy?", "What is the capital of Germany?"]
+    text = "What is the capital of France?"
+    llm = openai_llm()
+    embbedings = llm.get_embbeding(texts)
+    embbeding = llm.get_embbeding(text)
+
+    scores = cal_cosine_similarity_matric(embbedings, embbeding)
+    print(scores)
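With this change, get_embbeding mirrors the shape of its input: None when the API returns no data, a single vector for one string, and a list of vectors for a list of strings. A usage sketch under that reading (it assumes an openai_llm importable from this repo's LLM.py with credentials configured; the texts are illustrative):

# Sketch: the return shape of get_embbeding now depends on the input shape.
# Assumes LLM.py from this repo is importable and credentials are configured.
from LLM import openai_llm

llm = openai_llm()
single = llm.get_embbeding("What is the capital of France?")  # one vector
batch = llm.get_embbeding(["first text", "second text"])      # list of vectors

if single is not None and batch is not None:  # None signals an empty result
    print(len(single), len(batch), len(batch[0]))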
agents.py CHANGED
@@ -1,7 +1,5 @@
 import json
 import time
-import asyncio
-import os
 from searcher import Result,SementicSearcher
 from LLM import openai_llm
 from prompts import *
@@ -17,10 +15,10 @@ def get_llms():
     cheap_llm = get_llm("gpt-4o-mini")
     return main_llm,cheap_llm
 
-async def judge_idea(i,j,idea0,idea1,topic,llm):
+def judge_idea(i,j,idea0,idea1,topic,llm):
     prompt = get_judge_idea_all_prompt(idea0,idea1,topic)
     messages = [{"role":"user","content":prompt}]
-    response = await llm.response_async(messages)
+    response = llm.response(messages)
     novelty = extract(response,"novelty")
     relevance = extract(response,"relevance")
     significance = extract(response,"significance")
@@ -55,16 +53,16 @@ class DeepResearchAgent:
     def wrap_messages(self,prompt):
         return [{"role":"user","content":prompt}]
 
-    async def get_openai_response_async(self,messages):
-        return await self.llm.response_async(messages)
+    def get_openai_response(self,messages):
+        return self.llm.response(messages)
 
-    async def get_cheap_openai_response_async(self,messages):
-        return await self.cheap_llm.response_async(messages,max_tokens = 16000)
+    def get_cheap_openai_response(self,messages):
+        return self.cheap_llm.response(messages,max_tokens = 16000)
 
-    async def get_search_query(self,topic = None,query=None):
+    def get_search_query(self,topic = None,query=None):
         prompt = get_deep_search_query_prompt(topic,query)
         messages = self.wrap_messages(prompt)
-        response = await self.get_openai_response_async(messages)
+        response = self.get_openai_response(messages)
         search_query = extract(response,"queries")
         try:
             search_query = json.loads(search_query)
@@ -73,17 +71,17 @@ class DeepResearchAgent:
             search_query = [query]
         return search_query
 
-    async def generate_idea_with_chain(self,topic):
+    def generate_idea_with_chain(self,topic):
         self.topic = topic
         print(f"begin to generate search query for {topic}")
-        search_query = await self.get_search_query(topic=topic)
+        search_query = self.get_search_query(topic=topic)
         papers = []
         for query in search_query:
             failed_query = []
             current_papers = []
             cnt = 0
             while len(current_papers) == 0 and cnt < 10:
-                paper = await self.reader.search_async(query,1,paper_list=self.read_papers,llm=self.llm,rerank_query=f"{topic}",publicationDate=self.publicationData)
+                paper = self.reader.search(query,1,paper_list=self.read_papers,llm=self.llm,rerank_query=f"{topic}",publicationDate=self.publicationData)
                 if paper and len(paper) > 0 and paper[0]:
                     self.read_papers.add(paper[0].title)
                     current_papers.append(paper[0])
@@ -91,7 +89,7 @@ class DeepResearchAgent:
                     failed_query.append(query)
                     prompt = get_deep_rewrite_query_prompt(failed_query,topic)
                     messages = self.wrap_messages(prompt)
-                    new_query = await self.get_openai_response_async(messages)
+                    new_query = self.get_openai_response(messages)
                     new_query = extract(new_query,"query")
                     print(f"Failed to search papers for {query}, regenerating query {new_query} to search papers.")
                     query = new_query
@@ -104,67 +102,30 @@ class DeepResearchAgent:
             print(f"failed to generate idea {topic}")
             return None,None,None,None,None,None,None,None,None
 
-        tasks = [self.deep_research_paper_with_chain(paper) for paper in papers]
-        results = await asyncio.gather(*tasks)
-        results = [result for result in results if result]
-        if len(results) ==0:
-            print(f"failed to generate idea {topic}")
-            return None,None,None,None,None,None,None,None,None
-
-        ideas,idea_chains,experiments,entities,trends,futures,humans,years = [[result[i] for result in results] for i in range(8)]
-
-        tasks = []
-        for i,idea_1 in enumerate(ideas):
-            for j,idea_2 in enumerate(ideas):
-                if i != j:
-                    tasks.append(judge_idea(i,j,idea_1,idea_2,topic,self.llm))
-        results = await asyncio.gather(*tasks)
-        elo_scores = [0 for _ in range(len(ideas))]
-        elo_selected = 0
-        def change_winner_to_score(winner,score_1,score_2):
-            try:
-                winner = int(winner)
-            except:
-                return score_1+0.5,score_2+0.5
-            if winner == 0:
-                return score_1+1,score_2
-            if winner == 2:
-                return score_1+0.5,score_2+0.5
-            return score_1,score_2+1
-        for result in results:
-            i,j,novelty,relevance,significance,clarity,feasibility,effectiveness = result
-            for dimension in [novelty,relevance,significance,clarity,feasibility,effectiveness]:
-                elo_scores[i],elo_scores[j] = change_winner_to_score(dimension,elo_scores[i],elo_scores[j])
-            print(f"i:{i},j:{j},novelty:{novelty},relevance:{relevance},significance:{significance},clarity:{clarity},feasibility:{feasibility},effectiveness:{effectiveness}")
-        print(elo_scores)
-        try:
-            elo_selected = elo_scores.index(max(elo_scores))
-        except:
-            elo_selected = 0
+        idea,idea_chain,experiment,entities,trend,future,human,year = self.deep_research_paper_with_chain(papers[0])
 
-        idea,experiment,entities,idea_chain,trend,future,human,year = ideas[elo_selected],experiments[elo_selected],entities[elo_selected],idea_chains[elo_selected],trends[elo_selected],futures[elo_selected],humans[elo_selected],years[elo_selected]
         print(f"successfully generated idea")
-        return idea,experiment,entities,idea_chain,ideas,trend,future,human,year
+        return idea,experiment,entities,idea_chain,idea,trend,future,human,year
 
-    async def get_paper_idea_experiment_references_info(self,paper):
+    def get_paper_idea_experiment_references_info(self,paper):
         article = paper.article
         if not article:
             return None
         paper_content = self.reader.read_paper_content(article)
         prompt = get_deep_reference_prompt(paper_content,self.topic)
         messages = self.wrap_messages(prompt)
-        response = await self.get_cheap_openai_response_async(messages)
+        response = self.get_cheap_openai_response(messages)
         entities = extract(response,"entities")
         idea = extract(response,"idea")
         experiment = extract(response,"experiment")
         references = extract(response,"references")
         return idea,experiment,entities,references,paper.title
 
-    async def get_article_idea_experiment_references_info(self,article):
+    def get_article_idea_experiment_references_info(self,article):
         paper_content = self.reader.read_paper_content_with_ref(article)
         prompt = get_deep_reference_prompt(paper_content,self.topic)
         messages = self.wrap_messages(prompt)
-        response = await self.get_cheap_openai_response_async(messages)
+        response = self.get_cheap_openai_response(messages)
         entities = extract(response,"entities")
         idea = extract(response,"idea")
         experiment = extract(response,"experiment")
@@ -172,7 +133,7 @@ class DeepResearchAgent:
         return idea,experiment,entities,references
 
 
-    async def deep_research_paper_with_chain(self,paper:Result):
+    def deep_research_paper_with_chain(self,paper:Result):
         print(f"begin to deep research paper {paper.title}")
         article = paper.article
         if not article:
@@ -183,7 +144,7 @@ class DeepResearchAgent:
         experiments = []
         total_entities = []
         years = []
-        idea,experiment,entities,references = await self.get_article_idea_experiment_references_info(article)
+        idea,experiment,entities,references = self.get_article_idea_experiment_references_info(article)
         try:
             references = json.loads(references)
         except:
@@ -200,7 +161,7 @@ class DeepResearchAgent:
         # search before
         while len(idea_chain)<self.max_chain_length:
             rerank_query = f"{self.topic} {current_title} {current_abstract}"
-            citation_paper = await self.reader.search_related_paper_async(current_title,need_reference=False,rerank_query=rerank_query,llm=self.llm,paper_list=idea_papers)
+            citation_paper = self.reader.search_related_paper(current_title,need_reference=False,rerank_query=rerank_query,llm=self.llm,paper_list=idea_papers)
            if not citation_paper:
                 print(f"failed to find citation paper for {current_title}")
                 break
@@ -208,10 +169,10 @@ class DeepResearchAgent:
             abstract = citation_paper.abstract
             prompt = get_deep_judge_relevant_prompt(current_title,current_abstract,self.topic)
             messages = self.wrap_messages(prompt)
-            response = await self.get_openai_response_async(messages)
+            response = self.get_openai_response(messages)
             relevant = extract(response,"relevant")
             if relevant != "0":
-                result = await self.get_paper_idea_experiment_references_info(citation_paper)
+                result = self.get_paper_idea_experiment_references_info(citation_paper)
                 if not result:
                     break
                 idea,experiment,entities,_,_ = result
@@ -238,13 +199,13 @@ class DeepResearchAgent:
             references.pop(0)
             if reference in self.read_papers:
                 continue
-            search_paper = await self.reader.search_async(reference,3,llm=self.llm,publicationDate=self.publicationData,paper_list= idea_papers)
+            search_paper = self.reader.search(reference,3,llm=self.llm,publicationDate=self.publicationData,paper_list= idea_papers)
             if len(search_paper) > 0:
                 s_p = search_paper[0]
                 if s_p and s_p.title not in self.read_papers:
                     prompt = get_deep_judge_relevant_prompt(current_title,current_abstract,self.topic)
                     messages = self.wrap_messages(prompt)
-                    response = await self.get_openai_response_async(messages)
+                    response = self.get_openai_response(messages)
                     relevant = extract(response,"relevant")
                     if relevant != "0" or len(idea_chain) < self.min_chain_length:
                         article = s_p.article
@@ -257,7 +218,7 @@ class DeepResearchAgent:
 
         if not article:
             rerank_query = f"topic: {self.topic} Title: {current_title} Abstract: {current_abstract}"
-            search_paper = await self.reader.search_related_paper_async(current_title,need_citation=False,rerank_query = rerank_query,llm=self.llm,paper_list=idea_papers)
+            search_paper = self.reader.search_related_paper(current_title,need_citation=False,rerank_query = rerank_query,llm=self.llm,paper_list=idea_papers)
             if not search_paper:
                 print(f"failed to find citation paper for {current_title}")
                 continue
@@ -273,10 +234,10 @@ class DeepResearchAgent:
             if s_p and s_p.title not in self.read_papers:
                 prompt = get_deep_judge_relevant_prompt(current_title,current_abstract,self.topic)
                 messages = self.wrap_messages(prompt)
-                response = await self.get_openai_response_async(messages)
+                response = self.get_openai_response(messages)
                 relevant = extract(response,"relevant")
                 if relevant == "1" or len(idea_chain) < self.min_chain_length:
-                    article = await s_p.article
+                    article = s_p.article
                     if not article:
                         continue
                     else:
@@ -290,7 +251,7 @@ class DeepResearchAgent:
             paper_content = self.reader.read_paper_content_with_ref(article)
             prompt = get_deep_reference_prompt(paper_content,self.topic)
             messages = self.wrap_messages(prompt)
-            response = await self.get_cheap_openai_response_async(messages)
+            response = self.get_cheap_openai_response(messages)
             idea = extract(response,"idea")
             references = extract(response,"references")
             experiment = extract(response,"experiment")
@@ -317,7 +278,7 @@ class DeepResearchAgent:
 
         prompt = get_deep_trend_idea_chains_prompt(idea_chains,entities,self.topic)
         messages = self.wrap_messages(prompt)
-        response = await self.get_openai_response_async(messages)
+        response = self.get_openai_response(messages)
         trend = extract(response,"trend")
 
         self.deep_research_chains.append({"idea_chains":idea_chains,"trend":trend,"topic":self.topic,"ideas":idea_chain,"experiments":experiments,"entities":total_entities,"years":years})
@@ -326,26 +287,26 @@ class DeepResearchAgent:
         <entities> {{cleaned entities}}</entities>
         """
         messages = self.wrap_messages(prompt)
-        response = await self.get_openai_response_async(messages)
+        response = self.get_openai_response(messages)
         total_entities = extract(response,"entities")
         bad_case = []
         prompt = get_deep_generate_future_direciton_prompt(idea_chain,trend,self.topic,total_entities)
         messages = self.wrap_messages(prompt)
-        response = await self.get_openai_response_async(messages)
+        response = self.get_openai_response(messages)
         future = extract(response,"future")
         human = extract(response,"human")
 
 
         prompt = get_deep_generate_idea_prompt(idea_chains,trend,self.topic,total_entities,future,bad_case)
         messages = self.wrap_messages(prompt)
-        response = await self.get_openai_response_async(messages)
+        response = self.get_openai_response(messages)
         method = extract(response,"method")
         novelty = extract(response,"novelty")
         motivation = extract(response,"motivation")
         idea = {"motivation":motivation,"novelty":novelty,"method":method}
         prompt = get_deep_final_idea_prompt(idea_chains,trend,idea,self.topic)
         messages = self.wrap_messages(prompt)
-        response = await self.get_openai_response_async(messages)
+        response = self.get_openai_response(messages)
         final_idea = extract(response,"final_idea")
 
         idea = final_idea
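Most edits in agents.py are mechanical: each async def loses its await calls and delegates to the synchronous counterparts (llm.response, reader.search, reader.search_related_paper). The one behavioral change is in generate_idea_with_chain, which previously fanned out over all papers with asyncio.gather and picked a winner by pairwise judging, and now deep-researches only papers[0]. A minimal, self-contained sketch of the general conversion pattern (the worker names are placeholders, not repo functions):

import asyncio

# Before: concurrent fan-out over an async worker (placeholder names).
async def run_all_async(items):
    async def work_async(item):
        return item * 2
    return await asyncio.gather(*[work_async(i) for i in items])

# After: the same results computed sequentially, no event loop required.
def run_all(items):
    def work(item):
        return item * 2
    return [work(i) for i in items]

print(asyncio.run(run_all_async([1, 2, 3])))  # [2, 4, 6]
print(run_all([1, 2, 3]))                     # [2, 4, 6]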
app.py CHANGED
@@ -332,7 +332,7 @@ def form_post(topic: str = Form(...)):
     main_llm, cheap_llm = get_llms()
     deep_research_agent = DeepResearchAgent(llm=main_llm, cheap_llm=cheap_llm, improve_cnt=1, max_chain_length=5, min_chain_length=3, max_chain_numbers=1)
     print(f"begin to generate idea of topic {topic}")
-    idea, related_experiments, entities, idea_chain, ideas, trend, future, human, year = asyncio.run(deep_research_agent.generate_idea_with_chain(topic))
+    idea, related_experiments, entities, idea_chain, ideas, trend, future, human, year = deep_research_agent.generate_idea_with_chain(topic)
     idea_md = markdown.markdown(idea)
     # update the daily reply count
     reply_count += 1
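Since generate_idea_with_chain is now a blocking synchronous call, the asyncio.run wrapper in the route is dropped. In FastAPI, a plain def handler like form_post runs in a threadpool, so the blocking call does not stall the event loop. A condensed sketch of the calling pattern (the route path and return handling are assumptions; the full handler is not shown in this diff):

# Condensed sketch; the real handler has more setup than shown here.
from fastapi import FastAPI, Form
from agents import DeepResearchAgent, get_llms

app = FastAPI()

@app.post("/")  # route path assumed, not visible in this diff
def form_post(topic: str = Form(...)):
    main_llm, cheap_llm = get_llms()
    agent = DeepResearchAgent(llm=main_llm, cheap_llm=cheap_llm, improve_cnt=1,
                              max_chain_length=5, min_chain_length=3, max_chain_numbers=1)
    # a plain def route runs in FastAPI's threadpool, so the blocking
    # synchronous call below no longer needs asyncio.run
    idea, *_rest = agent.generate_idea_with_chain(topic)
    return {"idea": idea}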
main.py CHANGED
@@ -1,8 +1,9 @@
-from agents import DeepResearchAgent,ReviewAgent,get_llms
+from agents import DeepResearchAgent,get_llms
 import asyncio
 import json
 import argparse
 
+
 if __name__ == '__main__':
 
     argparser = argparse.ArgumentParser()
@@ -21,18 +22,12 @@ if __name__ == '__main__':
     topic = args.topic
     anchor_paper_path = args.anchor_paper_path
 
-
-    review_agent = ReviewAgent(save_file=args.save_file,llm=main_llm,cheap_llm=cheap_llm)
     deep_research_agent = DeepResearchAgent(llm=main_llm,cheap_llm=cheap_llm,**vars(args))
 
     print(f"begin to generate idea and experiment of topic {topic}")
-    idea,related_experiments,entities,idea_chain,ideas,trend,future,human,year= asyncio.run(deep_research_agent.generate_idea_with_chain(topic,anchor_paper_path))
-    experiment = asyncio.run(deep_research_agent.generate_experiment(idea,related_experiments,entities))
-
-    for i in range(args.improve_cnt):
-        experiment = asyncio.run(deep_research_agent.improve_experiment(review_agent,idea,experiment,entities))
+    idea,related_experiments,entities,idea_chain,ideas,trend,future,human,year= deep_research_agent.generate_idea_with_chain(topic)
 
     print(f"succeed to generate idea and experiment of topic {topic}")
-    res = {"idea":idea,"experiment":experiment,"related_experiments":related_experiments,"entities":entities,"idea_chain":idea_chain,"ideas":ideas,"trend":trend,"future":future,"year":year,"human":human}
+    res = {"idea":idea,"related_experiments":related_experiments,"entities":entities,"idea_chain":idea_chain,"ideas":ideas,"trend":trend,"future":future,"year":year,"human":human}
     with open("result.json","w") as f:
         json.dump(res,f)
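After a run, the keys written to result.json match the res dict above (note the "experiment" key is gone along with the review loop), so the output can be inspected directly. For example:

# Quick check of the saved output; keys follow the res dict in main.py.
import json

with open("result.json") as f:
    res = json.load(f)
print(res["idea"])
print(res["trend"])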
searcher/sementic_search.py CHANGED
@@ -7,7 +7,7 @@ import time
 import aiohttp
 import asyncio
 import numpy as np
-
+import random
 
 def get_content_between_a_b(start_tag, end_tag, text):
     extracted_text = ""
@@ -31,29 +31,6 @@ def extract(text, type):
         return text
     else:
         return ""
-
-
-async def fetch(url):
-    await asyncio.sleep(1)  # async sleep rather than time.sleep
-    try:
-        timeout = aiohttp.ClientTimeout(total=120)
-        connector = aiohttp.TCPConnector(limit_per_host=10)  # use a connection pool
-        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
-            async with session.get(url) as response:
-                if response.status == 200:
-                    content = await response.read()  # Read the response content as bytes
-                    return content
-                else:
-                    print(f"Failed to fetch the URL: {url} with status code: {response.status}")
-                    return None
-    except aiohttp.ClientError as e:  # more specific exception handling
-        print(f"An error occurred while fetching the URL: {url}")
-        print(e)
-        return None
-    except Exception as e:
-        print(f"An unexpected error occurred while fetching the URL: {url}")
-        print(e)
-        return None
 
 def download(url):
     try:
@@ -103,7 +80,7 @@ class SementicSearcher:
     def __init__(self, ban_paper = []) -> None:
         self.ban_paper = ban_paper
 
-    async def search_papers_async(self, query, limit=5, offset=0, fields=["title", "paperId", "abstract", "isOpenAccess", 'openAccessPdf', "year","publicationDate","citations.title","citations.abstract","citations.isOpenAccess","citations.openAccessPdf","citations.citationCount","citationCount","citations.year"],
+    def search_papers(self, query, limit=5, offset=0, fields=["title", "paperId", "abstract", "isOpenAccess", 'openAccessPdf', "year","publicationDate","citations.title","citations.abstract","citations.isOpenAccess","citations.openAccessPdf","citations.citationCount","citationCount","citations.year"],
                 publicationDate=None, minCitationCount=0, year=None,
                 publicationTypes=None, fieldsOfStudy=None):
         url = 'https://api.semanticscholar.org/graph/v1/paper/search'
@@ -124,7 +101,6 @@ class SementicSearcher:
         # Load the API key from the configuration file
         api_key = os.environ.get('SEMENTIC_SEARCH_API_KEY',None)
         headers = {'x-api-key': api_key} if api_key else None
-        await asyncio.sleep(0.5)
         try:
             filtered_query_params = {key: value for key, value in query_params.items() if value is not None}
             response = requests.get(url, params=filtered_query_params, headers=headers)
@@ -135,7 +111,7 @@ class SementicSearcher:
             elif response.status_code == 429:
                 time.sleep(1)
                 print(f"Request failed with status code {response.status_code}: begin to retry")
-                return await self.search_papers_async(query, limit, offset, fields, publicationDate, minCitationCount, year, publicationTypes, fieldsOfStudy)
+                return self.search_papers(query, limit, offset, fields, publicationDate, minCitationCount, year, publicationTypes, fieldsOfStudy)
             else:
                 print(f"Request failed with status code {response.status_code}: {response.text}")
                 return None
@@ -145,6 +121,23 @@ class SementicSearcher:
 
     def cal_cosine_similarity(self, vec1, vec2):
         return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
+
+    def cal_cosine_similarity_matric(self,matric1, matric2):
+        if isinstance(matric1, list):
+            matric1 = np.array(matric1)
+        if isinstance(matric2, list):
+            matric2 = np.array(matric2)
+        if len(matric1.shape) == 1:
+            matric1 = matric1.reshape(1, -1)
+        if len(matric2.shape) == 1:
+            matric2 = matric2.reshape(1, -1)
+        dot_product = np.dot(matric1, matric2.T)
+        norm1 = np.linalg.norm(matric1, axis=1)
+        norm2 = np.linalg.norm(matric2, axis=1)
+
+        cos_sim = dot_product / np.outer(norm1, norm2)
+        scores = cos_sim.flatten()
+        return scores.tolist()
 
     def read_arxiv_from_path(self, pdf_path):
         def is_pdf(binary_data):
@@ -163,97 +156,41 @@ class SementicSearcher:
             return None
         return article_dict
 
-    async def get_paper_embbeding_and_score_async(self,query_embedding, paper,llm):
+    def get_paper_embbeding_and_score(self,query_embedding, paper,llm):
         paper_content = f"""
 Title: {paper['title']}
 Abstract: {paper['abstract']}
 """
-        paper_embbeding = await llm.get_embbeding_async(paper_content)
+        paper_embbeding = llm.get_embbeding(paper_content)
         paper_embbeding = np.array(paper_embbeding)
         score = self.cal_cosine_similarity(query_embedding,paper_embbeding)
         return [paper,score]
 
 
-    async def rerank_papers_async(self, query_embedding, paper_list,llm):
+    def rerank_papers(self, query_embedding, paper_list,llm):
+        if len(paper_list) == 0:
+            return []
+        paper_list = [paper for paper in paper_list if paper]
         if len(paper_list) >= 50:
-            paper_list = paper_list[:50]
-        results = await asyncio.gather(*[self.get_paper_embbeding_and_score_async(query_embedding, paper,llm) for paper in paper_list if paper])
-        reranked_papers = sorted(results,key = lambda x: x[1],reverse = True)
-        return reranked_papers
-
-    async def get_embbeding_and_score_async(self,query_embedding, text,llm):
-        text_embbeding = await llm.get_embbeding_async(text)
-        text_embbeding = np.array(text_embbeding)
-        score = self.cal_cosine_similarity(query_embedding,text_embbeding)
-        return score
-
-    async def get_embbeding_and_score_from_texts_async(self,query_embedding, texts,llm):
-        results = await asyncio.gather(*[self.get_embbeding_and_score_async(query_embedding, text,llm) for text in texts])
-        return results
-
-    async def get_paper_details_async(self, paper_id, fields = ["title", "abstract", "year","citationCount","isOpenAccess","openAccessPdf"]):
-        url = f'https://api.semanticscholar.org/graph/v1/paper/{paper_id}'
-        fields = process_fields(fields)
-        paper_data_query_params = {'fields': fields}
-        try:
-            async with aiohttp.ClientSession() as session:
-                filtered_query_params = {key: value for key, value in paper_data_query_params.items() if value is not None}
-                headers = {'x-api-key': os.environ.get('SEMENTIC_SEARCH_API_KEY',None)}
-                async with session.get(url, params=filtered_query_params, headers=headers) as response:
-                    if response.status == 200:
-                        response_data = await response.json()
-                        return response_data
-                    else:
-                        await asyncio.sleep(0.01)
-                        print(f"Request failed with status code {response.status}: {await response.text()}")
-                        return None
-        except Exception as e:
-            print(f"Failed to get paper details for paper ID: {paper_id}")
-            return None
-
-    async def batch_retrieve_papers_async(self, paper_ids, fields = semantic_fields):
-        url = 'https://api.semanticscholar.org/graph/v1/paper/batch'
-        paper_data_query_params = {'fields': process_fields(fields)}
-        paper_ids_json = {"ids": paper_ids}
-        try:
-            async with aiohttp.ClientSession() as session:
-                filtered_query_params = {key: value for key, value in paper_data_query_params.items() if value is not None}
-                headers = {'x-api-key': os.environ.get('SEMENTIC_SEARCH_API_KEY',None)}
-                async with session.post(url, json=paper_ids_json, params=filtered_query_params, headers=headers) as response:
-                    if response.status == 200:
-                        response_data = await response.json()
-                        return response_data
-                    else:
-                        await asyncio.sleep(0.01)
-                        print(f"Request failed with status code {response.status}: {await response.text()}")
-                        return None
-        except Exception as e:
-            print(f"Failed to batch retrieve papers for paper IDs: {paper_ids}")
-            return None
-
-    async def search_paper_from_title_async(self, query,fields = ["title","paperId"]):
-        url = 'https://api.semanticscholar.org/graph/v1/paper/search/match'
-        fields = process_fields(fields)
-        query_params = {'query': query, 'fields': fields}
-        try:
-            async with aiohttp.ClientSession() as session:
-                filtered_query_params = {key: value for key, value in query_params.items() if value is not None}
-                headers = {'x-api-key': os.environ.get('SEMENTIC_SEARCH_API_KEY',None)}
-                async with session.get(url, params=filtered_query_params, headers=headers) as response:
-                    if response.status == 200:
-                        response_data = await response.json()
-                        return response_data
-                    else:
-                        await asyncio.sleep(0.01)
-                        print(f"Request failed with status code {response.status}: {await response.text()}")
-                        return None
-        except Exception as e:
-            await asyncio.sleep(0.01)
-            print(f"Failed to search paper from title: {query}")
-            return None
+            paper_list = random.sample(paper_list,50)
+        paper_contents = []
+        for paper in paper_list:
+            paper_content = f"""
+Title: {paper['title']}
+Abstract: {paper['abstract']}
+"""
+            paper_contents.append(paper_content)
+        paper_contents_embbeding = llm.get_embbeding(paper_contents)
+        paper_contents_embbeding = np.array(paper_contents_embbeding)
+        scores = self.cal_cosine_similarity_matric(query_embedding,paper_contents_embbeding)
+
+        # sort paper_list by score
+        paper_list = sorted(zip(paper_list,scores),key = lambda x: x[1],reverse = True)
+        paper_list = [paper[0] for paper in paper_list]
+        return paper_list
 
 
-    async def search_async(self,query,max_results = 5 ,paper_list = None ,rerank_query = None,llm = None,year = None,publicationDate = None,need_download = True,fields = ["title", "paperId", "abstract", "isOpenAccess", 'openAccessPdf', "year","publicationDate","citationCount"]):
+    def search(self,query,max_results = 5 ,paper_list = None ,rerank_query = None,llm = None,year = None,publicationDate = None,need_download = True,fields = ["title", "paperId", "abstract", "isOpenAccess", 'openAccessPdf', "year","publicationDate","citationCount"]):
         if rerank_query:
             rerank_query_embbeding = llm.get_embbeding(rerank_query)
             rerank_query_embbeding = np.array(rerank_query_embbeding)
@@ -270,7 +207,7 @@ Abstract: {paper['abstract']}
         readed_papers = [paper.title for paper in paper_list]
 
         print(f"Searching for papers related to the query: <{query}>")
-        results = await self.search_papers_async(query,limit = 10 * max_results,year=year,publicationDate = publicationDate,fields = fields)
+        results = self.search_papers(query,limit = 10 * max_results,year=year,publicationDate = publicationDate,fields = fields)
         if not results or "data" not in results:
             return []
 
@@ -293,8 +230,7 @@ Abstract: {paper['abstract']}
         paper_candidates = results
 
         if llm and rerank_query:
-            paper_candidates = await self.rerank_papers_async(rerank_query_embbeding, paper_candidates,llm)
-            paper_candidates = [paper[0] for paper in paper_candidates if paper]
+            paper_candidates = self.rerank_papers(rerank_query_embbeding, paper_candidates,llm)
 
         if need_download:
             for result in paper_candidates:
@@ -326,10 +262,10 @@ Abstract: {paper['abstract']}
                 break
         return final_results
 
-    async def search_related_paper_async(self,title,need_citation = True,need_reference = True,rerank_query = None,llm = None,paper_list = []):
-        print(f"Searching for the related papers of <{title}>")
+    def search_related_paper(self,title,need_citation = True,need_reference = True,rerank_query = None,llm = None,paper_list = []):
+        print(f"Searching for the related papers of <{title}>, need_citation: {need_citation}, need_reference: {need_reference}")
         fileds = ["title","abstract","citations.title","citations.abstract","citations.citationCount","references.title","references.abstract","references.citationCount","citations.isOpenAccess","citations.openAccessPdf","references.isOpenAccess","references.openAccessPdf","citations.year","references.year"]
-        results = await self.search_papers_async(title,limit = 3,fields=fileds)
+        results = self.search_papers(title,limit = 3,fields=fileds)
         related_papers = []
         related_papers_title = []
         if not results or "data" not in results:
@@ -367,8 +303,7 @@ Abstract: {paper['abstract']}
         if rerank_query and llm:
             rerank_query_embbeding = llm.get_embbeding(rerank_query)
             rerank_query_embbeding = np.array(rerank_query_embbeding)
-            related_papers = await self.rerank_papers_async(rerank_query_embbeding, related_papers,llm)
-            related_papers = [paper[0] for paper in related_papers]
+            related_papers = self.rerank_papers(rerank_query_embbeding, related_papers,llm)
             related_papers = [[paper["title"],paper["abstract"],paper["openAccessPdf"]["url"],paper["citationCount"],paper['year']] for paper in related_papers]
         else:
             related_papers = [[paper["title"],paper["abstract"],paper["openAccessPdf"]["url"],paper["citationCount"],paper['year']] for paper in related_papers]
@@ -385,13 +320,6 @@ Abstract: {paper['abstract']}
             return result
         return None
 
-
-    async def download_pdf_async(self, pdf_link):
-        content = await fetch(pdf_link)
-        if not content:
-            return None
-        else:
-            return content
 
     def download_pdf(self, pdf_link):
         content = download(pdf_link)
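The rewritten reranker embeds all candidate papers in a single get_embbeding call and scores them against the query with one matrix operation, instead of issuing one embedding request per paper. A self-contained numpy sketch of the scoring step, with synthetic vectors standing in for real embeddings:

import numpy as np

def cal_cosine_similarity_matric(matric1, matric2):
    # same computation as the method added above: every row of matric1
    # scored against every row of matric2 by cosine similarity
    matric1 = np.atleast_2d(np.asarray(matric1, dtype=float))
    matric2 = np.atleast_2d(np.asarray(matric2, dtype=float))
    cos = (matric1 @ matric2.T) / np.outer(
        np.linalg.norm(matric1, axis=1), np.linalg.norm(matric2, axis=1))
    return cos.flatten().tolist()

query_vec = [1.0, 0.0, 0.0]                 # stand-in for the query embedding
paper_vecs = [[1.0, 0.0, 0.0],              # same direction -> score 1.0
              [0.0, 1.0, 0.0]]              # orthogonal     -> score 0.0
print(cal_cosine_similarity_matric(query_vec, paper_vecs))  # [1.0, 0.0]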