pradachan committed on
Commit
f71c233
1 Parent(s): 701aa0b

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +16 -0
  2. .gitignore +2 -0
  3. .ipynb_checkpoints/launch_scientist-checkpoint.py +500 -0
  4. DockerFile +0 -0
  5. LICENSE +201 -0
  6. Miniconda3-latest-Linux-x86_64.sh +3 -0
  7. README.md +312 -0
  8. ai_scientist/.ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
  9. ai_scientist/.ipynb_checkpoints/generate_ideas-checkpoint.py +637 -0
  10. ai_scientist/.ipynb_checkpoints/llm-checkpoint.py +358 -0
  11. ai_scientist/.ipynb_checkpoints/perform_experiments-checkpoint.py +166 -0
  12. ai_scientist/.ipynb_checkpoints/perform_writeup-checkpoint.py +707 -0
  13. ai_scientist/Untitled.ipynb +62 -0
  14. ai_scientist/__init__.py +0 -0
  15. ai_scientist/__pycache__/__init__.cpython-311.pyc +0 -0
  16. ai_scientist/__pycache__/__init__.cpython-312.pyc +0 -0
  17. ai_scientist/__pycache__/generate_ideas.cpython-311.pyc +0 -0
  18. ai_scientist/__pycache__/generate_ideas.cpython-312.pyc +0 -0
  19. ai_scientist/__pycache__/llm.cpython-311.pyc +0 -0
  20. ai_scientist/__pycache__/llm.cpython-312.pyc +0 -0
  21. ai_scientist/__pycache__/perform_experiments.cpython-311.pyc +0 -0
  22. ai_scientist/__pycache__/perform_experiments.cpython-312.pyc +0 -0
  23. ai_scientist/__pycache__/perform_review.cpython-311.pyc +0 -0
  24. ai_scientist/__pycache__/perform_review.cpython-312.pyc +0 -0
  25. ai_scientist/__pycache__/perform_writeup.cpython-311.pyc +0 -0
  26. ai_scientist/__pycache__/perform_writeup.cpython-312.pyc +0 -0
  27. ai_scientist/fewshot_examples/132_automated_relational.json +3 -0
  28. ai_scientist/fewshot_examples/132_automated_relational.pdf +3 -0
  29. ai_scientist/fewshot_examples/132_automated_relational.txt +1190 -0
  30. ai_scientist/fewshot_examples/2_carpe_diem.json +3 -0
  31. ai_scientist/fewshot_examples/2_carpe_diem.pdf +0 -0
  32. ai_scientist/fewshot_examples/2_carpe_diem.txt +1035 -0
  33. ai_scientist/fewshot_examples/attention.json +3 -0
  34. ai_scientist/fewshot_examples/attention.pdf +0 -0
  35. ai_scientist/fewshot_examples/attention.txt +662 -0
  36. ai_scientist/generate_ideas.py +637 -0
  37. ai_scientist/llm.py +359 -0
  38. ai_scientist/perform_experiments.py +166 -0
  39. ai_scientist/perform_review.py +431 -0
  40. ai_scientist/perform_writeup.py +707 -0
  41. cuda-keyring_1.0-1_all.deb +0 -0
  42. cuda-keyring_1.0-1_all.deb.1 +0 -0
  43. cuda-repo-ubuntu2004-11-0-local_11.0.3-450.51.06-1_amd64.deb +3 -0
  44. data/enwik8/enwik8 +3 -0
  45. data/enwik8/enwik8.zip +3 -0
  46. data/enwik8/meta.pkl +3 -0
  47. data/enwik8/prepare.py +75 -0
  48. data/enwik8/test.bin +3 -0
  49. data/enwik8/train.bin +3 -0
  50. data/enwik8/val.bin +3 -0
.gitattributes CHANGED
@@ -33,3 +33,19 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Miniconda3-latest-Linux-x86_64.sh filter=lfs diff=lfs merge=lfs -text
37
+ ai_scientist/fewshot_examples/132_automated_relational.pdf filter=lfs diff=lfs merge=lfs -text
38
+ cuda-repo-ubuntu2004-11-0-local_11.0.3-450.51.06-1_amd64.deb filter=lfs diff=lfs merge=lfs -text
39
+ data/enwik8/enwik8 filter=lfs diff=lfs merge=lfs -text
40
+ data/text8/text8 filter=lfs diff=lfs merge=lfs -text
41
+ docs/adaptive_dual_scale_denoising.jpeg filter=lfs diff=lfs merge=lfs -text
42
+ example_papers/adaptive_dual_scale_denoising/adaptive_dual_scale_denoising.pdf filter=lfs diff=lfs merge=lfs -text
43
+ example_papers/adaptive_dual_scale_denoising.pdf filter=lfs diff=lfs merge=lfs -text
44
+ example_papers/data_augmentation_grokking/data_augmentation_grokking.pdf filter=lfs diff=lfs merge=lfs -text
45
+ example_papers/data_augmentation_grokking.pdf filter=lfs diff=lfs merge=lfs -text
46
+ example_papers/grid_based_noise_adaptation/grid_based_noise_adaptation.pdf filter=lfs diff=lfs merge=lfs -text
47
+ example_papers/grid_based_noise_adaptation.pdf filter=lfs diff=lfs merge=lfs -text
48
+ example_papers/layerwise_lr_grokking/layerwise_lr_grokking.pdf filter=lfs diff=lfs merge=lfs -text
49
+ example_papers/layerwise_lr_grokking.pdf filter=lfs diff=lfs merge=lfs -text
50
+ example_papers/weight_initialization_grokking/weight_initialization_grokking.pdf filter=lfs diff=lfs merge=lfs -text
51
+ example_papers/weight_initialization_grokking.pdf filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ .aider*
2
+ .env
.ipynb_checkpoints/launch_scientist-checkpoint.py ADDED
@@ -0,0 +1,500 @@
1
+ import argparse
2
+ import json
3
+ import multiprocessing
4
+ import os
5
+ import os.path as osp
6
+ import shutil
7
+ import sys
8
+ import time
9
+ from datetime import datetime
10
+
11
+ import openai
12
+ import torch
13
+ from aider.coders import Coder
14
+ from aider.io import InputOutput
15
+ from aider.models import Model
16
+
17
+ from ai_scientist.generate_ideas import check_idea_novelty, generate_ideas
18
+ from ai_scientist.llm import allchoices
19
+ from ai_scientist.perform_experiments import perform_experiments
20
+ from ai_scientist.perform_review import load_paper, perform_improvement, perform_review
21
+ from ai_scientist.perform_writeup import generate_latex, perform_writeup
22
+
23
+ NUM_REFLECTIONS = 3
24
+
25
+
26
+ def print_time():
27
+ print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
28
+
29
+
30
+ def parse_arguments():
31
+ parser = argparse.ArgumentParser(description="Run AI scientist experiments")
32
+ parser.add_argument(
33
+ "--skip-idea-generation",
34
+ action="store_true",
35
+ help="Skip idea generation and load existing ideas",
36
+ )
37
+ parser.add_argument(
38
+ "--skip-novelty-check",
39
+ action="store_true",
40
+ help="Skip novelty check and use existing ideas",
41
+ )
42
+ # add type of experiment (nanoGPT, Boston, etc.)
43
+ parser.add_argument(
44
+ "--experiment",
45
+ type=str,
46
+ default="nanoGPT_lite",
47
+ help="Experiment to run AI Scientist on.",
48
+ )
49
+ parser.add_argument(
50
+ "--model",
51
+ type=str,
52
+ default="Qwen/Qwen2.5-72B-Instruct",
53
+ choices=allchoices,
54
+ help="Model to use for AI Scientist.",
55
+ )
56
+ parser.add_argument(
57
+ "--writeup",
58
+ type=str,
59
+ default="latex",
60
+ choices=["latex"],
61
+ help="What format to use for writeup",
62
+ )
63
+ parser.add_argument(
64
+ "--parallel",
65
+ type=int,
66
+ default=0,
67
+ help="Number of parallel processes to run. 0 for sequential execution.",
68
+ )
69
+ parser.add_argument(
70
+ "--improvement",
71
+ action="store_true",
72
+ help="Improve based on reviews.",
73
+ )
74
+ parser.add_argument(
75
+ "--gpus",
76
+ type=str,
77
+ default=None,
78
+ help="Comma-separated list of GPU IDs to use (e.g., '0,1,2'). If not specified, all available GPUs will be used.",
79
+ )
80
+ parser.add_argument(
81
+ "--num-ideas",
82
+ type=int,
83
+ default=2,
84
+ help="Number of ideas to generate",
85
+ )
86
+ return parser.parse_args()
87
+
88
+
89
+ def get_available_gpus(gpu_ids=None):
90
+ if gpu_ids is not None:
91
+ return [int(gpu_id) for gpu_id in gpu_ids.split(",")]
92
+ return list(range(torch.cuda.device_count()))
93
+
94
+
95
+ def worker(
96
+ queue,
97
+ base_dir,
98
+ results_dir,
99
+ model,
100
+ client,
101
+ client_model,
102
+ writeup,
103
+ improvement,
104
+ gpu_id,
105
+ ):
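+ # Pin this worker process to its assigned GPU before it starts pulling ideas off the queue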
106
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
107
+ print(f"Worker {gpu_id} started.")
108
+ while True:
109
+ idea = queue.get()
110
+ if idea is None:
111
+ break
112
+ success = do_idea(
113
+ base_dir,
114
+ results_dir,
115
+ idea,
116
+ model,
117
+ client,
118
+ client_model,
119
+ writeup,
120
+ improvement,
121
+ log_file=True,
122
+ )
123
+ print(f"Completed idea: {idea['Name']}, Success: {success}")
124
+ print(f"Worker {gpu_id} finished.")
125
+
126
+
127
+ def do_idea(
128
+ base_dir,
129
+ results_dir,
130
+ idea,
131
+ model,
132
+ client,
133
+ client_model,
134
+ writeup,
135
+ improvement,
136
+ log_file=False,
137
+ ):
138
+ ## CREATE PROJECT FOLDER
139
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
140
+ idea_name = f"{timestamp}_{idea['Name']}"
141
+ folder_name = osp.join(results_dir, idea_name)
142
+ assert not osp.exists(folder_name), f"Folder {folder_name} already exists."
143
+ destination_dir = folder_name
144
+ shutil.copytree(base_dir, destination_dir, dirs_exist_ok=True)
145
+ with open(osp.join(base_dir, "run_0", "final_info.json"), "r") as f:
146
+ baseline_results = json.load(f)
147
+ baseline_results = {k: v["means"] for k, v in baseline_results.items()}
148
+ exp_file = osp.join(folder_name, "experiment.py")
149
+ vis_file = osp.join(folder_name, "plot.py")
150
+ notes = osp.join(folder_name, "notes.txt")
151
+ with open(notes, "w") as f:
152
+ f.write(f"# Title: {idea['Title']}\n")
153
+ f.write(f"# Experiment description: {idea['Experiment']}\n")
154
+ f.write(f"## Run 0: Baseline\n")
155
+ f.write(f"Results: {baseline_results}\n")
156
+ f.write(f"Description: Baseline results.\n")
157
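+ # Optionally redirect stdout/stderr to a per-idea log file (restored in the finally block below)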
+ if log_file:
158
+ original_stdout = sys.stdout
159
+ original_stderr = sys.stderr
160
+ log_path = osp.join(folder_name, "log.txt")
161
+ log = open(log_path, "a")
162
+ sys.stdout = log
163
+ sys.stderr = log
164
+ try:
165
+ print_time()
166
+ print(f"*Starting idea: {idea_name}*")
167
+ ## PERFORM EXPERIMENTS
168
+ fnames = [exp_file, vis_file, notes]
169
+ io = InputOutput(
170
+ yes=True, chat_history_file=f"{folder_name}/{idea_name}_aider.txt"
171
+ )
172
+ if model == "hybrid":
173
+ main_model = Model("claude-3-5-sonnet-20240620")
174
+ elif model == "deepseek-coder-v2-0724":
175
+ main_model = Model("deepseek-ai/DeepSeek-V2.5")
176
+ elif model == "llama3.1-405b":
177
+ main_model = Model("openrouter/meta-llama/llama-3.1-405b-instruct")
178
+
179
+ # ----------------------------------------------------
180
+
181
+ elif model == "Qwen/Qwen2.5-72B-Instruct":
182
+ print("aider model chosen")
183
+
184
+ # main_model = Model("fireworks_ai/accounts/fireworks/models/qwen2-72b-instruct")
185
+ # main_model = Model("openai/Qwen2.5-72B-Instruct")
186
+ main_model = Model("friendli/Qwen2.5-72B-Instruct")
187
+
188
+ elif model == "hyperbolic/meta-llama/Meta-Llama-3.1-70B-Instruct":
189
+ main_model = Model("hyperbolic/meta-llama/Meta-Llama-3.1-70B-Instruct")
190
+
191
+ # ----------------------------------------------------
192
+
193
+ else:
194
+ main_model = Model(model)
195
+ coder = Coder.create(
196
+ main_model=main_model,
197
+ fnames=fnames,
198
+ io=io,
199
+ stream=False,
200
+ use_git=False,
201
+ edit_format="diff",
202
+ )
203
+
204
+ print_time()
205
+ print(f"*Starting Experiments*")
206
+ try:
207
+ success = perform_experiments(idea, folder_name, coder, baseline_results)
208
+ except Exception as e:
209
+ print(f"Error during experiments: {e}")
210
+ print(f"Experiments failed for idea {idea_name}")
211
+ return False
212
+
213
+ if not success:
214
+ print(f"Experiments failed for idea {idea_name}")
215
+ return False
216
+
217
+ print_time()
218
+ print(f"*Starting Writeup*")
219
+ ## PERFORM WRITEUP
220
+ if writeup == "latex":
221
+ writeup_file = osp.join(folder_name, "latex", "template.tex")
222
+ fnames = [exp_file, writeup_file, notes]
223
+ if model == "deepseek-coder-v2-0724":
224
+ main_model = Model("deepseek-ai/DeepSeek-V2.5")
225
+ elif model == "llama3.1-405b":
226
+ main_model = Model("openrouter/meta-llama/llama-3.1-405b-instruct")
227
+
228
+ # ----------------------------------------------------
229
+
230
+ elif model == "Qwen/Qwen2.5-72B-Instruct":
231
+ print("aider model chosen")
232
+ # main_model = Model("fireworks_ai/accounts/fireworks/models/qwen2-72b-instruct")
233
+ main_model = Model("openai/Qwen/Qwen2.5-72B-Instruct")
234
+
235
+ elif model == "hyperbolic/meta-llama/Meta-Llama-3.1-70B-Instruct":
236
+ main_model = Model("hyperbolic/meta-llama/Meta-Llama-3.1-70B-Instruct")
237
+
238
+ # ----------------------------------------------------
239
+ else:
240
+ main_model = Model(model)
241
+ coder = Coder.create(
242
+ main_model=main_model,
243
+ fnames=fnames,
244
+ io=io,
245
+ stream=False,
246
+ use_git=False,
247
+ edit_format="diff",
248
+ )
249
+ try:
250
+ perform_writeup(idea, folder_name, coder, client, client_model)
251
+ except Exception as e:
252
+ print(f"Failed to perform writeup: {e}")
253
+ return False
254
+ print("Done writeup")
255
+ else:
256
+ raise ValueError(f"Writeup format {writeup} not supported.")
257
+
258
+ print_time()
259
+ print(f"*Starting Review*")
260
+ ## REVIEW PAPER
261
+ if writeup == "latex":
262
+ try:
263
+ paper_text = load_paper(f"{folder_name}/{idea['Name']}.pdf")
264
+ if model == "gpt-4o-2024-05-13":
265
+ main_model = Model(model)
266
+ review = perform_review(
267
+ paper_text,
268
+ model=main_model,
269
+ client=openai.OpenAI(),
270
+ num_reflections=5,
271
+ num_fs_examples=1,
272
+ num_reviews_ensemble=5,
273
+ temperature=0.1,
274
+ )
275
+ elif model.startswith("ollama"):
276
+ # Use Ollama API for review generation
277
+ review = perform_review(
278
+ paper_text,
279
+ model=model.split("/")[-1],
280
+ client=openai.OpenAI(
281
+ api_key="ollama", base_url="http://localhost:11434/v1"
282
+ ),
283
+ num_reflections=5,
284
+ num_fs_examples=1,
285
+ num_reviews_ensemble=5,
286
+ temperature=0.1,
287
+ )
288
+ # Store the review in separate review.txt file
289
+ with open(osp.join(folder_name, "review.txt"), "w") as f:
290
+ f.write(json.dumps(review, indent=4))
291
+ except Exception as e:
292
+ print(f"Failed to perform review: {e}")
293
+ return False
294
+
295
+ ## IMPROVE WRITEUP
296
+ if writeup == "latex" and improvement:
297
+ print_time()
298
+ print(f"*Starting Improvement*")
299
+ try:
300
+ perform_improvement(review, coder)
301
+ generate_latex(
302
+ coder, folder_name, f"{folder_name}/{idea['Name']}_improved.pdf"
303
+ )
304
+ paper_text = load_paper(f"{folder_name}/{idea['Name']}_improved.pdf")
305
+
306
+ if model == "gpt-4o-2024-05-13":
307
+ main_model = Model(model)
308
+ review = perform_review(
309
+ paper_text,
310
+ model=main_model,
311
+ client=openai.OpenAI(),
312
+ num_reflections=5,
313
+ num_fs_examples=1,
314
+ num_reviews_ensemble=5,
315
+ temperature=0.1,
316
+ )
317
+ elif model.startswith("ollama"):
318
+ # Use Ollama API for review generation
319
+ review = perform_review(
320
+ paper_text,
321
+ model=model.split("/")[-1],
322
+ client=openai.OpenAI(
323
+ api_key="ollama", base_url="http://localhost:11434/v1"
324
+ ),
325
+ num_reflections=5,
326
+ num_fs_examples=1,
327
+ num_reviews_ensemble=5,
328
+ temperature=0.1,
329
+ )
330
+ # Store the review in separate review.txt file
331
+ with open(osp.join(folder_name, "review_improved.txt"), "w") as f:
332
+ f.write(json.dumps(review))
333
+ except Exception as e:
334
+ print(f"Failed to perform improvement: {e}")
335
+ return False
336
+ return True
337
+ except Exception as e:
338
+ print(f"Failed to evaluate idea {idea_name}: {str(e)}")
339
+ return False
340
+ finally:
341
+ print("FINISHED IDEA")
342
+ if log_file:
343
+ sys.stdout = original_stdout
344
+ sys.stderr = original_stderr
345
+ log.close()
346
+
347
+
348
+ if __name__ == "__main__":
349
+ import traceback
350
+ try:
351
+ args = parse_arguments()
352
+
353
+ # Check available GPUs and adjust parallel processes if necessary
354
+ available_gpus = get_available_gpus(args.gpus)
355
+ if args.parallel > len(available_gpus):
356
+ print(
357
+ f"Warning: Requested {args.parallel} parallel processes, but only {len(available_gpus)} GPUs available. Adjusting to {len(available_gpus)}."
358
+ )
359
+ args.parallel = len(available_gpus)
360
+
361
+ print(f"Using GPUs: {available_gpus}")
362
+
363
+ # Create client
364
+ if args.model == "claude-3-5-sonnet-20240620":
365
+ import anthropic
366
+
367
+ print(f"Using Anthropic API with model {args.model}.")
368
+ client_model = "claude-3-5-sonnet-20240620"
369
+ client = anthropic.Anthropic()
370
+ elif args.model.startswith("bedrock") and "claude" in args.model:
371
+ import anthropic
372
+
373
+ # Expects: bedrock/<MODEL_ID>
374
+ client_model = args.model.split("/")[-1]
375
+
376
+ print(f"Using Amazon Bedrock with model {client_model}.")
377
+ client = anthropic.AnthropicBedrock(
378
+ aws_access_key=os.getenv("AWS_ACCESS_KEY_ID"),
379
+ aws_secret_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
380
+ aws_region=os.getenv("AWS_REGION_NAME"),
381
+ )
382
+ elif args.model.startswith("vertex_ai") and "claude" in args.model:
383
+ import anthropic
384
+
385
+ # Expects: vertex_ai/<MODEL_ID>
386
+ client_model = args.model.split("/")[-1]
387
+
388
+ print(f"Using Vertex AI with model {client_model}.")
389
+ client = anthropic.AnthropicVertex()
390
+ elif args.model == "gpt-4o-2024-05-13":
391
+ import openai
392
+
393
+ print(f"Using OpenAI API with model {args.model}.")
394
+ client_model = "gpt-4o-2024-05-13"
395
+ client = openai.OpenAI()
396
+
397
+ # ----------------------------------------------------
398
+ elif args.model == "Qwen/Qwen2.5-72B-Instruct":
399
+ # elif args.model.startswith("hyperbolic"):
400
+ print(f"Welcome to the PARADISE of debug <launch_scientist.py> {args.model}.")
401
+
402
+ import openai
403
+ import os
404
+ # client_model = args.model[11:]
405
+ client_model = args.model
406
+ client = openai.OpenAI(
407
+ api_key=os.environ["OPENAI_API_KEY"], base_url="https://api.hyperbolic.xyz/v1"
408
+ )
409
+ # ----------------------------------------------------
410
+
411
+ elif args.model.startswith("ollama"):
412
+ import openai
413
+
414
+ print(f"Using Ollama with {args.model}.")
415
+ client_model = args.model.split("/")[-1]
416
+ client = openai.OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
417
+ else:
418
+ raise ValueError(f"Model {args.model} not supported.")
419
+
420
+ base_dir = osp.join("templates", args.experiment)
421
+ results_dir = osp.join("results", args.experiment)
422
+ ideas = generate_ideas(
423
+ base_dir,
424
+ client=client,
425
+ model=client_model,
426
+ skip_generation=args.skip_idea_generation,
427
+ max_num_generations=args.num_ideas,
428
+ num_reflections=NUM_REFLECTIONS,
429
+ )
430
+ ideas = check_idea_novelty(
431
+ ideas,
432
+ base_dir=base_dir,
433
+ client=client,
434
+ model=client_model,
435
+ )
436
+
437
+ with open(osp.join(base_dir, "ideas.json"), "w") as f:
438
+ json.dump(ideas, f, indent=4)
439
+
440
+ novel_ideas = [idea for idea in ideas if idea["novel"]]
441
+ # novel_ideas = list(reversed(novel_ideas))
442
+
443
+ if args.parallel > 0:
444
+ print(f"Running {args.parallel} parallel processes")
445
+ queue = multiprocessing.Queue()
446
+ for idea in novel_ideas:
447
+ queue.put(idea)
448
+
449
+ processes = []
450
+ for i in range(args.parallel):
451
+ gpu_id = available_gpus[i % len(available_gpus)]
452
+ p = multiprocessing.Process(
453
+ target=worker,
454
+ args=(
455
+ queue,
456
+ base_dir,
457
+ results_dir,
458
+ args.model,
459
+ client,
460
+ client_model,
461
+ args.writeup,
462
+ args.improvement,
463
+ gpu_id,
464
+ ),
465
+ )
466
+ p.start()
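+ # Stagger worker start-up; the 150-second delay presumably avoids contention while each worker sets up its idea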
467
+ time.sleep(150)
468
+ processes.append(p)
469
+
470
+ # Signal workers to exit
471
+ for _ in range(args.parallel):
472
+ queue.put(None)
473
+
474
+ for p in processes:
475
+ p.join()
476
+
477
+ print("All parallel processes completed.")
478
+ else:
479
+ for idea in novel_ideas:
480
+ print(f"Processing idea: {idea['Name']}")
481
+ try:
482
+ success = do_idea(
483
+ base_dir,
484
+ results_dir,
485
+ idea,
486
+ args.model,
487
+ client,
488
+ client_model,
489
+ args.writeup,
490
+ args.improvement,
491
+ )
492
+ print(f"Completed idea: {idea['Name']}, Success: {success}")
493
+ except Exception as e:
494
+ print(f"Failed to evaluate idea {idea['Name']}: {str(e)}")
495
+
496
+ print("All ideas evaluated.")
497
+
498
+ except Exception as e:
499
+ print("error aya re baba")
500
+ traceback.print_exc()
DockerFile ADDED
File without changes
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2020 Rémi Louf
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
Miniconda3-latest-Linux-x86_64.sh ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d936ba600300e08eca3d874dee88c61c6f39303597b2b66baee54af4f7b4122
3
+ size 148337011
README.md ADDED
@@ -0,0 +1,312 @@
1
+ <h1 align="center">
2
+ <a href="https://github.com/SakanaAI/AI-Scientist/blob/main/docs/logo_2.png">
3
+ <img src="docs/logo_2.png" width="215" /></a><br>
4
+ <b>The AI Scientist: Towards Fully Automated</b><br>
5
+ <b>Open-Ended Scientific Discovery 🧑‍🔬</b><br>
6
+ </h1>
7
+
8
+ <p align="center">
9
+ 📚 <a href="https://arxiv.org/abs/2408.06292">[Paper]</a> |
10
+ 📝 <a href="https://sakana.ai/ai-scientist/">[Blog Post]</a> |
11
+ 📂 <a href="https://drive.google.com/drive/folders/1G7A0wTqfXVa-cpexjk0oaXakaSJwffEt">[Drive Folder]</a>
12
+ </p>
13
+
14
+ One of the grand challenges of artificial intelligence is developing agents capable of conducting scientific research and discovering new knowledge. While frontier models have already been used to aid human scientists, e.g. for brainstorming ideas or writing code, they still require extensive manual supervision or are heavily constrained to a specific task.
15
+
16
+ We're excited to introduce **The AI Scientist**, the first comprehensive system for fully automatic scientific discovery, enabling Foundation Models such as Large Language Models (LLMs) to perform research independently.
17
+
18
+ We further provide all runs and data from our paper [here](https://drive.google.com/drive/folders/1G7A0wTqfXVa-cpexjk0oaXakaSJwffEt?usp=sharing), where we run each base model on each template for ~50 ideas. We _highly_ recommend reading through some of the [Claude papers](https://drive.google.com/drive/folders/1Mmpz6M1FK4q8e-SewgZcUzdeD0Q2zC39?usp=sharing) (especially the diffusion ones) to get a sense of its strengths and weaknesses. Here are some example papers generated by **The AI Scientist** 📝:
19
+
20
+ 1. [DualScale Diffusion: Adaptive Feature Balancing for Low-Dimensional Generative Models](https://github.com/SakanaAI/AI-Scientist/blob/main/example_papers/adaptive_dual_scale_denoising.pdf)
21
+ 2. [Multi-scale Grid Noise Adaptation: Enhancing Diffusion Models For Low-dimensional Data](https://github.com/SakanaAI/AI-Scientist/blob/main/example_papers/grid_based_noise_adaptation.pdf)
22
+ 3. [GAN-Enhanced Diffusion: Boosting Sample Quality and Diversity](https://github.com/SakanaAI/AI-Scientist/blob/main/example_papers/gan_diffusion.pdf)
23
+ 4. [DualDiff: Enhancing Mode Capture in Low-dimensional Diffusion Models via Dual-expert Denoising](https://github.com/SakanaAI/AI-Scientist/tree/main/example_papers/dual_expert_denoiser.pdf)
24
+ 5. [StyleFusion: Adaptive Multi-style Generation in Character-Level Language Models](https://github.com/SakanaAI/AI-Scientist/blob/main/example_papers/multi_style_adapter.pdf)
25
+ 6. [Adaptive Learning Rates for Transformers via Q-Learning](https://github.com/SakanaAI/AI-Scientist/tree/main/example_papers/rl_lr_adaptation.pdf)
26
+ 7. [Unlocking Grokking: A Comparative Study of Weight Initialization Strategies in Transformer Models](https://github.com/SakanaAI/AI-Scientist/tree/main/example_papers/weight_initialization_grokking.pdf)
27
+ 8. [Grokking Accelerated: Layer-wise Learning Rates for Transformer Generalization](https://github.com/SakanaAI/AI-Scientist/tree/main/example_papers/layerwise_lr_grokking.pdf)
28
+ 9. [Grokking Through Compression: Unveiling Sudden Generalization via Minimal Description Length](https://github.com/SakanaAI/AI-Scientist/tree/main/example_papers/mdl_grokking_correlation.pdf)
29
+ 10. [Accelerating Mathematical Insight: Boosting Grokking Through Strategic Data Augmentation](https://github.com/SakanaAI/AI-Scientist/tree/main/example_papers/data_augmentation_grokking.pdf)
30
+
31
+ **Note**: Caution! This codebase will execute LLM-written code. There are various risks and challenges associated with this autonomy, including the use of potentially dangerous packages, web access, and the potential spawning of processes. Use at your own discretion. Please make sure to [containerize](#containerization) and restrict web access appropriately.
32
+
33
+ <p align="center">
34
+ <a href="https://github.com/SakanaAI/AI-Scientist/blob/main/example_papers/adaptive_dual_scale_denoising/adaptive_dual_scale_denoising.pdf"><img src="https://github.com/SakanaAI/AI-Scientist/blob/main/docs/anim-ai-scientist.gif" alt="Adaptive Dual Scale Denoising" width="80%" />
35
+ </p>
36
+
37
+ ## Table of Contents
38
+
39
+ 1. [Requirements](#requirements)
40
+ 2. [Run AI Scientist Paper Generation Experiments](#run-ai-scientist-paper-generation-experiments)
41
+ 3. [Getting an LLM Generated Paper Review](#getting-an-llm-generated-paper-review)
42
+ 4. [Making your own Template](#making-your-own-template)
43
+ 5. [Template Resources](#template-resources)
44
+ 6. [Citing The AI Scientist](#citing-the-ai-scientist)
45
+ 7. [Frequently Asked Questions](#faq)
46
+ 8. [Containerization](#containerization)
47
+
48
+ ## Requirements
49
+
50
+ This code was designed for NVIDIA GPUs with CUDA using PyTorch. Support for other GPU architectures may be possible by following [PyTorch guidelines](https://pytorch.org/get-started/locally/). Current templates would likely take an infeasible amount of time on CPU-only machines. All code is designed to be run on Linux; other operating systems will likely require major adjustments.
51
+
52
+ ### Installation
53
+
54
+ ```bash
55
+ conda create -n ai_scientist python=3.11
56
+ conda activate ai_scientist
57
+ # Install pdflatex
58
+ sudo apt-get install texlive-full
59
+
60
+ # Install pypi requirements
61
+ pip install -r requirements.txt
62
+ ```
63
+
64
+ When installing `texlive-full`, you may need to [hold Enter](https://askubuntu.com/questions/956006/pregenerating-context-markiv-format-this-may-take-some-time-takes-forever).
65
+
66
+ ### Supported Models and API Keys
67
+
68
+ We support a wide variety of models including open-weight and API-only models. In general, we recommend only using frontier models above the capability of the original GPT-4.
69
+
70
+ #### OpenAI API (GPT-4)
71
+
72
+ By default, this uses the `OPENAI_API_KEY` environment variable.
73
+
74
+ #### Anthropic API (Claude Sonnet 3.5)
75
+
76
+ By default, this uses the `ANTHROPIC_API_KEY` environment variable.
77
+
78
+ ##### Claude models via Bedrock
79
+
80
+ For Claude models provided by [Amazon Bedrock](https://aws.amazon.com/bedrock/), please install these additional packages:
81
+
82
+ ```bash
83
+ pip install anthropic[bedrock]
84
+ ```
85
+
86
+ Next, specify a set of valid [AWS Credentials](https://docs.aws.amazon.com/cli/v1/userguide/cli-configure-envvars.html) and the target [AWS Region](https://docs.aws.amazon.com/bedrock/latest/userguide/bedrock-regions.html):
87
+
88
+ Set these environment variables: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_REGION_NAME`.
89
+
90
+ ##### Claude models via Vertex AI
91
+
92
+ For Claude models provided by [Vertex AI Model Garden](https://cloud.google.com/model-garden?hl=en), please install these additional packages:
93
+
94
+ ```bash
95
+ pip install google-cloud-aiplatform
96
+ pip install anthropic[vertex]
97
+ ```
98
+
99
+ Next, set up a valid authentication for a [Google Cloud project](https://cloud.google.com/vertex-ai/docs/authentication), for example by providing region and project ID like so:
100
+
101
+ ```bash
102
+ export CLOUD_ML_REGION="REGION" # for Model Garden call
103
+ export ANTHROPIC_VERTEX_PROJECT_ID="PROJECT_ID" # for Model Garden call
104
+ export VERTEXAI_LOCATION="REGION" # for Aider/LiteLLM call, as per https://docs.litellm.ai/docs/providers/vertex#set-vertex-project--vertex-location
105
+ export VERTEXAI_PROJECT="PROJECT_ID" # for Aider/LiteLLM call as per https://docs.litellm.ai/docs/providers/vertex#set-vertex-project--vertex-location
106
+ ```
107
+
108
+ #### DeepSeek API (DeepSeek-Coder-V2)
109
+
110
+ By default, this uses the `DEEPSEEK_API_KEY` environment variable.
111
+
112
+ #### OpenRouter API (Llama3.1)
113
+
114
+ By default, this uses the `OPENROUTER_API_KEY` environment variable.
115
+
116
+ #### Semantic Scholar API (Literature Search)
117
+
118
+ Our code can also optionally use a Semantic Scholar API Key (`S2_API_KEY`) for higher throughput [if you have one](https://www.semanticscholar.org/product/api), though in principle it should work without it.
119
+
120
+ Be sure to provide the key for the model used for your runs, e.g.
121
+
122
+ ```bash
123
+ export OPENAI_API_KEY="YOUR KEY HERE"
124
+ export S2_API_KEY="YOUR KEY HERE"
125
+ ```
126
+
127
+ ### Setup NanoGPT
128
+
129
+ Here, and below, we give instructions for setting up the data and baseline evaluations for each template. You only need to run the setup steps for the templates you are interested in. The baselines must be run on your own machine, since training times vary depending on your hardware.
130
+
131
+ ```bash
132
+ # Prepare NanoGPT data
133
+ python data/enwik8/prepare.py
134
+ python data/shakespeare_char/prepare.py
135
+ python data/text8/prepare.py
136
+ ```
137
+
138
+ #### Create baseline runs (machine dependent)
139
+
140
+ ```bash
141
+ # Set up NanoGPT baseline run
142
+ # NOTE: YOU MUST FIRST RUN THE PREPARE SCRIPTS ABOVE!
143
+ cd templates/nanoGPT && python experiment.py --out_dir run_0 && python plot.py
144
+ ```
145
+
146
+ #### Create NanoGPT_lite baseline run. We use this for sanity-checking
147
+
148
+ ```bash
149
+ # NOTE: YOU MUST FIRST RUN THE PREPARE SCRIPTS ABOVE!
150
+ cd templates/nanoGPT_lite && python experiment.py --out_dir run_0 && python plot.py
151
+ ```
152
+
153
+ ### Setup 2D Diffusion
154
+
155
+ ```bash
156
+ # Set up 2D Diffusion
157
+ git clone https://github.com/gregversteeg/NPEET.git
158
+ cd NPEET
159
+ pip install .
160
+ pip install scikit-learn
161
+
162
+ # Set up 2D Diffusion baseline run
163
+ cd templates/2d_diffusion && python experiment.py --out_dir run_0 && python plot.py
164
+ ```
165
+
166
+ ### Setup Grokking
167
+
168
+ ```bash
169
+ # Set up Grokking
170
+ pip install einops
171
+
172
+ # Set up Grokking baseline run
173
+ cd templates/grokking && python experiment.py --out_dir run_0 && python plot.py
174
+ ```
175
+
176
+ ## Run AI Scientist Paper Generation Experiments
177
+
178
+ **Note:** please ensure the setup steps above are completed.
179
+
180
+ ```bash
181
+ conda activate ai_scientist
182
+ # Run the paper generation.
183
+ python launch_scientist.py --model "gpt-4o-2024-05-13" --experiment nanoGPT_lite --num-ideas 2
184
+ python launch_scientist.py --model "claude-3-5-sonnet-20240620" --experiment nanoGPT_lite --num-ideas 2
185
+ python launch_scientist.py --model "ollama/mistral-nemo" --experiment nanoGPT_lite --num-ideas 2
186
+ ```
187
+
188
+ If you have more than one GPU, use the `--parallel` option to parallelize ideas across multiple GPUs.
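+
+ For example, using the flags defined in `launch_scientist.py`, the following command runs two ideas at a time on GPUs 0 and 1 (model and template as in the commands above):
+
+ ```bash
+ python launch_scientist.py --model "gpt-4o-2024-05-13" --experiment nanoGPT_lite --num-ideas 2 --parallel 2 --gpus "0,1"
+ ```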
189
+
190
+ ## Getting an LLM Generated Paper Review
191
+
192
+ ```python
193
+ import openai
194
+ from ai_scientist.perform_review import load_paper, perform_review
195
+
196
+ client = openai.OpenAI()
197
+ model = "gpt-4o-2024-05-13"
198
+
199
+ # Load paper from pdf file (raw text)
200
+ paper_txt = load_paper("report.pdf")
201
+ # Get the review dict of the review
202
+ review = perform_review(
203
+ paper_txt,
204
+ model,
205
+ client,
206
+ num_reflections=5,
207
+ num_fs_examples=1,
208
+ num_reviews_ensemble=5,
209
+ temperature=0.1,
210
+ )
211
+
212
+ # Inspect review results
213
+ review["Overall"] # overall score 1-10
214
+ review["Decision"] # ['Accept', 'Reject']
215
+ review["Weaknesses"] # List of weaknesses (str)
216
+ ```
217
+
218
+ To run batch analysis:
219
+
220
+ ```bash
221
+ cd review_iclr_bench
222
+ python iclr_analysis.py --num_reviews 500 --batch_size 100 --num_fs_examples 1 --num_reflections 5 --temperature 0.1 --num_reviews_ensemble 5
223
+ ```
224
+
225
+ ## Making your own Template
226
+
227
+ If there is an area of study you would like **The AI Scientist** to explore, it should be very easy to create your own templates. In general, follow the structure of the existing templates, which consists of:
228
+
229
+ - `experiment.py` -- This is a single file where the 'meat' of the content is. It takes in an argument for `out_dir`, which is where it should create the folder and save the relevant information from the run (see the minimal sketch after this list).
230
+ - `plot.py` -- This should take in the information from the `run` folders and create plots. The code should be clear and easy to edit.
231
+ - `prompt.json` -- Put information about your template here; the idea-generation code reads a `"system"` prompt and a `"task_description"` from this file.
232
+ - `seed_ideas.json` -- Put example ideas here. You can also try to generate ideas without any examples, and then pick the best one or two to put here.
233
+ - `latex/template.tex` -- We recommend using our latex folder, but be sure to replace the pre-loaded citations with ones that you would expect to be more relevant.
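+
+ For illustration, here is a minimal sketch of what a new template's `experiment.py` could look like. The `--out_dir` argument and the `final_info.json` layout (each metric mapping to a dict with a `"means"` entry) mirror what `launch_scientist.py` reads from `run_0`; the metric name and the body of `run_experiment` are hypothetical placeholders for your own training and evaluation code.
+
+ ```python
+ import argparse
+ import json
+ import os
+
+
+ def run_experiment():
+     # Placeholder: run your training/evaluation here and collect metrics.
+     # launch_scientist.py expects final_info.json to map each metric name
+     # to a dict containing a "means" entry.
+     return {"eval_loss": {"means": 1.23}}
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--out_dir", type=str, default="run_0")
+     args = parser.parse_args()
+
+     os.makedirs(args.out_dir, exist_ok=True)
+     results = run_experiment()
+     with open(os.path.join(args.out_dir, "final_info.json"), "w") as f:
+         json.dump(results, f, indent=4)
+ ```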
234
+
235
+ ## Template Resources
236
+
237
+ We provide 3 templates, which heavily use code from other repositories; we credit them below. (Normally, we would credit this in the files themselves, but it is unclear how that would affect The AI Scientist, since the credits would be visible to it.)
238
+
239
+ The NanoGPT template used code from [NanoGPT](https://github.com/karpathy/nanoGPT) and this [PR](https://github.com/karpathy/nanoGPT/pull/254).
240
+
241
+ The 2D Diffusion template used code from [tiny-diffusion](https://github.com/tanelp/tiny-diffusion), [ema-pytorch](https://github.com/lucidrains/ema-pytorch), and [Datasaur](https://www.research.autodesk.com/publications/same-stats-different-graphs/).
242
+
243
+ The Grokking template used code from [Sea-Snell/grokking](https://github.com/Sea-Snell/grokking) and [danielmamay/grokking](https://github.com/danielmamay/grokking).
244
+
245
+ We would like to thank the developers of the open-source models and packages for their contributions and for making their work available.
246
+
247
+ ## Citing The AI Scientist
248
+
249
+ If you use **The AI Scientist** in your research, please cite it as follows:
250
+
251
+ ```
252
+ @article{lu2024aiscientist,
253
+ title={The {AI} {S}cientist: Towards Fully Automated Open-Ended Scientific Discovery},
254
+ author={Lu, Chris and Lu, Cong and Lange, Robert Tjarko and Foerster, Jakob and Clune, Jeff and Ha, David},
255
+ journal={arXiv preprint arXiv:2408.06292},
256
+ year={2024}
257
+ }
258
+ ```
259
+
260
+ ## FAQ
261
+
262
+ We recommend reading our paper in the first instance for any questions you have on The AI Scientist.
263
+
264
+ ### Why am I missing files when running The AI Scientist?
265
+
266
+ Make sure you have completed all the setup and preparation steps before running the main experiment script.
267
+
268
+ ### Why has a PDF or a review not been generated?
269
+
270
+ The AI Scientist's success rate in completing an idea depends on the template, the base foundation model, and the complexity of the idea. We advise referring to our main paper. The highest success rates are observed with Claude Sonnet 3.5.
271
+ Reviews are best done with GPT-4o; all other models have issues with positivity bias or failure to conform to required outputs.
272
+
273
+ ### What is the cost of each idea generated?
274
+
275
+ Typically less than $15 per paper with Claude Sonnet 3.5. We recommend DeepSeek Coder V2 for a much more cost-effective approach. A good place to look for new models is the [Aider leaderboard](https://aider.chat/docs/leaderboards/).
276
+
277
+ ### How do I change the base conference format associated with the write-ups?
278
+
279
+ Change the base `template.tex` files contained within each template.
280
+
281
+ ### How do I run The AI Scientist for different subject fields?
282
+
283
+ Please refer to the instructions for different templates. In this current iteration, this is restricted to ideas that can be expressed in code. However, lifting this restriction would represent exciting future work! :)
284
+
285
+ ### How do I add support for a new foundation model?
286
+
287
+ Please see this [PR](https://github.com/SakanaAI/AI-Scientist/pull/7) for an example of how to add a new model, e.g. this time for Claude via Bedrock.
288
+ We do not advise using any model that is significantly weaker than GPT-4 level for The AI Scientist.
289
+
290
+ ### Why do I need to run the baseline runs myself?
291
+ These appear as `run_0` and should be run per machine you execute The AI Scientist on for accurate run-time comparisons due to hardware differences.
292
+
293
+ ## Containerization
294
+
295
+ We include a [community-contributed](https://github.com/SakanaAI/AI-Scientist/pull/21) Docker image that may assist with your containerization efforts in `experimental/Dockerfile`.
296
+
297
+ You can use this image like this:
298
+
299
+ ```bash
300
+ # Endpoint Script
301
+ docker run -e OPENAI_API_KEY=$OPENAI_API_KEY -v `pwd`/templates:/app/AI-Scientist/templates <AI_SCIENTIST_IMAGE> \
302
+ --model gpt-4o-2024-05-13 \
303
+ --experiment 2d_diffusion \
304
+ --num-ideas 2
305
+ ```
306
+
307
+ ```bash
308
+ # Interactive
309
+ docker run -it -e OPENAI_API_KEY=$OPENAI_API_KEY \
310
+ --entrypoint /bin/bash \
311
+ <AI_SCIENTIST_IMAGE>
312
+ ```
ai_scientist/.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "cells": [],
3
+ "metadata": {},
4
+ "nbformat": 4,
5
+ "nbformat_minor": 5
6
+ }
ai_scientist/.ipynb_checkpoints/generate_ideas-checkpoint.py ADDED
@@ -0,0 +1,637 @@
1
+ import json
2
+ import os
3
+ import os.path as osp
4
+ import time
5
+ from typing import Dict, List, Union
6
+
7
+ import backoff
8
+ import requests
9
+ from strictjson import strict_json
10
+
11
+ from ai_scientist.llm import (
12
+ allchoices,
13
+ extract_json_between_markers,
14
+ get_response_from_llm,
15
+ llm_json_auto_correct,
16
+ )
17
+
18
+ S2_API_KEY = os.getenv("S2_API_KEY")
19
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
20
+
21
+
22
+ idea_first_prompt = """{task_description}
23
+ <experiment.py>
24
+ {code}
25
+ </experiment.py>
26
+
27
+ Here are the ideas that you have already generated:
28
+
29
+ '''
30
+ {prev_ideas_string}
31
+ '''
32
+
33
+ Come up with the next impactful and creative idea for research experiments and directions you can feasibly investigate with the code provided.
34
+ Note that you will not have access to any additional resources or datasets.
35
+ Make sure any idea is not overfit to the specific training dataset or model, and has wider significance.
36
+
37
+ Respond in the following format:
38
+
39
+ THOUGHT:
40
+ <THOUGHT>
41
+
42
+ NEW IDEA JSON:
43
+ ```json
44
+ <JSON>
45
+ ```
46
+
47
+ In <THOUGHT>, first briefly discuss your intuitions and motivations for the idea. Detail your high-level plan, necessary design choices and ideal outcomes of the experiments. Justify how the idea is different from the existing ones.
48
+
49
+ Add '```json' before the <JSON> and '```' after the <JSON> as above. In <JSON>, provide the new idea in JSON format with the following keys and values:
50
+ - "Name": A shortened descriptor of the idea. Lowercase, no spaces, underscores allowed.
51
+ - "Title": A title for the idea, will be used for the report writing.
52
+ - "Experiment": An outline of the implementation. E.g. which functions need to be added or modified, how results will be obtained, ...
53
+ - "Interestingness": A rating from 1 to 10 (lowest to highest).
54
+ - "Feasibility": A rating from 1 to 10 (lowest to highest).
55
+ - "Novelty": A rating from 1 to 10 (lowest to highest).
56
+
57
+ Be cautious and realistic in your ratings.
58
+ This JSON will be automatically parsed, so ensure the format is precise.
59
+ You will have {num_reflections} rounds to iterate on the idea, but do not need to use them all.
60
+ """
61
+
62
+ idea_reflection_prompt = """Round {current_round}/{num_reflections}.
63
+ In your thoughts, first carefully consider the quality, novelty, and feasibility of the idea you just created.
64
+ Include any other factors that you think are important in evaluating the idea.
65
+ Ensure the idea is clear and concise, and the JSON is the correct format.
66
+ Do not make things overly complicated.
67
+ In the next attempt, try and refine and improve your idea.
68
+ Stick to the spirit of the original idea unless there are glaring issues.
69
+
70
+ Respond in exactly the same format as before:
71
+ THOUGHT:
72
+ <THOUGHT>
73
+
74
+ NEW IDEA JSON:
75
+ ```json
76
+ <JSON>
77
+ ```
78
+
79
+ If there is nothing to improve, simply repeat the previous JSON EXACTLY after the thought and include "I am done" at the end of the thoughts but before the JSON.
80
+ ONLY INCLUDE "I am done" IF YOU ARE MAKING NO MORE CHANGES.
81
+ """
82
+
83
+
84
+ # Format the content in JSON
85
+ def format_idea_json(text):
86
+ json_start_marker = "```json"
87
+ json_end_marker = "```"
88
+ start_index = text.find(json_start_marker)
89
+ if start_index != -1:
90
+ start_index += len(json_start_marker) # Move past the marker
91
+ end_index = text.find(json_end_marker, start_index)
92
+ json_string = text[start_index:end_index].strip()
93
+ res = strict_json(
94
+ system_prompt="You are a JSON formatter",
95
+ user_prompt=json_string,
96
+ output_format={
97
+ "Name": "A shortened descriptor of the idea",
98
+ "Title": "A title for the idea, will be used for the report writing",
99
+ "Experiment": "An outline of the implementation, type: list",
100
+ "Interestingness": "A rating from 1 to 10 (lowest to highest), type: int",
101
+ "Feasibility": "A rating from 1 to 10 (lowest to highest), type: int",
102
+ "Novelty": "A rating from 1 to 10 (lowest to highest), type: int",
103
+ },
104
+ llm=llm_json_auto_correct,
105
+ )
106
+ text = "```json\n" + json.dumps(res) + "```\n"
107
+ return text
108
+
109
+
110
+ def format_novelty_json(text):
111
+ json_start_marker = "```json"
112
+ json_end_marker = "```"
113
+ start_index = text.find(json_start_marker)
114
+ if start_index != -1:
115
+ start_index += len(json_start_marker) # Move past the marker
116
+ end_index = text.find(json_end_marker, start_index)
117
+ json_string = text[start_index:end_index].strip()
118
+ res = strict_json(
119
+ system_prompt="You are a JSON formatter",
120
+ user_prompt=json_string,
121
+ output_format={
122
+ "Query": "An optional search query to search the literature (e.g. attention is all you need)",
123
+ },
124
+ llm=llm_json_auto_correct,
125
+ )
126
+ text = "```json\n" + json.dumps(res) + "```\n"
127
+ return text
128
+
129
+
130
+ # GENERATE IDEAS
131
+ def generate_ideas(
132
+ base_dir,
133
+ client,
134
+ model,
135
+ skip_generation=False,
136
+ max_num_generations=20,
137
+ num_reflections=5,
138
+ ):
139
+ if skip_generation:
140
+ # Load existing ideas from file
141
+ try:
142
+ with open(osp.join(base_dir, "ideas.json"), "r") as f:
143
+ ideas = json.load(f)
144
+ print("Loaded existing ideas:")
145
+ for idea in ideas:
146
+ print(idea)
147
+ return ideas
148
+ except FileNotFoundError:
149
+ print("No existing ideas found. Generating new ideas.")
150
+ except json.JSONDecodeError:
151
+ print("Error decoding existing ideas. Generating new ideas.")
152
+
153
+ idea_str_archive = []
154
+ with open(osp.join(base_dir, "seed_ideas.json"), "r") as f:
155
+ seed_ideas = json.load(f)
156
+ for seed_idea in seed_ideas:
157
+ idea_str_archive.append(json.dumps(seed_idea))
158
+
159
+ with open(osp.join(base_dir, "experiment.py"), "r") as f:
160
+ code = f.read()
161
+
162
+ with open(osp.join(base_dir, "prompt.json"), "r") as f:
163
+ prompt = json.load(f)
164
+
165
+ idea_system_prompt = prompt["system"]
166
+
167
+ for _ in range(max_num_generations):
168
+ print()
169
+ print(f"Generating idea {_ + 1}/{max_num_generations}")
170
+ import traceback
171
+ try:
172
+ prev_ideas_string = "\n\n".join(idea_str_archive)
173
+
174
+ msg_history = []
175
+ print(f"Iteration 1/{num_reflections}")
176
+ text, msg_history = get_response_from_llm(
177
+ idea_first_prompt.format(
178
+ task_description=prompt["task_description"],
179
+ code=code,
180
+ prev_ideas_string=prev_ideas_string,
181
+ num_reflections=num_reflections,
182
+ ),
183
+ client=client,
184
+ model=model,
185
+ system_message=idea_system_prompt,
186
+ msg_history=msg_history,
187
+ )
188
+ ## Format the content in JSON
189
+ text = format_idea_json(text)
190
+
191
+ ## PARSE OUTPUT
192
+ json_output = extract_json_between_markers(text)
193
+ assert json_output is not None, "Failed to extract JSON from LLM output"
194
+ # print(json_output)
195
+
196
+ # Iteratively improve task.
197
+ if num_reflections > 1:
198
+ for j in range(num_reflections - 1):
199
+ print(f"Iteration {j + 2}/{num_reflections}")
200
+ text, msg_history = get_response_from_llm(
201
+ idea_reflection_prompt.format(
202
+ current_round=j + 2, num_reflections=num_reflections
203
+ ),
204
+ client=client,
205
+ model=model,
206
+ system_message=idea_system_prompt,
207
+ msg_history=msg_history,
208
+ )
209
+ ## Format the content in JSON if using weak LLM
210
+ text = format_idea_json(text)
211
+ ## PARSE OUTPUT
212
+ json_output = extract_json_between_markers(text)
213
+ assert (
214
+ json_output is not None
215
+ ), "Failed to extract JSON from LLM output"
216
+ # print(json_output)
217
+
218
+ if "I am done" in text:
219
+ print(f"Idea generation converged after {j + 2} iterations.")
220
+ break
221
+
222
+ idea_str_archive.append(json.dumps(json_output))
223
+ except Exception as e:
224
+ print(f"Failed to generate idea: {e}")
225
+ traceback.print_exc()
226
+ continue
227
+
228
+ ## SAVE IDEAS
229
+ ideas = []
230
+ for idea_str in idea_str_archive:
231
+ ideas.append(json.loads(idea_str))
232
+
233
+ with open(osp.join(base_dir, "ideas.json"), "w") as f:
234
+ json.dump(ideas, f, indent=4)
235
+
236
+ return ideas
237
+
238
+
239
+ # GENERATE IDEAS OPEN-ENDED
240
+ def generate_next_idea(
241
+ base_dir,
242
+ client,
243
+ model,
244
+ prev_idea_archive=[],
245
+ num_reflections=5,
246
+ max_attempts=10,
247
+ ):
248
+ idea_archive = prev_idea_archive
249
+ original_archive_size = len(idea_archive)
250
+
251
+ print(f"Generating idea {original_archive_size + 1}")
252
+
253
+ if len(prev_idea_archive) == 0:
254
+ print(f"First iteration, taking seed ideas")
255
+ # seed the archive on the first run with pre-existing ideas
256
+ with open(osp.join(base_dir, "seed_ideas.json"), "r") as f:
257
+ seed_ideas = json.load(f)
258
+ for seed_idea in seed_ideas[:1]:
259
+ idea_archive.append(seed_idea)
260
+ else:
261
+ with open(osp.join(base_dir, "experiment.py"), "r") as f:
262
+ code = f.read()
263
+ with open(osp.join(base_dir, "prompt.json"), "r") as f:
264
+ prompt = json.load(f)
265
+ idea_system_prompt = prompt["system"]
266
+
267
+ for _ in range(max_attempts):
268
+ import traceback
269
+ try:
270
+ idea_strings = []
271
+ for idea in idea_archive:
272
+ idea_strings.append(json.dumps(idea))
273
+ prev_ideas_string = "\n\n".join(idea_strings)
274
+
275
+ msg_history = []
276
+ print(f"Iteration 1/{num_reflections}")
277
+ text, msg_history = get_response_from_llm(
278
+ idea_first_prompt.format(
279
+ task_description=prompt["task_description"],
280
+ code=code,
281
+ prev_ideas_string=prev_ideas_string,
282
+ num_reflections=num_reflections,
283
+ )
284
+ + """
285
+ Completed ideas have an additional "Score" field which indicates the assessment by an expert ML reviewer.
286
+ This is on a standard 1-10 ML conference scale.
287
+ Scores of 0 indicate the idea failed either during experimentation, writeup or reviewing.
288
+ """,
289
+ client=client,
290
+ model=model,
291
+ system_message=idea_system_prompt,
292
+ msg_history=msg_history,
293
+ )
294
+ ## Format the content in JSON if using weak LLM
295
+ text = format_idea_json(text)
296
+ ## PARSE OUTPUT
297
+ json_output = extract_json_between_markers(text)
298
+ assert json_output is not None, "Failed to extract JSON from LLM output"
299
+ # print(json_output)
300
+
301
+ # Iteratively improve task.
302
+ if num_reflections > 1:
303
+ for j in range(num_reflections - 1):
304
+ print(f"Iteration {j + 2}/{num_reflections}")
305
+ text, msg_history = get_response_from_llm(
306
+ idea_reflection_prompt.format(
307
+ current_round=j + 2, num_reflections=num_reflections
308
+ ),
309
+ client=client,
310
+ model=model,
311
+ system_message=idea_system_prompt,
312
+ msg_history=msg_history,
313
+ )
314
+ ## Format the content in JSON if using weak LLM
315
+ text = format_idea_json(text)
316
+ ## PARSE OUTPUT
317
+ json_output = extract_json_between_markers(text)
318
+ assert (
319
+ json_output is not None
320
+ ), "Failed to extract JSON from LLM output"
321
+ # print(json_output)
322
+
323
+ if "I am done" in text:
324
+ print(
325
+ f"Idea generation converged after {j + 2} iterations."
326
+ )
327
+ break
328
+
329
+ idea_archive.append(json_output)
330
+ break
331
+ except Exception as e:
332
+ print(f"Failed to generate idea: {e}")
333
+ traceback.print_exc()
334
+ continue
335
+
336
+ ## SAVE IDEAS
337
+ with open(osp.join(base_dir, "ideas.json"), "w") as f:
338
+ json.dump(idea_archive, f, indent=4)
339
+
340
+ return idea_archive
341
+
342
+
343
+ def on_backoff(details):
344
+ print(
345
+ f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries "
346
+ f"calling function {details['target'].__name__} at {time.strftime('%X')}"
347
+ )
348
+
349
+
350
+ @backoff.on_exception(
351
+ backoff.expo, requests.exceptions.HTTPError, on_backoff=on_backoff
352
+ )
353
+ def search_for_papers(query, result_limit=10) -> Union[None, List[Dict]]:
354
+ if not query:
355
+ return None
356
+ rsp = requests.get(
357
+ "https://api.semanticscholar.org/graph/v1/paper/search",
358
+ headers={"X-API-KEY": S2_API_KEY},
359
+ params={
360
+ "query": query,
361
+ "limit": result_limit,
362
+ "fields": "title,authors,venue,year,abstract,citationStyles,citationCount",
363
+ },
364
+ )
365
+ print(f"Response Status Code: {rsp.status_code}")
366
+ print(
367
+ f"Response Content: {rsp.text[:500]}"
368
+ ) # Print the first 500 characters of the response content
369
+ rsp.raise_for_status()
370
+ results = rsp.json()
371
+ total = results["total"]
372
+ if not total:
373
+ return None
374
+ time.sleep(2)
375
+ papers = results["data"]
376
+ return papers
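+ # Minimal usage sketch (illustrative only, not part of the pipeline): assumes S2_API_KEY is set
+ # and the Semantic Scholar API is reachable; the keys below match the "fields" requested above.
+ #   papers = search_for_papers("attention is all you need", result_limit=3)
+ #   if papers:
+ #       print(papers[0]["title"], papers[0]["year"], papers[0]["citationCount"])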
377
+
378
+
379
+ novelty_system_msg = """You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field.
380
+ You have an idea and you want to check if it is novel or not. I.e., not overlapping significantly with existing literature or already well explored.
381
+ Be a harsh critic for novelty; ensure there is a sufficient contribution in the idea for a new conference or workshop paper.
382
+ You will be given access to the Semantic Scholar API, which you may use to survey the literature and find relevant papers to help you make your decision.
383
+ The top 10 results for any search query will be presented to you with the abstracts.
384
+
385
+ You will be given {num_rounds} rounds to decide on the paper, but you do not need to use them all.
386
+ At any round, you may exit early and decide on the novelty of the idea.
387
+ Decide a paper idea is novel if after sufficient searching, you have not found a paper that significantly overlaps with your idea.
388
+ Decide a paper idea is not novel if you have found a paper that significantly overlaps with your idea.
389
+
390
+ {task_description}
391
+ <experiment.py>
392
+ {code}
393
+ </experiment.py>
394
+ """
395
+
396
+ novelty_prompt = '''Round {current_round}/{num_rounds}.
397
+ You have this idea:
398
+
399
+ """
400
+ {idea}
401
+ """
402
+
403
+ The results of the last query are (empty on first round):
404
+ """
405
+ {last_query_results}
406
+ """
407
+
408
+ Respond in the following format:
409
+
410
+ THOUGHT:
411
+ <THOUGHT>
412
+
413
+ RESPONSE:
414
+ ```json
415
+ <JSON>
416
+ ```
417
+
418
+ In <THOUGHT>, first briefly reason over the idea and identify any query that could help you make your decision.
419
+ If you have made your decision, add "Decision made: novel." or "Decision made: not novel." to your thoughts.
420
+
421
+ In <JSON>, respond in JSON format with ONLY the following field:
422
+ - "Query": An optional search query to search the literature (e.g. attention is all you need). You must make a query if you have not decided this round.
423
+
424
+ A query will work best if you are able to recall the exact name of the paper you are looking for, or the authors.
425
+ This JSON will be automatically parsed, so ensure the format is precise.
426
+ '''
427
+
428
+
429
+ def check_idea_novelty(
430
+ ideas,
431
+ base_dir,
432
+ client,
433
+ model,
434
+ max_num_iterations=10,
435
+ ):
436
+ with open(osp.join(base_dir, "experiment.py"), "r") as f:
437
+ code = f.read()
438
+ with open(osp.join(base_dir, "prompt.json"), "r") as f:
439
+ prompt = json.load(f)
440
+ task_description = prompt["task_description"]
441
+
442
+ for idx, idea in enumerate(ideas):
443
+ if "novel" in idea:
444
+ print(f"Skipping idea {idx}, already checked.")
445
+ continue
446
+
447
+ print(f"\nChecking novelty of idea {idx}: {idea['Name']}")
448
+
449
+ novel = False
450
+ msg_history = []
451
+ papers_str = ""
452
+
453
+ for j in range(max_num_iterations):
454
+ try:
455
+ text, msg_history = get_response_from_llm(
456
+ novelty_prompt.format(
457
+ current_round=j + 1,
458
+ num_rounds=max_num_iterations,
459
+ idea=idea,
460
+ last_query_results=papers_str,
461
+ ),
462
+ client=client,
463
+ model=model,
464
+ system_message=novelty_system_msg.format(
465
+ num_rounds=max_num_iterations,
466
+ task_description=task_description,
467
+ code=code,
468
+ ),
469
+ msg_history=msg_history,
470
+ )
471
+ if "decision made: novel" in text.lower():
472
+ print("Decision made: novel after round", j)
473
+ novel = True
474
+ break
475
+ if "decision made: not novel" in text.lower():
476
+ print("Decision made: not novel after round", j)
477
+ break
478
+
479
+ ## Format the content in JSON
480
+ text = format_novelty_json(text)
481
+ print("text after formating\n", text)
482
+ ## PARSE OUTPUT
483
+ json_output = extract_json_between_markers(text)
484
+ assert json_output is not None, "Failed to extract JSON from LLM output"
485
+
486
+ ## SEARCH FOR PAPERS
487
+ query = json_output["Query"]
488
+ papers = search_for_papers(query, result_limit=10)
489
+ if papers is None:
490
+ papers_str = "No papers found."
491
+
492
+ paper_strings = []
493
+ for i, paper in enumerate(papers):
494
+ paper_strings.append(
495
+ """{i}: {title}. {authors}. {venue}, {year}.\nNumber of citations: {cites}\nAbstract: {abstract}""".format(
496
+ i=i,
497
+ title=paper["title"],
498
+ authors=paper["authors"],
499
+ venue=paper["venue"],
500
+ year=paper["year"],
501
+ cites=paper["citationCount"],
502
+ abstract=paper["abstract"],
503
+ )
504
+ )
505
+ papers_str = "\n\n".join(paper_strings)
506
+
507
+ except Exception as e:
508
+ print(f"Error: {e}")
509
+ continue
510
+
511
+ idea["novel"] = novel
512
+
513
+ # Save results to JSON file
514
+ results_file = osp.join(base_dir, "ideas.json")
515
+ with open(results_file, "w") as f:
516
+ json.dump(ideas, f, indent=4)
517
+
518
+ return ideas
519
+
520
+
521
+ if __name__ == "__main__":
522
+ MAX_NUM_GENERATIONS = 32
523
+ NUM_REFLECTIONS = 5
524
+ import argparse
525
+
526
+ parser = argparse.ArgumentParser(description="Generate AI scientist ideas")
527
+ # add type of experiment (nanoGPT, Boston, etc.)
528
+ parser.add_argument(
529
+ "--experiment",
530
+ type=str,
531
+ default="nanoGPT",
532
+ help="Experiment to run AI Scientist on.",
533
+ )
534
+ parser.add_argument(
535
+ "--model",
536
+ type=str,
537
+ default="deepseek-ai/DeepSeek-V2.5",
538
+ choices=allchoices,
539
+ help="Model to use for AI Scientist.",
540
+ )
541
+ parser.add_argument(
542
+ "--skip-idea-generation",
543
+ action="store_true",
544
+ help="Skip idea generation and use existing ideas.",
545
+ )
546
+ parser.add_argument(
547
+ "--check-novelty",
548
+ action="store_true",
549
+ help="Check novelty of ideas.",
550
+ )
551
+ args = parser.parse_args()
552
+
553
+ # Create client
554
+
555
+ # ------------------------------------------------------------------------------------------------------
556
+
557
+ if args.model == "Qwen/Qwen2.5-72B-Instruct":
558
+ # elif args.model.startswith("hyperbolic"):
559
+ print(f"Welcome to the PARADISE of debug <generate_scientist.py> {args.model}.")
560
+
561
+ import openai
562
+ import os
563
+ # client_model = args.model[11:]
564
+ client_model = args.model
565
+ client = openai.OpenAI(
566
+ api_key=os.environ["OPENAI_API_KEY"], base_url="https://api.hyperbolic.xyz/v1"
567
+ )
568
+
569
+ # ------------------------------------------------------------------------------------------------------
570
+
571
+
572
+ elif args.model == "claude-3-5-sonnet-20240620":
573
+ import anthropic
574
+
575
+ print(f"Using Anthropic API with model {args.model}.")
576
+ client_model = "claude-3-5-sonnet-20240620"
577
+ client = anthropic.Anthropic()
578
+ elif args.model.startswith("bedrock") and "claude" in args.model:
579
+ import anthropic
580
+
581
+ # Expects: bedrock/<MODEL_ID>
582
+ client_model = args.model.split("/")[-1]
583
+
584
+ print(f"Using Amazon Bedrock with model {client_model}.")
585
+ client = anthropic.AnthropicBedrock()
586
+ elif args.model == "gpt-4o-2024-05-13" or args.model == "hybrid":
587
+ import openai
588
+
589
+ print(f"Using OpenAI API with model {args.model}.")
590
+ client_model = "gpt-4o-2024-05-13"
591
+ client = openai.OpenAI()
592
+ elif args.model == "deepseek-coder-v2-0724":
593
+ import openai
594
+
595
+ print(f"Using OpenAI API with {args.model}.")
596
+ client_model = "deepseek-coder-v2-0724"
597
+ client = openai.OpenAI(
598
+ api_key=os.environ["DEEPSEEK_API_KEY"], base_url="https://api.hyperbolic.xyz/v1"
599
+ )
600
+ elif args.model == "llama3.1-405b":
601
+ import openai
602
+
603
+ print(f"Using OpenAI API with {args.model}.")
604
+ client_model = "meta-llama/llama-3.1-405b-instruct"
605
+ client = openai.OpenAI(
606
+ api_key=os.environ["OPENROUTER_API_KEY"],
607
+ base_url="https://openrouter.ai/api/v1",
608
+ )
609
+ elif args.model.startswith("ollama"):
610
+ import openai
611
+
612
+ print(f"Using Ollama with {args.model}.")
613
+ client_model = args.model.split("/")[-1]
614
+ # client_model = args.model
615
+ client = openai.OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
616
+
617
+ else:
618
+ raise ValueError(f"Model {args.model} not supported.")
619
+
620
+ base_dir = osp.join("templates", args.experiment)
621
+ results_dir = osp.join("results", args.experiment)
622
+ print("going into line 623...")
623
+ ideas = generate_ideas(
624
+ base_dir,
625
+ client=client,
626
+ model=client_model,
627
+ skip_generation=args.skip_idea_generation,
628
+ max_num_generations=MAX_NUM_GENERATIONS,
629
+ num_reflections=NUM_REFLECTIONS,
630
+ )
631
+ if args.check_novelty:
632
+ ideas = check_idea_novelty(
633
+ ideas,
634
+ base_dir=base_dir,
635
+ client=client,
636
+ model=client_model,
637
+ )
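A minimal sketch of the record shape this module reads and writes (illustrative values only; the keys come from the idea prompt above, and check_idea_novelty attaches the boolean "novel" field after the literature search):

    example_idea = {
        "Name": "adaptive_lr_schedule",  # lowercase, underscores allowed
        "Title": "Adaptive Learning Rate Schedules for Character-Level Transformers",
        "Experiment": "Modify the training loop to sweep warmup steps and log val loss per run.",
        "Interestingness": 6,
        "Feasibility": 8,
        "Novelty": 4,
        "novel": False,  # added later by check_idea_novelty
    }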
ai_scientist/.ipynb_checkpoints/llm-checkpoint.py ADDED
@@ -0,0 +1,358 @@
1
+ import json
2
+
3
+ import backoff
4
+ import openai
5
+
6
+ # Ollama
7
+ ollama_choices = [
8
+ "mistral-nemo",
9
+ "llama3.1",
10
+ ]
11
+
12
+ # hyperbolic
13
+ hyperbolic_choices = [
14
+ "Qwen/Qwen2.5-72B-Instruct",
15
+ "meta-llama/Meta-Llama-3.1-70B-Instruct",
16
+ ]
17
+
18
+
19
+ allchoices = [
20
+ "Qwen/Qwen2.5-72B-Instruct",
21
+ "deepseek-ai/DeepSeek-V2.5",
22
+ "claude-3-5-sonnet-20240620",
23
+ "gpt-4o-2024-05-13",
24
+ "deepseek-coder-v2-0724",
25
+ "llama3.1-405b",
26
+ # Anthropic Claude models via Amazon Bedrock
27
+ "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
28
+ "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
29
+ "bedrock/anthropic.claude-3-haiku-20240307-v1:0",
30
+ "bedrock/anthropic.claude-3-opus-20240229-v1:0",
31
+ ]
32
+
33
+ for item in ollama_choices:
34
+ allchoices.append("ollama/" + item)
35
+
36
+ for item in hyperbolic_choices:
37
+ allchoices.append("hyperbolic/" + item)
38
+
39
+
40
+ # Get N responses from a single message, used for ensembling.
41
+ @backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
42
+ def get_batch_responses_from_llm(
43
+ msg,
44
+ client,
45
+ model,
46
+ system_message,
47
+ print_debug=False,
48
+ msg_history=None,
49
+ temperature=0.75,
50
+ n_responses=1,
51
+ ):
52
+ if msg_history is None:
53
+ msg_history = []
54
+
55
+ if model in [
56
+ "gpt-4o-2024-05-13",
57
+ "gpt-4o-mini-2024-07-18",
58
+ "gpt-4o-2024-08-06",
59
+ "Qwen/Qwen2.5-72B-Instruct"
60
+ ]:
61
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
62
+ response = client.chat.completions.create(
63
+ model=model,
64
+ messages=[
65
+ {"role": "system", "content": system_message},
66
+ *new_msg_history,
67
+ ],
68
+ temperature=temperature,
69
+ max_tokens=3000,
70
+ n=n_responses,
71
+ stop=None,
72
+ seed=0,
73
+ )
74
+ content = [r.message.content for r in response.choices]
75
+ new_msg_history = [
76
+ new_msg_history + [{"role": "assistant", "content": c}] for c in content
77
+ ]
78
+ elif model == "deepseek-coder-v2-0724":
79
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
80
+ response = client.chat.completions.create(
81
+ model="deepseek-coder",
82
+ messages=[
83
+ {"role": "system", "content": system_message},
84
+ *new_msg_history,
85
+ ],
86
+ temperature=temperature,
87
+ max_tokens=3000,
88
+ n=n_responses,
89
+ stop=None,
90
+ )
91
+ content = [r.message.content for r in response.choices]
92
+ new_msg_history = [
93
+ new_msg_history + [{"role": "assistant", "content": c}] for c in content
94
+ ]
95
+
96
+ # ------------------------------------------------------------------------------------------------------
97
+
98
+ elif model == "Qwen/Qwen2.5-72B-Instruct":
99
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
100
+ response = client.chat.completions.create(
101
+ model="Qwen/Qwen2.5-72B-Instruct",
102
+ messages=[
103
+ {"role": "system", "content": system_message},
104
+ *new_msg_history,
105
+ ],
106
+ temperature=temperature,
107
+ max_tokens=3000,
108
+ n=n_responses,
109
+ stop=None,
110
+ )
111
+ content = [r.message.content for r in response.choices]
112
+ new_msg_history = [
113
+ new_msg_history + [{"role": "assistant", "content": c}] for c in content
114
+ ]
115
+
116
+ # elif model in hyperbolic_choices:
117
+ # content, new_msg_history = [], []
118
+ # for i in range(n_responses):
119
+ # print(f"Getting {i+1}/{n_responses} response from {model}")
120
+ # c, hist = get_response_from_llm(
121
+ # msg,
122
+ # client,
123
+ # model,
124
+ # system_message,
125
+ # print_debug=False,
126
+ # msg_history=None,
127
+ # temperature=temperature,
128
+ # )
129
+ # content.append(c)
130
+ # new_msg_history.append(hist)
131
+
132
+ # ------------------------------------------------------------------------------------------------------
133
+
134
+ elif model == "llama-3-1-405b-instruct":
135
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
136
+ response = client.chat.completions.create(
137
+ model="meta-llama/llama-3.1-405b-instruct",
138
+ messages=[
139
+ {"role": "system", "content": system_message},
140
+ *new_msg_history,
141
+ ],
142
+ temperature=temperature,
143
+ max_tokens=3000,
144
+ n=n_responses,
145
+ stop=None,
146
+ )
147
+ content = [r.message.content for r in response.choices]
148
+ new_msg_history = [
149
+ new_msg_history + [{"role": "assistant", "content": c}] for c in content
150
+ ]
151
+ elif model == "claude-3-5-sonnet-20240620":
152
+ content, new_msg_history = [], []
153
+ for _ in range(n_responses):
154
+ c, hist = get_response_from_llm(
155
+ msg,
156
+ client,
157
+ model,
158
+ system_message,
159
+ print_debug=False,
160
+ msg_history=None,
161
+ temperature=temperature,
162
+ )
163
+ content.append(c)
164
+ new_msg_history.append(hist)
165
+
166
+ # ollama models
167
+ elif model in ollama_choices:
168
+ content, new_msg_history = [], []
169
+ for i in range(n_responses):
170
+ print(f"Getting {i+1}/{n_responses} response from {model}")
171
+ c, hist = get_response_from_llm(
172
+ msg,
173
+ client,
174
+ model,
175
+ system_message,
176
+ print_debug=False,
177
+ msg_history=None,
178
+ temperature=temperature,
179
+ )
180
+ content.append(c)
181
+ new_msg_history.append(hist)
182
+ else:
183
+ raise ValueError(f"Model {model} not supported.")
184
+
185
+ if print_debug:
186
+ # Just print the first one.
187
+ print()
188
+ print("*" * 20 + " LLM START " + "*" * 20)
189
+ for j, msg in enumerate(new_msg_history[0]):
190
+ print(f'{j}, {msg["role"]}: {msg["content"]}')
191
+ print(content)
192
+ print("*" * 21 + " LLM END " + "*" * 21)
193
+ print()
194
+
195
+ return content, new_msg_history
196
+
197
+
198
+ @backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
199
+ def get_response_from_llm(
200
+ msg,
201
+ client,
202
+ model,
203
+ system_message,
204
+ print_debug=False,
205
+ msg_history=None,
206
+ temperature=0.75,
207
+ ):
208
+ if msg_history is None:
209
+ msg_history = []
210
+
211
+ if model == "claude-3-5-sonnet-20240620":
212
+ new_msg_history = msg_history + [
213
+ {
214
+ "role": "user",
215
+ "content": [
216
+ {
217
+ "type": "text",
218
+ "text": msg,
219
+ }
220
+ ],
221
+ }
222
+ ]
223
+ response = client.messages.create(
224
+ model="claude-3-5-sonnet-20240620",
225
+ max_tokens=3000,
226
+ temperature=temperature,
227
+ system=system_message,
228
+ messages=new_msg_history,
229
+ )
230
+ content = response.content[0].text
231
+ new_msg_history = new_msg_history + [
232
+ {
233
+ "role": "assistant",
234
+ "content": [
235
+ {
236
+ "type": "text",
237
+ "text": content,
238
+ }
239
+ ],
240
+ }
241
+ ]
242
+ # ------------------------------------------------------------------------------------------------------
243
+
244
+ elif model in [
245
+ "gpt-4o-2024-05-13",
246
+ "gpt-4o-mini-2024-07-18",
247
+ "gpt-4o-2024-08-06",
248
+ "Qwen/Qwen2.5-72B-Instruct"
249
+ ]:
250
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
251
+ response = client.chat.completions.create(
252
+ model=model,
253
+ messages=[
254
+ {"role": "system", "content": system_message},
255
+ *new_msg_history,
256
+ ],
257
+ temperature=temperature,
258
+ max_tokens=3000,
259
+ n=1,
260
+ stop=None,
261
+ seed=0,
262
+ )
263
+ content = response.choices[0].message.content
264
+ new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
265
+
266
+
267
+ # ------------------------------------------------------------------------------------------------------
268
+
269
+
270
+ elif model in ["meta-llama/llama-3.1-405b-instruct", "llama-3-1-405b-instruct"]:
271
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
272
+ response = client.chat.completions.create(
273
+ model="meta-llama/llama-3.1-405b-instruct",
274
+ messages=[
275
+ {"role": "system", "content": system_message},
276
+ *new_msg_history,
277
+ ],
278
+ temperature=temperature,
279
+ max_tokens=3000,
280
+ n=1,
281
+ stop=None,
282
+ )
283
+ content = response.choices[0].message.content
284
+ new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
285
+
286
+
287
+ elif model in ollama_choices:
288
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
289
+ response = client.chat.completions.create(
290
+ model=model,
291
+ messages=[
292
+ {"role": "system", "content": system_message},
293
+ *new_msg_history,
294
+ ],
295
+ temperature=temperature,
296
+ max_tokens=6000,
297
+ n=1,
298
+ stop=None,
299
+ seed=0,
300
+ )
301
+ content = response.choices[0].message.content
302
+ # print("\nget_response_from_llm\n")
303
+ # print(content)
304
+ new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
305
+
306
+ else:
307
+ raise ValueError(f"Model {model} not supported.")
308
+
309
+ if print_debug:
310
+ print()
311
+ print("*" * 20 + " LLM START " + "*" * 20)
312
+ for j, msg in enumerate(new_msg_history):
313
+ print(f'{j}, {msg["role"]}: {msg["content"]}')
314
+ print(content)
315
+ print("*" * 21 + " LLM END " + "*" * 21)
316
+ print()
317
+
318
+ return content, new_msg_history
319
+
320
+
321
+ def llm_json_auto_correct(system_prompt: str, user_prompt: str) -> str:
322
+ import os
323
+ client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"], base_url="https://api.hyperbolic.xyz/v1")
324
+ response = client.chat.completions.create(
325
+ model="Qwen/Qwen2.5-72B-Instruct",
326
+ temperature=0,
327
+ messages=[
328
+ {"role": "system", "content": system_prompt},
329
+ {"role": "user", "content": user_prompt},
330
+ ],
331
+ )
332
+ return response.choices[0].message.content
333
+
334
+
335
+ def extract_json_between_markers(llm_output):
336
+ json_start_marker = "```json"
337
+ json_end_marker = "```"
338
+
339
+ # Find the start and end indices of the JSON string
340
+ start_index = llm_output.find(json_start_marker)
341
+ if start_index != -1:
342
+ start_index += len(json_start_marker) # Move past the marker
343
+ end_index = llm_output.find(json_end_marker, start_index)
344
+ else:
345
+ return None # JSON markers not found
346
+
347
+ if end_index == -1:
348
+ return None # End marker not found
349
+
350
+ # Extract the JSON string
351
+ json_string = llm_output[start_index:end_index].strip()
352
+ # print(json_string)
353
+ try:
354
+ parsed_json = json.loads(json_string)
355
+
356
+ return parsed_json
357
+ except json.JSONDecodeError:
358
+ return None # Invalid JSON format
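A minimal sketch of how these helpers compose (assumptions: an OpenAI-compatible endpoint reachable through the usual environment variables and a model name handled above; the prompt text and parsed keys are illustrative):

    import openai
    from ai_scientist.llm import extract_json_between_markers, get_response_from_llm

    client = openai.OpenAI()  # hypothetical client; any backend supported above works the same way
    text, history = get_response_from_llm(
        msg='Reply only with ```json\n{"status": "ok"}\n```.',
        client=client,
        model="gpt-4o-2024-05-13",
        system_message="You respond only with a fenced JSON block.",
    )
    parsed = extract_json_between_markers(text)  # dict on success, None if no valid JSON block found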
ai_scientist/.ipynb_checkpoints/perform_experiments-checkpoint.py ADDED
@@ -0,0 +1,166 @@
1
+ import shutil
2
+ import os.path as osp
3
+ import subprocess
4
+ from subprocess import TimeoutExpired
5
+ import sys
6
+ import json
7
+
8
+ MAX_ITERS = 4
9
+ MAX_RUNS = 5
10
+ MAX_STDERR_OUTPUT = 1500
11
+
12
+ coder_prompt = """Your goal is to implement the following idea: {title}.
13
+ The proposed experiment is as follows: {idea}.
14
+ You are given a total of up to {max_runs} runs to complete the necessary experiments. You do not need to use all {max_runs}.
15
+
16
+ First, plan the list of experiments you would like to run. For example, if you are sweeping over a specific hyperparameter, plan each value you would like to test for each run.
17
+
18
+ Note that we already provide the vanilla baseline results, so you do not need to re-run it.
19
+
20
+ For reference, the baseline results are as follows:
21
+
22
+ {baseline_results}
23
+
24
+ After you complete each change, we will run the command `python experiment.py --out_dir=run_i` where i is the run number and evaluate the results.
25
+ YOUR PROPOSED CHANGE MUST USE THIS COMMAND FORMAT, DO NOT ADD ADDITIONAL COMMAND LINE ARGS.
26
+ You can then implement the next thing on your list."""
27
+
28
+
29
+ # RUN EXPERIMENT
30
+ def run_experiment(folder_name, run_num, timeout=7200):
31
+ cwd = osp.abspath(folder_name)
32
+ # COPY CODE SO WE CAN SEE IT.
33
+ shutil.copy(
34
+ osp.join(folder_name, "experiment.py"),
35
+ osp.join(folder_name, f"run_{run_num}.py"),
36
+ )
37
+
38
+ # LAUNCH COMMAND
39
+ command = [
40
+ "python",
41
+ "experiment.py",
42
+ f"--out_dir=run_{run_num}",
43
+ ]
44
+ try:
45
+ result = subprocess.run(
46
+ command, cwd=cwd, stderr=subprocess.PIPE, text=True, timeout=timeout
47
+ )
48
+
49
+ if result.stderr:
50
+ print(result.stderr, file=sys.stderr)
51
+
52
+ if result.returncode != 0:
53
+ print(f"Run {run_num} failed with return code {result.returncode}")
54
+ if osp.exists(osp.join(cwd, f"run_{run_num}")):
55
+ shutil.rmtree(osp.join(cwd, f"run_{run_num}"))
56
+ print(f"Run failed with the following error {result.stderr}")
57
+ stderr_output = result.stderr
58
+ if len(stderr_output) > MAX_STDERR_OUTPUT:
59
+ stderr_output = "..." + stderr_output[-MAX_STDERR_OUTPUT:]
60
+ next_prompt = f"Run failed with the following error {stderr_output}"
61
+ else:
62
+ with open(osp.join(cwd, f"run_{run_num}", "final_info.json"), "r") as f:
63
+ results = json.load(f)
64
+ results = {k: v["means"] for k, v in results.items()}
65
+
66
+ next_prompt = f"""Run {run_num} completed. Here are the results:
67
+ {results}
68
+
69
+ Decide if you need to re-plan your experiments given the result (you often will not need to).
70
+
71
+ Someone else will be using `notes.txt` to perform a writeup on this in the future.
72
+ Please include *all* relevant information for the writeup on Run {run_num}, including an experiment description and the run number. Be as verbose as necessary.
73
+
74
+ Then, implement the next thing on your list.
75
+ We will then run the command `python experiment.py --out_dir=run_{run_num + 1}`.
76
+ YOUR PROPOSED CHANGE MUST USE THIS COMMAND FORMAT, DO NOT ADD ADDITIONAL COMMAND LINE ARGS.
77
+ If you are finished with experiments, respond with 'ALL_COMPLETED'."""
78
+ return result.returncode, next_prompt
79
+ except TimeoutExpired:
80
+ print(f"Run {run_num} timed out after {timeout} seconds")
81
+ if osp.exists(osp.join(cwd, f"run_{run_num}")):
82
+ shutil.rmtree(osp.join(cwd, f"run_{run_num}"))
83
+ next_prompt = f"Run timed out after {timeout} seconds"
84
+ return 1, next_prompt
85
+
86
+
87
+ # RUN PLOTTING
88
+ def run_plotting(folder_name, timeout=600):
89
+ cwd = osp.abspath(folder_name)
90
+ # LAUNCH COMMAND
91
+ command = [
92
+ "python",
93
+ "plot.py",
94
+ ]
95
+ try:
96
+ result = subprocess.run(
97
+ command, cwd=cwd, stderr=subprocess.PIPE, text=True, timeout=timeout
98
+ )
99
+
100
+ if result.stderr:
101
+ print(result.stderr, file=sys.stderr)
102
+
103
+ if result.returncode != 0:
104
+ print(f"Plotting failed with return code {result.returncode}")
105
+ next_prompt = f"Plotting failed with the following error {result.stderr}"
106
+ else:
107
+ next_prompt = ""
108
+ return result.returncode, next_prompt
109
+ except TimeoutExpired:
110
+ print(f"Plotting timed out after {timeout} seconds")
111
+ next_prompt = f"Plotting timed out after {timeout} seconds"
112
+ return 1, next_prompt
113
+
114
+
115
+ # PERFORM EXPERIMENTS
116
+ def perform_experiments(idea, folder_name, coder, baseline_results) -> bool:
117
+ ## RUN EXPERIMENT
118
+ current_iter = 0
119
+ run = 1
120
+ next_prompt = coder_prompt.format(
121
+ title=idea["Title"],
122
+ idea=idea["Experiment"],
123
+ max_runs=MAX_RUNS,
124
+ baseline_results=baseline_results,
125
+ )
126
+ while run < MAX_RUNS + 1:
127
+ if current_iter >= MAX_ITERS:
128
+ print("Max iterations reached")
129
+ break
130
+ coder_out = coder.run(next_prompt)
131
+ print(coder_out)
132
+ if "ALL_COMPLETED" in coder_out:
133
+ break
134
+ return_code, next_prompt = run_experiment(folder_name, run)
135
+ if return_code == 0:
136
+ run += 1
137
+ current_iter = 0
138
+ current_iter += 1
139
+ if current_iter >= MAX_ITERS:
140
+ print("Not all experiments completed.")
141
+ return False
142
+
143
+ current_iter = 0
144
+ next_prompt = """
145
+ Great job! Please modify `plot.py` to generate the most relevant plots for the final writeup.
146
+
147
+ In particular, be sure to fill in the "labels" dictionary with the correct names for each run that you want to plot.
148
+
149
+ Only the runs in the `labels` dictionary will be plotted, so make sure to include all relevant runs.
150
+
151
+ We will be running the command `python plot.py` to generate the plots.
152
+ """
153
+ while True:
154
+ coder_out = coder.run(next_prompt)
155
+ return_code, next_prompt = run_plotting(folder_name)
156
+ current_iter += 1
157
+ if return_code == 0 or current_iter >= MAX_ITERS:
158
+ break
159
+ next_prompt = """
160
+ Please modify `notes.txt` with a description of what each plot shows along with the filename of the figure. Please do so in-depth.
161
+
162
+ Somebody else will be using `notes.txt` to write a report on this in the future.
163
+ """
164
+ coder.run(next_prompt)
165
+
166
+ return True
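A minimal sketch of how this module is driven (assumption: `coder` is an aider Coder wired to the run folder's files, as the launcher sets it up; the paths, idea fields and baseline numbers are illustrative):

    from aider.coders import Coder
    from aider.io import InputOutput
    from aider.models import Model

    from ai_scientist.perform_experiments import perform_experiments

    folder = "results/nanoGPT/20240101_adaptive_lr"
    coder = Coder.create(
        main_model=Model("gpt-4o-2024-05-13"),
        fnames=[f"{folder}/experiment.py", f"{folder}/plot.py", f"{folder}/notes.txt"],
        io=InputOutput(yes=True),
    )
    idea = {"Title": "Adaptive LR schedules", "Experiment": "Sweep warmup steps across runs."}
    success = perform_experiments(idea, folder, coder, baseline_results={"val_loss": 1.47})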
ai_scientist/.ipynb_checkpoints/perform_writeup-checkpoint.py ADDED
@@ -0,0 +1,707 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import os.path as osp
5
+ import re
6
+ import shutil
7
+ import subprocess
8
+ from typing import Optional, Tuple
9
+
10
+ from strictjson import strict_json
11
+
12
+ from ai_scientist.generate_ideas import search_for_papers
13
+ from ai_scientist.llm import (
14
+ allchoices,
15
+ extract_json_between_markers,
16
+ get_response_from_llm,
17
+ llm_json_auto_correct,
18
+ )
19
+
20
+
21
+ def format_citation_first_json(text):
22
+ res = strict_json(
23
+ system_prompt="You are a JSON formatter",
24
+ user_prompt=text,
25
+ return_as_json=True,
26
+ output_format={
27
+ "Description": "A precise description of the required edit, along with the proposed text and location where it should be made",
28
+ "Query": "The search query to find the paper (e.g. attention is all you need)",
29
+ },
30
+ llm=llm_json_auto_correct,
31
+ )
32
+ text = json.loads(res)
33
+ return text
34
+
35
+
36
+ def format_citation_second_json(text):
37
+ res = strict_json(
38
+ system_prompt="You are a JSON formatter",
39
+ user_prompt=text,
40
+ return_as_json=True,
41
+ output_format={
42
+ "Selected": "A list of the indices of the selected papers to be cited, e.g. '[0, 1]'. Can be '[]' if no papers are selected. This must be a string",
43
+ "Description": "Update the previous description of the required edit if needed. Ensure that any cites precisely match the name in the bibtex",
44
+ },
45
+ llm=llm_json_auto_correct,
46
+ )
47
+ text = json.loads(res)
48
+ return text
49
+
50
+
51
+ # GENERATE LATEX
52
+ def generate_latex(coder, folder_name, pdf_file, timeout=30, num_error_corrections=5):
53
+ folder = osp.abspath(folder_name)
54
+ cwd = osp.join(folder, "latex") # Fixed potential issue with path
55
+ writeup_file = osp.join(cwd, "template.tex")
56
+
57
+ # Check all references are valid and in the references.bib file
58
+ with open(writeup_file, "r") as f:
59
+ tex_text = f.read()
60
+ cites = re.findall(r"\\cite[a-z]*{([^}]*)}", tex_text)
61
+ references_bib = re.search(
62
+ r"\\begin{filecontents}{references.bib}(.*?)\\end{filecontents}",
63
+ tex_text,
64
+ re.DOTALL,
65
+ )
66
+ if references_bib is None:
67
+ print("No references.bib found in template.tex")
68
+ return
69
+ bib_text = references_bib.group(1)
70
+ cites = [cite.strip() for item in cites for cite in item.split(",")]
71
+ for cite in cites:
72
+ if cite not in bib_text:
73
+ print(f"Reference {cite} not found in references.")
74
+ prompt = f"""Reference {cite} not found in references.bib. Is this included under a different name?
75
+ If so, please modify the citation in template.tex to match the name in references.bib at the top. Otherwise, remove the cite."""
76
+ coder.run(prompt)
77
+
78
+ # Check all included figures are actually in the directory.
79
+ with open(writeup_file, "r") as f:
80
+ tex_text = f.read()
81
+ referenced_figs = re.findall(r"\\includegraphics.*?{(.*?)}", tex_text)
82
+ all_figs = [f for f in os.listdir(folder) if f.endswith(".png")]
83
+ for figure in referenced_figs:
84
+ if figure not in all_figs:
85
+ print(f"Figure {figure} not found in directory.")
86
+ prompt = f"""The image {figure} not found in the directory. The images in the directory are: {all_figs}.
87
+ Please ensure that the figure is in the directory and that the filename is correct. Check the notes to see what each figure contains."""
88
+ coder.run(prompt)
89
+
90
+ # Remove duplicate figures.
91
+ with open(writeup_file, "r") as f:
92
+ tex_text = f.read()
93
+ referenced_figs = re.findall(r"\\includegraphics.*?{(.*?)}", tex_text)
94
+ duplicates = {x for x in referenced_figs if referenced_figs.count(x) > 1}
95
+ if duplicates:
96
+ for dup in duplicates:
97
+ print(f"Duplicate figure found: {dup}.")
98
+ prompt = f"""Duplicate figures found: {dup}. Ensure any figure is only included once.
99
+ If duplicated, identify the best location for the figure and remove any other."""
100
+ coder.run(prompt)
101
+
102
+ # Remove duplicate section headers.
103
+ with open(writeup_file, "r") as f:
104
+ tex_text = f.read()
105
+ sections = re.findall(r"\\section{([^}]*)}", tex_text)
106
+ duplicates = {x for x in sections if sections.count(x) > 1}
107
+ if duplicates:
108
+ for dup in duplicates:
109
+ print(f"Duplicate section header found: {dup}")
110
+ prompt = f"""Duplicate section header found: {dup}. Ensure any section header is declared once.
111
+ If duplicated, identify the best location for the section header and remove any other."""
112
+ coder.run(prompt)
113
+
114
+ # Iteratively fix any LaTeX bugs
115
+ for i in range(num_error_corrections):
116
+ # Filter trivial bugs in chktex
117
+ check_output = os.popen(f"chktex {writeup_file} -q -n2 -n24 -n13 -n1").read()
118
+ if check_output:
119
+ prompt = f"""Please fix the following LaTeX errors in `template.tex` guided by the output of `chktek`:
120
+ {check_output}.
121
+
122
+ Make the minimal fix required and do not remove or change any packages.
123
+ Pay attention to any accidental uses of HTML syntax, e.g. </end instead of \\end.
124
+ """
125
+ coder.run(prompt)
126
+ else:
127
+ break
128
+ compile_latex(cwd, pdf_file, timeout=timeout)
129
+
130
+
131
+ def compile_latex(cwd, pdf_file, timeout=30):
132
+ print("GENERATING LATEX")
133
+
134
+ commands = [
135
+ ["pdflatex", "-interaction=nonstopmode", "template.tex"],
136
+ ["bibtex", "template"],
137
+ ["pdflatex", "-interaction=nonstopmode", "template.tex"],
138
+ ["pdflatex", "-interaction=nonstopmode", "template.tex"],
139
+ ]
140
+
141
+ for command in commands:
142
+ try:
143
+ result = subprocess.run(
144
+ command,
145
+ cwd=cwd,
146
+ stdout=subprocess.PIPE,
147
+ stderr=subprocess.PIPE,
148
+ text=True,
149
+ timeout=timeout,
150
+ )
151
+ print("Standard Output:\n", result.stdout)
152
+ print("Standard Error:\n", result.stderr)
153
+ except subprocess.TimeoutExpired:
154
+ print(f"Latex timed out after {timeout} seconds")
155
+ except subprocess.CalledProcessError as e:
156
+ print(f"Error running command {' '.join(command)}: {e}")
157
+
158
+ print("FINISHED GENERATING LATEX")
159
+
160
+ # Attempt to move the PDF to the desired location
161
+ try:
162
+ shutil.move(osp.join(cwd, "template.pdf"), pdf_file)
163
+ except FileNotFoundError:
164
+ print("Failed to rename PDF.")
165
+
166
+
167
+ per_section_tips = {
168
+ "Abstract": """
169
+ - TL;DR of the paper
170
+ - What are we trying to do and why is it relevant?
171
+ - Why is this hard?
172
+ - How do we solve it (i.e. our contribution!)
173
+ - How do we verify that we solved it (e.g. Experiments and results)
174
+
175
+ Please make sure the abstract reads smoothly and is well-motivated. This should be one continuous paragraph with no breaks between the lines.
176
+ """,
177
+ "Introduction": """
178
+ - Longer version of the Abstract, i.e. of the entire paper
179
+ - What are we trying to do and why is it relevant?
180
+ - Why is this hard?
181
+ - How do we solve it (i.e. our contribution!)
182
+ - How do we verify that we solved it (e.g. Experiments and results)
183
+ - New trend: specifically list your contributions as bullet points
184
+ - Extra space? Future work!
185
+ """,
186
+ "Related Work": """
187
+ - Academic siblings of our work, i.e. alternative attempts in literature at trying to solve the same problem.
188
+ - Goal is to “Compare and contrast” - how does their approach differ in either assumptions or method? If their method is applicable to our Problem Setting I expect a comparison in the experimental section. If not, there needs to be a clear statement why a given method is not applicable.
189
+ - Note: Just describing what another paper is doing is not enough. We need to compare and contrast.
190
+ """,
191
+ "Background": """
192
+ - Academic Ancestors of our work, i.e. all concepts and prior work that are required for understanding our method.
193
+ - Usually includes a subsection, Problem Setting, which formally introduces the problem setting and notation (Formalism) for our method. Highlights any specific assumptions that are made that are unusual.
194
+ - Note: If our paper introduces a novel problem setting as part of its contributions, it's best to have a separate Section.
195
+ """,
196
+ "Method": """
197
+ - What we do. Why we do it. All described using the general Formalism introduced in the Problem Setting and building on top of the concepts / foundations introduced in Background.
198
+ """,
199
+ "Experimental Setup": """
200
+ - How do we test that our stuff works? Introduces a specific instantiation of the Problem Setting and specific implementation details of our Method for this Problem Setting.
201
+ - Do not imagine unknown hardware details.
202
+ - Includes a description of the dataset, evaluation metrics, important hyperparameters, and implementation details.
203
+ """,
204
+ "Results": """
205
+ - Shows the results of running Method on our problem described in Experimental Setup.
206
+ - Includes statements on hyperparameters and other potential issues of fairness.
207
+ - Only includes results that have actually been run and saved in the logs. Do not hallucinate results that don't exist.
208
+ - If results exist: compares to baselines and includes statistics and confidence intervals.
209
+ - If results exist: includes ablation studies to show that specific parts of the method are relevant.
210
+ - Discusses limitations of the method.
211
+ - Make sure to include all the results from the experiments, and include all relevant figures.
212
+ """,
213
+ "Conclusion": """
214
+ - Brief recap of the entire paper.
215
+ - To keep going with the analogy, you can think of future work as (potential) academic offspring.
216
+ """,
217
+ }
218
+
219
+ error_list = """- Unenclosed math symbols
220
+ - Only reference figures that exist in our directory
221
+ - LaTeX syntax errors
222
+ - Numerical results that do not come from explicit experiments and logs
223
+ - Repeatedly defined figure labels
224
+ - References to papers that are not in the .bib file, DO NOT ADD ANY NEW CITATIONS!
225
+ - Unnecessary verbosity or repetition, unclear text
226
+ - Results or insights in the `notes.txt` that have not yet been included
227
+ - Any relevant figures that have not yet been included in the text
228
+ - Closing any \\begin{{figure}} with a \\end{{figure}} and \\begin{{table}} with a \\end{{table}}, etc.
229
+ - Duplicate headers, e.g. duplicated \\section{{Introduction}} or \\end{{document}}
230
+ - Unescaped symbols, e.g. shakespeare_char should be shakespeare\\_char in text
231
+ - Incorrect closing of environments, e.g. </end{{figure}}> instead of \\end{{figure}}
232
+ """
233
+
234
+ refinement_prompt = (
235
+ """Great job! Now criticize and refine only the {section} that you just wrote.
236
+ Make this complete in this pass, do not leave any placeholders.
237
+
238
+ Pay particular attention to fixing any errors such as:
239
+ """
240
+ + error_list
241
+ )
242
+
243
+ second_refinement_prompt = (
244
+ """Criticize and refine the {section} only. Recall the advice:
245
+ {tips}
246
+ Make this complete in this pass, do not leave any placeholders.
247
+
248
+ Pay attention to how it fits in with the rest of the paper.
249
+ Identify any redundancies (e.g. repeated figures or repeated text), if there are any, decide where in the paper things should be cut.
250
+ Identify where we can save space, and be more concise without weakening the message of the text.
251
+ Fix any remaining errors as before:
252
+ """
253
+ + error_list
254
+ )
255
+
256
+ # CITATION HELPERS
257
+ citation_system_msg = """You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field.
258
+ You have already written an initial draft of the paper and now you are looking to add missing citations to related papers throughout the paper.
259
+ The related work section already has some initial comments on which papers to add and discuss.
260
+
261
+ Focus on completing the existing write-up and do not add entirely new elements unless necessary.
262
+ Ensure every point in the paper is substantiated with sufficient evidence.
263
+ Feel free to add more cites to a particular point if there is only one or two references.
264
+ Ensure no paper is cited without a corresponding reference in the `references.bib` file.
265
+ Ensure each paragraph of the related work has sufficient background, e.g. a few papers cited.
266
+ You will be given access to the Semantic Scholar API, only add citations that you have found using the API.
267
+ Aim to discuss a broad range of relevant papers, not just the most popular ones.
268
+ Make sure not to copy verbatim from prior literature to avoid plagiarism.
269
+
270
+ You will be prompted to give a precise description of where and how to add the cite, and a search query for the paper to be cited.
271
+ Finally, you will select the most relevant cite from the search results (top 10 results will be shown).
272
+ You will have {total_rounds} rounds to add to the references, but do not need to use them all.
273
+
274
+ DO NOT ADD A CITATION THAT ALREADY EXISTS!"""
275
+
276
+ citation_first_prompt = '''Round {current_round}/{total_rounds}:
277
+
278
+ You have written this LaTeX draft so far:
279
+
280
+ """
281
+ {draft}
282
+ """
283
+
284
+ Identify the most important citation that you still need to add, and the query to find the paper.
285
+
286
+ Respond in the following format:
287
+
288
+ THOUGHT:
289
+ <THOUGHT>
290
+
291
+ RESPONSE:
292
+ ```json
293
+ <JSON>
294
+ ```
295
+
296
+ In <THOUGHT>, first briefly reason over the paper and identify where citations should be added.
297
+ If no more citations are needed, add "No more citations needed" to your thoughts.
298
+ Do not add "No more citations needed" if you are adding citations this round.
299
+
300
+ In <JSON>, respond in JSON format with the following fields:
301
+ - "Description": A precise description of the required edit, along with the proposed text and location where it should be made.
302
+ - "Query": The search query to find the paper (e.g. attention is all you need).
303
+
304
+ Ensure the description is sufficient to make the change without further context. Someone else will make the change.
305
+ The query will work best if you are able to recall the exact name of the paper you are looking for, or the authors.
306
+ This JSON will be automatically parsed, so ensure the format is precise.'''
307
+
308
+ citation_second_prompt = """Search has recovered the following articles:
309
+
310
+ {papers}
311
+
312
+ Respond in the following format:
313
+
314
+ THOUGHT:
315
+ <THOUGHT>
316
+
317
+ RESPONSE:
318
+ ```json
319
+ <JSON>
320
+ ```
321
+
322
+ In <THOUGHT>, first briefly reason over the search results and identify which citation best fits your paper and the location is to be added at.
323
+ If none are appropriate, add "Do not add any" to your thoughts.
324
+
325
+ In <JSON>, respond in JSON format with the following fields:
326
+ - "Selected": A list of the indices of the selected papers to be cited, e.g. "[0, 1]". Can be "[]" if no papers are selected. This must be a string.
327
+ - "Description": Update the previous description of the required edit if needed. Ensure that any cites precisely match the name in the bibtex!!!
328
+
329
+ Do not select papers that are already in the `references.bib` file at the top of the draft, or if the same citation exists under a different name.
330
+ This JSON will be automatically parsed, so ensure the format is precise."""
331
+
332
+
333
+ def get_citation_aider_prompt(
334
+ client, model, draft, current_round, total_rounds
335
+ ) -> Tuple[Optional[str], bool]:
336
+ msg_history = []
337
+ try:
338
+ text, msg_history = get_response_from_llm(
339
+ citation_first_prompt.format(
340
+ draft=draft, current_round=current_round, total_rounds=total_rounds
341
+ ),
342
+ client=client,
343
+ model=model,
344
+ system_message=citation_system_msg.format(total_rounds=total_rounds),
345
+ msg_history=msg_history,
346
+ )
347
+ if "No more citations needed" in text:
348
+ print("No more citations needed.")
349
+ return None, True
350
+
351
+ ## PARSE OUTPUT
352
+ json_output = format_citation_first_json(text)
353
+ assert json_output is not None, "Failed to extract JSON from LLM output"
354
+ query = json_output["Query"]
355
+ papers = search_for_papers(query)
356
+ except Exception as e:
357
+ print(f"Error: {e}")
358
+ return None, False
359
+
360
+ if papers is None:
361
+ print("No papers found.")
362
+ return None, False
363
+
364
+ paper_strings = []
365
+ for i, paper in enumerate(papers):
366
+ paper_strings.append(
367
+ """{i}: {title}. {authors}. {venue}, {year}.\nAbstract: {abstract}""".format(
368
+ i=i,
369
+ title=paper["title"],
370
+ authors=paper["authors"],
371
+ venue=paper["venue"],
372
+ year=paper["year"],
373
+ abstract=paper["abstract"],
374
+ )
375
+ )
376
+ papers_str = "\n\n".join(paper_strings)
377
+
378
+ try:
379
+ text, msg_history = get_response_from_llm(
380
+ citation_second_prompt.format(
381
+ papers=papers_str,
382
+ current_round=current_round,
383
+ total_rounds=total_rounds,
384
+ ),
385
+ client=client,
386
+ model=model,
387
+ system_message=citation_system_msg.format(total_rounds=total_rounds),
388
+ msg_history=msg_history,
389
+ )
390
+ if "Do not add any" in text:
391
+ print("Do not add any.")
392
+ return None, False
393
+ ## PARSE OUTPUT
394
+ json_output = format_citation_second_json(text)
395
+ assert json_output is not None, "Failed to extract JSON from LLM output"
396
+ desc = json_output["Description"]
397
+ selected_papers = json_output["Selected"]
398
+ selected_papers = str(selected_papers)
399
+
400
+ # convert to list
401
+ if selected_papers != "[]":
402
+ selected_papers = list(map(int, selected_papers.strip("[]").split(",")))
403
+ assert all(
404
+ [0 <= i < len(papers) for i in selected_papers]
405
+ ), "Invalid paper index"
406
+ bibtexs = [papers[i]["citationStyles"]["bibtex"] for i in selected_papers]
407
+ bibtex_string = "\n".join(bibtexs)
408
+ else:
409
+ return None, False
410
+
411
+ except Exception as e:
412
+ print(f"Error: {e}")
413
+ return None, False
414
+
415
+ # Add citation to draft
416
+ aider_format = '''The following citations have just been added to the end of the `references.bib` file definition at the top of the file:
417
+ """
418
+ {bibtex}
419
+ """
420
+ You do not need to add them yourself.
421
+ ABSOLUTELY DO NOT ADD IT AGAIN!!!
422
+
423
+ Make the proposed change to the draft incorporating these new cites:
424
+ {description}
425
+
426
+ Use your judgment for whether these should be cited anywhere else.
427
+ Make sure that any citation precisely matches the name in `references.bib`. Change its name to the correct name in the bibtex if needed.
428
+ Ensure the citation is well-integrated into the text.'''
429
+
430
+ aider_prompt = (
431
+ aider_format.format(bibtex=bibtex_string, description=desc)
432
+ + """\n You must use \cite or \citet to reference papers, do not manually type out author names."""
433
+ )
434
+ return aider_prompt, False
435
+
436
+
437
+ # PERFORM WRITEUP
438
+ def perform_writeup(
439
+ idea, folder_name, coder, cite_client, cite_model, num_cite_rounds=20
440
+ ):
441
+ # CURRENTLY ASSUMES LATEX
442
+ abstract_prompt = f"""We've provided the `latex/template.tex` file to the project. We will be filling it in section by section.
443
+
444
+ First, please fill in the "Title" and "Abstract" sections of the writeup.
445
+
446
+ Some tips are provided below:
447
+ {per_section_tips["Abstract"]}
448
+
449
+ Before every paragraph, please include a brief description of what you plan to write in that paragraph in a comment.
450
+
451
+ Be sure to first name the file and use *SEARCH/REPLACE* blocks to perform these edits.
452
+ """
453
+ coder_out = coder.run(abstract_prompt)
454
+ coder_out = coder.run(
455
+ refinement_prompt.format(section="Abstract")
456
+ .replace(r"{{", "{")
457
+ .replace(r"}}", "}")
458
+ )
459
+ for section in [
460
+ "Introduction",
461
+ "Background",
462
+ "Method",
463
+ "Experimental Setup",
464
+ "Results",
465
+ "Conclusion",
466
+ ]:
467
+ section_prompt = f"""Please fill in the {section} of the writeup. Some tips are provided below:
468
+ {per_section_tips[section]}
469
+
470
+ Be sure to use \cite or \citet where relevant, referring to the works provided in the file.
471
+ Do not cite anything that is not already in `references.bib`. Do not add any new entries to this.
472
+
473
+ Keep the experimental results (figures and tables) only in the Results section, and make sure that any captions are filled in.
474
+ In this pass, do not reference anything in later sections of the paper.
475
+
476
+ Before every paragraph, please include a brief description of what you plan to write in that paragraph in a comment.
477
+
478
+ Be sure to first name the file and use *SEARCH/REPLACE* blocks to perform these edits.
479
+ """
480
+ coder_out = coder.run(section_prompt)
481
+ coder_out = coder.run(
482
+ refinement_prompt.format(section=section)
483
+ .replace(r"{{", "{")
484
+ .replace(r"}}", "}")
485
+ )
486
+
487
+ # SKETCH THE RELATED WORK
488
+ section_prompt = f"""Please fill in the Related Work of the writeup. Some tips are provided below:
489
+
490
+ {per_section_tips["Related Work"]}
491
+
492
+ For this section, very briefly sketch out the structure of the section, and clearly indicate what papers you intend to include.
493
+ Do this all in LaTeX comments using %.
494
+ The related work should be concise, only plan to discuss the most relevant work.
495
+ Do not modify `references.bib` to add any new citations, this will be filled in at a later stage.
496
+
497
+ Be sure to first name the file and use *SEARCH/REPLACE* blocks to perform these edits.
498
+ """
499
+ coder_out = coder.run(section_prompt)
500
+
501
+ # Fill paper with cites.
502
+ for _ in range(num_cite_rounds):
503
+ with open(osp.join(folder_name, "latex", "template.tex"), "r") as f:
504
+ draft = f.read()
505
+ prompt, done = get_citation_aider_prompt(
506
+ cite_client, cite_model, draft, _, num_cite_rounds
507
+ )
508
+ if done:
509
+ break
510
+ if prompt is not None:
511
+ # extract bibtex string
512
+ bibtex_string = prompt.split('"""')[1]
513
+ # insert this into draft before the "\end{filecontents}" line
514
+ search_str = r"\end{filecontents}"
515
+ draft = draft.replace(search_str, f"{bibtex_string}{search_str}")
516
+ with open(osp.join(folder_name, "latex", "template.tex"), "w") as f:
517
+ f.write(draft)
518
+ coder_out = coder.run(prompt)
519
+
520
+ coder_out = coder.run(
521
+ refinement_prompt.format(section="Related Work")
522
+ .replace(r"{{", "{")
523
+ .replace(r"}}", "}")
524
+ )
525
+
526
+ ## SECOND REFINEMENT LOOP
527
+ coder.run(
528
+ """Great job! Now that there is a complete draft of the entire paper, let's refine each section again.
529
+ First, re-think the Title if necessary. Keep this concise and descriptive of the paper's concept, but try to be creative with it."""
530
+ )
531
+ for section in [
532
+ "Abstract",
533
+ "Related Work",
534
+ "Introduction",
535
+ "Background",
536
+ "Method",
537
+ "Experimental Setup",
538
+ "Results",
539
+ "Conclusion",
540
+ ]:
541
+ coder_out = coder.run(
542
+ second_refinement_prompt.format(
543
+ section=section, tips=per_section_tips[section]
544
+ )
545
+ .replace(r"{{", "{")
546
+ .replace(r"}}", "}")
547
+ )
548
+
549
+ generate_latex(coder, folder_name, f"{folder_name}/{idea['Name']}.pdf")
550
+
551
+
552
+ if __name__ == "__main__":
553
+ import json
554
+
555
+ from aider.coders import Coder
556
+ from aider.io import InputOutput
557
+ from aider.models import Model
558
+
559
+ parser = argparse.ArgumentParser(description="Perform writeup for a project")
560
+ parser.add_argument("--folder", type=str)
561
+ parser.add_argument("--no-writing", action="store_true", help="Only generate")
562
+ parser.add_argument(
563
+ "--model",
564
+ type=str,
565
+ default="gpt-4o-2024-05-13",
566
+ choices=allchoices,
567
+ help="Model to use for AI Scientist.",
568
+ )
569
+ args = parser.parse_args()
570
+ if args.model == "claude-3-5-sonnet-20240620":
571
+ import anthropic
572
+
573
+ print(f"Using Anthropic API with model {args.model}.")
574
+ client_model = "claude-3-5-sonnet-20240620"
575
+ client = anthropic.Anthropic()
576
+ elif args.model.startswith("bedrock") and "claude" in args.model:
577
+ import anthropic
578
+
579
+ # Expects: bedrock/<MODEL_ID>
580
+ client_model = args.model.split("/")[-1]
581
+
582
+ print(f"Using Amazon Bedrock with model {client_model}.")
583
+ client = anthropic.AnthropicBedrock()
584
+ elif args.model.startswith("vertex_ai") and "claude" in args.model:
585
+ import anthropic
586
+
587
+ # Expects: vertex_ai/<MODEL_ID>
588
+ client_model = args.model.split("/")[-1]
589
+
590
+ print(f"Using Vertex AI with model {client_model}.")
591
+ client = anthropic.AnthropicVertex()
592
+ elif args.model == "gpt-4o-2024-05-13":
593
+ import openai
594
+
595
+ print(f"Using OpenAI API with model {args.model}.")
596
+ client_model = "gpt-4o-2024-05-13"
597
+ client = openai.OpenAI()
598
+ elif args.model == "deepseek-coder-v2-0724":
599
+ import openai
600
+
601
+ print(f"Using OpenAI API with {args.model}.")
602
+ client_model = "deepseek-coder-v2-0724"
603
+ client = openai.OpenAI(
604
+ api_key=os.environ["DEEPSEEK_API_KEY"], base_url="https://api.deepseek.com"
605
+ )
606
+
607
+ # ----------------------------------------------------
608
+
609
+ elif args.model == "Qwen/Qwen2.5-72B-Instruct":
610
+ # elif args.model.startswith("hyperbolic"):
611
+ print(f"Welcome to the PARADISE of debug <launch_scientist.py> {args.model}.")
612
+
613
+ import openai
614
+ import os
615
+ # client_model = args.model[11:]
616
+ client_model = args.model
617
+ client = openai.OpenAI(
618
+ api_key=os.environ["OPENAI_API_KEY"], base_url="https://api.hyperbolic.xyz/v1"
619
+ )
620
+ # ----------------------------------------------------
621
+ elif args.model == "llama3.1-405b":
622
+ import openai
623
+
624
+ print(f"Using OpenAI API with {args.model}.")
625
+ client_model = "meta-llama/llama-3.1-405b-instruct"
626
+ client = openai.OpenAI(
627
+ api_key=os.environ["OPENROUTER_API_KEY"],
628
+ base_url="https://openrouter.ai/api/v1",
629
+ )
630
+
631
+ elif args.model.startswith("ollama"):
632
+ import openai
633
+
634
+ print(f"Using Ollama with {args.model}.")
635
+ client_model = args.model.split("/")[-1]
636
+ client = openai.OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
637
+ else:
638
+ raise ValueError(f"Model {args.model} not recognized.")
639
+
640
+
641
+ print("Make sure you cleaned the Aider logs if re-generating the writeup!")
642
+ folder_name = args.folder
643
+ idea_name = osp.basename(folder_name)
644
+ exp_file = osp.join(folder_name, "experiment.py")
645
+ vis_file = osp.join(folder_name, "plot.py")
646
+ notes = osp.join(folder_name, "notes.txt")
647
+
648
+ model = args.model
649
+
650
+ writeup_file = osp.join(folder_name, "latex", "template.tex")
651
+ ideas_file = osp.join(folder_name, "ideas.json")
652
+ with open(ideas_file, "r") as f:
653
+ ideas = json.load(f)
654
+ for idea in ideas:
655
+ if idea["Name"] in idea_name:
656
+ print(f"Found idea: {idea['Name']}")
657
+ break
658
+ if idea["Name"] not in idea_name:
659
+ raise ValueError(f"Idea {idea_name} not found")
660
+ fnames = [exp_file, writeup_file, notes]
661
+ io = InputOutput(yes=True, chat_history_file=f"{folder_name}/{idea_name}_aider.txt")
662
+
663
+
664
+
665
+ # AIDER CHAT INITIALIZATION CODE
666
+
667
+ if args.model == "deepseek-ai/DeepSeek-V2.5":
668
+ print("aider chosen deepseek")
669
+ main_model = Model("deepseek-chat")
670
+
671
+ elif args.model == "deepseek-coder-v2-0724":
672
+ main_model = Model("deepseek-ai/DeepSeek-V2.5")
673
+
674
+ elif args.model == "llama3.1-405b":
675
+ main_model = Model("openrouter/meta-llama/llama-3.1-405b-instruct")
676
+
677
+ # ----------------------------------------------------
678
+
679
+ elif args.model == "hyperbolic/Qwen/Qwen2.5-72B-Instruct":
680
+ print("aider model chosen")
681
+ # main_model = Model("fireworks_ai/accounts/fireworks/models/qwen2-72b-instruct")
682
+ main_model = Model("hyperbolic/Qwen/Qwen2.5-72B-Instruct")
683
+
684
+ elif args.model == "hyperbolic/meta-llama/Meta-Llama-3.1-70B-Instruct":
685
+ main_model = Model("hyperbolic/meta-llama/Meta-Llama-3.1-70B-Instruct")
686
+
687
+ # ----------------------------------------------------
688
+
689
+ else:
690
+ print("hello world")
691
+ main_model = Model(model)
692
+
693
+ coder = Coder.create(
694
+ main_model=main_model,
695
+ fnames=fnames,
696
+ io=io,
697
+ stream=False,
698
+ use_git=False,
699
+ edit_format="diff",
700
+ )
701
+ if args.no_writing:
702
+ generate_latex(coder, args.folder, f"{args.folder}/test.pdf")
703
+ else:
704
+ try:
705
+ perform_writeup(idea, folder_name, coder, client, client_model)
706
+ except Exception as e:
707
+ print(f"Failed to perform writeup: {e}")
ai_scientist/Untitled.ipynb ADDED
@@ -0,0 +1,62 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 3,
6
+ "id": "011cecbb-41c0-4943-a37b-73ebd91d1e46",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import openai\n",
11
+ "import os \n",
12
+ "\n",
13
+ "client_model = \"deepseek-ai/DeepSeek-V2.5\"\n",
14
+ "client = openai.OpenAI(\n",
15
+ " # api_key=os.environ[\"DEEPSEEK_API_KEY\"], base_url=\"https://api.deepseek.com\"\n",
16
+ " api_key=DEEPSEEK_API_KEY, base_url=\"https://api.hyperbolic.xyz/v1\"\n",
17
+ ")\n"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": 1,
23
+ "id": "d55b29d2-bcf9-4e92-9b79-1f2e20478069",
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "DEEPSEEK_API_KEY=\"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJqb2huZG9lIn0.adqb-sj6G4xl7w7t9A4oRUBf1jNmcrc-0IcHjGhbz3o\""
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
+ "id": "86a236ee-791f-4ae7-92c4-c10d8b085f76",
34
+ "metadata": {},
35
+ "outputs": [],
36
+ "source": [
37
+ "from aider."
38
+ ]
39
+ }
40
+ ],
41
+ "metadata": {
42
+ "kernelspec": {
43
+ "display_name": "Python 3 (ipykernel)",
44
+ "language": "python",
45
+ "name": "python3"
46
+ },
47
+ "language_info": {
48
+ "codemirror_mode": {
49
+ "name": "ipython",
50
+ "version": 3
51
+ },
52
+ "file_extension": ".py",
53
+ "mimetype": "text/x-python",
54
+ "name": "python",
55
+ "nbconvert_exporter": "python",
56
+ "pygments_lexer": "ipython3",
57
+ "version": "3.11.10"
58
+ }
59
+ },
60
+ "nbformat": 4,
61
+ "nbformat_minor": 5
62
+ }
ai_scientist/__init__.py ADDED
File without changes
ai_scientist/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (166 Bytes). View file
 
ai_scientist/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (161 Bytes). View file
 
ai_scientist/__pycache__/generate_ideas.cpython-311.pyc ADDED
Binary file (26 kB). View file
 
ai_scientist/__pycache__/generate_ideas.cpython-312.pyc ADDED
Binary file (22.7 kB). View file
 
ai_scientist/__pycache__/llm.cpython-311.pyc ADDED
Binary file (10.2 kB). View file
 
ai_scientist/__pycache__/llm.cpython-312.pyc ADDED
Binary file (9.02 kB). View file
 
ai_scientist/__pycache__/perform_experiments.cpython-311.pyc ADDED
Binary file (8.02 kB). View file
 
ai_scientist/__pycache__/perform_experiments.cpython-312.pyc ADDED
Binary file (7.43 kB). View file
 
ai_scientist/__pycache__/perform_review.cpython-311.pyc ADDED
Binary file (22.6 kB). View file
 
ai_scientist/__pycache__/perform_review.cpython-312.pyc ADDED
Binary file (21.2 kB). View file
 
ai_scientist/__pycache__/perform_writeup.cpython-311.pyc ADDED
Binary file (34.1 kB). View file
 
ai_scientist/__pycache__/perform_writeup.cpython-312.pyc ADDED
Binary file (29.9 kB). View file
 
ai_scientist/fewshot_examples/132_automated_relational.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "review": "{\n \"Summary\": \"The paper provides an interesting direction in the meta-learning field. In particular, it proposes to enhance meta learning performance by fully exploring relations across multiple tasks. To capture such information, the authors develop a heterogeneity-aware meta-learning framework by introducing a novel architecture--meta-knowledge graph, which can dynamically find the most relevant structure for new tasks.\",\n \"Strengths\": [\n \"The paper takes one of the most important issues of meta-learning: task heterogeneity. For me, the problem itself is real and practical.\",\n \"The proposed meta-knowledge graph is novel for capturing the relation between tasks and addressing the problem of task heterogeneity. Graph structure provides a more flexible way of modeling relations. The design for using the prototype-based relational graph to query the meta-knowledge graph is reasonable and interesting.\",\n \"This paper provides comprehensive experiments, including both qualitative analysis and quantitative results, to show the effectiveness of the proposed framework. The newly constructed Art-Multi dataset further enhances the difficulty of tasks and makes the performance more convincing.\"\n ],\n \"Weaknesses\": [\n \"Although the proposed method provides several ablation studies, I still suggest the authors conduct the following ablation studies to enhance the quality of the paper: (1) It might be valuable to investigate the modulation function. In the paper, the authors compare sigmoid, tanh, and Film layer. Can the authors analyze the results by reducing the number of gating parameters in Eq. 10 by sharing the gate value of each filter in Conv layers? (2) What is the performance of the proposed model by changing the type of aggregators?\",\n \"For the autoencoder aggregator, it would be better to provide more details about it, which seems not very clear to me.\",\n \"In the qualitative analysis (i.e., Figure 2 and Figure 3), the authors provide one visualization for each task. It would be more convincing if the authors can provide more cases in the rebuttal period.\"\n ],\n \"Originality\": 3,\n \"Quality\": 3,\n \"Clarity\": 3,\n \"Significance\": 4,\n \"Questions\": [\n \"Please address and clarify the cons above.\"\n ],\n \"Limitations\": [\n \"My major concern is about the clarity of the paper and some additional ablation models (see cons below). Hopefully the authors can address my concern in the rebuttal period.\"\n ],\n \"Ethical Concerns\": false,\n \"Soundness\": 3,\n \"Presentation\": 3,\n \"Contribution\": 3,\n \"Overall\": 7,\n \"Confidence\": 5,\n \"Decision\": \"Accept\"\n}"
3
+ }
ai_scientist/fewshot_examples/132_automated_relational.pdf ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a29ed4d84f6be5b9547097c2bc8bd57bfe197e91dc8f4ec9bcde6b545e7abe59
3
+ size 1348476
ai_scientist/fewshot_examples/132_automated_relational.txt ADDED
@@ -0,0 +1,1190 @@
1
+ # AUTOMATED RELATIONAL META-LEARNING
2
+
3
+ **Anonymous authors**
4
+ Paper under double-blind review
5
+
6
+ ABSTRACT
7
+
8
+ In order to efficiently learn with small amount of data on new tasks, meta-learning
9
+ transfers knowledge learned from previous tasks to the new ones. However, a
10
+ critical challenge in meta-learning is the task heterogeneity which cannot be well
11
+ handled by traditional globally shared meta-learning methods. In addition, current
12
+ task-specific meta-learning methods may either suffer from hand-crafted structure
13
+ design or lack the capability to capture complex relations between tasks. In this
14
+ paper, motivated by the way of knowledge organization in knowledge bases, we
15
+ propose an automated relational meta-learning (ARML) framework that automatically extracts the cross-task relations and constructs the meta-knowledge graph.
16
+ When a new task arrives, it can quickly find the most relevant structure and tailor
17
+ the learned structure knowledge to the meta-learner. As a result, the proposed
18
+ framework not only addresses the challenge of task heterogeneity by a learned
19
+ meta-knowledge graph, but also increases the model interpretability. We conduct
20
+ extensive experiments on 2D toy regression and few-shot image classification and
21
+ the results demonstrate the superiority of ARML over state-of-the-art baselines.
22
+
23
+ 1 INTRODUCTION
24
+
25
+ Learning quickly with a few samples is the key characteristic of human intelligence, which remains a
26
+ daunting problem in machine intelligence. The mechanism of learning to learn (a.k.a., meta-learning)
27
+ is widely used to generalize and transfer prior knowledge learned from previous tasks to improve
28
+ the effectiveness of learning on new tasks, which has benefited various applications, ranging from
29
+ computer vision (Kang et al., 2019; Liu et al., 2019) to natural language processing (Gu et al., 2018;
30
+ Lin et al., 2019). Most of existing meta-learning algorithms learn a globally shared meta-learner
31
+ (e.g., parameter initialization (Finn et al., 2017), meta-optimizer (Ravi & Larochelle, 2016), metric
32
+ space (Snell et al., 2017)). However, globally shared meta-learners fail to handle tasks lying in
33
+ different distributions, which is known as task heterogeneity. Task heterogeneity has been regarded as
34
+ one of the most challenging issues in few-shot learning, and thus it is desirable to design meta-learning
35
+ models that effectively optimize each of the heterogeneous tasks.
36
+
37
+ The key challenge to deal with task heterogeneity is how to customize globally shared meta-learner
38
+ by using task-aware information? Recently, a handful of works try to solve the problem by learning
39
+ a task-specific representation for tailoring the transferred knowledge to each task (Oreshkin et al.,
40
+ 2018; Vuorio et al., 2019; Lee & Choi, 2018). However, the success of these methods relies on the
41
+ impaired knowledge generalization among closely correlated tasks (e.g., the tasks sampled from the
42
+ same distribution). Recently, learning the underlying structure among tasks provides a more effective
43
+ way for balancing the customization and generalization. Representatively, Yao et al. propose a
44
+ hierarchically structured meta-learning method to customize the globally shared knowledge to each
45
+ cluster in a hierarchical way (Yao et al., 2019). Nonetheless, the hierarchical clustering structure
46
+ completely relies on the handcrafted design which needs to be tuned carefully and may lack the
47
+ capability to capture complex relationships.
48
+
49
+ Hence, we are motivated to propose a framework to automatically extract underlying relational
50
+ structures from previously learned tasks and leverage those relational structures to facilitate knowledge
51
+ customization on a new task. This inspiration comes from the way of structuring knowledge in
52
+ knowledge bases (i.e., knowledge graphs). In knowledge bases, the underlying relational structures
53
+ across text entities are automatically constructed and applied to a new query to improve the searching
54
+ efficiency. In the meta-learning problem, similarly, we aim at automatically establishing the metaknowledge graph between prior knowledge learned from previous tasks. When a new task arrives,
55
+ it queries the meta-knowledge graph and quickly attends to the most relevant entities (nodes), and
56
+ then takes advantage of the relational knowledge structures between them to boost the learning
57
+ effectiveness with the limited training data.
58
+
59
+
60
+ -----
61
+
62
+ The proposed meta-learning framework is named as Automated Relational Meta-Learning (ARML).
63
+ Specifically, the ARML framework automatically builds the meta-knowledge graph from metatraining tasks to memorize and organize learned knowledge from historical tasks, where each vertex
64
+ represent one type of meta-knowledge (e.g., the common contour between birds and aircrafts). To
65
+ learn the meta-knowledge graph at meta-training time, for each task, we construct a prototype-based
66
+ relational graph for each class, where each vertex represents one prototype. The prototype-based
67
+ relational graph not only captures the underlying relationship behind samples, but alleviates the
68
+ potential effects of abnormal samples. The meta-knowledge graph is then learned by summarizing
69
+ the information from the corresponding prototype-based relational graphs of meta-training tasks.
70
+ After constructing the meta-knowledge graph, when a new task comes in, the prototype-based
71
+ relational graph of the new task taps into the meta-knowledge graph for acquiring the most relevant
72
+ knowledge, which further enhances the task representation and facilitates its training process.
73
+
74
+ Our major contributions of the proposed ARML are three-fold: (1) it automatically constructs the
75
+ meta-knowledge graph to facilitate learning a new task; (2) it empirically outperforms state-of-the-art
76
+ meta-learning algorithms; (3) the meta-knowledge graph well captures the relationship among tasks
77
+ and improves the interpretability of meta-learning algorithms.
78
+
79
+ 2 RELATED WORK
80
+
81
+ Meta-learning, allowing machines to learn new skills or adapt to new environments rapidly with a
82
+ few training examples, has been demonstrated to be successful in both supervised learning tasks
83
+ (e.g., few-shot image classification) and reinforcement learning settings. There are mainly three
84
+ research lines of meta-learning: (1) black-box amortized methods design black-box meta-learners
85
+ (e.g., neural networks) to infer the model parameters (Ravi & Larochelle, 2016; Andrychowicz et al.,
86
+ 2016; Mishra et al., 2018); (2) gradient-based methods aim to learn an optimized initialization of
87
+ model parameters, which can be adapted to new tasks by a few steps of gradient descent (Finn et al.,
88
+ 2017; 2018; Lee & Choi, 2018); (3) non-parameteric methods combine parameteric meta-learners
89
+ and non-parameteric learners to learn an appropriate distance metric for few-shot classification (Snell
90
+ et al., 2017; Vinyals et al., 2016; Yang et al., 2018; Oreshkin et al., 2018; Yoon et al., 2019).
91
+
92
+ Our work is built upon the gradient-based meta-learning methods. In the line of gradient-based
93
+ meta-learning, most algorithms learn a globally shared meta-learners from all previous tasks (Finn
94
+ et al., 2017; Li et al., 2017; Flennerhag et al., 2019), to improve the effectiveness of learning process
95
+ on new tasks. However, these algorithms typically lack the ability to handle heterogeneous tasks
96
+ (i.e., tasks sample from sufficient different distributions). To tackle this challenge, recent works
97
+ tailor the globally shared initialization to different tasks by leveraging task-specific information (Lee
98
+ & Choi, 2018; Vuorio et al., 2019; Oreshkin et al., 2018) and using probabilistic models (Grant
99
+ et al., 2018; Yoon et al., 2018; Gordon et al., 2019). Recently, HSML customizes the global shared
100
+ initialization with a manually designed hierarchical clustering structure to balance the generalization
101
+ and customization between previous tasks (Yao et al., 2019). However, the hierarchical structure
102
+ may not accurately reflect the real structure since it highly relies on the hand-crafted design. In
103
+ addition, the clustering structure further constricts the complexity of relational structures. However, to
104
+ customize each task, our proposed ARML leverages the most relevant structure from meta-knowledge
105
+ graph which are automatically constructed by previous knowledge. Thus, ARML not only discovers
106
+ more accurate underlying structures to improve the effectiveness of meta-learning algorithms, but
107
+ also the meta-knowledge graph can further enhance the model interpretability.
108
+
109
+ 3 PRELIMINARIES
110
+
111
+ **Few-shot Learning** Considering a task Ti, the goal of few-shot learning is to learn a model with a dataset
+ D_i = {D_i^tr, D_i^ts}, where the labeled training set D_i^tr = {x_j^tr, y_j^tr | ∀ j ∈ [1, N^tr]} only has a few samples
+ and D_i^ts represents the corresponding test set. A learning model (a.k.a., base model) f with parameters θ is used
+ to evaluate the effectiveness on D_i^ts by minimizing the expected empirical loss on D_i^tr, i.e., L(D_i^tr, θ), and
+ obtain the optimal parameters θ_i. For the regression problem, the loss function is defined based on the mean square
+ error, i.e., L(D_i^{tr}, θ) = \sum_{(x_j, y_j) \in D_i^{tr}} \| f_θ(x_j) − y_j \|_2^2, and for the classification problem, the
+ loss function uses the cross-entropy loss, i.e., L(D_i^{tr}, θ) = −\sum_{(x_j, y_j) \in D_i^{tr}} \log p(y_j | x_j, f_θ).
+
+ Usually, optimizing and learning parameter θ for the task Ti with a few labeled training samples is difficult. To
+ address this limitation, meta-learning provides a new perspective to improve the performance by leveraging
+ knowledge from multiple tasks.
123
+
124
+
125
+ -----
126
+
127
+ **Meta-learning and Model-agnostic Meta-learning** In meta-learning, a sequence of tasks {T_1, ..., T_I} are
+ sampled from a task-level probability distribution p(T), where each one is a few-shot learning task. To facilitate
+ the adaption for incoming tasks, the meta-learning algorithm aims to find a well-generalized meta-learner on I
+ training tasks at the meta-learning phase. At the meta-testing phase, the optimal meta-learner is applied to adapt
+ the new tasks T_t. In this way, meta-learning algorithms are capable of adapting to new tasks efficiently even with
+ a shortage of training data for a new task.
+
+ Model-agnostic meta-learning (MAML) (Finn et al., 2017), one of the representative algorithms in gradient-based
+ meta-learning, regards the meta-learner as the initialization of parameter θ, i.e., θ_0, and learns a well-generalized
+ initialization θ_0^* during the meta-training process. The optimization problem is formulated as (one gradient step
+ as exemplary):
+
+     θ_0^* := \arg\min_{θ_0} \sum_{i=1}^{I} L(f_{θ_i}, D_i^{ts}) = \arg\min_{θ_0} \sum_{i=1}^{I} L(f_{θ_0 − α ∇_θ L(f_θ, D_i^{tr})}, D_i^{ts}).    (1)
+
+ At the meta-testing phase, to obtain the adaptive parameter θ_t for each new task T_t, we finetune the initialization
+ of parameter θ_0^* by performing gradient updates for a few steps, i.e., f_{θ_t} = f_{θ_0^* − α ∇_θ L(f_θ, D_t^{tr})}.
161
+
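To make equation 1 concrete, a minimal sketch of one MAML meta-update (PyTorch ≥ 2.0 assumed for torch.func.functional_call; `model`, `loss_fn`, and the task format are illustrative placeholders, not tied to any particular implementation):

import torch
from torch.func import functional_call

def maml_meta_step(model, tasks, loss_fn, alpha=0.01, beta=0.001):
    # tasks: iterable of (x_tr, y_tr, x_ts, y_ts) tensors, one tuple per sampled task.
    meta_opt = torch.optim.SGD(model.parameters(), lr=beta)
    meta_loss = 0.0
    for x_tr, y_tr, x_ts, y_ts in tasks:
        # Inner loop: one gradient step from the shared initialization theta_0.
        params = dict(model.named_parameters())
        inner_loss = loss_fn(functional_call(model, params, (x_tr,)), y_tr)
        grads = torch.autograd.grad(inner_loss, list(params.values()), create_graph=True)
        adapted = {n: p - alpha * g for (n, p), g in zip(params.items(), grads)}
        # Outer objective: loss of the adapted parameters on the task's test set D_i^ts.
        meta_loss = meta_loss + loss_fn(functional_call(model, adapted, (x_ts,)), y_ts)
    meta_opt.zero_grad()
    meta_loss.backward()
    meta_opt.step()
    return float(meta_loss)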
162
+ 4 METHODOLOGY
163
+
164
+ In this section, we introduce the details of the proposed ARML. To better explain how it works,
165
+ we show its framework in Figure 1. The goal of ARML is to facilitate the learning process of new
166
+ tasks by leveraging transferable knowledge learned from historical tasks. To achieve this goal, we
167
+ introduce a meta-knowledge graph, which is automatically constructed at the meta-training time, to
168
+ organize and memorize historical learned knowledge. Given a task, which is built as a prototypebased relational structure, it taps into the meta-knowledge graph to acquire relevant knowledge for
169
+ enhancing its own representation. The enhanced prototype representation further aggregate and
170
+ incorporate with meta-learner for fast and effective adaptions by utilizing a modulating function. In
171
+ the following subsections, we elaborate three key components: prototype-based sample structuring,
172
+ automated meta-knowledge graph construction and utilization, and task-specific knowledge fusion
173
+ and adaptation, respectively.
174
+
175
+ [Figure 1 graphic: prototypes, prototype-based relational structure R_i, meta-knowledge graph G, propagation, aggregators, and modulation.]
+
+ Figure 1: The framework of ARML. For each task T_i, ARML first builds a prototype-based relational structure R_i
+ by mapping the training samples D_i^tr into prototypes, with each prototype representing one class. Then, R_i
+ interacts with the meta-knowledge graph G to acquire the most relevant historical knowledge by information
+ propagation. Finally, the task-specific modulation tailors the globally shared initialization θ_0 by aggregating the
+ raw prototypes and enriched prototypes, which absorb relevant historical information from the meta-knowledge graph.
207
+
208
+ 4.1 PROTOTYPE-BASED SAMPLE STRUCTURING
209
+
210
+ Given a task which involves either classifications or regressions regarding a set of samples, we first
211
+ investigate the relationships among these samples. Such relationship is represented by a graph, called
212
+ prototype-based relational graph in this work, where the vertices in the graph denote the prototypes
213
+ of different classes while the edges and the corresponding edge weights are created based on the
214
+
215
+
216
+ -----
217
+
218
+ similarities between prototypes. Constructing the relational graph based on prototypes instead of raw samples
+ allows us to alleviate the issue raised by abnormal samples, which locate far away from normal samples and could
+ pose significant concerns, especially when only a limited number of samples are available for training.
+ Specifically, for the classification problem, the prototype, denoted by c_i^k ∈ R^d, is defined as:
+
+     c_i^k = \frac{1}{N_k^{tr}} \sum_{j=1}^{N^{tr}} E(x_j),    (2)
+
+ where N_k^{tr} denotes the number of samples in class k. E is an embedding function, which projects x_j into a
+ hidden space where samples from the same class are located closer to each other while samples from different
+ classes stay apart. For the regression problem, it is not straightforward to construct the prototypes explicitly based
+ on class information. Therefore, we cluster samples by learning an assignment matrix P_i ∈ R^{K×N^{tr}}.
+ Specifically, we formulate the process as:
+
+     P_i = Softmax(W_p E^T(X) + b_p),    c_i^k = P_i[k] F(X),    (3)
+
+ where P_i[k] represents the k-th row of P_i. Thus, training samples are clustered into K clusters, which serve as the
+ representation of prototypes.
+
+ After calculating all prototype representations {c_i^k | ∀ k ∈ [1, K]}, which serve as the vertices in the
+ prototype-based relational graph R_i, we further define the edges and the corresponding edge weights. The edge
+ weight A_{R_i}(c_i^j, c_i^m) between two prototypes c_i^j and c_i^m is gauged by the similarity between them. Formally:
+
+     A_{R_i}(c_i^j, c_i^m) = σ(W_r(|c_i^j − c_i^m| / γ_r) + b_r),    (4)
+
+ where W_r and b_r represent learnable parameters, γ_r is a scalar and σ is the Sigmoid function, which normalizes
+ the weight between 0 and 1. For simplicity, we denote the prototype-based relational graph as R_i = (C_{R_i}, A_{R_i}),
+ where C_{R_i} = {c_i^j | ∀ j ∈ [1, K]} ∈ R^{K×d} represents the set of vertices, with each one corresponding to the
+ prototype of a class, while A_{R_i} = {A_{R_i}(c_i^j, c_i^m) | ∀ j, m ∈ [1, K]} ∈ R^{K×K} gives the adjacency matrix,
+ which indicates the proximity between prototypes.
266
+
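To make equations 2 and 4 concrete, a minimal sketch of the prototype and edge-weight computation (PyTorch assumed; the helper name `prototype_graph` and the tensor shapes are illustrative, not the authors' implementation):

import torch

def prototype_graph(embeddings, labels, num_classes, W_r, b_r, gamma_r=1.0):
    # embeddings: (N_tr, d) outputs of the embedding function E; labels: (N_tr,) class ids.
    # Equation 2: each class prototype is the mean embedding of that class's samples.
    C = torch.stack([embeddings[labels == k].mean(dim=0) for k in range(num_classes)])  # (K, d)
    # Equation 4: sigmoid of a learnable map of the scaled absolute difference between prototypes.
    diff = (C.unsqueeze(1) - C.unsqueeze(0)).abs() / gamma_r   # (K, K, d)
    A = torch.sigmoid(diff @ W_r + b_r).squeeze(-1)            # (K, K), entries in (0, 1)
    return C, A

# Toy usage: 5 classes, 25 samples, 64-dimensional embeddings.
emb, lab = torch.randn(25, 64), torch.arange(25) % 5
C, A = prototype_graph(emb, lab, num_classes=5, W_r=torch.randn(64, 1), b_r=torch.zeros(1))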
267
+ 4.2 AUTOMATED META-KNOWLEDGE GRAPH CONSTRUCTION AND UTILIZATION
268
+
269
+ In this section, we first discuss how to organize and distill knowledge from historical learning process
270
+ and then expound how to leverage such knowledge to benefit the training of new tasks. To organize
271
+ and distill knowledge from historical learning process, we construct and maintain a meta-knowledge
272
+ graph. The vertices represent different types of meta-knowledge (e.g., the common contour between
273
+ aircrafts and birds) and the edges are automatically constructed and reflect the relationship between
274
+ meta-knowledge. When serving a new task, we refer to the meta-knowledge, which allows us to
275
+ efficiently and automatically identify relational knowledge from previous tasks. In this way, the
276
+ training of a new task can benefit from related training experience and get optimized much faster
277
+ than otherwise possible. In this paper, the meta-knowledge graph is automatically constructed at the
278
+ meta-training phase. The details of the construction are elaborated as follows:
279
+
280
+ Assuming the representation of a vertex g is given by h^g ∈ R^d, we define the meta-knowledge graph as
+ G = (H_G, A_G), where H_G = {h^j | ∀ j ∈ [1, G]} ∈ R^{G×d} and A_G = {A_G(h^j, h^m) | ∀ j, m ∈ [1, G]} ∈ R^{G×G}
+ denote the vertex feature matrix and vertex adjacency matrix, respectively. To better explain the construction of
+ the meta-knowledge graph, we first discuss the vertex representation H_G. During meta-training, tasks arrive one
+ after another in a sequence and their corresponding vertex representations are expected to be updated dynamically
+ in a timely manner. Therefore, the vertex representations of the meta-knowledge graph are parameterized and
+ learned at training time. Moreover, to encourage the diversity of meta-knowledge encoded in the meta-knowledge
+ graph, the vertex representations are randomly initialized. Analogous to the definition of weight in the
+ prototype-based relational graph R_i in equation 4, the weight between a pair of vertices j and m is constructed as:
+
+     A_G(h^j, h^m) = σ(W_o(|h^j − h^m| / γ_o) + b_o),    (5)
+
+ where W_o and b_o represent learnable parameters and γ_o is a scalar.
296
+
297
+ To enhance the learning of new tasks with involvement of historical knowledge, we query the
298
+ prototype-based relational graph in the meta-knowledge graph to obtain the relevant knowledge in
299
+ history. The ideal query mechanism is expected to optimize both graph representations simultaneously
300
+
301
+
302
+ -----
303
+
304
+ at the meta-training time, with the training of one graph facilitating the training of the other. In light
305
+ of this, we construct a super-graph Si by connecting the prototype-based relational graph Ri with the
306
+ meta-knowledge graph G for each task Ti. The union of the vertices in Ri and G contributes to the
307
+ vertices in the super-graph. The edges in Ri and G are also reserved in the super-graph. We connect
308
+ _Ri with G by creating links between the prototype-based relational graph with the meta-knowledge_
309
+ graph. The link between prototype c_i^j in the prototype-based relational graph and vertex h^m in the
+ meta-knowledge graph is weighted by the similarity between them. More precisely, for each prototype c_i^j, the
+ link weight A_S(c_i^j, h^m) is calculated by applying a softmax over the Euclidean distances between c_i^j and
+ {h^m | ∀ m ∈ [1, G]} as follows:
+
+     A_S(c_i^j, h^k) = \frac{\exp(−‖(c_i^j − h^k)/γ_s‖_2^2 / 2)}{\sum_{k'=1}^{G} \exp(−‖(c_i^j − h^{k'})/γ_s‖_2^2 / 2)},    (6)
+
+ where γ_s is a scaling factor. We denote the intra-adjacency matrix as A_S = {A_S(c_i^j, h^m) | ∀ j ∈ [1, K], m ∈ [1, G]} ∈ R^{K×G}.
+ Thus, for task T_i, the adjacency matrix and feature matrix of the super-graph S_i = (A_i, H_i) are defined as
+ A_i = (A_{R_i}, A_S; A_S^T, A_G) ∈ R^{(K+G)×(K+G)} and H_i = (C_{R_i}; H_G) ∈ R^{(K+G)×d}, respectively.
+
+ After constructing the super-graph S_i, we are able to propagate the most relevant knowledge from the
+ meta-knowledge graph G to the prototype-based relational graph R_i by introducing a Graph Neural Network
+ (GNN). In this work, following the "message-passing" framework (Gilmer et al., 2017), the GNN is formulated as:
+
+     H_i^{(l+1)} = MP(A_i, H_i^{(l)}; W^{(l)}),    (7)
+
+ where MP(·) is the message passing function and has several possible implementations (Hamilton et al., 2017;
+ Kipf & Welling, 2017; Veličković et al., 2018), H_i^{(l)} is the vertex embedding after l layers of GNN and W^{(l)} is
+ a learnable weight matrix of layer l. The input H_i^{(0)} = H_i. After stacking L GNN layers, we get the
+ information-propagated feature representation for the prototype-based relational graph R_i as the top-K rows of
+ H_i^{(L)}, which is denoted as Ĉ_{R_i} = {ĉ_i^j | j ∈ [1, K]}.
334
+
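A minimal sketch of one propagation step over the super-graph in equation 7 (PyTorch assumed; a plain degree-normalized update stands in for whichever MP implementation is used), reading the enriched prototypes back from the top-K rows:

import torch

def propagate_super_graph(C_R, H_G, A_R, A_G, A_S, W):
    # C_R: (K, d) prototypes; H_G: (G, d) meta-knowledge vertices; W: (d, d) layer weight.
    # Assemble the super-graph adjacency A_i and feature matrix H_i from Section 4.2.
    A = torch.cat([torch.cat([A_R, A_S], dim=1),
                   torch.cat([A_S.t(), A_G], dim=1)], dim=0)   # (K+G, K+G)
    H = torch.cat([C_R, H_G], dim=0)                            # (K+G, d)
    # One message-passing layer (equation 7): normalize by degree, mix features, apply tanh.
    deg = A.sum(dim=1, keepdim=True).clamp(min=1e-8)
    H_next = torch.tanh((A / deg) @ H @ W)
    # The first K rows are the information-propagated prototypes C_hat_{R_i}.
    return H_next[: C_R.size(0)]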
335
+ 4.3 TASK-SPECIFIC KNOWLEDGE FUSION AND ADAPTATION
336
+
337
+ After propagating information from meta-knowledge graph to prototype-based relational graph, in
338
+ this section, we discuss how to learn a well-generalized meta-learner for fast and effective adaptions
339
+ to new tasks with limited training data. To tackle the challenge of task heterogeneity, in this
340
+ paper, we incorporate task-specific information to customize the globally shared meta-learner (e.g.,
341
+ initialization here) by leveraging a modulating function, which has been proven to be effective to
342
+ provide customized initialization in previous studies (Wang et al., 2019; Vuorio et al., 2019).
343
+
344
+ The modulating function relies on well-discriminated task representations, while it is difficult to learn all
+ representations by merely utilizing the loss signal derived from the test set D_i^ts. To encourage such stability, we
+ introduce two reconstructions by utilizing two auto-encoders. There are two collections of parameters, i.e., C_{R_i}
+ and Ĉ_{R_i}, which contribute the most to the creation of the task-specific meta-learner. C_{R_i} expresses the raw
+ prototype information without tapping into the meta-knowledge graph, while Ĉ_{R_i} gives the prototype
+ representations after absorbing the relevant knowledge from the meta-knowledge graph. Therefore, the two
+ reconstructions are built on C_{R_i} and Ĉ_{R_i}. To reconstruct C_{R_i}, an aggregator AG^q(·) (e.g., recurrent network,
+ fully connected layers) is involved to encode C_{R_i} into a dense representation, which is further fed into a decoder
+ AG^q_dec(·) to achieve reconstructions. Then, the corresponding task representation q_i of C_{R_i} is summarized by
+ applying a mean pooling operator over prototypes on the encoded dense representation. Formally,
+
+     q_i = MeanPool(AG^q(C_{R_i})) = \frac{1}{K} \sum_{j=1}^{K} AG^q(c_i^j),    L_q = ‖C_{R_i} − AG^q_dec(AG^q(C_{R_i}))‖_F^2    (8)
+
+ Similarly, we reconstruct Ĉ_{R_i} and get the corresponding task representation t_i as follows:
+
+     t_i = MeanPool(AG^t(Ĉ_{R_i})) = \frac{1}{K} \sum_{j=1}^{K} AG^t(ĉ_i^j),    L_t = ‖Ĉ_{R_i} − AG^t_dec(AG^t(Ĉ_{R_i}))‖_F^2    (9)
+
+ The reconstruction errors in Equations 8 and 9 pose an extra constraint to enhance the training stability, leading to
+ improvement of task representation learning.
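A minimal sketch of the aggregator, the mean-pooled task representation, and the reconstruction loss of equations 8 and 9 (PyTorch assumed; a GRU encoder with a linear decoder, matching the GRU choice mentioned later in the hyperparameter settings, is only one possible instantiation):

import torch
import torch.nn as nn

class PrototypeAggregator(nn.Module):
    # Encodes a set of prototypes, mean-pools a task representation, and reconstructs the input.
    def __init__(self, d, hidden):
        super().__init__()
        self.enc = nn.GRU(d, hidden, batch_first=True)
        self.dec = nn.Linear(hidden, d)

    def forward(self, C):                      # C: (K, d) prototypes (C_R or the enriched C_hat_R)
        z, _ = self.enc(C.unsqueeze(0))        # (1, K, hidden) dense per-prototype codes
        task_repr = z.mean(dim=1).squeeze(0)   # mean pooling over prototypes -> q_i or t_i
        recon = self.dec(z).squeeze(0)         # decoded prototypes
        loss = ((C - recon) ** 2).sum()        # squared Frobenius reconstruction error (L_q or L_t)
        return task_repr, loss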
389
+
390
+
391
+ -----
392
+
393
+ **Algorithm 1 Meta-Training Process of ARML**
394
+
395
+ **Require:** p(T): distribution over tasks; K: number of vertices in the meta-knowledge graph; α: stepsize for
+ gradient descent of each task (i.e., inner loop stepsize); β: stepsize for meta-optimization (i.e., outer loop
+ stepsize); µ_1, µ_2: balancing factors in the loss function
+
+ 1: Randomly initialize all learnable parameters Φ
+ 2: while not done do
+ 3:   Sample a batch of tasks {T_i | i ∈ [1, I]} from p(T)
+ 4:   for all T_i do
+ 5:     Sample training set D_i^tr and testing set D_i^ts
+ 6:     Construct the prototype-based relational graph R_i by computing the prototypes in equation 2 and the weights in equation 4
+ 7:     Compute the similarity between each prototype and meta-knowledge vertex in equation 6 and construct the super-graph S_i
+ 8:     Apply the GNN on the super-graph S_i and get the information-propagated representation Ĉ_{R_i}
+ 9:     Aggregate C_{R_i} in equation 8 and Ĉ_{R_i} in equation 9 to get the representations q_i, t_i and the reconstruction losses L_q, L_t
+ 10:    Compute the task-specific initialization θ_{0i} in equation 10 and update parameters θ_i = θ_{0i} − α ∇_θ L(f_θ, D_i^tr)
+ 11:  end for
+ 12:  Update Φ ← Φ − β ∇_Φ \sum_{i=1}^{I} [ L(f_{θ_i}, D_i^ts) + µ_1 L_t + µ_2 L_q ]
+ 13: end while
428
+
429
+
430
+ After getting the task representations q_i and t_i, the modulating function is then used to tailor the task-specific
+ information to the globally shared initialization θ_0, which is formulated as:
+
+     θ_{0i} = σ(W_g(t_i ⊕ q_i) + b_g) ◦ θ_0,    (10)
+
+ where W_g and b_g are learnable parameters of a fully connected layer. Note that we adopt the Sigmoid gating as
+ exemplary, and more discussion about different modulating functions can be found in the ablation studies of
+ Section 5.
+
+ For each task T_i, we perform the gradient descent process from θ_{0i} and reach its optimal parameter θ_i.
+ Combining the reconstruction losses L_t and L_q with the meta-learning loss defined in equation 1, the overall
+ objective function of ARML is:
+
+     \min_Φ L_all = \min_Φ (L + µ_1 L_t + µ_2 L_q) = \min_Φ \sum_{i=1}^{I} L(f_{θ_{0i} − α ∇_θ L(f_θ, D_i^{tr})}, D_i^{ts}) + µ_1 L_t + µ_2 L_q,    (11)
+
+ where µ_1 and µ_2 are introduced to balance the importance of these three items. Φ represents all learnable
+ parameters. The algorithm of the meta-training process of ARML is shown in Alg. 1. The details of the
+ meta-testing process of ARML are available in Appendix A.
454
+
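A minimal sketch of the sigmoid gating in equation 10 (PyTorch assumed; `SigmoidModulation` is an illustrative name), applied to a flattened copy of the shared initialization θ_0:

import torch
import torch.nn as nn

class SigmoidModulation(nn.Module):
    # Gates the globally shared initialization theta_0 with the task representation [t_i ; q_i].
    def __init__(self, repr_dim, num_params):
        super().__init__()
        self.gate = nn.Linear(2 * repr_dim, num_params)   # W_g, b_g in equation 10

    def forward(self, t_i, q_i, theta0_flat):
        g = torch.sigmoid(self.gate(torch.cat([t_i, q_i], dim=-1)))   # sigmoid(W_g [t_i ⊕ q_i] + b_g)
        return g * theta0_flat                                        # element-wise gating of theta_0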
455
+ 5 EXPERIMENTS
456
+
457
+ In this section, we conduct extensive experiments to demonstrate the effectiveness of the ARML on
458
+ 2D regression and few-shot classification with the goal of answering the following questions: (1) Can
459
+ ARML outperform other meta-learning methods?; (2) Can our proposed components improve the
460
+ learning performance?; (3) Can ARML framework improve the model interpretability by discovering
461
+ reasonable meta-knowledge graph?
462
+
463
+ 5.1 EXPERIMENTAL SETTINGS
464
+
465
+ **Methods for Comparison** We compare our proposed ARML with two types of baselines: gradientbased meta-learning algorithms and non-parameteric meta-learning algorithms.
466
+
467
+ _For gradient-based meta-learning methods: both globally shared methods (MAML (Finn et al.,_
468
+ 2017), Meta-SGD (Li et al., 2017)) and task-specific methods (MT-Net (Lee & Choi, 2018), MUMOMAML (Vuorio et al., 2019), HSML (Yao et al., 2019)) are considered for comparison.
469
+
470
+ _For non-parametric meta-learning methods: we select globally shared method Prototypical Network_
471
+ (ProtoNet) (Snell et al., 2017) and task-specific method TADAM (Oreshkin et al., 2018) as baselines.
472
+ Note that, following the traditional settings, non-parametric baselines are only used in few-shot
473
+ classification problem. The detailed implementations of baselines are discussed in Appendix B.3.
474
+
475
+
476
+ -----
477
+
478
+ **Hyperparameter Settings** For the aggregator functions in the autoencoder structure (AG^t, AG^t_dec, AG^q,
+ AG^q_dec), we use a GRU as the encoder and decoder in this autoencoder framework. We adopt a one-layer
+ GCN (Kipf & Welling, 2017) with tanh activation as the implementation of the GNN in equation 7. For the
+ modulation network, we try sigmoid, tanh and FiLM modulation and find that sigmoid modulation performs better.
+ Thus, in the following experiments, we use the sigmoid modulation as the modulating function. More detailed
+ discussion about the experiment settings is presented in Appendix B.
484
+
485
+ 5.2 2D REGRESSION
486
+
487
+ **Dataset Description.** In the 2D regression problem, we adopt similar regression problem settings as (Finn et al.,
+ 2018; Vuorio et al., 2019; Yao et al., 2019; Rusu et al., 2019), which include several families of functions. In this
+ paper, to model more complex relational structures, we design a 2D regression problem rather than the traditional
+ 1D regression. Inputs x ∼ U[0.0, 5.0] and y ∼ U[0.0, 5.0] are sampled randomly, and random Gaussian noise with
+ standard deviation 0.3 is added to the output. Furthermore, six underlying functions are selected, including
+ (1) Sinusoids: z(x, y) = a_s sin(w_s x + b_s), where a_s ∼ U[0.1, 5.0], b_s ∼ U[0, 2π], w_s ∼ U[0.8, 1.2];
+ (2) Line: z(x, y) = a_l x + b_l, where a_l ∼ U[−3.0, 3.0], b_l ∼ U[−3.0, 3.0];
+ (3) Quadratic: z(x, y) = a_q x^2 + b_q x + c_q, where a_q ∼ U[−0.2, 0.2], b_q ∼ U[−2.0, 2.0], c_q ∼ U[−3.0, 3.0];
+ (4) Cubic: z(x, y) = a_c x^3 + b_c x^2 + c_c x + d_c, where a_c ∼ U[−0.1, 0.1], b_c ∼ U[−0.2, 0.2], c_c ∼ U[−2.0, 2.0], d_c ∼ U[−3.0, 3.0];
+ (5) Quadratic Surface: z(x, y) = a_qs x^2 + b_qs y^2, where a_qs ∼ U[−1.0, 1.0], b_qs ∼ U[−1.0, 1.0];
+ (6) Ripple: z(x, y) = sin(−a_r(x^2 + y^2)) + b_r, where a_r ∼ U[−0.2, 0.2], b_r ∼ U[−3.0, 3.0].
+ Note that functions 1-4 are located in the subspace of y = 1. Following (Finn et al., 2017), we use two fully
+ connected layers with 40 neurons as the base model. The number of vertices of the meta-knowledge graph is set as 6.
501
+
502
+ **Results and Analysis.** In Figure 2, we summarize the interpretation of meta-knowledge graph
503
+ (see top figure) and the qualitative results (see bottom table) of 10-shot 2D regression. In the
504
+ bottom table, we can observe that ARML achieves the best performance as compared to competitive
505
+ gradient-based meta-learning methods, i.e., globally shared models and task-specific models. This
506
+ finding demonstrates that the meta-knowledge graph is necessary to model and capture task-specific
507
+ information. The superior performance can also be interpreted in the top figure. In the left, we
508
+ show the heatmap between prototypes and meta-knowledge vertices (deeper color means higher
509
+ similarity). We can see that sinusoids and line activate V1 and V4, which may represent curve and
510
+ line, respectively. V1 and V4 also contribute to quadratic and quadratic surface, which also show
511
+ the similarity between these two families of functions. V3 is activated in P0 of all functions and the
512
+ quadratic surface and ripple further activate V1 in P0, which may show the difference between 2D
513
+ functions and 3D functions (sinusoid, line, quadratic and cubic lie in the subspace). Specifically,
514
+ in the right figure, we illustrate the meta-knowledge graph, where we set a threshold to filter the
515
+ link with low similarity score and show the rest. We can see that V3 is the most popular vertex and is connected
+ with V1, V5 (representing curves) and V4 (representing lines). V1 is further connected with V5, demonstrating the
+ similarity of the curve representation.
+
+ |Model|MAML|Meta-SGD|MT-Net|MUMOMAML|HSML|ARML|
+ |---|---|---|---|---|---|---|
+ |10-shot|2.292 ± 0.163|2.908 ± 0.229|1.757 ± 0.120|0.523 ± 0.036|0.494 ± 0.038|**0.438 ± 0.029**|
+
+ [Figure 2 graphic: heatmap of prototype-to-vertex similarities (P0, P1 vs. V0-V5) and the learned meta-knowledge
+ graph over vertices V0-V5 for the function families Sinusoids, Line, Quadratic, Cubic, Quadratic Surface and Ripple.]
545
+
546
+
547
+ Figure 2: In the top figure, we show the interpretation of meta-knowledge graph. The left heatmap
548
+ shows the similarity between prototypes (P0, P1) and meta-knowledge vertices (V0-V5). The right
549
+ part show the meta-knowledge graph. In the bottom table, we show the overall performance (mean
550
+ square error with 95% confidence) of 10-shot 2D regression.
551
+
552
+
553
+ -----
554
+
555
+ 5.3 FEW-SHOT CLASSIFICATION
556
+
557
+ **Dataset Description and Settings** In the few-shot classification problem, we first use the benchmark proposed in (Yao et al., 2019), where four fine-grained image classification datasets are included
558
+ (i.e., CUB-200-2011 (Bird), Describable Textures Dataset (Texture), FGVC of Aircraft (Aircraft),
559
+ and FGVCx-Fungi (Fungi)). For each few-shot classification task, it samples classes from one of four
560
+ datasets. In this paper, we call this dataset as Plain-Multi and each fine-grained dataset as subdataset.
561
+
562
+ Then, to demonstrate the effectiveness of our proposed model for handling more complex underlying
563
+ structures, in this paper, we increase the difficulty of few-shot classification problem by introducing
564
+ two image filters: a blur filter and a pencil filter. Similar to (Jerfel et al., 2019), for each image in Plain-Multi, one
+ artistic filter is applied to simulate a changing distribution of few-shot classification tasks. After applying the
+ filters, the total number of subdatasets is 12 and each task is sampled from one of them. This dataset is named
+ Art-Multi. More detailed descriptions of the effects of the different filters are discussed in Appendix C.
568
+
569
+ Following the traditional meta-learning settings, all datasets are divided into meta-training, metavalidation and meta-testing classes. The traditional N-way K-shot settings are used to split training and
570
+ test set for each task. We adopt the standard four-block convolutional layers as the base learner (Finn
571
+ et al., 2017; Snell et al., 2017). The number of vertices of meta-knowledge graph for Plain-Multi
572
+ and Art-Multi datasets are set as 4 and 8, respectively. Additionally, for the miniImagenet, similar
573
+ as (Finn et al., 2018), which tasks are constructed from a single domain and do not have heterogeneity,
574
+ we compare our proposed ARML with other baselines and present the results in Appendix D.
575
+
576
+ 5.3.1 PERFORMANCE VALIDATION
577
+
578
+ **Overall Qualitative Analyses** Experimental results for Plain-Multi and Art-Multi are shown in
579
+ Table 1 and Table 2, respectively. For each dataset, the performance accuracy with 95% confidence
580
+ interval are reported. Note that, due to the space limitation, in Art-Multi dataset, we only show
581
+ the average value of each filter and the full results table are shown in Table 9 of Appendix E. In
582
+ these two tables, first, we can observe that task-specific models (MT-Net, MUMOMAML, HSML,
583
+ TADAM) significantly outperforms globally shared models (MAML, Meta-SGD, ProtoNet) in both
584
+ gradient-based and non-parametric meta-learning research lines. Second, compared ARML with
585
+ other task-specific gradient-based meta-learning methods, the better performance confirms that
586
+ ARML can model and extract task-specific information more accurately by leveraging the constructed
587
+ meta-knowledge graph. Especially, the performance gap between the ARML and HSML verifies the
588
+ benefits of relational structure compared with isolated clustering structure. Finally, as a gradient-based
589
+ meta-learning algorithm, ARML can also outperform ProtoNet and TADAM, two representative
590
+ non-parametric meta-learning algorithms.
591
+
592
+ Table 1: Overall few-shot classification results (accuracy ± 95% confidence) on Plain-Multi dataset.
593
+
594
+ |Settings|Algorithms|Data: Bird|Data: Texture|Data: Aircraft|Data: Fungi|
+ |---|---|---|---|---|---|
+ |5-way 1-shot|MAML|53.94 ± 1.45%|31.66 ± 1.31%|51.37 ± 1.38%|42.12 ± 1.36%|
+ ||MetaSGD|55.58 ± 1.43%|32.38 ± 1.32%|52.99 ± 1.36%|41.74 ± 1.34%|
+ ||MT-Net|58.72 ± 1.43%|32.80 ± 1.35%|47.72 ± 1.46%|43.11 ± 1.42%|
+ ||MUMOMAML|56.82 ± 1.49%|33.81 ± 1.36%|53.14 ± 1.39%|42.22 ± 1.40%|
+ ||HSML|60.98 ± 1.50%|35.01 ± 1.36%|57.38 ± 1.40%|44.02 ± 1.39%|
+ ||ProtoNet|54.11 ± 1.38%|32.52 ± 1.28%|50.63 ± 1.35%|41.05 ± 1.37%|
+ ||TADAM|56.58 ± 1.34%|33.34 ± 1.27%|53.24 ± 1.33%|43.06 ± 1.33%|
+ ||ARML|62.33 ± 1.47%|35.65 ± 1.40%|58.56 ± 1.41%|44.82 ± 1.38%|
+ |5-way 5-shot|MAML|68.52 ± 0.79%|44.56 ± 0.68%|66.18 ± 0.71%|51.85 ± 0.85%|
+ ||MetaSGD|67.87 ± 0.74%|45.49 ± 0.68%|66.84 ± 0.70%|52.51 ± 0.81%|
+ ||MT-Net|69.22 ± 0.75%|46.57 ± 0.70%|63.03 ± 0.69%|53.49 ± 0.83%|
+ ||MUMOMAML|70.49 ± 0.76%|45.89 ± 0.69%|67.31 ± 0.68%|53.96 ± 0.82%|
+ ||HSML|71.68 ± 0.73%|48.08 ± 0.69%|73.49 ± 0.68%|56.32 ± 0.80%|
+ ||ProtoNet|68.67 ± 0.72%|45.21 ± 0.67%|65.29 ± 0.68%|51.27 ± 0.81%|
+ ||TADAM|69.13 ± 0.75%|45.78 ± 0.65%|69.87 ± 0.66%|53.15 ± 0.82%|
+ ||ARML|73.34 ± 0.70%|49.67 ± 0.67%|74.88 ± 0.64%|57.55 ± 0.82%|
607
+
608
+
609
+ -----
610
+
611
+ Table 2: Overall few-shot classification results (accuracy ± 95% confidence) on Art-Multi dataset.
612
+
613
+ |Settings|Algorithms|Avg. Original|Avg. Blur|Avg. Pencil|
+ |---|---|---|---|---|
+ |5-way, 1-shot|MAML|42.70 ± 1.35%|40.53 ± 1.38%|36.71 ± 1.37%|
+ ||MetaSGD|44.21 ± 1.38%|42.36 ± 1.39%|37.21 ± 1.39%|
+ ||MT-Net|43.94 ± 1.40%|41.64 ± 1.37%|37.79 ± 1.38%|
+ ||MUMOMAML|45.63 ± 1.39%|41.59 ± 1.38%|39.24 ± 1.36%|
+ ||HSML|45.68 ± 1.37%|42.62 ± 1.38%|39.78 ± 1.36%|
+ ||ProtoNet|42.08 ± 1.34%|40.51 ± 1.37%|36.24 ± 1.35%|
+ ||TADAM|44.73 ± 1.33%|42.44 ± 1.35%|39.02 ± 1.34%|
+ ||ARML|47.92 ± 1.34%|44.43 ± 1.34%|41.44 ± 1.34%|
+ |5-way, 5-shot|MAML|58.30 ± 0.74%|55.71 ± 0.74%|49.59 ± 0.73%|
+ ||MetaSGD|57.82 ± 0.72%|55.54 ± 0.73%|50.24 ± 0.72%|
+ ||MT-Net|57.95 ± 0.74%|54.65 ± 0.73%|49.18 ± 0.73%|
+ ||MUMOMAML|58.60 ± 0.75%|56.29 ± 0.72%|51.15 ± 0.73%|
+ ||HSML|60.63 ± 0.73%|57.91 ± 0.72%|53.93 ± 0.72%|
+ ||ProtoNet|58.12 ± 0.74%|55.07 ± 0.73%|50.15 ± 0.74%|
+ ||TADAM|60.35 ± 0.72%|58.36 ± 0.73%|53.15 ± 0.74%|
+ ||ARML|61.78 ± 0.74%|58.73 ± 0.75%|55.27 ± 0.73%|
628
+
629
+
630
+
631
+ **Model Ablation Study** In this section, we perform an ablation study of the proposed ARML to
+ demonstrate the effectiveness of each component. The results of the ablation study in the 5-way, 5-shot
+ scenario on the Art-Multi dataset are presented in Table 3. In Appendix F, we also show the full results
+ for Art-Multi in Table 6 and the ablation study on Plain-Multi in Table 7. Specifically, to show
+ the effectiveness of prototype construction, in ablation I we use the mean pooling aggregation
+ of each sample, rather than the prototype-based relational graph, to interact with the meta-knowledge
+ graph. In ablation II, we use all samples to construct a sample-level relational graph without
+ using prototypes. The better performance of ARML compared with ablations I and II shows that
+ structuring samples into prototypes (1) better captures the underlying relations and (2) alleviates the
+ effect of potential anomalies.
641
+
642
+ In ablation III, we remove the meta-knowledge graph and use the prototype-based relational graph
+ structure with the aggregator AG^q as the task representation. The better performance of ARML demonstrates the effectiveness of the meta-knowledge graph in capturing the relational structure and facilitating
+ classification performance. We further remove the reconstruction loss and show the results in
+ ablation IV; the results demonstrate that the autoencoder structure benefits the process of
+ learning the representation.
647
+
648
+ In ablations V and VI, we change the modulating function to tanh and FiLM (Perez et al., 2018),
+ respectively. We can see that ARML is not very sensitive to the modulating function, and the sigmoid
+ function is slightly better than the other activation functions in most cases.
651
+
652
+ Table 3: Results (accuracy ± 95% confidence) of Ablation Models (5-way, 5-shot) on Art-Multi.
653
+
654
+ |Ablation Models|Avg. Original|Avg. Blur|Avg. Pencil|
+ |---|---|---|---|
+ |I. no prototype-based graph|60.80 ± 0.74%|58.36 ± 0.73%|54.79 ± 0.73%|
+ |II. no prototype|61.34 ± 0.73%|58.34 ± 0.74%|54.81 ± 0.73%|
+ |III. no meta-knowledge graph|59.99 ± 0.75%|57.79 ± 0.73%|53.68 ± 0.74%|
+ |IV. no reconstruction loss|59.07 ± 0.73%|57.20 ± 0.74%|52.45 ± 0.73%|
+ |V. tanh modulation|62.34 ± 0.74%|58.58 ± 0.75%|54.01 ± 0.74%|
+ |VI. film modulation|60.06 ± 0.75%|57.47 ± 0.73%|52.06 ± 0.74%|
+ |ARML|61.78 ± 0.74%|58.73 ± 0.75%|55.27 ± 0.73%|
668
+
669
+
670
+ 5.3.2 ANALYSIS OF CONSTRUCTED META-KNOWLEDGE GRAPH
671
+
672
+ In this section, we conduct an extensive analysis of the constructed meta-knowledge graph, which is
+ regarded as the key component of ARML. Due to the space limit, we only present the results on the
+ Art-Multi dataset. For Plain-Multi, the analysis, which yields similar observations, is discussed in Appendix G.
674
+
675
+
676
+ -----
677
+
678
+ **Performance v.s. Number of Vertices** We first investigate the impact of the number of vertices in the
+ meta-knowledge graph. The results are shown in Table 4. From the results, we can notice that the
+ performance saturates as the number of vertices reaches around 8. One potential reason is that 8
+ vertices are enough to capture the potential relations. For a larger dataset with more complex
+ relations, more vertices may be needed. In addition, when the meta-knowledge graph does not have enough
+ vertices, the worse performance suggests that the graph may not be able to capture enough relations
+ across tasks.
684
+
685
+ Table 4: Sensitivity analysis with different # of vertices in meta-knowledge graph (5-way, 5-shot).
686
+
687
+ |# of vertices|Avg. Original|Avg. Blur|Avg. Pencil|
+ |---|---|---|---|
+ |4|61.18 ± 0.72%|58.13 ± 0.73%|54.88 ± 0.75%|
+ |8|61.78 ± 0.74%|58.73 ± 0.75%|55.27 ± 0.73%|
+ |12|61.66 ± 0.73%|58.61 ± 0.72%|55.07 ± 0.74%|
+ |16|61.75 ± 0.73%|58.67 ± 0.74%|55.26 ± 0.73%|
+ |20|61.91 ± 0.74%|58.92 ± 0.73%|55.24 ± 0.72%|
693
+
694
+
695
+
696
+ **Model Interpretation Analysis of Meta-Knowledge Graph** We then analyze the learned meta-knowledge graph. For each subdataset, we randomly select one task as an example. For each task,
+ in the left part of Figure 3 we show the similarity heatmap between prototypes and vertices in the
+ meta-knowledge graph, where a deeper color means higher similarity. V0-V7 and P1-P5 denote
+ the different vertices and prototypes, respectively. The meta-knowledge graph is also illustrated
+ in the right part. Similar to the graph in 2D regression, we set a threshold to filter out links with low
+ similarity and illustrate the rest of them. First, we can see that V1 is mainly activated by bird
+ and aircraft (including all filters), which may reflect the shape similarity between birds and aircraft.
+ Second, V2, V3, V4 are primarily activated by texture, and they form a loop in the meta-knowledge
+ graph. In particular, V2 also benefits images with blur and pencil filters. Thus, V2 may represent the
+ main texture information and facilitate the training process on other subdatasets. The meta-knowledge graph also
+ shows the importance of V2 since it is connected with almost all other vertices. Third, when the
+ blur filter is used, in most cases (bird blur, texture blur, fungi blur) V7 is activated. Thus, V7 may capture the
+ similarity among images with the blur filter. In addition, the connections between V7 and V2/V3 show that
+ classifying blurred images may depend on texture information. Fourth, V6 (mostly activated by aircraft)
+ connects with V2 and V3, justifying the importance of texture information for classifying aircraft.
711
+
712
+ [Figure 3 graphic: similarity heatmaps for the twelve subdatasets (Bird/Texture/Aircraft/Fungi, each with the original, blur, and pencil filters) and the meta-knowledge graph over vertices V0-V7.]
729
+
730
+
731
+ Figure 3: Interpretation of the meta-knowledge graph on the Art-Multi dataset. For each subdataset, we
+ randomly select one task. On the left, we show the similarity heatmap between prototypes
+ (P0-P5) and meta-knowledge vertices (V0-V7). On the right, we show the meta-knowledge graph.
734
+
735
+ 6 CONCLUSION
736
+
737
+ In this paper, to improve the effectiveness of meta-learning for handling heterogeneous tasks, we
+ propose a new framework called ARML, which automatically extracts relations across tasks and
+ constructs a meta-knowledge graph. When a new task arrives, ARML can quickly find the most relevant
+ relations through the meta-knowledge graph and use this knowledge to facilitate its training process.
+ The experiments demonstrate the effectiveness of our proposed algorithm.
742
+
743
+
744
+ -----
745
+
746
+ REFERENCES
747
+
748
+ Marcin Andrychowicz, Misha Denil, Sergio Gomez, Matthew W Hoffman, David Pfau, Tom Schaul,
749
+ Brendan Shillingford, and Nando De Freitas. Learning to learn by gradient descent by gradient
750
+ descent. In NeurIPS, pp. 3981–3989, 2016.
751
+
752
+ Chelsea Finn and Sergey Levine. Meta-learning and universality: Deep representations and gradient
753
+ descent can approximate any learning algorithm. In ICLR, 2018.
754
+
755
+ Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of
756
+ deep networks. In ICML, pp. 1126–1135, 2017.
757
+
758
+ Chelsea Finn, Kelvin Xu, and Sergey Levine. Probabilistic model-agnostic meta-learning. In NeurIPS,
759
+ 2018.
760
+
761
+ Sebastian Flennerhag, Pablo G Moreno, Neil D Lawrence, and Andreas Damianou. Transferring
762
+ knowledge across learning processes. ICLR, 2019.
763
+
764
+ Justin Gilmer, Samuel S Schoenholz, Patrick F Riley, Oriol Vinyals, and George E Dahl. Neural
765
+ message passing for quantum chemistry. In ICML, pp. 1263–1272. JMLR. org, 2017.
766
+
767
+ Jonathan Gordon, John Bronskill, Matthias Bauer, Sebastian Nowozin, and Richard E Turner. Meta-learning probabilistic inference for prediction. In ICLR, 2019.
768
+
769
+ Erin Grant, Chelsea Finn, Sergey Levine, Trevor Darrell, and Thomas Griffiths. Recasting gradient-based meta-learning as hierarchical Bayes. In ICLR, 2018.
770
+
771
+ Jiatao Gu, Yong Wang, Yun Chen, Kyunghyun Cho, and Victor OK Li. Meta-learning for low-resource
772
+ neural machine translation. In EMNLP, 2018.
773
+
774
+ Will Hamilton, Zhitao Ying, and Jure Leskovec. Inductive representation learning on large graphs. In
775
+ _NeurIPS, pp. 1024–1034, 2017._
776
+
777
+ Ghassen Jerfel, Erin Grant, Thomas L Griffiths, and Katherine Heller. Reconciling meta-learning and
778
+ continual learning with online mixtures of tasks. NeurIPS, 2019.
779
+
780
+ Bingyi Kang, Zhuang Liu, Xin Wang, Fisher Yu, Jiashi Feng, and Trevor Darrell. Few-shot object
781
+ detection via feature reweighting. In ICCV, 2019.
782
+
783
+ Thomas N Kipf and Max Welling. Semi-supervised classification with graph convolutional networks.
784
+ In ICLR, 2017.
785
+
786
+ Yoonho Lee and Seungjin Choi. Gradient-based meta-learning with learned layerwise metric and
787
+ subspace. In ICML, pp. 2933–2942, 2018.
788
+
789
+ Zhenguo Li, Fengwei Zhou, Fei Chen, and Hang Li. Meta-sgd: Learning to learn quickly for few
790
+ shot learning. arXiv preprint arXiv:1707.09835, 2017.
791
+
792
+ Zhaojiang Lin, Andrea Madotto, Chien-Sheng Wu, and Pascale Fung. Personalizing dialogue agents
793
+ via meta-learning. 2019.
794
+
795
+ Ming-Yu Liu, Xun Huang, Arun Mallya, Tero Karras, Timo Aila, Jaakko Lehtinen, and Jan Kautz.
796
+ Few-shot unsupervised image-to-image translation. arXiv preprint arXiv:1905.01723, 2019.
797
+
798
+ Nikhil Mishra, Mostafa Rohaninejad, Xi Chen, and Pieter Abbeel. A simple neural attentive meta-learner. ICLR, 2018.
799
+
800
+ Alex Nichol and John Schulman. Reptile: a scalable metalearning algorithm. arXiv preprint
801
+ _arXiv:1803.02999, 2018._
802
+
803
+ Boris Oreshkin, Pau Rodríguez López, and Alexandre Lacoste. Tadam: Task dependent adaptive
+ metric for improved few-shot learning. In NeurIPS, pp. 721–731, 2018.
805
+
806
+ Ethan Perez, Florian Strub, Harm de Vries, Vincent Dumoulin, and Aaron C. Courville. Film: Visual
807
+ reasoning with a general conditioning layer. In AAAI, 2018.
808
+
809
+
810
+ -----
811
+
812
+ Sachin Ravi and Hugo Larochelle. Optimization as a model for few-shot learning. ICLR, 2016.
813
+
814
+ Andrei A Rusu, Dushyant Rao, Jakub Sygnowski, Oriol Vinyals, Razvan Pascanu, Simon Osindero,
815
+ and Raia Hadsell. Meta-learning with latent embedding optimization. In ICLR, 2019.
816
+
817
+ Jake Snell, Kevin Swersky, and Richard Zemel. Prototypical networks for few-shot learning. In
818
+ _NeurIPS, pp. 4077–4087, 2017._
819
+
820
+ Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, and Yoshua
+ Bengio. Graph attention networks. In ICLR, 2018.
822
+
823
+ Oriol Vinyals, Charles Blundell, Timothy Lillicrap, Daan Wierstra, et al. Matching networks for one
824
+ shot learning. In NeurIPS, pp. 3630–3638, 2016.
825
+
826
+ Risto Vuorio, Shao-Hua Sun, Hexiang Hu, and Joseph J Lim. Toward multimodal model-agnostic
827
+ meta-learning. NeurIPS, 2019.
828
+
829
+ Xin Wang, Fisher Yu, Ruth Wang, Trevor Darrell, and Joseph E Gonzalez. Tafe-net: Task-aware
830
+ feature embeddings for low shot learning. In CVPR, pp. 1831–1840, 2019.
831
+
832
+ Flood Sung, Yongxin Yang, Li Zhang, Tao Xiang, Philip HS Torr, and Timothy M Hospedales.
833
+ Learning to compare: Relation network for few-shot learning. In CVPR, 2018.
834
+
835
+ Huaxiu Yao, Ying Wei, Junzhou Huang, and Zhenhui Li. Hierarchically structured meta-learning. In
836
+ _ICML, pp. 7045–7054, 2019._
837
+
838
+ Jaesik Yoon, Taesup Kim, Ousmane Dia, Sungwoong Kim, Yoshua Bengio, and Sungjin Ahn.
839
+ Bayesian model-agnostic meta-learning. In NeurIPS, pp. 7343–7353, 2018.
840
+
841
+ Sung Whan Yoon, Jun Seo, and Jaekyun Moon. Tapnet: Neural network augmented with task-adaptive
842
+ projection for few-shot learning. In ICML, 2019.
843
+
844
+
845
+ -----
846
+
847
+ A ALGORITHM IN META-TESTING PROCESS
848
+
849
+ **Algorithm 2 Meta-Testing Process of ARML**
+
+ **Require:** Training data D_t^{tr} of a new task T_t
+
+ 1: Construct the prototype-based relational graph R_t by computing the prototypes in equation 2 and
+ the weights in equation 4
+
+ 2: Compute the similarity between each prototype and each meta-knowledge vertex in equation 6 and
+ construct the super-graph S_t
+
+ 3: Apply the GNN on the super-graph S_t and get the updated prototype representations Ĉ_{R_t}
+
+ 4: Aggregate C_{R_t} in equation 8 and Ĉ_{R_t} in equation 9 to get the representations q_t, t_t
+
+ 5: Compute the task-specific initialization θ_{0t} in equation 10
+ 6: Update parameters θ_t = θ_{0t} − α∇_θ L(f_θ, D_t^{tr})
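+ To make the flow of Algorithm 2 concrete, the following minimal NumPy sketch mimics the meta-testing steps with simplified stand-ins (class-mean prototypes, a softmax similarity to the meta-knowledge vertices, mean aggregation, and a sigmoid gate on the initialization). The function and variable names are illustrative only and do not correspond to the released ARML code.
+
+ ```python
+ import numpy as np
+
+ def softmax(x, axis=-1):
+     e = np.exp(x - x.max(axis=axis, keepdims=True))
+     return e / e.sum(axis=axis, keepdims=True)
+
+ def arml_meta_test_step(features, labels, meta_vertices, theta0):
+     """Simplified sketch of Algorithm 2 (hypothetical helper, not the official implementation)."""
+     # Step 1: prototype-based relational graph, reduced here to class-mean prototypes.
+     classes = np.unique(labels)
+     prototypes = np.stack([features[labels == c].mean(axis=0) for c in classes])
+     # Step 2: similarity between prototypes and meta-knowledge vertices (super-graph weights).
+     sim = softmax(prototypes @ meta_vertices.T)
+     # Step 3: propagate vertex information back to the prototypes (one GNN-like step).
+     prototypes_hat = sim @ meta_vertices
+     # Step 4: aggregate both graphs into the task representations q_t and t_t.
+     q_t = prototypes.mean(axis=0)
+     t_t = prototypes_hat.mean(axis=0)
+     # Step 5: task-specific initialization via a sigmoid gate (toy modulation).
+     gate = 1.0 / (1.0 + np.exp(-(q_t + t_t).mean()))
+     # Step 6 (the inner-loop gradient update) is omitted because it needs the task loss of f_theta.
+     return gate * theta0
+
+ # Toy usage: 5-way support set, 16-dim embeddings, 8 meta-knowledge vertices.
+ rng = np.random.default_rng(0)
+ theta = arml_meta_test_step(rng.normal(size=(25, 16)), np.repeat(np.arange(5), 5),
+                             rng.normal(size=(8, 16)), theta0=rng.normal(size=100))
+ print(theta.shape)
+ ```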
868
+ B HYPERPARAMETERS SETTINGS
869
+
870
+ B.1 2D REGRESSION
871
+
872
+ In 2D regression problem, we set the inner-loop stepsize (i.e., α) and outer-loop stepsize (i.e., β) as
873
+ 0.001 and 0.001, respectively. The embedding function E is set as one layer with 40 neurons. The
874
+ autoencoder aggregator is constructed by the gated recurrent structures. We set the meta-batch size as
875
+ 25 and the inner loop gradient steps as 5.
876
+
877
+ B.2 FEW-SHOT IMAGE CLASSIFICATION
878
+
879
+ In few-shot image classification, for both the Plain-Multi and Art-Multi datasets, we set the
+ inner stepsize (i.e., α) as 0.001 and the outer stepsize (i.e., β) as 0.01. For the embedding function E,
+ we employ two convolutional layers with 3 × 3 filters. The channel size of these two convolutional
+ layers is 32. After the convolutional layers, we use two fully connected layers with 384 and 64 neurons,
+ respectively. Similar to the hyperparameter settings in 2D regression, the autoencoder aggregators
+ are constructed by gated recurrent structures, i.e., AG^t, AG^t_dec, AG^q, AG^q_dec are all GRUs. The
+ meta-batch size is set as 4. For the inner loop, we use 5 gradient steps.
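+ As a compact summary of these settings, a hedged Python snippet collecting them in one place could look as follows (the dictionary keys are illustrative and not taken from the released code):
+
+ ```python
+ # Hyperparameters for few-shot image classification (Plain-Multi and Art-Multi),
+ # as described in Appendix B.2; the key names are illustrative.
+ image_classification_config = {
+     "inner_stepsize_alpha": 0.001,
+     "outer_stepsize_beta": 0.01,
+     "embedding_conv_layers": 2,
+     "conv_filter_size": (3, 3),
+     "conv_channels": 32,
+     "fc_layer_sizes": (384, 64),
+     "aggregators": "GRU",   # AG^t, AG^t_dec, AG^q, AG^q_dec are all GRUs
+     "meta_batch_size": 4,
+     "inner_gradient_steps": 5,
+ }
+ ```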
886
+
887
+ B.3 DETAILED BASELINE SETTINGS
888
+
889
+ For the gradient-based baselines (i.e., MAML, MetaSGD, MT-Net, BMAML, MUMOMAML,
+ HSML), we use the same inner-loop and outer-loop stepsizes as our ARML. As for the
+ non-parametric meta-learning algorithms, TADAM and the Prototypical Network, we use the
+ same meta-training and meta-testing process as the gradient-based models. Additionally, TADAM uses
+ the same embedding function E as ARML for a fair comparison (i.e., similar expressive ability).
894
+
895
+ C ADDITIONAL DISCUSSION OF DATASETS
896
+
897
+ In this dataset, we use pencil and blur filters to change the task distribution. To investigate the effect
+ of the pencil and blur filters, we provide one example in Figure 4. We can observe that different filters
+ result in different data distributions. All filters used are provided by OpenCV[1].
900
+
901
+ D RESULTS ON MINIIMAGENET
902
+
903
+ For MiniImagenet, since it does not have the characteristic of task heterogeneity, we show the results in
+ Table 5. In this table, we compare ARML on the MiniImagenet dataset with other gradient-based meta-learning
+ models (the first four baselines are globally shared models and the next four are task-specific models).
+ Similar to (Finn et al., 2018), we also apply the standard 4-block convolutional layers for each
907
+
908
+ 1https://opencv.org/
909
+
910
+
911
+ -----
912
+
913
+ (a) : Plain Image (b) : with blur filter (c) : with pencil filter
914
+
915
+ Figure 4: Effect of different filters.
916
+
917
+ baseline. For MT-Net, we use the results reported in (Yao et al., 2019), which control for the same model
+ expressive power. The results indicate that our proposed ARML outperforms the original
+ MAML and achieves comparable performance with the task-specific models (e.g., MT-Net, PLATIPUS,
+ HSML). Most task-specific models achieve similar performance on this standard benchmark due
+ to the homogeneity between tasks.
922
+
923
+ Table 5: Performance comparison on the 5-way, 1-shot MiniImagenet dataset.
924
+
925
+ |Algorithms|5-way 1-shot Accuracy|
+ |---|---|
+ |MAML (Finn et al., 2017)|48.70 ± 1.84%|
+ |LLAMA (Finn & Levine, 2018)|49.40 ± 1.83%|
+ |Reptile (Nichol & Schulman, 2018)|49.97 ± 0.32%|
+ |MetaSGD (Li et al., 2017)|50.47 ± 1.87%|
+ |MT-Net (Lee & Choi, 2018)|49.75 ± 1.83%|
+ |MUMOMAML (Vuorio et al., 2019)|49.86 ± 1.85%|
+ |HSML (Yao et al., 2019)|50.38 ± 1.85%|
+ |PLATIPUS (Finn et al., 2018)|50.13 ± 1.86%|
+ |ARML|50.42 ± 1.73%|
936
+
937
+
938
+ E ADDITIONAL RESULTS OF FEW-SHOT IMAGE CLASSIFICATION
939
+
940
+ E.1 FULL OVERALL RESULTS TABLE OF ART-MULTI DATASET
941
+
942
+ We provide the full results table of the Art-Multi dataset in Table 9. In this table, we can see that our proposed
+ ARML outperforms almost all baselines on every sub-dataset.
944
+
945
+ F FURTHER INVESTIGATION OF ABLATION STUDY
946
+
947
+ In this section, we first show the full evaluation results of the model ablation study on the Art-Multi dataset
+ in Table 6. Note that, for the tanh activation (ablation model V), the performance is similar to that of
+ the sigmoid activation. On some subdatasets, the results are even better. We choose the sigmoid
+ activation for ARML because it achieves better overall performance than the tanh activation on more
+ subdatasets. Then, for the Plain-Multi dataset, we show the results in Table 7. The conclusion of the ablation study
+ on Plain-Multi is similar to the conclusion drawn from the results on Art-Multi. The
+ improvement on these two datasets verifies the necessity of the joint framework in ARML.
954
+
955
+ G ADDITIONAL ANALYSIS OF META-KNOWLEDGE GRAPH
956
+
957
+ In this section, we add more interpretation analysis of meta-knowledge graph. First, we show the full
958
+ evaluation results of sensitivity analysis on Art-Multi dataset in Table 8.
959
+
960
+
961
+ -----
962
+
963
+ Table 6: Full evaluation results of model ablation study on Art-Multi dataset. B, T, A, F represent
964
+ bird, texture, aircraft, fungi, respectively. Plain means original image.
965
+
966
+ |Model|B Plain|B Blur|B Pencil|T Plain|T Blur|T Pencil|
+ |---|---|---|---|---|---|---|
+ |I. no prototype-based graph|72.08%|71.06%|66.83%|45.23%|39.97%|41.67%|
+ |II. no prototype|72.99%|70.92%|67.19%|45.17%|40.05%|41.04%|
+ |III. no meta-knowledge graph|70.79%|69.53%|64.87%|43.37%|39.86%|41.23%|
+ |IV. no reconstruction loss|70.82%|69.87%|65.32%|44.02%|40.18%|40.52%|
+ |V. tanh|72.70%|69.53%|66.85%|45.81%|40.79%|38.64%|
+ |VI. film|71.52%|68.70%|64.23%|43.83%|40.52%|39.49%|
+ |ARML|**73.05%**|**71.31%**|**67.14%**|45.32%|40.15%|**41.98%**|
+
+ |Model|A Plain|A Blur|A Pencil|F Plain|F Blur|F Pencil|
+ |---|---|---|---|---|---|---|
+ |I. no prototype-based graph|70.06%|68.02%|60.66%|55.81%|54.39%|50.01%|
+ |II. no prototype|71.10%|67.59%|61.07%|56.11%|54.82%|49.95%|
+ |III. no meta-knowledge graph|69.97%|68.03%|59.72%|55.84%|53.72%|48.91%|
+ |IV. no reconstruction loss|66.83%|65.73%|55.98%|54.62%|53.02%|48.01%|
+ |V. tanh|73.96%|69.70%|60.75%|56.87%|54.30%|49.82%|
+ |VI. film|69.13%|66.93%|55.59%|55.77%|53.72%|48.92%|
+ |ARML|71.89%|68.59%|61.41%|56.83%|54.87%|50.53%|
1003
+
1004
+
1005
+ Table 7: Results of Model Ablation (5-way, 5-shot results) on Plain-Multi dataset.
1006
+
1007
+ |Ablation Models|Bird|Texture|Aircraft|Fungi|
+ |---|---|---|---|---|
+ |I. no sample-level graph|71.96 ± 0.72%|48.79 ± 0.67%|74.02 ± 0.65%|56.83 ± 0.80%|
+ |II. no prototype|72.86 ± 0.74%|49.03 ± 0.69%|74.36 ± 0.65%|57.02 ± 0.81%|
+ |III. no meta-knowledge graph|71.23 ± 0.75%|47.96 ± 0.68%|73.71 ± 0.69%|55.97 ± 0.82%|
+ |IV. no reconstruction loss|70.99 ± 0.74%|48.03 ± 0.69%|69.86 ± 0.66%|55.78 ± 0.83%|
+ |V. tanh|73.45 ± 0.71%|49.23 ± 0.66%|74.39 ± 0.65%|57.38 ± 0.80%|
+ |VI. film|72.95 ± 0.73%|49.18 ± 0.69%|73.82 ± 0.68%|56.89 ± 0.80%|
+ |ARML|73.34 ± 0.70%|49.67 ± 0.67%|74.88 ± 0.64%|57.55 ± 0.82%|
1021
+
1022
+
1023
+ Then, we analyze the meta-knowledge graph learned on the Plain-Multi dataset by visualizing it
+ (as shown in Figure 5). In this figure, we can see that
+ different subdatasets activate different vertices. Specifically, V2, which is mainly activated by texture,
+ plays a particularly important role for aircraft and fungi. Thus, V2 connects with V3 and V1 in the
+ meta-knowledge graph, which are mainly activated by fungi and aircraft, respectively. In addition,
+ V0 is also activated by aircraft because of the similar contour between aircraft and birds. Furthermore,
+ in the meta-knowledge graph, V0 connects with V3, which shows the similarity of environments between
+ bird images and fungi images.
1030
+
1031
+
1032
+ -----
1033
+
1034
+ [Figure 5 graphic: similarity heatmaps for the Bird, Texture, Aircraft, and Fungi subdatasets and the meta-knowledge graph over vertices V0-V3.]
1050
+
1051
+ Figure 5: Interpretation of the meta-knowledge graph on the Plain-Multi dataset. For each subdataset, one
+ task is randomly selected. In the left figure, we show the similarity heatmap between
1053
+ prototypes (P1-P5) and meta-knowledge vertices (denoted as E1-E4), where deeper color means
1054
+ higher similarity. In the right part, we show the meta-knowledge graph, where a threshold is also set
1055
+ to filter low similarity links.
1056
+
1057
+ Table 8: Full evaluation results of performance v.s. # vertices of meta-knowledge graph on Art-Multi.
1058
+ B, T, A, F represent bird, texture, aircraft, fungi, respectively. Plain means original image.
1059
+
1060
+ |# of Vertices|B Plain|B Blur|B Pencil|T Plain|T Blur|T Pencil|
+ |---|---|---|---|---|---|---|
+ |4|72.29%|70.36%|67.88%|45.37%|41.05%|41.43%|
+ |8|73.05%|71.31%|67.14%|45.32%|40.15%|41.98%|
+ |12|73.45%|70.64%|67.41%|44.53%|41.41%|41.05%|
+ |16|72.68%|70.18%|68.34%|45.63%|41.43%|42.18%|
+ |20|73.41%|71.07%|68.64%|46.26%|41.80%|41.61%|
+
+ |# of Vertices|A Plain|A Blur|A Pencil|F Plain|F Blur|F Pencil|
+ |---|---|---|---|---|---|---|
+ |4|70.98%|67.36%|60.46%|56.07%|53.77%|50.08%|
+ |8|71.89%|68.59%|61.41%|56.83%|54.87%|50.53%|
+ |12|71.78%|67.26%|60.97%|56.87%|55.14%|50.86%|
+ |16|71.96%|68.55%|61.14%|56.76%|54.54%|49.41%|
+ |20|72.02%|68.29%|60.59%|55.95%|54.53%|50.13%|
1075
+
1076
+
1077
+ -----
1078
+
1079
+ Table 9: Full few-shot classification results (accuracy) on the Art-Multi dataset. B, T, A, F represent
+ bird, texture, aircraft, fungi, respectively. Plain means original image.
+
+ |Settings|Algorithms|B Plain|B Blur|B Pencil|T Plain|T Blur|T Pencil|A Plain|A Blur|A Pencil|F Plain|F Blur|F Pencil|
+ |---|---|---|---|---|---|---|---|---|---|---|---|---|---|
+ |5-way, 1-shot|MAML|55.27%|52.62%|48.58%|30.57%|28.65%|28.39%|45.59%|42.24%|34.52%|39.37%|38.58%|35.38%|
+ ||MetaSGD|55.23%|53.08%|48.18%|29.28%|28.70%|28.38%|51.24%|47.29%|35.98%|41.08%|40.38%|36.30%|
+ ||MT-Net|56.99%|54.21%|50.25%|32.13%|29.63%|29.23%|43.64%|40.08%|33.73%|43.02%|42.64%|37.96%|
+ ||MUMOMAML|57.73%|53.18%|50.96%|31.88%|29.72%|29.90%|49.95%|43.36%|39.61%|42.97%|40.08%|36.52%|
+ ||HSML|58.15%|53.20%|51.09%|32.01%|30.21%|30.17%|49.98%|45.79%|40.87%|42.58%|41.29%|37.01%|
+ ||ProtoNet|53.67%|50.98%|46.66%|31.37%|29.08%|28.48%|45.54%|43.94%|35.49%|37.71%|38.00%|34.36%|
+ ||TADAM|54.76%|52.18%|48.85%|32.03%|29.90%|30.82%|50.42%|47.59%|40.17%|41.73%|40.09%|36.27%|
+ ||ARML|59.67%|54.89%|52.97%|32.31%|30.77%|31.51%|51.99%|47.92%|41.93%|44.69%|42.13%|38.36%|
+ |5-way, 5-shot|MAML|71.51%|68.65%|63.93%|42.96%|39.59%|38.87%|64.68%|62.54%|49.20%|54.08%|52.02%|46.39%|
+ ||MetaSGD|71.31%|68.73%|64.33%|41.89%|37.79%|37.91%|64.88%|63.36%|52.31%|53.18%|52.26%|46.43%|
+ ||MT-Net|71.18%|69.29%|68.28%|43.23%|39.42%|39.20%|63.39%|58.29%|46.12%|54.01%|51.70%|47.02%|
+ ||MUMOMAML|71.57%|70.50%|64.57%|44.57%|40.31%|40.07%|63.36%|61.55%|52.17%|54.89%|52.82%|47.79%|
+ ||HSML|71.75%|69.31%|65.62%|44.68%|40.13%|41.33%|70.12%|67.63%|59.40%|55.97%|54.60%|49.40%|
+ ||ProtoNet|70.42%|67.90%|61.82%|44.78%|38.43%|38.40%|65.84%|63.41%|54.08%|51.45%|50.56%|46.33%|
+ ||TADAM|70.08%|69.05%|65.45%|44.93%|41.80%|40.18%|70.35%|68.56%|59.09%|56.04%|54.04%|47.85%|
+ ||ARML|73.05%|71.31%|67.14%|45.32%|40.15%|41.98%|71.89%|68.59%|61.41%|56.83%|54.87%|50.53%|
1189
+ -----
1190
+
ai_scientist/fewshot_examples/2_carpe_diem.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "review": "{\n \"Summary\": \"This paper proposes Recency Bias, an adaptive mini batch selection method for training deep neural networks. To select informative minibatches for training, the proposed method maintains a fixed size sliding window of past model predictions for each data sample. At a given iteration, samples which have highly inconsistent predictions within the sliding window are added to the minibatch. The main contribution of this paper is the introduction of a sliding window to remember past model predictions, as an improvement over the SOTA approach: Active Bias, which maintains a growing window of model predictions. Empirical studies are performed to show the superiority of Recency Bias over two SOTA approaches. Results are shown on the task of (1) image classification from scratch and (2) image classification by fine-tuning pretrained networks.\",\n \"Strengths\": [\n \"The idea of using a sliding window over a growing window in active batch selection is interesting.\",\n \"Overall, the paper is well written. In particular, the Related Work section has a nice flow and puts the proposed method into context. Despite the method having limited novelty (sliding window instead of a growing window), the method has been well motivated by pointing out the limitations in SOTA methods.\",\n \"The results section is well structured. It's nice to see hyperparameter tuning results; and loss convergence graphs in various learning settings for each dataset.\"\n ],\n \"Weaknesses\": [\n \"The key concern about the paper is the lack of rigorous experimentation to study the usefulness of the proposed method. Despite the paper stating that there have been earlier work (Joseph et al, 2019 and Wang et al, 2019) that attempt mini-batch selection, the paper does not compare with them. This is limiting. Further, since the proposed method is not specific to the domain of images, evaluating it on tasks other than image classification, such as text classification for instance, would have helped validate its applicability across domains.\",\n \"Considering the limited results, a deeper analysis of the proposed method would have been nice. The idea of a sliding window over a growing window is a generic one, and there have been many efforts to theoretically analyze active learning over the last two decades. How does the proposed method fit in there? (For e.g., how does the expected model variance change in this setting?) Some form of theoretical/analytical reasoning behind the effectiveness of recency bias (which is missing) would provide greater insights to the community and facilitate further research in this direction.\",\n \"The claim of 20.5% reduction in test error mentioned in the abstract has not been clearly addressed and pointed out in the results section of the paper.\",\n \"The results would have been more complete if results were shown in a setting where just recency bias is used without the use of the selection pressure parameter. In other words, an ablation study on the effect of the selection pressure parameter would have been very useful.\",\n \"The intuition behind the method is described well, however, the proposed method would have been really solidified if it were analysed in the context of a simple machine learning problem (such as logistic regression). 
As an example, verifying if the chosen minibatch samples are actually close to the decision boundary of a model (even if the model is very simple) would have helped analyze the proposed method well.\"\n ],\n \"Originality\": 3,\n \"Quality\": 2,\n \"Clarity\": 4,\n \"Significance\": 2,\n \"Questions\": [\n \"How important is the warm-up phase to the proposed method? Considering the paper states that this is required to get good estimates of the quantization index of the samples, some ablation studies on reducing/increasing the warm-up phase and showing the results would have been useful to understand this.\",\n \"Fig 4: Why are there sharp dips periodically in all the graphs? What do these correspond to?\",\n \"The results are not conclusively in favor of the proposed method, and only is marginally better than the competitors. Why does online batch perform consistently than the proposed method? There is no discussion of these inferences from the results.\"\n ],\n \"Limitations\": [\n \"The primary concern is about the strength of the experimental results, which showed only a modest benefit on relatively simple datasets.\"\n ],\n \"Ethical Concerns\": false,\n \"Soundness\": 2,\n \"Presentation\": 3,\n \"Contribution\": 2,\n \"Overall\": 4,\n \"Confidence\": 3,\n \"Decision\": \"Reject\"\n}"
3
+ }
ai_scientist/fewshot_examples/2_carpe_diem.pdf ADDED
Binary file (858 kB). View file
 
ai_scientist/fewshot_examples/2_carpe_diem.txt ADDED
@@ -0,0 +1,1035 @@
1
+ # CARPE DIEM, SEIZE THE SAMPLES UNCERTAIN “AT
2
+ ## THE MOMENT” FOR ADAPTIVE BATCH SELECTION
3
+
4
+ **Anonymous authors**
5
+ Paper under double-blind review
6
+
7
+ ABSTRACT
8
+
9
+ The performance of deep neural networks is significantly affected by how well
10
+ mini-batches are constructed. In this paper, we propose a novel adaptive batch
11
+ selection algorithm called Recency Bias that exploits the uncertain samples
12
+ predicted inconsistently in recent iterations. The historical label predictions of
13
+ each sample are used to evaluate its predictive uncertainty within a sliding window.
14
+ By taking advantage of this design, Recency Bias not only accelerates the training
15
+ step but also achieves a more accurate network. We demonstrate the superiority
16
+ of Recency Bias by extensive evaluation on two independent tasks. Compared with
17
+ existing batch selection methods, the results showed that Recency Bias reduced
18
+ the test error by up to 20.5% in a fixed wall-clock training time. At the same time,
19
+ it improved the training time by up to 59.3% to reach the same test error.
20
+
21
+ 1 INTRODUCTION
22
+
23
+ Stochastic gradient descent (SGD) on randomly selected mini-batch samples is commonly used to
+ train deep neural networks (DNNs). However, many recent studies have pointed out that the performance of DNNs is heavily dependent on how well the mini-batch samples are selected (Shrivastava
25
+ et al., 2016; Chang et al., 2017; Katharopoulos & Fleuret, 2018). In earlier approaches, a sample’s difficulty is employed to identify proper mini-batch samples, and these approaches achieve
26
+ a more accurate and robust network (Han et al., 2018) or expedite the training convergence of
27
+ SGD (Loshchilov & Hutter, 2016). However, the two opposing difficulty-based strategies, i.e., preferring easy samples (Kumar et al., 2010; Han et al., 2018) versus hard samples (Loshchilov & Hutter,
28
+ 2016; Shrivastava et al., 2016), work well in different situations. Thus, for practical reasons to cover
29
+ more diverse situations, recent approaches begin to exploit a sample’s uncertainty that indicates the
30
+ consistency of previous predictions (Chang et al., 2017; Song et al., 2019).
31
+
32
+ An important question here is how to evaluate the sample’s uncertainty based on its historical
33
+ predictions during the training process. Intuitively, because a series of historical predictions can
34
+ be seen as a series of data indexed in chronological order, the uncertainty can be measured based on
35
+ _two forms of handling time-series observations: (i) a growing window (Figure 1(a)) that consistently_
36
+ increases the size of a window to use all available observations and (ii) a sliding window (Figure 1(b))
37
+ that maintains a window of a fixed size on the most recent observations by deleting outdated ones.
38
+ While the state-of-the-art algorithm, Active Bias (Chang et al., 2017), adopts the growing window,
39
+ we propose to use the sliding window in this paper.
40
+
41
+ [Figure 1 graphic: (a) Growing Window uses all available historical observations; (b) Sliding Window keeps only recent observations and discards outdated ones.]
55
+
56
+ Figure 1: Two forms of handling the time-series observations.
57
+
58
+ In more detail, Active Bias recognizes uncertain samples based on the inconsistency of the predictions
59
+ in the entire history of past SGD iterations. Then, it emphasizes such uncertain samples by choosing
60
+ them with high probability for the next mini-batch. However, according to our experiments presented
61
+
62
+
63
+ -----
64
+
65
+ [Figure 2 graphic: two example (Horse) images whose outdated predictions are inconsistent but whose recent predictions are consistent; Active Bias assigns both high uncertainty, whereas Recency Bias assigns them low uncertainty because one is now too easy and the other too hard.]
91
+
92
+
93
+ Figure 2: The difference in sample uncertainty estimated by Active Bias and Recency Bias.
94
+
95
+ in Section 5.2, such uncertain samples slowed down the convergence speed of training, though they
96
+ ultimately reduced the generalization error. This weakness is attributed to the inherent limitation of
97
+ the growing window, where older observations could be too outdated (Torgo, 2011). In other words,
98
+ the outdated predictions no longer represent a network’s current behavior. As illustrated in Figure
99
+ 2, when the label predictions of two samples were inconsistent for a long time, Active Bias invariably
100
+ regards them as highly uncertain, although their recent label predictions become consistent along
101
+ with the network’s training progress. This characteristic evidently entails the risk of emphasizing
102
+ uninformative samples that are too easy or too hard at the current moment, thereby slowing down
103
+ the convergence speed of training.
104
+
105
+ Therefore, we propose a simple but effective batch selection method, called Recency Bias, that takes
106
+ advantage of the sliding window to evaluate the uncertainty in fresher observations. As opposed to
107
+ _Active Bias, Recency Bias excludes the outdated predictions by managing a sliding window of a fixed_
108
+ size and picks up the samples predicted inconsistently within the sliding window. Thus, as shown
109
+ in Figure 2, the two samples uninformative at the moment are no longer selected by Recency Bias
110
+ simply because their recent predictions are consistent. Consequently, since informative samples are
111
+ effectively selected throughout the training process, this strategy not only accelerates the training
112
+ speed but also leads to a more accurate network.
113
+
114
+ To validate the superiority of Recency Bias, two popular convolutional neural networks (CNNs) were
115
+ trained for two independent tasks: image classification and fine tuning. We compared Recency Bias
116
+ with not only random batch selection (baseline) but also two state-of-the-art batch selection strategies.
117
+ Compared with three batch selection strategies, Recency Bias provided a relative reduction of test
118
+ error by 1.81%–20.5% in a fixed wall-clock training time. At the same time, it significantly reduced
119
+ the execution time by 24.6%–59.3% to reach the same test error.
120
+
121
+ 2 RELATED WORK
122
+
123
+ Let D = {(x_i, y_i) | 1 ≤ i ≤ N} be the entire training dataset composed of a sample x_i with its
+ true label y_i, where N is the total number of training samples. Then, a straightforward strategy to
+ construct a mini-batch M = {(x_i, y_i) | 1 ≤ i ≤ b} is to select b samples uniformly at random
+ (i.e., P(x_i|D) = 1/N) from the training dataset D.
129
+
130
+ Because not all samples have an equal impact on training, many research efforts have been devoted
131
+ to develop advanced sampling schemes. Bengio et al. (2009) first took easy samples and then
132
+ gradually increased the difficulty of samples using heuristic rules. Kumar et al. (2010) determined the
133
+ easiness of the samples using their prediction errors. Recently, Tsvetkov et al. (2016) used Bayesian
134
+ optimization to learn an optimal curriculum for training dense, distributed word representations.
135
+ Sachan & Xing (2016) emphasized that the right curriculum must introduce a small number of the
136
+ samples dissimilar to those previously seen. Fan et al. (2017) proposed a neural data filter based on
137
+ reinforcement learning to select training samples adaptively. However, it is common for deep learning
138
+ to emphasize hard samples because of the plethora of easy ones (Katharopoulos & Fleuret, 2018).
139
+
140
+ Loshchilov & Hutter (2016) proposed a difficulty-based sampling scheme, called Online Batch,
+ that uses the rank of the loss computed from previous epochs. Online Batch sorts the previously
+ computed losses of samples in descending order and exponentially decays the sampling probability
+ of a sample according to its rank r. Then, the r-th ranked sample x_{(r)} is selected with the probability
+ dropping by a factor of \exp(\log(s_e)/N), where s_e is the selection pressure parameter that affects
+ the probability gap between the most and the least important samples. When normalized to sum
+ to 1.0, the probability P(x_{(r)}|D; s_e) is defined by Eq. (1).
+
+ -----
+
+ It has been reported that Online Batch accelerates the convergence of training but deteriorates the
+ generalization error because of the overfitting to hard training samples (Loshchilov & Hutter, 2016).
+
+ P(x_{(r)}|D; s_e) = \frac{1/\exp(\log(s_e)/N)^{r}}{\sum_{j=1}^{N} 1/\exp(\log(s_e)/N)^{j}} \quad (1)
159
+
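+ As an informal illustration of Eq. (1), the following NumPy sketch (our own rendering, not code from the Online Batch authors) converts loss ranks into normalized selection probabilities:
+
+ ```python
+ import numpy as np
+
+ def online_batch_probabilities(losses, selection_pressure):
+     """Rank-based sampling probabilities in the spirit of Eq. (1); illustrative only."""
+     n = len(losses)
+     # Rank 1 is the sample with the largest loss (the hardest sample).
+     ranks = np.empty(n, dtype=int)
+     ranks[np.argsort(-losses)] = np.arange(1, n + 1)
+     decay = 1.0 / np.exp(np.log(selection_pressure) / n)
+     probs = decay ** ranks
+     return probs / probs.sum()
+
+ # Example: 5 samples, selection pressure 100.
+ losses = np.array([0.9, 0.1, 2.3, 0.5, 1.2])
+ print(online_batch_probabilities(losses, selection_pressure=100.0))
+ ```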
160
+ Most close to our work, Chang et al. (2017) devised an uncertainty-based sampling scheme, called
+ Active Bias, that chooses uncertain samples with high probability for the next batch. Active Bias
+ maintains the history H_i^{t-1} that stores all h(y_i|x_i) before the current iteration t (i.e., growing window),
+ where h(y_i|x_i) is the softmax probability of a given sample x_i for its true label y_i. Then, it measures
+ the uncertainty of the sample x_i by computing the variance over all h(y_i|x_i) in H_i^{t-1} and draws the
+ next mini-batch samples based on the normalized probability P(x_i|D, H_i^{t-1}; ε) in Eq. (2), where ε is
+ the smoothness constant to prevent the low-variance samples from never being selected again. As
+ mentioned earlier in Section 1, Active Bias slows down the training process because the oldest part in
+ the history H_i^{t-1} no longer represents the current behavior of the network.
+
+ P(x_i|D, H_i^{t-1}; \epsilon) = \frac{\hat{std}(H_i^{t-1}) + \epsilon}{\sum_{j=1}^{N} (\hat{std}(H_j^{t-1}) + \epsilon)}, \quad \hat{std}(H_i^{t-1}) = \sqrt{\mathrm{var}(h(y_i|x_i)) + \frac{\mathrm{var}(h(y_i|x_i))^2}{|H_i^{t-1}|}} \quad (2)
177
+
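+ For intuition, a rough NumPy rendering of the Active Bias weighting in Eq. (2) could look like this (illustrative only; the actual method tracks the softmax probability of the true label over the whole growing history):
+
+ ```python
+ import numpy as np
+
+ def active_bias_probabilities(histories, epsilon=1e-3):
+     """histories: list of 1-D arrays, each holding h(y_i|x_i) over all past iterations."""
+     stds = []
+     for h in histories:
+         v = np.var(h)
+         stds.append(np.sqrt(v + v ** 2 / len(h)))   # variance with a finite-sample correction
+     weights = np.array(stds) + epsilon              # smoothing so no sample is starved
+     return weights / weights.sum()
+
+ # Example: three samples with growing prediction histories.
+ hist = [np.array([0.2, 0.8, 0.3, 0.9]), np.array([0.95, 0.97, 0.96]), np.array([0.5, 0.5, 0.4, 0.6])]
+ print(active_bias_probabilities(hist))
+ ```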
178
+ For the completeness of the survey, we include the recent studies on submodular batch selection.
179
+ Joseph et al. (2019) and Wang et al. (2019) designed their own submodular objectives that cover
180
+ diverse aspects, such as sample redundancy and sample representativeness, for more effective
181
+ batch selection. Differently from their work, we explore the issue of truly uncertain samples in
182
+ an orthogonal perspective. Our uncertainty measure can be easily injected into their submodular
183
+ optimization framework as a measure of sample informativeness.
184
+
185
+ In Section 5, we will confirm that Recency Bias outperforms Online Batch and Active Bias, which are
186
+ regarded as two state-of-the-art adaptive batch selection methods for deep learning.
187
+
188
+ 3 _Recency Bias COMPONENTS_
189
+
190
+ 3.1 CRITERION OF AN UNCERTAIN SAMPLE
191
+
192
+ The main challenge of Recency Bias is to identify the samples whose recent label predictions are
193
+ highly inconsistent, which are neither too easy nor too hard at the moment. Thus, we adopt the
194
+ _predictive uncertainty (Song et al., 2019) in Definition 3.1 that uses the information entropy (Chandler,_
195
+ 1987) to measure the inconsistency of recent label predictions. Here, the sample with high predictive
196
+ uncertainty is regarded as uncertain and selected with high probability for the next mini-batch.
197
+ **Definition 3.1. (Predictive Uncertainty)** Let ŷ_t = Φ(x_i, θ_t) be the predicted label of a sample x_i at
+ time t and H_{x_i}(q) = {ŷ_{t_1}, ŷ_{t_2}, ..., ŷ_{t_q}} be the label history of the sample x_i that stores the predicted
+ labels at the previous q times, where Φ is a neural network. The label history H_{x_i}(q) corresponds
+ to the sliding window of size q used to compute the uncertainty of the sample x_i. Next, p(y|x_i; q) is
+ formulated such that it provides the probability of the label y ∈ {1, 2, ..., k} being estimated as the label of
+ the sample x_i based on H_{x_i}(q), as in Eq. (3), where [·] is the Iverson bracket[1].
+
+ p(y|x_i; q) = \frac{\sum_{\hat{y} \in H_{x_i}(q)} [\hat{y} = y]}{|H_{x_i}(q)|} \quad (3)
210
+
211
+ Then, to quantify the uncertainty of the sample xi, the predictive uncertainty F (xi; q) is defined by
212
+ Eq. (4), where δ is the standardization term to normalize the value to [0, 1].
213
+
214
+
215
+ F(x_i; q) = -\frac{1}{\delta} \sum_{j=1}^{k} p(j|x_i; q) \log p(j|x_i; q), \quad \delta = -\log(1/k) \quad (4) □
230
+
231
+ [1] The Iverson bracket [p] returns 1 if p is true; 0 otherwise.
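+ As an informal aid, the following Python sketch (ours, with illustrative names) computes the predictive uncertainty of Definition 3.1 from a fixed-size label history:
+
+ ```python
+ import numpy as np
+ from collections import Counter, deque
+
+ def predictive_uncertainty(label_history, num_classes):
+     """Entropy-based uncertainty over a sliding window of predicted labels (Eqs. (3)-(4))."""
+     counts = Counter(label_history)
+     probs = np.array([counts.get(c, 0) / len(label_history) for c in range(num_classes)])
+     nonzero = probs[probs > 0]
+     entropy = -np.sum(nonzero * np.log(nonzero))
+     return entropy / (-np.log(1.0 / num_classes))   # normalize by delta so the value lies in [0, 1]
+
+ # Example: a sliding window of size q = 10 over a 10-class problem.
+ history = deque([3, 7, 3, 3, 7, 3, 7, 7, 3, 7], maxlen=10)
+ print(predictive_uncertainty(history, num_classes=10))   # about 0.30: two labels alternate evenly
+ ```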
232
+
233
+
234
+ -----
235
+
236
+ 3.2 SAMPLING PROBABILITY FOR MINI-BATCH CONSTRUCTION
237
+
238
+ To construct next mini-batch samples, we assign the sampling probability according to the predictive
239
+ uncertainty in Definition 3.1. Motivated by Loshchilov & Hutter (2016), the sampling probability
240
+ of a given sample xi is exponentially decayed with its predictive uncertainty F (xi; q). In detail,
241
+ we adopt the quantization method (Chen & Wornell, 2001) and use the quantization index to decay
242
+ the sampling probability. The index is obtained by the simple quantizer Q in Eq. (5), where ∆ is
243
+ the quantization step size. Compared with the rank-based index (Loshchilov & Hutter, 2016), the
244
+ quantization index is known to well reflect the difference in actual values (Widrow et al., 1996).
245
+
246
+ Q(F(x_i; q)) = \lceil (1 - F(x_i; q)) / \Delta \rceil, \quad 0 \le F(x_i; q) \le 1 \quad (5)
+
+ In Eq. (5), we set ∆ to be 1/N such that the index is bounded by N (the total number of samples).
+ Then, the sampling probability P(x_i|D; s_e) is defined as in Eq. (6). The higher the predictive
+ uncertainty, the smaller the quantization index. Therefore, a higher sampling probability is assigned
+ to uncertain samples in Eq. (6).
+
+ P(x_i|D; s_e) = \frac{1/\exp(\log(s_e)/N)^{Q(F(x_i;q))}}{\sum_{j=1}^{N} 1/\exp(\log(s_e)/N)^{Q(F(x_j;q))}} \quad (6)
260
+ Meanwhile, it is known that using only some part of the training data exacerbates the overfitting problem
+ at a late stage of training (Loshchilov & Hutter, 2016; Zhou & Bilmes, 2018). Thus, to alleviate
+ the problem, we include more training samples as the training progresses by exponentially decaying
+ the selection pressure s_e as in Eq. (7). At each epoch e from e_0 to e_end, the selection pressure
+ s_e exponentially decreases from s_{e_0} to 1. Because this technique gradually reduces the sampling
+ probability gap between the most and the least uncertain samples, more diverse samples are selected
+ for the next mini-batch at a later epoch. When the selection pressure s_e becomes 1, the mini-batch
+ samples are randomly chosen from the entire dataset.
+
+ s_e = s_{e_0} \left[ \exp\left( \log(1/s_{e_0}) / (e_{end} - e_0) \right) \right]^{e - e_0} \quad (7)
272
+
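+ Putting Eqs. (5)-(7) together, a hedged NumPy sketch of the probability table used for batch selection might look as follows (the function names are ours, not from the paper's code):
+
+ ```python
+ import numpy as np
+
+ def decay_selection_pressure(se0, epoch, e0, e_end):
+     """Eq. (7): exponentially anneal the selection pressure from se0 down to 1."""
+     return se0 * np.exp(np.log(1.0 / se0) / (e_end - e0)) ** (epoch - e0)
+
+ def sampling_probabilities(uncertainties, se):
+     """Eqs. (5)-(6): quantize uncertainties and turn them into sampling probabilities."""
+     n = len(uncertainties)
+     delta = 1.0 / n
+     q_index = np.ceil((1.0 - uncertainties) / delta)      # Eq. (5): small index = high uncertainty
+     weights = (1.0 / np.exp(np.log(se) / n)) ** q_index   # Eq. (6): exponential decay by index
+     return weights / weights.sum()
+
+ # Example: 5 samples with given predictive uncertainties at epoch 20 (warm-up ends at e0 = 15).
+ u = np.array([0.9, 0.1, 0.5, 0.7, 0.3])
+ se = decay_selection_pressure(se0=100.0, epoch=20, e0=15, e_end=100)
+ print(sampling_probabilities(u, se))
+ ```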
273
+ 4 _Recency Bias ALGORITHM_
274
+
275
+ **Algorithm 1 Recency Bias Algorithm**
276
+
277
+ INPUT: D: data, epochs, b: batch size, q: window size, s_{e_0}: initial selection pressure, γ: warm-up
+ OUTPUT: θ_t: model parameter
+
+ 1: t ← 1;
+ 2: θ_t ← Initialize the model parameter;
+ 3: for i = 1 to epochs do
+ 4:   /* Sampling Probability Derivation */
+ 5:   if i > γ then
+ 6:     s_e ← Decay_Selection_Pressure(s_{e_0}, i); /* Decaying s_e by Eq. (7) */
+ 7:     for m = 1 to N do /* Updating the index and the sampling probability in a batch */
+ 8:       q_dict[x_m] = Q(F(x_m; q)); /* By Eq. (5) */
+ 9:     p_table ← Compute_Prob(q_dict, s_e); /* By Eq. (6) */
+ 10:  /* Network Training */
+ 11:  for j = 1 to N/b do /* Mini-batch */
+ 12:    if i ≤ γ then /* Warm-up */
+ 13:      {(x_1, y_1), ..., (x_b, y_b)} ← Randomly select next mini-batch samples;
+ 14:    else /* Adaptive batch selection */
+ 15:      {(x_1, y_1), ..., (x_b, y_b)} ← Select next mini-batch samples based on p_table;
+ 16:    losses, labels ← Inference_Step({(x_1, y_1), ..., (x_b, y_b)}, θ_t); /* Forward */
+ 17:    θ_{t+1} ← SGD_Step(losses, θ_t); /* Backward */
+ 18:    Update_Label_History(labels); /* By Definition 3.1 */
+ 19:    t ← t + 1;
+ 20: return θ_t;
322
+ Algorithm 1 describes the overall procedure of Recency Bias. The algorithm requires a warm-up
323
+ period of γ epochs because the quantization index for each sample is not confirmed yet. During
324
+ the warm-up period, which should be at least q epochs (γ ≥ _q) to obtain the label history of size_
325
+
326
+
327
+ -----
328
+
329
+ _q, randomly selected mini-batch samples are used for the network update (Lines 12–13). After the_
330
+ warm-up period, the algorithm decays the selection pressure se and updates not only the quantization
331
+ index but also the sampling probability in a batch at the beginning of each epoch (Lines 4–9).
332
+ Subsequently, the uncertain samples are selected for the next mini-batch according to the updated
333
+ sampling probability (Line 14–15), and then the label history is updated along with the network
334
+ update (Lines 16–19).
335
+
336
+ Overall, the key technical novelty of Recency Bias is to incorporate the notion of a sliding win_dow (Line 8) rather than a growing window into adaptive batch selection, thereby improving both_
337
+ training speed and generalization error.
338
+
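+ For readers who prefer code, the warm-up/adaptive switch in Algorithm 1 (Lines 11-15) can be sketched as follows; this is a schematic with illustrative names, not the authors' implementation, and it assumes a precomputed probability table such as the one sketched after Eq. (7):
+
+ ```python
+ import numpy as np
+
+ def select_minibatch(p_table, batch_size, epoch, warmup, seed=None):
+     """Pick sample indices for the next mini-batch (Algorithm 1, Lines 11-15); illustrative only."""
+     rng = np.random.default_rng(seed)
+     n = len(p_table)
+     if epoch <= warmup:
+         # Warm-up: uniform random selection while label histories are still being filled.
+         return rng.choice(n, size=batch_size, replace=False)
+     # Adaptive selection: draw indices according to the uncertainty-based probability table.
+     return rng.choice(n, size=batch_size, replace=False, p=p_table)
+
+ # Example with 10 samples and batch size 4.
+ p_table = np.full(10, 0.1)
+ print(select_minibatch(p_table, batch_size=4, epoch=3, warmup=15, seed=0))
+ print(select_minibatch(p_table, batch_size=4, epoch=20, warmup=15, seed=0))
+ ```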
339
+ **Time Complexity: The main “additional” cost of Recency Bias is the derivation of the sampling**
340
+ probability for each sample (Lines 4–9). Because only simple mathematical operations are needed
341
+ per sample, its time complexity is linear to the number of samples (i.e., O(N )), which is negligible
342
+ compared with that of the forward and backward steps of a complex network (Lines 16–17). Therefore,
343
+ we contend that Recency Bias does not add the complexity of an underlying optimization algorithm.
344
+
345
+ 5 EVALUATION
346
+
347
+ We empirically show the improvement of Recency Bias over not only Random Batch (baseline) but also
348
+ _Online Batch (Loshchilov & Hutter, 2016) and Active Bias (Chang et al., 2017), which are two state-_
349
+ of-the-art adaptive batch selections. In particular, we elaborate on the effect of the sliding window
350
+ approach (Recency Bias) compared with the growing window approach (Active Bias). Random Batch
351
+ selects next mini-batch samples uniformly at random from the entire dataset. Online Batch selects hard
352
+ samples based on the rank of the loss computed from previous epochs. Active Bias selects uncertain
353
+ samples with high variance of true label probabilities in the growing window. All the algorithms
354
+ were implemented using TensorFlow 1.8.0 and executed using a single NVIDIA Titan Volta GPU.
355
+ [For reproducibility, we provide the source code at https://github.com/anonymized.](https://github.com/anonymized)
356
+
357
+ Image classification and fine-tuning tasks were performed to validate the superiority of Recency Bias.
358
+ Because fine-tuning is used to quickly adapt to a new dataset, it is suitable to reap the benefit of fast
359
+ training speed. In support of reliable evaluation, we repeated every task thrice and reported the average
360
+ and standard error of the best test errors. The best test error in a given time has been widely used for
361
+ the studies on fast and accurate training (Katharopoulos & Fleuret, 2018; Loshchilov & Hutter, 2016).
362
+
363
+ 5.1 ANALYSIS ON SELECTED MINI-BATCH SAMPLES
364
+
365
+ For an in-depth analysis on selected samples, we plot the loss distribution of mini-batch samples
366
+ selected from CIFAR-10 by four different strategies in Figure 3. (i) The distribution of Online Batch
367
+ is the most skewed toward high loss by the design principle of selecting hard samples. (ii) Active Bias
368
+ emphasizes moderately hard samples at an early training stage in considering that its loss distribution
369
+ lies between those of Random Batch and Online Batch. However, owing to the outdated predictions
370
+ caused by the growing window, the proportion of easy samples with low loss increases at a late
371
+ training stage. These easy samples, which are misclassified as uncertain at that stage, tend to make the
372
+ convergence of training slow down. (iii) In contrast to Active Bias, by virtue of the sliding window,
373
+ the distribution of Recency Bias lies between those of Random Batch and Online Batch regardless of
374
+ the training stage. Consequently, Recency Bias continues to highlight the moderately hard samples,
375
+ which are likely to be informative, during the training process.
376
+
377
+ [Figure 3 graphic: loss distributions (log-scale) of mini-batches selected by Random Batch, Online Batch, Active Bias (growing window), and Recency Bias (sliding window); panel (a) Early Stage (30%), panel (b) Late Stage (70%).]
388
+
389
+ Figure 3: The loss distribution of mini-batch samples selected by four batch selection strategies: (a)
390
+ and (b) show the loss distribution at the 30% and 70% of total training epochs, respectively.
391
+
392
+
393
+ -----
394
+
395
+ 5.2 TASK I: IMAGE CLASSIFICATION
396
+
397
+ **Experiment Setting: We trained DenseNet (L=40, k=12) and ResNet (L=50) with a momentum**
398
+ optimizer and an SGD optimizer on three benchmark datasets: MNIST (10 classes)[2], classification
399
+ of handwritten digits (LeCun, 1998), and CIFAR-10 (10 classes)[3] and CIFAR-100 (100 classes)[3],
400
+ classification of a subset of 80 million categorical images (Krizhevsky et al., 2014). Specifically, we
401
+ used data augmentation, batch normalization, a momentum of 0.9, and a batch size of 128. As for the
402
+ algorithm parameters, we fixed the window size q = 10 and the initial selection pressure se0 = 100,[4]
403
+ which were the best values found by the grid search (see Appendix A for details). The warm-up
404
+ epoch γ was set to be 15. To reduce the performance variance caused by randomly initialized model
405
+ parameters, all parameters were shared by all algorithms during the warm-up period. Regarding
406
+ the training schedule, we trained the network for 40, 000 iterations and used an initial learning rate
407
+ of 0.1, which was divided by 10 at 50% and 75% of the total number of training iterations.
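A small sketch of the training schedule described above (step-wise learning-rate decay and the warm-up period). Treating the warm-up period as a fall-back to uniform random sampling is our reading of the setup, not something stated explicitly.

```python
def learning_rate(iteration, total_iters=40_000, base_lr=0.1):
    """Step schedule from the experiment setting: the initial rate 0.1 is
    divided by 10 at 50% and 75% of the total training iterations."""
    lr = base_lr
    if iteration >= 0.5 * total_iters:
        lr /= 10.0
    if iteration >= 0.75 * total_iters:
        lr /= 10.0
    return lr

def use_adaptive_selection(epoch, warmup_epochs=15):
    """During the warm-up period (gamma = 15 epochs) we assume all methods fall
    back to uniform sampling, so that history/window statistics can be populated."""
    return epoch >= warmup_epochs
```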
408
+
409
+ **Results: Figure 4 shows the convergence curves of training loss and test error for four batch selection**
410
+ strategies using DenseNet and a momentum optimizer. In order to highlight the improvement of
411
+ _Recency Bias over the baseline (Random Batch), their lines are dark colored. The best test errors in_
412
+ Figures 4(b), 4(d), and 4(f) are summarized on the left side of Table 1.
413
+
414
+ In general, Recency Bias achieved the most accurate network while accelerating the training process
415
+ on all datasets. The training loss of Recency Bias converged faster (Figures 4(a), 4(c), and 4(e))
416
+ without the increase in the generalization error, thereby achieving the lower test error (Figures 4(b),
417
+ 4(d), and 4(f)). In contrast, the test error of Online Batch was not the best even if its training loss
418
+ converged the fastest among all strategies. As the training difficulty increased from CIFAR-10 to
419
+ CIFAR-100, the test error of Online Batch became even worse than that of Random Batch. That
420
+ is, emphasizing hard samples accelerated the training step but made the network overfit to hard
421
+ samples. Meanwhile, Active Bias was prone to make the network better generalized on test data.
422
+ In CIFAR-10, despite its highest training loss, the test error of Active Bias was better than that of
423
+ _Random Batch. However, Active Bias slowed down the training process because of the limitation_
424
+ of growing windows, as discussed in Section 5.1. We note that, although both Recency Bias and
425
+ _Active Bias exploited uncertain samples, only Recency Bias, based on sliding windows, succeeded_
+ in not only speeding up the training process but also reducing the generalization error.
427
+
428
+ The results of the best test error for ResNet or an SGD optimizer are summarized in Tables 1 and
429
+ 2 (see Appendix B for more details). Regardless of the neural network and the optimizer, Recency
430
+ _Bias achieved the lowest test error except in MNIST with an SGD optimizer. The improvement of_
431
+ _Recency Bias over the others was higher with an SGD optimizer than with a momentum optimizer._
432
+
433
+ Table 1: The best test errors (%) of four batch selection strategies using DenseNet.
434
+
435
+ |Optimizer|Momentum (Figure 4)|||SGD (Figure 7, Appendix B.1)|||
436
+ |---|---|---|---|---|---|---|
437
+ |Method|MNIST|CIFAR-10|CIFAR-100|MNIST|CIFAR-10|CIFAR-100|
438
+ |Random Batch|0.527 ± 0.03|7.33 ± 0.09|28.0 ± 0.16|1.23 ± 0.03|14.9 ± 0.09|40.2 ± 0.06|
439
+ |Online Batch|0.514 ± 0.01|7.00 ± 0.10|28.4 ± 0.25|0.765 ± 0.02|13.5 ± 0.02|40.7 ± 0.12|
440
+ |Active Bias|0.616 ± 0.03|7.07 ± 0.04|27.9 ± 0.11|0.679 ± 0.02|14.2 ± 0.25|42.9 ± 0.05|
441
+ |Recency Bias|0.490 ± 0.02|6.60 ± 0.02|27.1 ± 0.19|0.986 ± 0.06|13.2 ± 0.11|38.7 ± 0.11|
442
+
443
+
444
+
445
+ Table 2: The best test errors (%) of four batch selection strategies using ResNet.
446
+
447
+ |Optimizer|Momentum (Figure 8, Appendix B.2)|||SGD (Figure 9, Appendix B.3)|||
448
+ |---|---|---|---|---|---|---|
449
+ |Method|MNIST|CIFAR-10|CIFAR-100|MNIST|CIFAR-10|CIFAR-100|
450
+ |Random Batch|0.636 ± 0.04|10.2 ± 0.12|33.2 ± 0.07|1.16 ± 0.03|12.7 ± 0.09|40.1 ± 0.16|
451
+ |Online Batch|0.666 ± 0.05|10.1 ± 0.05|33.4 ± 0.01|0.890 ± 0.03|12.2 ± 0.08|40.7 ± 0.09|
452
+ |Active Bias|0.613 ± 0.04|10.6 ± 0.08|34.2 ± 0.07|0.804 ± 0.01|13.5 ± 0.07|45.6 ± 0.07|
453
+ |Recency Bias|0.607 ± 0.01|9.79 ± 0.04|32.4 ± 0.04|0.972 ± 0.03|11.6 ± 0.09|38.9 ± 0.14|
454
+
455
+
456
+
457
+ [2http://yann.lecun.com/exdb/mnist](http://yann.lecun.com/exdb/mnist)
458
+ [3https://www.cs.toronto.edu/~kriz/cifar.html](https://www.cs.toronto.edu/~kriz/cifar.html)
459
+ 4Online Batch also used the same decaying selection pressure value.
460
+
461
+
462
+ -----
463
+
464
+ [Figure 4 panels: (a) MNIST training loss, (b) MNIST test error, (c) CIFAR-10 training loss, (d) CIFAR-10 test error, (e) CIFAR-100 training loss, (f) CIFAR-100 test error; x-axis: time (s); legend: Random Batch, Online Batch, Active Bias, Recency Bias.]
529
+
530
+ Figure 4: Convergence curves of four batch selection strategies using DenseNet with momentum.
531
+
532
+
533
+ 5.3 TASK II: FINE-TUNING
534
+
535
+ **Experiment Setting: We prepared DenseNet (L=121, k=32) previously trained on ImageNet (Deng**
536
+ et al., 2009) and then fine-tuned the network on two benchmark datasets: MIT-67 (67 classes)[5],
537
+ classification of indoor scenes (Quattoni & Torralba, 2009), and Food-100 (100 classes)[6], classification of popular foods in Japan (Kawano & Yanai, 2014). After replacing the last classification
538
+ layer, the network was trained end-to-end for 50 epochs with a batch size of 32 and a constant learning
+ rate of 2 × 10^-4. Data augmentation was not applied here. The other configurations were the same
540
+ as those in Section 5.2.
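A minimal fine-tuning sketch under these settings, written with modern tf.keras rather than the authors' TensorFlow 1.8 code; the dataset pipeline and class count are placeholders.

```python
import tensorflow as tf

# Take an ImageNet-pretrained DenseNet-121, replace the classification layer,
# and train end-to-end with a constant learning rate of 2e-4 and batch size 32.
NUM_CLASSES = 67  # e.g. MIT-67

base = tf.keras.applications.DenseNet121(weights="imagenet", include_top=False, pooling="avg")
outputs = tf.keras.layers.Dense(NUM_CLASSES, activation="softmax")(base.output)
model = tf.keras.Model(base.input, outputs)

model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=2e-4, momentum=0.9),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
# model.fit(train_ds.batch(32), epochs=50, validation_data=test_ds.batch(32))
```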
541
+
542
+ **Results on Test Error: Figure 5 shows the convergence curves of training loss and test error for**
543
+ the fine-tuning task on MIT-67 and Food-100. Overall, all convergence curves showed similar trends
544
+ to those of the classification task in Figure 4. Only Recency Bias converged faster than Random
545
+ _Batch in both training loss and test error. Online Batch converged the fastest in training loss, but_
546
+ its test error was higher than that of Random Batch owing to overfitting. Active Bias converged the
547
+
548
+
549
+ [5http://web.mit.edu/torralba/www/indoor.html](http://web.mit.edu/torralba/www/indoor.html)
550
+ [6http://foodcam.mobi/dataset100.html](http://foodcam.mobi/dataset100.html)
551
+
552
+
553
+ -----
554
+
555
+ [Figure 5 panels: (a) MIT-67 training loss, (b) MIT-67 test error (time reduction: 24.6%), (c) Food-100 training loss, (d) Food-100 test error (time reduction: 26.1%); x-axis: time (s); legend: Random Batch, Online Batch, Active Bias, Recency Bias.]
622
+
623
+ Figure 5: Convergence curves for fine-tuning on two benchmark datasets.
624
+
625
+
626
+ Table 3: Recency Bias’s reduction in training time over other batch selection strategies.
627
+
628
+
629
+ |Method|MIT-67|FOOD-100|
630
+ |---|---|---|
631
+ |Random Batch|(5,218 − 3,936)/5,218 × 100 = 24.6%|(7,263 − 5,365)/7,263 × 100 = 26.1%|
+ |Online Batch|(6,079 − 3,823)/6,079 × 100 = 37.1%|(8,333 − 3,685)/8,333 × 100 = 55.8%|
+ |Active Bias|(5,738 − 3,032)/5,738 × 100 = 47.2%|(7,933 − 3,227)/7,933 × 100 = 59.3%|
634
+
635
+
636
+ slowest in both training loss and test error. Quantitatively, compared with Random Batch, Recency
637
+ _Bias reduced the test error by 2.88% and 1.81% in MIT-67 and Food-100, respectively._
638
+
639
+ **Results on Training Time: Moreover, to assess the performance gain in training time, we computed**
640
+ the reduction in the training time taken to reach the same error. For example, in Figure 5(b), the
641
+ best test error of 28.8% achieved in 5,218 seconds by Random Batch could be reached in only
+ 3,936 seconds by Recency Bias; thus, Recency Bias improved the training time by 24.6%. Table
643
+ 3 summarizes the reduction in the training time of Recency Bias over three other batch selection
644
+ strategies. Notably, Recency Bias improved the training time by 24.6%–47.2% and 26.1%–59.3% in
645
+ fine-tuning MIT-67 and FOOD-100 datasets, respectively.
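The reduction figures can be reproduced with a one-line computation (values taken from Table 3):

```python
def time_reduction(baseline_seconds, recency_bias_seconds):
    """Reduction in wall-clock time to reach the same test error, in percent."""
    return (baseline_seconds - recency_bias_seconds) / baseline_seconds * 100

print(time_reduction(5_218, 3_936))  # ~24.6 (MIT-67, vs. Random Batch)
print(time_reduction(7_263, 5_365))  # ~26.1 (Food-100, vs. Random Batch)
```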
646
+
647
+ 6 CONCLUSION
648
+
649
+
650
+ In this paper, we presented a novel adaptive batch selection algorithm called Recency Bias that
651
+ emphasizes predictively uncertain samples for accelerating the training of neural networks. Toward
652
+ this goal, the predictive uncertainty of each sample is evaluated using its recent label predictions
653
+ managed by a sliding window of a fixed size. Then, uncertain samples at the moment are selected with
654
+ high probability for the next mini-batch. We conducted extensive experiments on both classification
655
+ and fine-tuning tasks. The results showed that Recency Bias is effective in reducing the training
656
+ time as well as the best test error. It is worth noting that using all historical observations to
657
+ estimate the uncertainty has the side effect of slowing down the training process. Overall, a merger of
658
+ uncertain samples and sliding windows greatly improves the power of adaptive batch selection.
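For concreteness, here is a minimal sketch of one plausible sliding-window uncertainty estimator consistent with this description; the normalized-entropy scoring and class names are assumptions, not necessarily the exact formula used in the paper.

```python
from collections import deque
import numpy as np

class SlidingWindowUncertainty:
    """Keep the last q predicted labels per sample and score uncertainty as the
    normalized entropy of that empirical label distribution."""

    def __init__(self, num_samples, num_classes, q=10):
        self.q = q
        self.num_classes = num_classes
        self.history = [deque(maxlen=q) for _ in range(num_samples)]

    def record(self, sample_id, predicted_label):
        # Called once per epoch (or per forward pass) for each sample.
        self.history[sample_id].append(predicted_label)

    def uncertainty(self, sample_id):
        preds = self.history[sample_id]
        if not preds:
            return 1.0  # no evidence yet: treat as maximally uncertain
        counts = np.bincount(list(preds), minlength=self.num_classes)
        probs = counts / counts.sum()
        nonzero = probs[probs > 0]
        entropy = -(nonzero * np.log(nonzero)).sum()
        return float(entropy / np.log(self.num_classes))  # normalized to [0, 1]
```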
659
+
660
+
661
+ -----
662
+
663
+ REFERENCES
664
+
665
+ Yoshua Bengio, Jérôme Louradour, Ronan Collobert, and Jason Weston. Curriculum learning. In
666
+ _ICML, pp. 41–48, 2009._
667
+
668
+ David Chandler. Introduction to modern statistical mechanics. Oxford University Press, 1987.
669
+
670
+ Haw-Shiuan Chang, Erik Learned-Miller, and Andrew McCallum. Active Bias: Training more
671
+ accurate neural networks by emphasizing high variance samples. In NeurIPS, pp. 1002–1012,
672
+ 2017.
673
+
674
+ Brian Chen and Gregory W Wornell. Quantization index modulation: A class of provably good
675
+ methods for digital watermarking and information embedding. IEEE Trans. on Information Theory,
676
+ 47(4):1423–1443, 2001.
677
+
678
+ Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale
679
+ hierarchical image database. In CVPR, pp. 248–255, 2009.
680
+
681
+ Yang Fan, Fei Tian, Tao Qin, and Tie-Yan Liu. Neural data filter for bootstrapping stochastic gradient
682
+ descent. In ICLR, 2017.
683
+
684
+ Bo Han, Quanming Yao, Xingrui Yu, Gang Niu, Miao Xu, Weihua Hu, Ivor Tsang, and Masashi
685
+ Sugiyama. Co-teaching: Robust training of deep neural networks with extremely noisy labels. In
686
+ _NeurIPS, pp. 8527–8537, 2018._
687
+
688
+ KJ Joseph, Krishnakant Singh, Vineeth N Balasubramanian, et al. Submodular batch selection for
689
+ training deep neural networks. In IJCAI, pp. 2677–3683, 2019.
690
+
691
+ Angelos Katharopoulos and François Fleuret. Not all samples are created equal: Deep learning with
692
+ importance sampling. In ICML, pp. 2525–2534, 2018.
693
+
694
+ Y. Kawano and K. Yanai. Food image recognition with deep convolutional features. In UbiComp,
695
+ 2014.
696
+
697
+ Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. CIFAR-10 and CIFAR-100 datasets, 2014.
698
+ [https://www.cs.toronto.edu/~kriz/cifar.html.](https://www.cs.toronto.edu/~kriz/cifar.html)
699
+
700
+ M Pawan Kumar, Benjamin Packer, and Daphne Koller. Self-paced learning for latent variable
701
+ models. In NeurIPS, pp. 1189–1197, 2010.
702
+
703
+ [Yann LeCun. The MNIST database of handwritten digits, 1998. http://yann.lecun.com/](http://yann.lecun.com/exdb/mnist)
704
+ [exdb/mnist.](http://yann.lecun.com/exdb/mnist)
705
+
706
+ Ilya Loshchilov and Frank Hutter. Online batch selection for faster training of neural networks. In
707
+ _ICLR, 2016._
708
+
709
+ Ariadna Quattoni and Antonio Torralba. Recognizing indoor scenes. In CVPR, pp. 413–420, 2009.
710
+
711
+ Mrinmaya Sachan and Eric Xing. Easy questions first? A case study on curriculum learning for
712
+ question answering. In ACL, pp. 453–463, 2016.
713
+
714
+ Abhinav Shrivastava, Abhinav Gupta, and Ross Girshick. Training region-based object detectors
715
+ with online hard example mining. In CVPR, pp. 761–769, 2016.
716
+
717
+ Hwanjun Song, Minseok Kim, and Jae-Gil Lee. SELFIE: Refurbishing unclean samples for robust
718
+ deep learning. In ICML, pp. 5907–5915, 2019.
719
+
720
+ Luis Torgo. Data mining with R: learning with case studies. Chapman and Hall/CRC, 2011.
721
+
722
+ Yulia Tsvetkov, Manaal Faruqui, Wang Ling, Brian MacWhinney, and Chris Dyer. Learning the
723
+ curriculum with bayesian optimization for task-specific word representation learning. In ACL, pp.
724
+ 130–139, 2016.
725
+
726
+ Shengjie Wang, Wenruo Bai, Chandrashekhar Lavania, and Jeff Bilmes. Fixing mini-batch sequences
727
+ with hierarchical robust partitioning. In AISTATS, pp. 3352–3361, 2019.
728
+
729
+
730
+ -----
731
+
732
+ Bernard Widrow, Istvan Kollar, and Ming-Chang Liu. Statistical theory of quantization. IEEE
733
+ _Transactions on instrumentation and measurement, 45(2):353–361, 1996._
734
+
735
+ Tianyi Zhou and Jeff Bilmes. Minimax curriculum learning: Machine teaching with desirable
736
+ difficulties and scheduled diversity. In ICLR, 2018.
737
+
738
+
739
+ -----
740
+
741
+ A HYPERPARAMETER SELECTION
742
+
743
+ _Recency Bias takes two hyperparameters: (i) the initial selection pressure se0, which determines_
+ the gap in sampling probability between the most and the least uncertain samples, and (ii) the window
+ size q, which determines how many recent label predictions are involved in estimating the uncertainty.
+ To decide the best hyperparameters, we trained ResNet (L=50) on CIFAR-10 and CIFAR-100 with a
+ momentum optimizer, performing a grid search over se0 ∈ {1, 10, 100, 1000} and q ∈ {5, 10, 15}.
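A sketch of this grid search; `train_and_evaluate` is a hypothetical helper standing in for a full training run with the given hyperparameters.

```python
import itertools

def train_and_evaluate(se0, window_size):
    """Placeholder: train ResNet-50 with Recency Bias and return the best test error."""
    raise NotImplementedError

grid = list(itertools.product([1, 10, 100, 1000], [5, 10, 15]))
# results = {(se0, q): train_and_evaluate(se0=se0, window_size=q) for se0, q in grid}
# best = min(results, key=results.get)   # the paper reports (100, 10)
```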
750
+
751
+
752
+
753
+ [Figure 6 panels: best test error (y-axis: 9.5%–10.4% for CIFAR-10, 32.0%–33.5% for CIFAR-100) vs. initial selection pressure se0 ∈ {1, 10, 100, 1000}, with one curve per window size q ∈ {5, 10, 15}.]
775
+
776
+
777
+ (a) CIFAR-10. (b) CIFAR-100.
778
+
779
+ Figure 6: Grid search on CIFAR-10 and CIFAR-100 datasets using ResNet.
780
+
781
+ Figure 6 shows the test errors of Recency Bias obtained by the grid search on the two datasets.
782
+ Regarding the initial selection pressure se0, the lowest test error was typically achieved when the
783
+ _se0 value was 100. As for the window size q, the test error was almost always the lowest when the q_
784
+ value was 10. Similar trends were observed for the other combinations of a neural network and an
785
+ optimizer. Therefore, in all experiments, we set se0 to be 100 and q to be 10.
786
+
787
+
788
+ -----
789
+
790
+ B GENERALIZATION OF Recency Bias
+
+ B.1 CONVERGENCE CURVES USING DENSENET WITH SGD
+
+ Figure 7 shows the convergence curves of training loss and test error for four batch selection strategies
+ using DenseNet and an SGD optimizer, which corresponds to the right side of Table 1.
+
+ [Figure 7 panels: (a) MNIST training loss, (b) MNIST test error, (c) CIFAR-10 training loss, (d) CIFAR-10 test error, (e) CIFAR-100 training loss, (f) CIFAR-100 test error; x-axis: time (s); legend: Random Batch, Online Batch, Active Bias, Recency Bias.]
877
+
878
+ Figure 7: Convergence curves of four batch selection strategies using DenseNet with SGD.
879
+
880
+
881
+ -----
882
+
883
+ B.2 CONVERGENCE CURVES USING RESNET WITH MOMENTUM
+
+ Figure 8 shows the convergence curves of training loss and test error for four batch selection strategies
+ using ResNet and a momentum optimizer, which corresponds to the left side of Table 2.
+
+ [Figure 8 panels: (a) MNIST training loss, (b) MNIST test error, (c) CIFAR-10 training loss, (d) CIFAR-10 test error, (e) CIFAR-100 training loss, (f) CIFAR-100 test error; x-axis: time (s); legend: Random Batch, Online Batch, Active Bias, Recency Bias.]
928
+
929
+ Figure 8: Convergence curves of four batch selection strategies using ResNet with momentum.
930
+
931
+
932
+ -----
933
+
934
+ B.3 CONVERGENCE CURVES USING RESNET WITH SGD
+
+ Figure 9 shows the convergence curves of training loss and test error for four batch selection strategies
+ using ResNet and an SGD optimizer, which corresponds to the right side of Table 2.
+
+ [Figure 9 panels: (a) MNIST training loss, (b) MNIST test error, (c) CIFAR-10 training loss, (d) CIFAR-10 test error, (e) CIFAR-100 training loss, (f) CIFAR-100 test error; x-axis: time (s); legend: Random Batch, Online Batch, Active Bias, Recency Bias.]
1030
+
1031
+ Figure 9: Convergence curves of four batch selection strategies using ResNet with SGD.
1032
+
1033
+
1034
+ -----
1035
+
ai_scientist/fewshot_examples/attention.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "review": "{\n \"Summary\": \"The paper proposes the Transformer, a novel neural network architecture that relies entirely on self-attention mechanisms, eschewing traditional recurrent and convolutional layers. This innovation allows the model to achieve state-of-the-art results in machine translation tasks with significant improvements in both training efficiency and translation quality. The paper includes detailed descriptions of the model architecture, including multi-head attention and positional encodings, as well as extensive experimental results to validate the model's performance.\",\n \"Questions\": [\n \"Could the authors provide more detailed comparisons with other recent models not included in Table 2?\",\n \"What is the impact of varying the number of layers (N) in both the encoder and decoder stacks?\",\n \"Can the authors provide more insights into the choice of hyperparameters, especially the learning rate schedule and warmup steps?\"\n ],\n \"Limitations\": [\n \"The paper does not explore the application of the Transformer to tasks beyond machine translation, such as image or audio processing.\",\n \"The discussion on the potential negative societal impacts of the model is minimal and could be expanded.\"\n ],\n \"Ethical Concerns\": false,\n \"Soundness\": 4,\n \"Presentation\": 3,\n \"Contribution\": 4,\n \"Overall\": 8,\n \"Confidence\": 5,\n \"Strengths\": [\n \"The Transformer model introduces a highly innovative use of self-attention mechanisms, replacing traditional recurrent and convolutional layers.\",\n \"Comprehensive experimental validation showing state-of-the-art performance in machine translation tasks.\",\n \"Clear and detailed description of the model architecture and its components, facilitating reproducibility and further research.\"\n ],\n \"Weaknesses\": [\n \"Limited discussion on the application of the model to other domains beyond machine translation.\",\n \"The paper could benefit from a deeper analysis of the potential negative societal impacts of the model.\"\n ],\n \"Originality\": 4,\n \"Quality\": 4,\n \"Clarity\": 4,\n \"Significance\": 4,\n \"Decision\": \"Accept\"\n}"
3
+ }
ai_scientist/fewshot_examples/attention.pdf ADDED
Binary file (569 kB). View file
 
ai_scientist/fewshot_examples/attention.txt ADDED
@@ -0,0 +1,662 @@
1
+ # Attention Is All You Need
2
+
3
+
4
+ **Ashish Vaswani[∗]**
5
+ Google Brain
6
+ ```
7
+ avaswani@google.com
8
+
9
+ ```
10
+ **Llion Jones[∗]**
11
+ Google Research
12
+ ```
13
+ llion@google.com
14
+
15
+ ```
16
+
17
+ **Noam Shazeer[∗]**
18
+ Google Brain
19
+ ```
20
+ noam@google.com
21
+
22
+ ```
23
+
24
+ **Niki Parmar[∗]**
25
+ Google Research
26
+ ```
27
+ nikip@google.com
28
+
29
+ ```
30
+
31
+ **Jakob Uszkoreit[∗]**
32
+ Google Research
33
+ ```
34
+ usz@google.com
35
+
36
+ ```
37
+
38
+ **Aidan N. Gomez[∗†]**
39
+ University of Toronto
40
+ ```
41
+ aidan@cs.toronto.edu
42
+
43
+ ```
44
+
45
+ **Łukasz Kaiser[∗]**
46
+ Google Brain
47
+ ```
48
+ lukaszkaiser@google.com
49
+
50
+ ```
51
+
52
+ **Illia Polosukhin[∗‡]**
53
+ ```
54
+ illia.polosukhin@gmail.com
55
+
56
+ ```
57
+ **Abstract**
58
+
59
+ The dominant sequence transduction models are based on complex recurrent or
60
+ convolutional neural networks that include an encoder and a decoder. The best
61
+ performing models also connect the encoder and decoder through an attention
62
+ mechanism. We propose a new simple network architecture, the Transformer,
63
+ based solely on attention mechanisms, dispensing with recurrence and convolutions
64
+ entirely. Experiments on two machine translation tasks show these models to
65
+ be superior in quality while being more parallelizable and requiring significantly
66
+ less time to train. Our model achieves 28.4 BLEU on the WMT 2014 Englishto-German translation task, improving over the existing best results, including
67
+ ensembles, by over 2 BLEU. On the WMT 2014 English-to-French translation task,
68
+ our model establishes a new single-model state-of-the-art BLEU score of 41.0 after
69
+ training for 3.5 days on eight GPUs, a small fraction of the training costs of the
70
+ best models from the literature.
71
+
72
+ **1** **Introduction**
73
+
74
+ Recurrent neural networks, long short-term memory [12] and gated recurrent [7] neural networks
75
+ in particular, have been firmly established as state of the art approaches in sequence modeling and
76
+ transduction problems such as language modeling and machine translation [29, 2, 5]. Numerous
77
+ efforts have since continued to push the boundaries of recurrent language models and encoder-decoder
78
+ architectures [31, 21, 13].
79
+
80
+ _∗Equal contribution. Listing order is random. Jakob proposed replacing RNNs with self-attention and started_
81
+ the effort to evaluate this idea. Ashish, with Illia, designed and implemented the first Transformer models and
82
+ has been crucially involved in every aspect of this work. Noam proposed scaled dot-product attention, multi-head
83
+ attention and the parameter-free position representation and became the other person involved in nearly every
84
+ detail. Niki designed, implemented, tuned and evaluated countless model variants in our original codebase and
85
+ tensor2tensor. Llion also experimented with novel model variants, was responsible for our initial codebase, and
86
+ efficient inference and visualizations. Lukasz and Aidan spent countless long days designing various parts of and
87
+ implementing tensor2tensor, replacing our earlier codebase, greatly improving results and massively accelerating
88
+ our research.
89
+ _†Work performed while at Google Brain._
90
+ _‡Work performed while at Google Research._
91
+
92
+ 31st Conference on Neural Information Processing Systems (NIPS 2017), Long Beach, CA, USA.
93
+
94
+
95
+ -----
96
+
97
+ Recurrent models typically factor computation along the symbol positions of the input and output
98
+ sequences. Aligning the positions to steps in computation time, they generate a sequence of hidden
99
+ states h_t, as a function of the previous hidden state h_{t−1} and the input for position t. This inherently
+ sequential nature precludes parallelization within training examples, which becomes critical at longer
102
+ sequence lengths, as memory constraints limit batching across examples. Recent work has achieved
103
+ significant improvements in computational efficiency through factorization tricks [18] and conditional
104
+ computation [26], while also improving model performance in case of the latter. The fundamental
105
+ constraint of sequential computation, however, remains.
106
+
107
+ Attention mechanisms have become an integral part of compelling sequence modeling and transduction models in various tasks, allowing modeling of dependencies without regard to their distance in
108
+ the input or output sequences [2, 16]. In all but a few cases [22], however, such attention mechanisms
109
+ are used in conjunction with a recurrent network.
110
+
111
+ In this work we propose the Transformer, a model architecture eschewing recurrence and instead
112
+ relying entirely on an attention mechanism to draw global dependencies between input and output.
113
+ The Transformer allows for significantly more parallelization and can reach a new state of the art in
114
+ translation quality after being trained for as little as twelve hours on eight P100 GPUs.
115
+
116
+ **2** **Background**
117
+
118
+ The goal of reducing sequential computation also forms the foundation of the Extended Neural GPU
119
+
120
+ [20], ByteNet [15] and ConvS2S [8], all of which use convolutional neural networks as basic building
121
+ block, computing hidden representations in parallel for all input and output positions. In these models,
122
+ the number of operations required to relate signals from two arbitrary input or output positions grows
123
+ in the distance between positions, linearly for ConvS2S and logarithmically for ByteNet. This makes
124
+ it more difficult to learn dependencies between distant positions [11]. In the Transformer this is
125
+ reduced to a constant number of operations, albeit at the cost of reduced effective resolution due
126
+ to averaging attention-weighted positions, an effect we counteract with Multi-Head Attention as
127
+ described in section 3.2.
128
+
129
+ Self-attention, sometimes called intra-attention is an attention mechanism relating different positions
130
+ of a single sequence in order to compute a representation of the sequence. Self-attention has been
131
+ used successfully in a variety of tasks including reading comprehension, abstractive summarization,
132
+ textual entailment and learning task-independent sentence representations [4, 22, 23, 19].
133
+
134
+ End-to-end memory networks are based on a recurrent attention mechanism instead of sequence-aligned recurrence and have been shown to perform well on simple-language question answering and
135
+ language modeling tasks [28].
136
+
137
+ To the best of our knowledge, however, the Transformer is the first transduction model relying
138
+ entirely on self-attention to compute representations of its input and output without using sequence-aligned RNNs or convolution. In the following sections, we will describe the Transformer, motivate
139
+ self-attention and discuss its advantages over models such as [14, 15] and [8].
140
+
141
+ **3** **Model Architecture**
142
+
143
+ Most competitive neural sequence transduction models have an encoder-decoder structure [5, 2, 29].
144
+ Here, the encoder maps an input sequence of symbol representations (x1, ..., xn) to a sequence
145
+ of continuous representations z = (z1, ..., zn). Given z, the decoder then generates an output
146
+ sequence (y1, ..., ym) of symbols one element at a time. At each step the model is auto-regressive
147
+
148
+ [9], consuming the previously generated symbols as additional input when generating the next.
149
+
150
+ The Transformer follows this overall architecture using stacked self-attention and point-wise, fully
151
+ connected layers for both the encoder and decoder, shown in the left and right halves of Figure 1,
152
+ respectively.
153
+
154
+ **3.1** **Encoder and Decoder Stacks**
155
+
156
+ **Encoder:** The encoder is composed of a stack of N = 6 identical layers. Each layer has two
157
+ sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position
158
+
159
+ -----
160
+
161
+ Figure 1: The Transformer - model architecture.
162
+
163
+ wise fully connected feed-forward network. We employ a residual connection [10] around each of
164
+ the two sub-layers, followed by layer normalization [1]. That is, the output of each sub-layer is
165
+ LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by the sub-layer
166
+ itself. To facilitate these residual connections, all sub-layers in the model, as well as the embedding
167
+ layers, produce outputs of dimension dmodel = 512.
168
+
169
+ **Decoder:** The decoder is also composed of a stack of N = 6 identical layers. In addition to the two
170
+ sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head
171
+ attention over the output of the encoder stack. Similar to the encoder, we employ residual connections
172
+ around each of the sub-layers, followed by layer normalization. We also modify the self-attention
173
+ sub-layer in the decoder stack to prevent positions from attending to subsequent positions. This
174
+ masking, combined with fact that the output embeddings are offset by one position, ensures that the
175
+ predictions for position i can depend only on the known outputs at positions less than i.
176
+
177
+ **3.2** **Attention**
178
+
179
+ An attention function can be described as mapping a query and a set of key-value pairs to an output,
180
+ where the query, keys, values, and output are all vectors. The output is computed as a weighted sum
181
+ of the values, where the weight assigned to each value is computed by a compatibility function of the
182
+ query with the corresponding key.
183
+
184
+ **3.2.1** **Scaled Dot-Product Attention**
185
+
186
+ We call our particular attention "Scaled Dot-Product Attention" (Figure 2). The input consists of
187
+ queries and keys of dimension dk, and values of dimension dv. We compute the dot products of the
188
+
189
+
190
+ -----
191
+
192
193
+
194
+ Figure 2: (left) Scaled Dot-Product Attention. (right) Multi-Head Attention consists of several
195
+ attention layers running in parallel.
196
+
197
+ query with all keys, divide each by √d_k, and apply a softmax function to obtain the weights on the
+ values.
200
+ values.
201
+
202
+ In practice, we compute the attention function on a set of queries simultaneously, packed together
203
+ into a matrix Q. The keys and values are also packed together into matrices K and V . We compute
204
+ the matrix of outputs as:
205
+
206
+ Attention(Q, K, V) = softmax(QK^T / √d_k) V    (1)
207
+
208
+ The two most commonly used attention functions are additive attention [2], and dot-product (multiplicative) attention. Dot-product attention is identical to our algorithm, except for the scaling factor
209
+ of 1/√d_k. Additive attention computes the compatibility function using a feed-forward network with
210
+ a single hidden layer. While the two are similar in theoretical complexity, dot-product attention is
211
+ much faster and more space-efficient in practice, since it can be implemented using highly optimized
212
+ matrix multiplication code.
213
+
214
+ While for small values of dk the two mechanisms perform similarly, additive attention outperforms
215
+ dot product attention without scaling for larger values of dk [3]. We suspect that for large values of
216
+ _dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has_
217
+ extremely small gradients [4]. To counteract this effect, we scale the dot products by 1/√d_k.
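A small NumPy sketch of scaled dot-product attention as defined in Equation (1), including the optional mask used later in the decoder; this is illustrative, not the authors' implementation.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Equation (1): softmax(Q K^T / sqrt(d_k)) V.
    Q: (..., n_q, d_k), K: (..., n_k, d_k), V: (..., n_k, d_v).
    `mask` is an optional boolean array, True at positions that must not be
    attended to (e.g. future positions in the decoder)."""
    d_k = Q.shape[-1]
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)       # (..., n_q, n_k)
    if mask is not None:
        scores = np.where(mask, -1e9, scores)                # ~ setting to -inf
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ V, weights
```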
218
+
219
+ **3.2.2** **Multi-Head Attention**
220
+
221
+ Instead of performing a single attention function with dmodel-dimensional keys, values and queries,
222
+ we found it beneficial to linearly project the queries, keys and values h times with different, learned
223
+ linear projections to dk, dk and dv dimensions, respectively. On each of these projected versions of
224
+ queries, keys and values we then perform the attention function in parallel, yielding dv-dimensional
225
+ output values. These are concatenated and once again projected, resulting in the final values, as
226
+ depicted in Figure 2.
227
+
228
+ Multi-head attention allows the model to jointly attend to information from different representation
229
+ subspaces at different positions. With a single attention head, averaging inhibits this.
230
+
231
+ 4To illustrate why the dot products get large, assume that the components of q and k are independent random
232
+ variables with mean 0 and variance 1. Then their dot product, q · k = Σ_{i=1}^{d_k} q_i k_i, has mean 0 and variance d_k.
235
+
236
+
237
+ -----
238
+
239
+ MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
+
+     where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
+
+ where the projections are parameter matrices W_i^Q ∈ R^(d_model × d_k), W_i^K ∈ R^(d_model × d_k), W_i^V ∈ R^(d_model × d_v)
+ and W^O ∈ R^(h·d_v × d_model).
246
+
247
+ In this work we employ h = 8 parallel attention layers, or heads. For each of these we use
248
+ _dk = dv = dmodel/h = 64. Due to the reduced dimension of each head, the total computational cost_
249
+ is similar to that of single-head attention with full dimensionality.
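A sketch of multi-head attention built on the `scaled_dot_product_attention` helper above; shapes follow the base model (d_model = 512, h = 8, d_k = d_v = 64), and the explicit per-head weight lists are an illustrative simplification.

```python
def multi_head_attention(Q, K, V, W_q, W_k, W_v, W_o, h=8):
    """W_q, W_k, W_v are lists of h matrices of shape (d_model, d_k) / (d_model, d_v);
    W_o has shape (h * d_v, d_model)."""
    heads = [
        scaled_dot_product_attention(Q @ W_q[i], K @ W_k[i], V @ W_v[i])[0]
        for i in range(h)
    ]
    return np.concatenate(heads, axis=-1) @ W_o

# Example shapes for the base model.
rng = np.random.default_rng(0)
d_model, h, d_k = 512, 8, 64
x = rng.normal(size=(10, d_model))                        # 10 positions
W_q = [rng.normal(size=(d_model, d_k)) for _ in range(h)]
W_k = [rng.normal(size=(d_model, d_k)) for _ in range(h)]
W_v = [rng.normal(size=(d_model, d_k)) for _ in range(h)]
W_o = rng.normal(size=(h * d_k, d_model))
out = multi_head_attention(x, x, x, W_q, W_k, W_v, W_o)   # self-attention
```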
250
+
251
+ **3.2.3** **Applications of Attention in our Model**
252
+
253
+ The Transformer uses multi-head attention in three different ways:
254
+
255
+ _• In "encoder-decoder attention" layers, the queries come from the previous decoder layer,_
256
+ and the memory keys and values come from the output of the encoder. This allows every
257
+ position in the decoder to attend over all positions in the input sequence. This mimics the
258
+ typical encoder-decoder attention mechanisms in sequence-to-sequence models such as
259
+
260
+ [31, 2, 8].
261
+
262
+ _• The encoder contains self-attention layers. In a self-attention layer all of the keys, values_
263
+ and queries come from the same place, in this case, the output of the previous layer in the
264
+ encoder. Each position in the encoder can attend to all positions in the previous layer of the
265
+ encoder.
266
+
267
+ _• Similarly, self-attention layers in the decoder allow each position in the decoder to attend to_
268
+ all positions in the decoder up to and including that position. We need to prevent leftward
269
+ information flow in the decoder to preserve the auto-regressive property. We implement this
270
+ inside of scaled dot-product attention by masking out (setting to −∞) all values in the input
271
+ of the softmax which correspond to illegal connections. See Figure 2.
272
+
273
+ **3.3** **Position-wise Feed-Forward Networks**
274
+
275
+ In addition to attention sub-layers, each of the layers in our encoder and decoder contains a fully
276
+ connected feed-forward network, which is applied to each position separately and identically. This
277
+ consists of two linear transformations with a ReLU activation in between.
278
+
279
+ FFN(x) = max(0, xW1 + b1)W2 + b2 (2)
280
+
281
+ While the linear transformations are the same across different positions, they use different parameters
282
+ from layer to layer. Another way of describing this is as two convolutions with kernel size 1.
283
+ The dimensionality of input and output is dmodel = 512, and the inner-layer has dimensionality
284
+ _dff = 2048._
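Equation (2) as a one-line NumPy sketch; it is applied position-wise because the matrix products act on the last dimension only.

```python
def position_wise_ffn(x, W1, b1, W2, b2):
    """FFN(x) = max(0, x W1 + b1) W2 + b2, applied to every position independently.
    For the base model W1: (512, 2048), W2: (2048, 512)."""
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2
```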
285
+
286
+ **3.4** **Embeddings and Softmax**
287
+
288
+ Similarly to other sequence transduction models, we use learned embeddings to convert the input
289
+ tokens and output tokens to vectors of dimension dmodel. We also use the usual learned linear transformation and softmax function to convert the decoder output to predicted next-token probabilities. In
290
+ our model, we share the same weight matrix between the two embedding layers and the pre-softmax
291
+ linear transformation, similar to [24]. In the embedding layers, we multiply those weights by √d_model.
294
+
295
+ **3.5** **Positional Encoding**
296
+
297
+ Since our model contains no recurrence and no convolution, in order for the model to make use of the
298
+ order of the sequence, we must inject some information about the relative or absolute position of the
299
+ tokens in the sequence. To this end, we add "positional encodings" to the input embeddings at the
300
+
301
+
302
+ -----
303
+
304
+ Table 1: Maximum path lengths, per-layer complexity and minimum number of sequential operations
305
+ for different layer types. n is the sequence length, d is the representation dimension, k is the kernel
306
+ size of convolutions and r the size of the neighborhood in restricted self-attention.
307
+
308
+ |Layer Type|Complexity per Layer|Sequential Operations|Maximum Path Length|
+ |---|---|---|---|
+ |Self-Attention|O(n^2 · d)|O(1)|O(1)|
+ |Recurrent|O(n · d^2)|O(n)|O(n)|
+ |Convolutional|O(k · n · d^2)|O(1)|O(log_k(n))|
+ |Self-Attention (restricted)|O(r · n · d)|O(1)|O(n/r)|
316
+
317
+ bottoms of the encoder and decoder stacks. The positional encodings have the same dimension dmodel
318
+ as the embeddings, so that the two can be summed. There are many choices of positional encodings,
319
+ learned and fixed [8].
320
+
321
+ In this work, we use sine and cosine functions of different frequencies:
322
+
323
+ PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
+
+ PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
326
+
327
+ where pos is the position and i is the dimension. That is, each dimension of the positional encoding
328
+ corresponds to a sinusoid. The wavelengths form a geometric progression from 2π to 10000 · 2π. We
329
+ chose this function because we hypothesized it would allow the model to easily learn to attend by
330
+ relative positions, since for any fixed offset k, PEpos+k can be represented as a linear function of
331
+ _PEpos._
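A NumPy sketch of the sinusoidal encoding; it assumes an even d_model, as in the paper.

```python
def sinusoidal_positional_encoding(max_len, d_model):
    """PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
       PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))"""
    pos = np.arange(max_len)[:, None]                 # (max_len, 1)
    i = np.arange(d_model // 2)[None, :]              # (1, d_model/2)
    angles = pos / np.power(10000.0, 2 * i / d_model)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)                      # even dimensions
    pe[:, 1::2] = np.cos(angles)                      # odd dimensions
    return pe
```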
332
+
333
+ We also experimented with using learned positional embeddings [8] instead, and found that the two
334
+ versions produced nearly identical results (see Table 3 row (E)). We chose the sinusoidal version
335
+ because it may allow the model to extrapolate to sequence lengths longer than the ones encountered
336
+ during training.
337
+
338
+ **4** **Why Self-Attention**
339
+
340
+ In this section we compare various aspects of self-attention layers to the recurrent and convolutional layers commonly used for mapping one variable-length sequence of symbol representations
+ (x_1, ..., x_n) to another sequence of equal length (z_1, ..., z_n), with x_i, z_i ∈ R^d, such as a hidden
+ layer in a typical sequence transduction encoder or decoder. Motivating our use of self-attention, we
+ consider three desiderata.
343
+
344
+ One is the total computational complexity per layer. Another is the amount of computation that can
345
+ be parallelized, as measured by the minimum number of sequential operations required.
346
+
347
+ The third is the path length between long-range dependencies in the network. Learning long-range
348
+ dependencies is a key challenge in many sequence transduction tasks. One key factor affecting the
349
+ ability to learn such dependencies is the length of the paths forward and backward signals have to
350
+ traverse in the network. The shorter these paths between any combination of positions in the input
351
+ and output sequences, the easier it is to learn long-range dependencies [11]. Hence we also compare
352
+ the maximum path length between any two input and output positions in networks composed of the
353
+ different layer types.
354
+
355
+ As noted in Table 1, a self-attention layer connects all positions with a constant number of sequentially
356
+ executed operations, whereas a recurrent layer requires O(n) sequential operations. In terms of
357
+ computational complexity, self-attention layers are faster than recurrent layers when the sequence
358
+ length n is smaller than the representation dimensionality d, which is most often the case with
359
+ sentence representations used by state-of-the-art models in machine translations, such as word-piece
360
+
361
+ [31] and byte-pair [25] representations. To improve computational performance for tasks involving
362
+ very long sequences, self-attention could be restricted to considering only a neighborhood of size r in
363
+
364
+
365
+ -----
366
+
367
+ the input sequence centered around the respective output position. This would increase the maximum
368
+ path length to O(n/r). We plan to investigate this approach further in future work.
369
+
370
+ A single convolutional layer with kernel width k < n does not connect all pairs of input and output
371
+ positions. Doing so requires a stack of O(n/k) convolutional layers in the case of contiguous kernels,
372
+ or O(log_k(n)) in the case of dilated convolutions [15], increasing the length of the longest paths
+ between any two positions in the network. Convolutional layers are generally more expensive than
+ recurrent layers, by a factor of k. Separable convolutions [6], however, decrease the complexity
+ considerably, to O(k · n · d + n · d^2). Even with k = n, however, the complexity of a separable
376
+ convolution is equal to the combination of a self-attention layer and a point-wise feed-forward layer,
377
+ the approach we take in our model.
378
+
379
+ As a side benefit, self-attention could yield more interpretable models. We inspect attention distributions
380
+ from our models and present and discuss examples in the appendix. Not only do individual attention
381
+ heads clearly learn to perform different tasks, many appear to exhibit behavior related to the syntactic
382
+ and semantic structure of the sentences.
383
+
384
+ **5** **Training**
385
+
386
+ This section describes the training regime for our models.
387
+
388
+ **5.1** **Training Data and Batching**
389
+
390
+ We trained on the standard WMT 2014 English-German dataset consisting of about 4.5 million
391
+ sentence pairs. Sentences were encoded using byte-pair encoding [3], which has a shared source-target vocabulary of about 37000 tokens. For English-French, we used the significantly larger WMT
392
+ 2014 English-French dataset consisting of 36M sentences and split tokens into a 32000 word-piece
393
+ vocabulary [31]. Sentence pairs were batched together by approximate sequence length. Each training
394
+ batch contained a set of sentence pairs containing approximately 25000 source tokens and 25000
395
+ target tokens.
396
+
397
+ **5.2** **Hardware and Schedule**
398
+
399
+ We trained our models on one machine with 8 NVIDIA P100 GPUs. For our base models using
400
+ the hyperparameters described throughout the paper, each training step took about 0.4 seconds. We
401
+ trained the base models for a total of 100,000 steps or 12 hours. For our big models (described on the
+ bottom line of Table 3), step time was 1.0 seconds. The big models were trained for 300,000 steps
403
+ (3.5 days).
404
+
405
+ **5.3** **Optimizer**
406
+
407
+ We used the Adam optimizer [17] with β1 = 0.9, β2 = 0.98 and ε = 10^-9. We varied the learning
408
+ rate over the course of training, according to the formula:
409
+
410
+ lrate = d_model^(−0.5) · min(step_num^(−0.5), step_num · warmup_steps^(−1.5))    (3)
413
+
414
+ This corresponds to increasing the learning rate linearly for the first warmup_steps training steps,
415
+ and decreasing it thereafter proportionally to the inverse square root of the step number. We used
416
+ _warmup_steps = 4000._
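Equation (3) as a small helper; guarding against step 0 is our addition.

```python
def transformer_lrate(step, d_model=512, warmup_steps=4000):
    """Linear warm-up for `warmup_steps`, then decay with the inverse
    square root of the step number."""
    step = max(step, 1)  # avoid 0 ** -0.5 at the very first step
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
```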
417
+
418
+ **5.4** **Regularization**
419
+
420
+ We employ three types of regularization during training:
421
+
422
+ **Residual Dropout** We apply dropout [27] to the output of each sub-layer, before it is added to the
423
+ sub-layer input and normalized. In addition, we apply dropout to the sums of the embeddings and the
424
+ positional encodings in both the encoder and decoder stacks. For the base model, we use a rate of
425
+ _Pdrop = 0.1._
426
+
427
+
428
+ -----
429
+
430
+ Table 2: The Transformer achieves better BLEU scores than previous state-of-the-art models on the
431
+ English-to-German and English-to-French newstest2014 tests at a fraction of the training cost.
432
+
433
+ |Model|BLEU EN-DE|BLEU EN-FR|Training Cost (FLOPs) EN-DE|Training Cost (FLOPs) EN-FR|
+ |---|---|---|---|---|
+ |ByteNet [15]|23.75||||
+ |Deep-Att + PosUnk [32]||39.2||1.0 · 10^20|
+ |GNMT + RL [31]|24.6|39.92|2.3 · 10^19|1.4 · 10^20|
+ |ConvS2S [8]|25.16|40.46|9.6 · 10^18|1.5 · 10^20|
+ |MoE [26]|26.03|40.56|2.0 · 10^19|1.2 · 10^20|
+ |Deep-Att + PosUnk Ensemble [32]||40.4||8.0 · 10^20|
+ |GNMT + RL Ensemble [31]|26.30|41.16|1.8 · 10^20|1.1 · 10^21|
+ |ConvS2S Ensemble [8]|26.36|**41.29**|7.7 · 10^19|1.2 · 10^21|
+ |Transformer (base model)|27.3|38.1|**3.3 · 10^18**||
+ |Transformer (big)|**28.4**|**41.0**|2.3 · 10^19||
456
+
457
+
458
+ **Label Smoothing** During training, we employed label smoothing of value ϵls = 0.1 [30]. This
459
+ hurts perplexity, as the model learns to be more unsure, but improves accuracy and BLEU score.
460
+
461
+ **6** **Results**
462
+
463
+ **6.1** **Machine Translation**
464
+
465
+ On the WMT 2014 English-to-German translation task, the big transformer model (Transformer (big)
466
+ in Table 2) outperforms the best previously reported models (including ensembles) by more than 2.0
467
+ BLEU, establishing a new state-of-the-art BLEU score of 28.4. The configuration of this model is
468
+ listed in the bottom line of Table 3. Training took 3.5 days on 8 P100 GPUs. Even our base model
469
+ surpasses all previously published models and ensembles, at a fraction of the training cost of any of
470
+ the competitive models.
471
+
472
+ On the WMT 2014 English-to-French translation task, our big model achieves a BLEU score of 41.0,
473
+ outperforming all of the previously published single models, at less than 1/4 the training cost of the
474
+ previous state-of-the-art model. The Transformer (big) model trained for English-to-French used
475
+ dropout rate Pdrop = 0.1, instead of 0.3.
476
+
477
+ For the base models, we used a single model obtained by averaging the last 5 checkpoints, which
478
+ were written at 10-minute intervals. For the big models, we averaged the last 20 checkpoints. We
479
+ used beam search with a beam size of 4 and length penalty α = 0.6 [31]. These hyperparameters
480
+ were chosen after experimentation on the development set. We set the maximum output length during
481
+ inference to input length + 50, but terminate early when possible [31].
482
+
483
+ Table 2 summarizes our results and compares our translation quality and training costs to other model
484
+ architectures from the literature. We estimate the number of floating point operations used to train a
485
+ model by multiplying the training time, the number of GPUs used, and an estimate of the sustained
486
+ single-precision floating-point capacity of each GPU [5].
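As a sanity check on this estimate (our arithmetic, not the paper's), the base-model figure in Table 2 is consistent with 12 hours on 8 P100 GPUs at the 9.5 TFLOPS value from footnote 5:

```python
# 12 hours of training on 8 GPUs at ~9.5e12 sustained FLOP/s per GPU.
train_seconds = 12 * 3600
flops = train_seconds * 8 * 9.5e12
print(f"{flops:.1e}")  # ~3.3e+18, matching the base-model entry in Table 2
```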
487
+
488
+ **6.2** **Model Variations**
489
+
490
+ To evaluate the importance of different components of the Transformer, we varied our base model
491
+ in different ways, measuring the change in performance on English-to-German translation on the
492
+ development set, newstest2013. We used beam search as described in the previous section, but no
493
+ checkpoint averaging. We present these results in Table 3.
494
+
495
+ In Table 3 rows (A), we vary the number of attention heads and the attention key and value dimensions,
496
+ keeping the amount of computation constant, as described in Section 3.2.2. While single-head
497
+ attention is 0.9 BLEU worse than the best setting, quality also drops off with too many heads.
498
+
499
+ 5We used values of 2.8, 3.7, 6.0 and 9.5 TFLOPS for K80, K40, M40 and P100, respectively.
500
+
501
+
502
+ -----
503
+
504
+ Table 3: Variations on the Transformer architecture. Unlisted values are identical to those of the base
505
+ model. All metrics are on the English-to-German translation development set, newstest2013. Listed
506
+ perplexities are per-wordpiece, according to our byte-pair encoding, and should not be compared to
507
+ per-word perplexities.
508
+
509
+ |Col1|train N d d h d d P ϵ model ff k v drop ls steps|PPL BLEU params (dev) (dev) 106 ×|
510
+ |---|---|---|
511
+ |base|6 512 2048 8 64 64 0.1 0.1 100K|4.92 25.8 65|
512
+ |(A)|1 512 512 4 128 128 16 32 32 32 16 16|5.29 24.9 5.00 25.5 4.91 25.8 5.01 25.4|
513
+ |(B)|16 32|5.16 25.1 58 5.01 25.4 60|
514
+ |(C)|2 4 8 256 32 32 1024 128 128 1024 4096|6.11 23.7 36 5.19 25.3 50 4.88 25.5 80 5.75 24.5 28 4.66 26.0 168 5.12 25.4 53 4.75 26.2 90|
515
+ |(D)|0.0 0.2 0.0 0.2|5.77 24.6 4.95 25.5 4.67 25.3 5.47 25.7|
516
+ |(E)|positional embedding instead of sinusoids|4.92 25.7|
517
+ |big|6 1024 4096 16 0.3 300K|4.33 26.4 213|
518
+
519
+
520
+
521
+ In Table 3 rows (B), we observe that reducing the attention key size dk hurts model quality. This
522
+ suggests that determining compatibility is not easy and that a more sophisticated compatibility
523
+ function than dot product may be beneficial. We further observe in rows (C) and (D) that, as expected,
524
+ bigger models are better, and dropout is very helpful in avoiding over-fitting. In row (E) we replace our
525
+ sinusoidal positional encoding with learned positional embeddings [8], and observe nearly identical
526
+ results to the base model.
527
+
528
+ **7** **Conclusion**
529
+
530
+ In this work, we presented the Transformer, the first sequence transduction model based entirely on
531
+ attention, replacing the recurrent layers most commonly used in encoder-decoder architectures with
532
+ multi-headed self-attention.
533
+
534
+ For translation tasks, the Transformer can be trained significantly faster than architectures based
535
+ on recurrent or convolutional layers. On both WMT 2014 English-to-German and WMT 2014
536
+ English-to-French translation tasks, we achieve a new state of the art. In the former task our best
537
+ model outperforms even all previously reported ensembles.
538
+
539
+ We are excited about the future of attention-based models and plan to apply them to other tasks. We
540
+ plan to extend the Transformer to problems involving input and output modalities other than text and
541
+ to investigate local, restricted attention mechanisms to efficiently handle large inputs and outputs
542
+ such as images, audio and video. Making generation less sequential is another research goal of ours.
543
+
544
+ [The code we used to train and evaluate our models is available at https://github.com/](https://github.com/tensorflow/tensor2tensor)
545
+ ```
546
+ tensorflow/tensor2tensor.
547
+
548
+ ```
549
+ **Acknowledgements** We are grateful to Nal Kalchbrenner and Stephan Gouws for their fruitful
550
+ comments, corrections and inspiration.
551
+
552
+
553
+ -----
554
+
555
+ **References**
556
+
557
+ [1] Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization. arXiv preprint
558
+ _arXiv:1607.06450, 2016._
559
+
560
+ [2] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly
561
+ learning to align and translate. CoRR, abs/1409.0473, 2014.
562
+
563
+ [3] Denny Britz, Anna Goldie, Minh-Thang Luong, and Quoc V. Le. Massive exploration of neural
564
+ machine translation architectures. CoRR, abs/1703.03906, 2017.
565
+
566
+ [4] Jianpeng Cheng, Li Dong, and Mirella Lapata. Long short-term memory-networks for machine
567
+ reading. arXiv preprint arXiv:1601.06733, 2016.
568
+
569
+ [5] Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Fethi Bougares, Holger Schwenk,
570
+ and Yoshua Bengio. Learning phrase representations using rnn encoder-decoder for statistical
571
+ machine translation. CoRR, abs/1406.1078, 2014.
572
+
573
+ [6] Francois Chollet. Xception: Deep learning with depthwise separable convolutions. arXiv
574
+ _preprint arXiv:1610.02357, 2016._
575
+
576
+ [7] Junyoung Chung, Çaglar Gülçehre, Kyunghyun Cho, and Yoshua Bengio. Empirical evaluation
577
+ of gated recurrent neural networks on sequence modeling. CoRR, abs/1412.3555, 2014.
578
+
579
+ [8] Jonas Gehring, Michael Auli, David Grangier, Denis Yarats, and Yann N. Dauphin. Convolutional sequence to sequence learning. arXiv preprint arXiv:1705.03122v2, 2017.
580
+
581
+ [9] Alex Graves. Generating sequences with recurrent neural networks. _arXiv preprint_
582
+ _arXiv:1308.0850, 2013._
583
+
584
+ [10] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern
585
+ _Recognition, pages 770–778, 2016._
586
+
587
+ [11] Sepp Hochreiter, Yoshua Bengio, Paolo Frasconi, and Jürgen Schmidhuber. Gradient flow in
588
+ recurrent nets: the difficulty of learning long-term dependencies, 2001.
589
+
590
+ [12] Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural computation,
591
+ 9(8):1735–1780, 1997.
592
+
593
+ [13] Rafal Jozefowicz, Oriol Vinyals, Mike Schuster, Noam Shazeer, and Yonghui Wu. Exploring
594
+ the limits of language modeling. arXiv preprint arXiv:1602.02410, 2016.
595
+
596
+ [14] Łukasz Kaiser and Ilya Sutskever. Neural GPUs learn algorithms. In International Conference
597
+ _on Learning Representations (ICLR), 2016._
598
+
599
+ [15] Nal Kalchbrenner, Lasse Espeholt, Karen Simonyan, Aaron van den Oord, Alex Graves, and Koray Kavukcuoglu. Neural machine translation in linear time. arXiv preprint arXiv:1610.10099v2,
600
+ 2017.
601
+
602
+ [16] Yoon Kim, Carl Denton, Luong Hoang, and Alexander M. Rush. Structured attention networks.
603
+ In International Conference on Learning Representations, 2017.
604
+
605
+ [17] Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In ICLR, 2015.
606
+
607
+ [18] Oleksii Kuchaiev and Boris Ginsburg. Factorization tricks for LSTM networks. arXiv preprint
608
+ _arXiv:1703.10722, 2017._
609
+
610
+ [19] Zhouhan Lin, Minwei Feng, Cicero Nogueira dos Santos, Mo Yu, Bing Xiang, Bowen
611
+ Zhou, and Yoshua Bengio. A structured self-attentive sentence embedding. arXiv preprint
612
+ _arXiv:1703.03130, 2017._
613
+
614
+ [20] Łukasz Kaiser and Samy Bengio. Can active memory replace attention? In Advances in Neural
615
+ _Information Processing Systems, (NIPS), 2016._
616
+
617
+
618
+ -----
619
+
620
+ [21] Minh-Thang Luong, Hieu Pham, and Christopher D Manning. Effective approaches to attentionbased neural machine translation. arXiv preprint arXiv:1508.04025, 2015.
621
+
622
+ [22] Ankur Parikh, Oscar Täckström, Dipanjan Das, and Jakob Uszkoreit. A decomposable attention
623
+ model. In Empirical Methods in Natural Language Processing, 2016.
624
+
625
+ [23] Romain Paulus, Caiming Xiong, and Richard Socher. A deep reinforced model for abstractive
626
+ summarization. arXiv preprint arXiv:1705.04304, 2017.
627
+
628
+ [24] Ofir Press and Lior Wolf. Using the output embedding to improve language models. arXiv
629
+ _preprint arXiv:1608.05859, 2016._
630
+
631
+ [25] Rico Sennrich, Barry Haddow, and Alexandra Birch. Neural machine translation of rare words
632
+ with subword units. arXiv preprint arXiv:1508.07909, 2015.
633
+
634
+ [26] Noam Shazeer, Azalia Mirhoseini, Krzysztof Maziarz, Andy Davis, Quoc Le, Geoffrey Hinton,
635
+ and Jeff Dean. Outrageously large neural networks: The sparsely-gated mixture-of-experts
636
+ layer. arXiv preprint arXiv:1701.06538, 2017.
637
+
638
+ [27] Nitish Srivastava, Geoffrey E Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: a simple way to prevent neural networks from overfitting. Journal of Machine
639
+ _Learning Research, 15(1):1929–1958, 2014._
640
+
641
+ [28] Sainbayar Sukhbaatar, arthur szlam, Jason Weston, and Rob Fergus. End-to-end memory
642
+ networks. In C. Cortes, N. D. Lawrence, D. D. Lee, M. Sugiyama, and R. Garnett, editors,
643
+ _Advances in Neural Information Processing Systems 28, pages 2440–2448. Curran Associates,_
644
+ Inc., 2015.
645
+
646
+ [29] Ilya Sutskever, Oriol Vinyals, and Quoc V. Le. Sequence to sequence learning with neural
647
+ networks. In Advances in Neural Information Processing Systems, pages 3104–3112, 2014.
648
+
649
+ [30] Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, and Zbigniew Wojna.
650
+ Rethinking the inception architecture for computer vision. CoRR, abs/1512.00567, 2015.
651
+
652
+ [31] Yonghui Wu, Mike Schuster, Zhifeng Chen, Quoc V Le, Mohammad Norouzi, Wolfgang
653
+ Macherey, Maxim Krikun, Yuan Cao, Qin Gao, Klaus Macherey, et al. Google’s neural machine
654
+ translation system: Bridging the gap between human and machine translation. arXiv preprint
655
+ _arXiv:1609.08144, 2016._
656
+
657
+ [32] Jie Zhou, Ying Cao, Xuguang Wang, Peng Li, and Wei Xu. Deep recurrent models with
658
+ fast-forward connections for neural machine translation. CoRR, abs/1606.04199, 2016.
659
+
660
+
661
+ -----
662
+
ai_scientist/generate_ideas.py ADDED
@@ -0,0 +1,637 @@
1
+ import json
2
+ import os
3
+ import os.path as osp
4
+ import time
5
+ from typing import Dict, List, Union
6
+
7
+ import backoff
8
+ import requests
9
+ from strictjson import strict_json
10
+
11
+ from ai_scientist.llm import (
12
+ allchoices,
13
+ extract_json_between_markers,
14
+ get_response_from_llm,
15
+ llm_json_auto_correct,
16
+ )
17
+
18
+ S2_API_KEY = os.getenv("S2_API_KEY")
19
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
20
+
21
+
22
+ idea_first_prompt = """{task_description}
23
+ <experiment.py>
24
+ {code}
25
+ </experiment.py>
26
+
27
+ Here are the ideas that you have already generated:
28
+
29
+ '''
30
+ {prev_ideas_string}
31
+ '''
32
+
33
+ Come up with the next impactful and creative idea for research experiments and directions you can feasibly investigate with the code provided.
34
+ Note that you will not have access to any additional resources or datasets.
35
+ Make sure any idea does not overfit to the specific training dataset or model, and has wider significance.
36
+
37
+ Respond in the following format:
38
+
39
+ THOUGHT:
40
+ <THOUGHT>
41
+
42
+ NEW IDEA JSON:
43
+ ```json
44
+ <JSON>
45
+ ```
46
+
47
+ In <THOUGHT>, first briefly discuss your intuitions and motivations for the idea. Detail your high-level plan, necessary design choices and ideal outcomes of the experiments. Justify how the idea is different from the existing ones.
48
+
49
+ Add '```json' before the <JSON> and '```' after the <JSON> as above. In <JSON>, provide the new idea in JSON format with the following keys and values:
50
+ - "Name": A shortened descriptor of the idea. Lowercase, no spaces, underscores allowed.
51
+ - "Title": A title for the idea, will be used for the report writing.
52
+ - "Experiment": An outline of the implementation. E.g. which functions need to be added or modified, how results will be obtained, ...
53
+ - "Interestingness": A rating from 1 to 10 (lowest to highest).
54
+ - "Feasibility": A rating from 1 to 10 (lowest to highest).
55
+ - "Novelty": A rating from 1 to 10 (lowest to highest).
56
+
57
+ Be cautious and realistic on your ratings.
58
+ This JSON will be automatically parsed, so ensure the format is precise.
59
+ You will have {num_reflections} rounds to iterate on the idea, but do not need to use them all.
60
+ """
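+ # Illustrative only: the prompt above asks the LLM for a JSON object with the keys listed there.
+ # A hypothetical well-formed reply (values invented for the example) would contain:
+ # {
+ #     "Name": "adaptive_lr_schedule",
+ #     "Title": "Adaptive Learning Rate Schedules for Character-Level Language Models",
+ #     "Experiment": "Modify the training loop to ...",
+ #     "Interestingness": 6,
+ #     "Feasibility": 8,
+ #     "Novelty": 5
+ # }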
61
+
62
+ idea_reflection_prompt = """Round {current_round}/{num_reflections}.
63
+ In your thoughts, first carefully consider the quality, novelty, and feasibility of the idea you just created.
64
+ Include any other factors that you think are important in evaluating the idea.
65
+ Ensure the idea is clear and concise, and the JSON is the correct format.
66
+ Do not make things overly complicated.
67
+ In the next attempt, try and refine and improve your idea.
68
+ Stick to the spirit of the original idea unless there are glaring issues.
69
+
70
+ Respond in exactly the same format as before:
71
+ THOUGHT:
72
+ <THOUGHT>
73
+
74
+ NEW IDEA JSON:
75
+ ```json
76
+ <JSON>
77
+ ```
78
+
79
+ If there is nothing to improve, simply repeat the previous JSON EXACTLY after the thought and include "I am done" at the end of the thoughts but before the JSON.
80
+ ONLY INCLUDE "I am done" IF YOU ARE MAKING NO MORE CHANGES.
81
+ """
82
+
83
+
84
+ # Format the content in JSON
85
+ def format_idea_json(text):
86
+ json_start_marker = "```json"
87
+ json_end_marker = "```"
88
+ start_index = text.find(json_start_marker)
89
+ if start_index != -1:
90
+ start_index += len(json_start_marker) # Move past the marker
91
+ end_index = text.find(json_end_marker, start_index)
92
+ json_string = text[start_index:end_index].strip()
93
+ res = strict_json(
94
+ system_prompt="You are a JSON formatter",
95
+ user_prompt=json_string,
96
+ output_format={
97
+ "Name": "A shortened descriptor of the idea",
98
+ "Title": "A title for the idea, will be used for the report writing",
99
+ "Experiment": "An outline of the implementation, type: list",
100
+ "Interestingness": "A rating from 1 to 10 (lowest to highest), type: int",
101
+ "Feasibility": "A rating from 1 to 10 (lowest to highest), type: int",
102
+ "Novelty": "A rating from 1 to 10 (lowest to highest), type: int",
103
+ },
104
+ llm=llm_json_auto_correct,
105
+ )
106
+ text = "```json\n" + json.dumps(res) + "```\n"
107
+ return text
108
+
109
+
110
+ def format_novelty_json(text):
111
+ json_start_marker = "```json"
112
+ json_end_marker = "```"
113
+ start_index = text.find(json_start_marker)
114
+ if start_index != -1:
115
+ start_index += len(json_start_marker) # Move past the marker
116
+ end_index = text.find(json_end_marker, start_index)
117
+ json_string = text[start_index:end_index].strip()
118
+ res = strict_json(
119
+ system_prompt="You are a JSON formatter",
120
+ user_prompt=json_string,
121
+ output_format={
122
+ "Query": "An optional search query to search the literature (e.g. attention is all you need)",
123
+ },
124
+ llm=llm_json_auto_correct,
125
+ )
126
+ text = "```json\n" + json.dumps(res) + "```\n"
127
+ return text
128
+
129
+
130
+ # GENERATE IDEAS
131
+ def generate_ideas(
132
+ base_dir,
133
+ client,
134
+ model,
135
+ skip_generation=False,
136
+ max_num_generations=20,
137
+ num_reflections=5,
138
+ ):
139
+ if skip_generation:
140
+ # Load existing ideas from file
141
+ try:
142
+ with open(osp.join(base_dir, "ideas.json"), "r") as f:
143
+ ideas = json.load(f)
144
+ print("Loaded existing ideas:")
145
+ for idea in ideas:
146
+ print(idea)
147
+ return ideas
148
+ except FileNotFoundError:
149
+ print("No existing ideas found. Generating new ideas.")
150
+ except json.JSONDecodeError:
151
+ print("Error decoding existing ideas. Generating new ideas.")
152
+
153
+ idea_str_archive = []
154
+ with open(osp.join(base_dir, "seed_ideas.json"), "r") as f:
155
+ seed_ideas = json.load(f)
156
+ for seed_idea in seed_ideas:
157
+ idea_str_archive.append(json.dumps(seed_idea))
158
+
159
+ with open(osp.join(base_dir, "experiment.py"), "r") as f:
160
+ code = f.read()
161
+
162
+ with open(osp.join(base_dir, "prompt.json"), "r") as f:
163
+ prompt = json.load(f)
164
+
165
+ idea_system_prompt = prompt["system"]
166
+
167
+ for _ in range(max_num_generations):
168
+ print()
169
+ print(f"Generating idea {_ + 1}/{max_num_generations}")
170
+ import traceback
171
+ try:
172
+ prev_ideas_string = "\n\n".join(idea_str_archive)
173
+
174
+ msg_history = []
175
+ print(f"Iteration 1/{num_reflections}")
176
+ text, msg_history = get_response_from_llm(
177
+ idea_first_prompt.format(
178
+ task_description=prompt["task_description"],
179
+ code=code,
180
+ prev_ideas_string=prev_ideas_string,
181
+ num_reflections=num_reflections,
182
+ ),
183
+ client=client,
184
+ model=model,
185
+ system_message=idea_system_prompt,
186
+ msg_history=msg_history,
187
+ )
188
+ ## Format the content in JSON
189
+ text = format_idea_json(text)
190
+
191
+ ## PARSE OUTPUT
192
+ json_output = extract_json_between_markers(text)
193
+ assert json_output is not None, "Failed to extract JSON from LLM output"
194
+ # print(json_output)
195
+
196
+ # Iteratively improve task.
197
+ if num_reflections > 1:
198
+ for j in range(num_reflections - 1):
199
+ print(f"Iteration {j + 2}/{num_reflections}")
200
+ text, msg_history = get_response_from_llm(
201
+ idea_reflection_prompt.format(
202
+ current_round=j + 2, num_reflections=num_reflections
203
+ ),
204
+ client=client,
205
+ model=model,
206
+ system_message=idea_system_prompt,
207
+ msg_history=msg_history,
208
+ )
209
+ ## Format the content in JSON if using weak LLM
210
+ text = format_idea_json(text)
211
+ ## PARSE OUTPUT
212
+ json_output = extract_json_between_markers(text)
213
+ assert (
214
+ json_output is not None
215
+ ), "Failed to extract JSON from LLM output"
216
+ # print(json_output)
217
+
218
+ if "I am done" in text:
219
+ print(f"Idea generation converged after {j + 2} iterations.")
220
+ break
221
+
222
+ idea_str_archive.append(json.dumps(json_output))
223
+ except Exception as e:
224
+ print(f"Failed to generate idea: {e}")
225
+ traceback.print_exc()
226
+ continue
227
+
228
+ ## SAVE IDEAS
229
+ ideas = []
230
+ for idea_str in idea_str_archive:
231
+ ideas.append(json.loads(idea_str))
232
+
233
+ with open(osp.join(base_dir, "ideas.json"), "w") as f:
234
+ json.dump(ideas, f, indent=4)
235
+
236
+ return ideas
237
+
238
+
239
+ # GENERATE IDEAS OPEN-ENDED
240
+ def generate_next_idea(
241
+ base_dir,
242
+ client,
243
+ model,
244
+ prev_idea_archive=[],
245
+ num_reflections=5,
246
+ max_attempts=10,
247
+ ):
248
+ idea_archive = prev_idea_archive
249
+ original_archive_size = len(idea_archive)
250
+
251
+ print(f"Generating idea {original_archive_size + 1}")
252
+
253
+ if len(prev_idea_archive) == 0:
254
+ print(f"First iteration, taking seed ideas")
255
+ # seed the archive on the first run with pre-existing ideas
256
+ with open(osp.join(base_dir, "seed_ideas.json"), "r") as f:
257
+ seed_ideas = json.load(f)
258
+ for seed_idea in seed_ideas[:1]:
259
+ idea_archive.append(seed_idea)
260
+ else:
261
+ with open(osp.join(base_dir, "experiment.py"), "r") as f:
262
+ code = f.read()
263
+ with open(osp.join(base_dir, "prompt.json"), "r") as f:
264
+ prompt = json.load(f)
265
+ idea_system_prompt = prompt["system"]
266
+
267
+ for _ in range(max_attempts):
268
+ import traceback
269
+ try:
270
+ idea_strings = []
271
+ for idea in idea_archive:
272
+ idea_strings.append(json.dumps(idea))
273
+ prev_ideas_string = "\n\n".join(idea_strings)
274
+
275
+ msg_history = []
276
+ print(f"Iteration 1/{num_reflections}")
277
+ text, msg_history = get_response_from_llm(
278
+ idea_first_prompt.format(
279
+ task_description=prompt["task_description"],
280
+ code=code,
281
+ prev_ideas_string=prev_ideas_string,
282
+ num_reflections=num_reflections,
283
+ )
284
+ + """
285
+ Completed ideas have an additional "Score" field which indicates the assessment by an expert ML reviewer.
286
+ This is on a standard 1-10 ML conference scale.
287
+ Scores of 0 indicate the idea failed either during experimentation, writeup or reviewing.
288
+ """,
289
+ client=client,
290
+ model=model,
291
+ system_message=idea_system_prompt,
292
+ msg_history=msg_history,
293
+ )
294
+ ## Format the content in JSON if using weak LLM
295
+ text = format_idea_json(text)
296
+ ## PARSE OUTPUT
297
+ json_output = extract_json_between_markers(text)
298
+ assert json_output is not None, "Failed to extract JSON from LLM output"
299
+ # print(json_output)
300
+
301
+ # Iteratively improve task.
302
+ if num_reflections > 1:
303
+ for j in range(num_reflections - 1):
304
+ print(f"Iteration {j + 2}/{num_reflections}")
305
+ text, msg_history = get_response_from_llm(
306
+ idea_reflection_prompt.format(
307
+ current_round=j + 2, num_reflections=num_reflections
308
+ ),
309
+ client=client,
310
+ model=model,
311
+ system_message=idea_system_prompt,
312
+ msg_history=msg_history,
313
+ )
314
+ ## Format the content in JSON if using weak LLM
315
+ text = format_idea_json(text)
316
+ ## PARSE OUTPUT
317
+ json_output = extract_json_between_markers(text)
318
+ assert (
319
+ json_output is not None
320
+ ), "Failed to extract JSON from LLM output"
321
+ # print(json_output)
322
+
323
+ if "I am done" in text:
324
+ print(
325
+ f"Idea generation converged after {j + 2} iterations."
326
+ )
327
+ break
328
+
329
+ idea_archive.append(json_output)
330
+ break
331
+ except Exception as e:
332
+ print(f"Failed to generate idea: {e}")
333
+ traceback.print_exc()
334
+ continue
335
+
336
+ ## SAVE IDEAS
337
+ with open(osp.join(base_dir, "ideas.json"), "w") as f:
338
+ json.dump(idea_archive, f, indent=4)
339
+
340
+ return idea_archive
341
+
342
+
343
+ def on_backoff(details):
344
+ print(
345
+ f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries "
346
+ f"calling function {details['target'].__name__} at {time.strftime('%X')}"
347
+ )
348
+
349
+
350
+ @backoff.on_exception(
351
+ backoff.expo, requests.exceptions.HTTPError, on_backoff=on_backoff
352
+ )
353
+ def search_for_papers(query, result_limit=10) -> Union[None, List[Dict]]:
354
+ if not query:
355
+ return None
356
+ rsp = requests.get(
357
+ "https://api.semanticscholar.org/graph/v1/paper/search",
358
+ headers={"X-API-KEY": S2_API_KEY},
359
+ params={
360
+ "query": query,
361
+ "limit": result_limit,
362
+ "fields": "title,authors,venue,year,abstract,citationStyles,citationCount",
363
+ },
364
+ )
365
+ print(f"Response Status Code: {rsp.status_code}")
366
+ print(
367
+ f"Response Content: {rsp.text[:500]}"
368
+ ) # Print the first 500 characters of the response content
369
+ rsp.raise_for_status()
370
+ results = rsp.json()
371
+ total = results["total"]
372
+ if not total:
373
+ return None
374
+ time.sleep(2)
375
+ papers = results["data"]
376
+ return papers
377
+
378
+
379
+ novelty_system_msg = """You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field.
380
+ You have an idea and you want to check if it is novel or not. I.e., not overlapping significantly with existing literature or already well explored.
381
+ Be a harsh critic for novelty, ensure there is a sufficient contribution in the idea for a new conference or workshop paper.
382
+ You will be given access to the Semantic Scholar API, which you may use to survey the literature and find relevant papers to help you make your decision.
383
+ The top 10 results for any search query will be presented to you with the abstracts.
384
+
385
+ You will be given {num_rounds} rounds to decide on the paper, but you do not need to use them all.
386
+ At any round, you may exit early and decide on the novelty of the idea.
387
+ Decide a paper idea is novel if after sufficient searching, you have not found a paper that significantly overlaps with your idea.
388
+ Decide a paper idea is not novel if you have found a paper that significantly overlaps with your idea.
389
+
390
+ {task_description}
391
+ <experiment.py>
392
+ {code}
393
+ </experiment.py>
394
+ """
395
+
396
+ novelty_prompt = '''Round {current_round}/{num_rounds}.
397
+ You have this idea:
398
+
399
+ """
400
+ {idea}
401
+ """
402
+
403
+ The results of the last query are (empty on first round):
404
+ """
405
+ {last_query_results}
406
+ """
407
+
408
+ Respond in the following format:
409
+
410
+ THOUGHT:
411
+ <THOUGHT>
412
+
413
+ RESPONSE:
414
+ ```json
415
+ <JSON>
416
+ ```
417
+
418
+ In <THOUGHT>, first briefly reason over the idea and identify any query that could help you make your decision.
419
+ If you have made your decision, add "Decision made: novel." or "Decision made: not novel." to your thoughts.
420
+
421
+ In <JSON>, respond in JSON format with ONLY the following field:
422
+ - "Query": An optional search query to search the literature (e.g. attention is all you need). You must make a query if you have not decided this round.
423
+
424
+ A query will work best if you are able to recall the exact name of the paper you are looking for, or the authors.
425
+ This JSON will be automatically parsed, so ensure the format is precise.
426
+ '''
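+ # Illustrative only: the JSON block requested above contains a single field, e.g.
+ # {"Query": "attention is all you need"}; the query may only be left empty once a
+ # "Decision made: ..." line has appeared in the THOUGHT section.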
427
+
428
+
429
+ def check_idea_novelty(
430
+ ideas,
431
+ base_dir,
432
+ client,
433
+ model,
434
+ max_num_iterations=10,
435
+ ):
436
+ with open(osp.join(base_dir, "experiment.py"), "r") as f:
437
+ code = f.read()
438
+ with open(osp.join(base_dir, "prompt.json"), "r") as f:
439
+ prompt = json.load(f)
440
+ task_description = prompt["task_description"]
441
+
442
+ for idx, idea in enumerate(ideas):
443
+ if "novel" in idea:
444
+ print(f"Skipping idea {idx}, already checked.")
445
+ continue
446
+
447
+ print(f"\nChecking novelty of idea {idx}: {idea['Name']}")
448
+
449
+ novel = False
450
+ msg_history = []
451
+ papers_str = ""
452
+
453
+ for j in range(max_num_iterations):
454
+ try:
455
+ text, msg_history = get_response_from_llm(
456
+ novelty_prompt.format(
457
+ current_round=j + 1,
458
+ num_rounds=max_num_iterations,
459
+ idea=idea,
460
+ last_query_results=papers_str,
461
+ ),
462
+ client=client,
463
+ model=model,
464
+ system_message=novelty_system_msg.format(
465
+ num_rounds=max_num_iterations,
466
+ task_description=task_description,
467
+ code=code,
468
+ ),
469
+ msg_history=msg_history,
470
+ )
471
+ if "decision made: novel" in text.lower():
472
+ print("Decision made: novel after round", j)
473
+ novel = True
474
+ break
475
+ if "decision made: not novel" in text.lower():
476
+ print("Decision made: not novel after round", j)
477
+ break
478
+
479
+ ## Format the content in JSON
480
+ text = format_novelty_json(text)
481
+ print("text after formatting\n", text)
482
+ ## PARSE OUTPUT
483
+ json_output = extract_json_between_markers(text)
484
+ assert json_output is not None, "Failed to extract JSON from LLM output"
485
+
486
+ ## SEARCH FOR PAPERS
487
+ query = json_output["Query"]
488
+ papers = search_for_papers(query, result_limit=10)
489
+ if papers is None:
490
+ papers_str = "No papers found."
+ continue
491
+
492
+ paper_strings = []
493
+ for i, paper in enumerate(papers):
494
+ paper_strings.append(
495
+ """{i}: {title}. {authors}. {venue}, {year}.\nNumber of citations: {cites}\nAbstract: {abstract}""".format(
496
+ i=i,
497
+ title=paper["title"],
498
+ authors=paper["authors"],
499
+ venue=paper["venue"],
500
+ year=paper["year"],
501
+ cites=paper["citationCount"],
502
+ abstract=paper["abstract"],
503
+ )
504
+ )
505
+ papers_str = "\n\n".join(paper_strings)
506
+
507
+ except Exception as e:
508
+ print(f"Error: {e}")
509
+ continue
510
+
511
+ idea["novel"] = novel
512
+
513
+ # Save results to JSON file
514
+ results_file = osp.join(base_dir, "ideas.json")
515
+ with open(results_file, "w") as f:
516
+ json.dump(ideas, f, indent=4)
517
+
518
+ return ideas
519
+
520
+
521
+ if __name__ == "__main__":
522
+ MAX_NUM_GENERATIONS = 32
523
+ NUM_REFLECTIONS = 5
524
+ import argparse
525
+
526
+ parser = argparse.ArgumentParser(description="Generate AI scientist ideas")
527
+ # add type of experiment (nanoGPT, Boston, etc.)
528
+ parser.add_argument(
529
+ "--experiment",
530
+ type=str,
531
+ default="nanoGPT",
532
+ help="Experiment to run AI Scientist on.",
533
+ )
534
+ parser.add_argument(
535
+ "--model",
536
+ type=str,
537
+ default="deepseek-ai/DeepSeek-V2.5",
538
+ choices=allchoices,
539
+ help="Model to use for AI Scientist.",
540
+ )
541
+ parser.add_argument(
542
+ "--skip-idea-generation",
543
+ action="store_true",
544
+ help="Skip idea generation and use existing ideas.",
545
+ )
546
+ parser.add_argument(
547
+ "--check-novelty",
548
+ action="store_true",
549
+ help="Check novelty of ideas.",
550
+ )
551
+ args = parser.parse_args()
552
+
553
+ # Create client
554
+
555
+ # ------------------------------------------------------------------------------------------------------
556
+
557
+ if args.model == "Qwen/Qwen2.5-72B-Instruct":
558
+ # elif args.model.startswith("hyperbolic"):
559
+ print(f"Using Hyperbolic API with model {args.model}.")
560
+
561
+ import openai
562
+ import os
563
+ # client_model = args.model[11:]
564
+ client_model = args.model
565
+ client = openai.OpenAI(
566
+ api_key=os.environ["OPENAI_API_KEY"], base_url="https://api.hyperbolic.xyz/v1"
567
+ )
568
+
569
+ # ------------------------------------------------------------------------------------------------------
570
+
571
+
572
+ elif args.model == "claude-3-5-sonnet-20240620":
573
+ import anthropic
574
+
575
+ print(f"Using Anthropic API with model {args.model}.")
576
+ client_model = "claude-3-5-sonnet-20240620"
577
+ client = anthropic.Anthropic()
578
+ elif args.model.startswith("bedrock") and "claude" in args.model:
579
+ import anthropic
580
+
581
+ # Expects: bedrock/<MODEL_ID>
582
+ client_model = args.model.split("/")[-1]
583
+
584
+ print(f"Using Amazon Bedrock with model {client_model}.")
585
+ client = anthropic.AnthropicBedrock()
586
+ elif args.model == "gpt-4o-2024-05-13" or args.model == "hybrid":
587
+ import openai
588
+
589
+ print(f"Using OpenAI API with model {args.model}.")
590
+ client_model = "gpt-4o-2024-05-13"
591
+ client = openai.OpenAI()
592
+ elif args.model == "deepseek-coder-v2-0724":
593
+ import openai
594
+
595
+ print(f"Using OpenAI API with {args.model}.")
596
+ client_model = "deepseek-coder-v2-0724"
597
+ client = openai.OpenAI(
598
+ api_key=os.environ["DEEPSEEK_API_KEY"], base_url="https://api.hyperbolic.xyz/v1"
599
+ )
600
+ elif args.model == "llama3.1-405b":
601
+ import openai
602
+
603
+ print(f"Using OpenAI API with {args.model}.")
604
+ client_model = "meta-llama/llama-3.1-405b-instruct"
605
+ client = openai.OpenAI(
606
+ api_key=os.environ["OPENROUTER_API_KEY"],
607
+ base_url="https://openrouter.ai/api/v1",
608
+ )
609
+ elif args.model.startswith("ollama"):
610
+ import openai
611
+
612
+ print(f"Using Ollama with {args.model}.")
613
+ client_model = args.model.split("/")[-1]
614
+ # client_model = args.model
615
+ client = openai.OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
616
+
617
+ else:
618
+ raise ValueError(f"Model {args.model} not supported.")
619
+
620
+ base_dir = osp.join("templates", args.experiment)
621
+ results_dir = osp.join("results", args.experiment)
622
+ print("Generating ideas...")
623
+ ideas = generate_ideas(
624
+ base_dir,
625
+ client=client,
626
+ model=client_model,
627
+ skip_generation=args.skip_idea_generation,
628
+ max_num_generations=MAX_NUM_GENERATIONS,
629
+ num_reflections=NUM_REFLECTIONS,
630
+ )
631
+ if args.check_novelty:
632
+ ideas = check_idea_novelty(
633
+ ideas,
634
+ base_dir=base_dir,
635
+ client=client,
636
+ model=client_model,
637
+ )
ai_scientist/llm.py ADDED
@@ -0,0 +1,359 @@
1
+ import json
2
+
3
+ import backoff
4
+ import openai
5
+
6
+ # Ollama
7
+ ollama_choices = [
8
+ "mistral-nemo",
9
+ "llama3.1",
10
+ "qwen2.5:32b"
11
+ ]
12
+
13
+ # hyperbolic
14
+ hyperbolic_choices = [
15
+ "Qwen/Qwen2.5-72B-Instruct",
16
+ "meta-llama/Meta-Llama-3.1-70B-Instruct",
17
+ ]
18
+
19
+
20
+ allchoices = [
21
+ "Qwen/Qwen2.5-72B-Instruct",
22
+ "deepseek-ai/DeepSeek-V2.5",
23
+ "claude-3-5-sonnet-20240620",
24
+ "gpt-4o-2024-05-13",
25
+ "deepseek-coder-v2-0724",
26
+ "llama3.1-405b",
27
+ # Anthropic Claude models via Amazon Bedrock
28
+ "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
29
+ "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
30
+ "bedrock/anthropic.claude-3-haiku-20240307-v1:0",
31
+ "bedrock/anthropic.claude-3-opus-20240229-v1:0",
32
+ ]
33
+
34
+ for item in ollama_choices:
35
+ allchoices.append("ollama/" + item)
36
+
37
+ for item in hyperbolic_choices:
38
+ allchoices.append("hyperbolic/" + item)
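+ # The loops above extend allchoices with prefixed entries such as "ollama/llama3.1" and
+ # "hyperbolic/Qwen/Qwen2.5-72B-Instruct"; callers (e.g. the ollama branch in generate_ideas.py)
+ # recover the bare model name with args.model.split("/")[-1].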
39
+
40
+
41
+ # Get N responses from a single message, used for ensembling.
42
+ @backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
43
+ def get_batch_responses_from_llm(
44
+ msg,
45
+ client,
46
+ model,
47
+ system_message,
48
+ print_debug=False,
49
+ msg_history=None,
50
+ temperature=0.75,
51
+ n_responses=1,
52
+ ):
53
+ if msg_history is None:
54
+ msg_history = []
55
+
56
+ if model in [
57
+ "gpt-4o-2024-05-13",
58
+ "gpt-4o-mini-2024-07-18",
59
+ "gpt-4o-2024-08-06",
60
+ "Qwen/Qwen2.5-72B-Instruct"
61
+ ]:
62
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
63
+ response = client.chat.completions.create(
64
+ model=model,
65
+ messages=[
66
+ {"role": "system", "content": system_message},
67
+ *new_msg_history,
68
+ ],
69
+ temperature=temperature,
70
+ max_tokens=3000,
71
+ n=n_responses,
72
+ stop=None,
73
+ seed=0,
74
+ )
75
+ content = [r.message.content for r in response.choices]
76
+ new_msg_history = [
77
+ new_msg_history + [{"role": "assistant", "content": c}] for c in content
78
+ ]
79
+ elif model == "deepseek-coder-v2-0724":
80
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
81
+ response = client.chat.completions.create(
82
+ model="deepseek-coder",
83
+ messages=[
84
+ {"role": "system", "content": system_message},
85
+ *new_msg_history,
86
+ ],
87
+ temperature=temperature,
88
+ max_tokens=3000,
89
+ n=n_responses,
90
+ stop=None,
91
+ )
92
+ content = [r.message.content for r in response.choices]
93
+ new_msg_history = [
94
+ new_msg_history + [{"role": "assistant", "content": c}] for c in content
95
+ ]
96
+
97
+ # ------------------------------------------------------------------------------------------------------
98
+
99
+ elif model == "Qwen/Qwen2.5-72B-Instruct":
100
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
101
+ response = client.chat.completions.create(
102
+ model="Qwen/Qwen2.5-72B-Instruct",
103
+ messages=[
104
+ {"role": "system", "content": system_message},
105
+ *new_msg_history,
106
+ ],
107
+ temperature=temperature,
108
+ max_tokens=3000,
109
+ n=n_responses,
110
+ stop=None,
111
+ )
112
+ content = [r.message.content for r in response.choices]
113
+ new_msg_history = [
114
+ new_msg_history + [{"role": "assistant", "content": c}] for c in content
115
+ ]
116
+
117
+ # elif model in hyperbolic_choices:
118
+ # content, new_msg_history = [], []
119
+ # for i in range(n_responses):
120
+ # print(f"Getting {i+1}/{n_responses} response from {model}")
121
+ # c, hist = get_response_from_llm(
122
+ # msg,
123
+ # client,
124
+ # model,
125
+ # system_message,
126
+ # print_debug=False,
127
+ # msg_history=None,
128
+ # temperature=temperature,
129
+ # )
130
+ # content.append(c)
131
+ # new_msg_history.append(hist)
132
+
133
+ # ------------------------------------------------------------------------------------------------------
134
+
135
+ elif model == "llama-3-1-405b-instruct":
136
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
137
+ response = client.chat.completions.create(
138
+ model="meta-llama/llama-3.1-405b-instruct",
139
+ messages=[
140
+ {"role": "system", "content": system_message},
141
+ *new_msg_history,
142
+ ],
143
+ temperature=temperature,
144
+ max_tokens=3000,
145
+ n=n_responses,
146
+ stop=None,
147
+ )
148
+ content = [r.message.content for r in response.choices]
149
+ new_msg_history = [
150
+ new_msg_history + [{"role": "assistant", "content": c}] for c in content
151
+ ]
152
+ elif model == "claude-3-5-sonnet-20240620":
153
+ content, new_msg_history = [], []
154
+ for _ in range(n_responses):
155
+ c, hist = get_response_from_llm(
156
+ msg,
157
+ client,
158
+ model,
159
+ system_message,
160
+ print_debug=False,
161
+ msg_history=None,
162
+ temperature=temperature,
163
+ )
164
+ content.append(c)
165
+ new_msg_history.append(hist)
166
+
167
+ # ollama models
168
+ elif model in ollama_choices:
169
+ content, new_msg_history = [], []
170
+ for i in range(n_responses):
171
+ print(f"Getting {i+1}/{n_responses} response from {model}")
172
+ c, hist = get_response_from_llm(
173
+ msg,
174
+ client,
175
+ model,
176
+ system_message,
177
+ print_debug=False,
178
+ msg_history=None,
179
+ temperature=temperature,
180
+ )
181
+ content.append(c)
182
+ new_msg_history.append(hist)
183
+ else:
184
+ raise ValueError(f"Model {model} not supported.")
185
+
186
+ if print_debug:
187
+ # Just print the first one.
188
+ print()
189
+ print("*" * 20 + " LLM START " + "*" * 20)
190
+ for j, msg in enumerate(new_msg_history[0]):
191
+ print(f'{j}, {msg["role"]}: {msg["content"]}')
192
+ print(content)
193
+ print("*" * 21 + " LLM END " + "*" * 21)
194
+ print()
195
+
196
+ return content, new_msg_history
197
+
198
+
199
+ @backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
200
+ def get_response_from_llm(
201
+ msg,
202
+ client,
203
+ model,
204
+ system_message,
205
+ print_debug=False,
206
+ msg_history=None,
207
+ temperature=0.75,
208
+ ):
209
+ if msg_history is None:
210
+ msg_history = []
211
+
212
+ if model == "claude-3-5-sonnet-20240620":
213
+ new_msg_history = msg_history + [
214
+ {
215
+ "role": "user",
216
+ "content": [
217
+ {
218
+ "type": "text",
219
+ "text": msg,
220
+ }
221
+ ],
222
+ }
223
+ ]
224
+ response = client.messages.create(
225
+ model="claude-3-5-sonnet-20240620",
226
+ max_tokens=3000,
227
+ temperature=temperature,
228
+ system=system_message,
229
+ messages=new_msg_history,
230
+ )
231
+ content = response.content[0].text
232
+ new_msg_history = new_msg_history + [
233
+ {
234
+ "role": "assistant",
235
+ "content": [
236
+ {
237
+ "type": "text",
238
+ "text": content,
239
+ }
240
+ ],
241
+ }
242
+ ]
243
+ # ------------------------------------------------------------------------------------------------------
244
+
245
+ elif model in [
246
+ "gpt-4o-2024-05-13",
247
+ "gpt-4o-mini-2024-07-18",
248
+ "gpt-4o-2024-08-06",
249
+ "Qwen/Qwen2.5-72B-Instruct"
250
+ ]:
251
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
252
+ response = client.chat.completions.create(
253
+ model=model,
254
+ messages=[
255
+ {"role": "system", "content": system_message},
256
+ *new_msg_history,
257
+ ],
258
+ temperature=temperature,
259
+ max_tokens=3000,
260
+ n=1,
261
+ stop=None,
262
+ seed=0,
263
+ )
264
+ content = response.choices[0].message.content
265
+ new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
266
+
267
+
268
+ # ------------------------------------------------------------------------------------------------------
269
+
270
+
271
+ elif model in ["meta-llama/llama-3.1-405b-instruct", "llama-3-1-405b-instruct"]:
272
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
273
+ response = client.chat.completions.create(
274
+ model="meta-llama/llama-3.1-405b-instruct",
275
+ messages=[
276
+ {"role": "system", "content": system_message},
277
+ *new_msg_history,
278
+ ],
279
+ temperature=temperature,
280
+ max_tokens=3000,
281
+ n=1,
282
+ stop=None,
283
+ )
284
+ content = response.choices[0].message.content
285
+ new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
286
+
287
+
288
+ elif model in ollama_choices:
289
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
290
+ response = client.chat.completions.create(
291
+ model=model,
292
+ messages=[
293
+ {"role": "system", "content": system_message},
294
+ *new_msg_history,
295
+ ],
296
+ temperature=temperature,
297
+ max_tokens=6000,
298
+ n=1,
299
+ stop=None,
300
+ seed=0,
301
+ )
302
+ content = response.choices[0].message.content
303
+ # print("\nget_response_from_llm\n")
304
+ # print(content)
305
+ new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
306
+
307
+ else:
308
+ raise ValueError(f"Model {model} not supported.")
309
+
310
+ if print_debug:
311
+ print()
312
+ print("*" * 20 + " LLM START " + "*" * 20)
313
+ for j, msg in enumerate(new_msg_history):
314
+ print(f'{j}, {msg["role"]}: {msg["content"]}')
315
+ print(content)
316
+ print("*" * 21 + " LLM END " + "*" * 21)
317
+ print()
318
+
319
+ return content, new_msg_history
320
+
321
+
322
+ def llm_json_auto_correct(system_prompt: str, user_prompt: str) -> str:
323
+ import os
324
+ client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"], base_url="https://api.hyperbolic.xyz/v1")
325
+ response = client.chat.completions.create(
326
+ model="Qwen/Qwen2.5-72B-Instruct",
327
+ temperature=0,
328
+ messages=[
329
+ {"role": "system", "content": system_prompt},
330
+ {"role": "user", "content": user_prompt},
331
+ ],
332
+ )
333
+ return response.choices[0].message.content
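+ # Note: llm_json_auto_correct is the repair model handed to strict_json(); as written it reads
+ # OPENAI_API_KEY but sends the request to the Hyperbolic endpoint, querying Qwen2.5-72B-Instruct
+ # at temperature 0 for deterministic JSON fixing.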
334
+
335
+
336
+ def extract_json_between_markers(llm_output):
337
+ json_start_marker = "```json"
338
+ json_end_marker = "```"
339
+
340
+ # Find the start and end indices of the JSON string
341
+ start_index = llm_output.find(json_start_marker)
342
+ if start_index != -1:
343
+ start_index += len(json_start_marker) # Move past the marker
344
+ end_index = llm_output.find(json_end_marker, start_index)
345
+ else:
346
+ return None # JSON markers not found
347
+
348
+ if end_index == -1:
349
+ return None # End marker not found
350
+
351
+ # Extract the JSON string
352
+ json_string = llm_output[start_index:end_index].strip()
353
+ # print(json_string)
354
+ try:
355
+ parsed_json = json.loads(json_string)
356
+
357
+ return parsed_json
358
+ except json.JSONDecodeError:
359
+ return None # Invalid JSON format
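+ # Example (hypothetical LLM output): extract_json_between_markers('... ```json\n{"Query": "x"}\n``` ...')
+ # returns {"Query": "x"}; it returns None if the ```json fences are missing or the payload is not valid JSON.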
ai_scientist/perform_experiments.py ADDED
@@ -0,0 +1,166 @@
1
+ import shutil
2
+ import os.path as osp
3
+ import subprocess
4
+ from subprocess import TimeoutExpired
5
+ import sys
6
+ import json
7
+
8
+ MAX_ITERS = 4
9
+ MAX_RUNS = 5
10
+ MAX_STDERR_OUTPUT = 1500
11
+
12
+ coder_prompt = """Your goal is to implement the following idea: {title}.
13
+ The proposed experiment is as follows: {idea}.
14
+ You are given a total of up to {max_runs} runs to complete the necessary experiments. You do not need to use all {max_runs}.
15
+
16
+ First, plan the list of experiments you would like to run. For example, if you are sweeping over a specific hyperparameter, plan each value you would like to test for each run.
17
+
18
+ Note that we already provide the vanilla baseline results, so you do not need to re-run it.
19
+
20
+ For reference, the baseline results are as follows:
21
+
22
+ {baseline_results}
23
+
24
+ After you complete each change, we will run the command `python experiment.py --out_dir=run_i` where i is the run number and evaluate the results.
25
+ YOUR PROPOSED CHANGE MUST USE THIS COMMAND FORMAT, DO NOT ADD ADDITIONAL COMMAND LINE ARGS.
26
+ You can then implement the next thing on your list."""
27
+
28
+
29
+ # RUN EXPERIMENT
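+ # run_experiment launches `python experiment.py --out_dir=run_<run_num>` inside `folder_name`,
+ # snapshots experiment.py as run_<run_num>.py, and returns (return_code, next_prompt): on success
+ # next_prompt summarizes final_info.json, on failure it carries the (truncated) stderr, and on
+ # timeout it reports the timeout so the coder LLM can react.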
30
+ def run_experiment(folder_name, run_num, timeout=7200):
31
+ cwd = osp.abspath(folder_name)
32
+ # COPY CODE SO WE CAN SEE IT.
33
+ shutil.copy(
34
+ osp.join(folder_name, "experiment.py"),
35
+ osp.join(folder_name, f"run_{run_num}.py"),
36
+ )
37
+
38
+ # LAUNCH COMMAND
39
+ command = [
40
+ "python",
41
+ "experiment.py",
42
+ f"--out_dir=run_{run_num}",
43
+ ]
44
+ try:
45
+ result = subprocess.run(
46
+ command, cwd=cwd, stderr=subprocess.PIPE, text=True, timeout=timeout
47
+ )
48
+
49
+ if result.stderr:
50
+ print(result.stderr, file=sys.stderr)
51
+
52
+ if result.returncode != 0:
53
+ print(f"Run {run_num} failed with return code {result.returncode}")
54
+ if osp.exists(osp.join(cwd, f"run_{run_num}")):
55
+ shutil.rmtree(osp.join(cwd, f"run_{run_num}"))
56
+ print(f"Run failed with the following error {result.stderr}")
57
+ stderr_output = result.stderr
58
+ if len(stderr_output) > MAX_STDERR_OUTPUT:
59
+ stderr_output = "..." + stderr_output[-MAX_STDERR_OUTPUT:]
60
+ next_prompt = f"Run failed with the following error {stderr_output}"
61
+ else:
62
+ with open(osp.join(cwd, f"run_{run_num}", "final_info.json"), "r") as f:
63
+ results = json.load(f)
64
+ results = {k: v["means"] for k, v in results.items()}
65
+
66
+ next_prompt = f"""Run {run_num} completed. Here are the results:
67
+ {results}
68
+
69
+ Decide if you need to re-plan your experiments given the result (you often will not need to).
70
+
71
+ Someone else will be using `notes.txt` to perform a writeup on this in the future.
72
+ Please include *all* relevant information for the writeup on Run {run_num}, including an experiment description and the run number. Be as verbose as necessary.
73
+
74
+ Then, implement the next thing on your list.
75
+ We will then run the command `python experiment.py --out_dir=run_{run_num + 1}`.
76
+ YOUR PROPOSED CHANGE MUST USE THIS COMMAND FORMAT, DO NOT ADD ADDITIONAL COMMAND LINE ARGS.
77
+ If you are finished with experiments, respond with 'ALL_COMPLETED'."""
78
+ return result.returncode, next_prompt
79
+ except TimeoutExpired:
80
+ print(f"Run {run_num} timed out after {timeout} seconds")
81
+ if osp.exists(osp.join(cwd, f"run_{run_num}")):
82
+ shutil.rmtree(osp.join(cwd, f"run_{run_num}"))
83
+ next_prompt = f"Run timed out after {timeout} seconds"
84
+ return 1, next_prompt
85
+
86
+
87
+ # RUN PLOTTING
88
+ def run_plotting(folder_name, timeout=600):
89
+ cwd = osp.abspath(folder_name)
90
+ # LAUNCH COMMAND
91
+ command = [
92
+ "python",
93
+ "plot.py",
94
+ ]
95
+ try:
96
+ result = subprocess.run(
97
+ command, cwd=cwd, stderr=subprocess.PIPE, text=True, timeout=timeout
98
+ )
99
+
100
+ if result.stderr:
101
+ print(result.stderr, file=sys.stderr)
102
+
103
+ if result.returncode != 0:
104
+ print(f"Plotting failed with return code {result.returncode}")
105
+ next_prompt = f"Plotting failed with the following error {result.stderr}"
106
+ else:
107
+ next_prompt = ""
108
+ return result.returncode, next_prompt
109
+ except TimeoutExpired:
110
+ print(f"Plotting timed out after {timeout} seconds")
111
+ next_prompt = f"Plotting timed out after {timeout} seconds"
112
+ return 1, next_prompt
113
+
114
+
115
+ # PERFORM EXPERIMENTS
116
+ def perform_experiments(idea, folder_name, coder, baseline_results) -> bool:
117
+ ## RUN EXPERIMENT
118
+ current_iter = 0
119
+ run = 1
120
+ next_prompt = coder_prompt.format(
121
+ title=idea["Title"],
122
+ idea=idea["Experiment"],
123
+ max_runs=MAX_RUNS,
124
+ baseline_results=baseline_results,
125
+ )
126
+ while run < MAX_RUNS + 1:
127
+ if current_iter >= MAX_ITERS:
128
+ print("Max iterations reached")
129
+ break
130
+ coder_out = coder.run(next_prompt)
131
+ print(coder_out)
132
+ if "ALL_COMPLETED" in coder_out:
133
+ break
134
+ return_code, next_prompt = run_experiment(folder_name, run)
135
+ if return_code == 0:
136
+ run += 1
137
+ current_iter = 0
138
+ current_iter += 1
139
+ if current_iter >= MAX_ITERS:
140
+ print("Not all experiments completed.")
141
+ return False
142
+
143
+ current_iter = 0
144
+ next_prompt = """
145
+ Great job! Please modify `plot.py` to generate the most relevant plots for the final writeup.
146
+
147
+ In particular, be sure to fill in the "labels" dictionary with the correct names for each run that you want to plot.
148
+
149
+ Only the runs in the `labels` dictionary will be plotted, so make sure to include all relevant runs.
150
+
151
+ We will be running the command `python plot.py` to generate the plots.
152
+ """
153
+ while True:
154
+ coder_out = coder.run(next_prompt)
155
+ return_code, next_prompt = run_plotting(folder_name)
156
+ current_iter += 1
157
+ if return_code == 0 or current_iter >= MAX_ITERS:
158
+ break
159
+ next_prompt = """
160
+ Please modify `notes.txt` with a description of what each plot shows along with the filename of the figure. Please do so in-depth.
161
+
162
+ Somebody else will be using `notes.txt` to write a report on this in the future.
163
+ """
164
+ coder.run(next_prompt)
165
+
166
+ return True
ai_scientist/perform_review.py ADDED
@@ -0,0 +1,431 @@
1
+ import json
2
+ import os
3
+
4
+ import numpy as np
5
+ import pymupdf
6
+ import pymupdf4llm
7
+ from pypdf import PdfReader
8
+ from strictjson import strict_json
9
+
10
+ from ai_scientist.llm import (
11
+ extract_json_between_markers,
12
+ get_batch_responses_from_llm,
13
+ get_response_from_llm,
14
+ llm_json_auto_correct,
15
+ )
16
+
17
+
18
+ # Format the content in JSON
19
+ def format_llm_review_json(text):
20
+ res = strict_json(
21
+ system_prompt="You are a JSON formatter",
22
+ user_prompt=text,
23
+ return_as_json=True,
24
+ output_format={
25
+ "Summary": "A summary of the paper content and its contributions.",
26
+ "Strengths": "A list of strengths of the paper, type: list",
27
+ "Weaknesses": "A list of weaknesses of the paper, type: list",
28
+ "Originality": "A rating from 1 to 4 (low, medium, high, very high), type: int",
29
+ "Quality": "A rating from 1 to 4 (low, medium, high, very high), type: int",
30
+ "Clarity": "A rating from 1 to 4 (low, medium, high, very high), type: int",
31
+ "Significance": "A rating from 1 to 4 (low, medium, high, very high), type: int",
32
+ "Questions": "A set of clarifying questions to be answered by the paper authors, type: list",
33
+ "Limitations": "A set of limitations and potential negative societal impacts of the work, type: str",
34
+ "Ethical Concerns": "A boolean value indicating whether there are ethical concerns, type: bool",
35
+ "Soundness": "A rating from 1 to 4 (poor, fair, good, excellent), type: int",
36
+ "Presentation": "A rating from 1 to 4 (poor, fair, good, excellent), type: int",
37
+ "Contribution": "A rating from 1 to 4 (poor, fair, good, excellent), type: int",
38
+ "Overall": "A rating from 1 to 10 (very strong reject to award quality), type: int",
39
+ "Confidence": "A rating from 1 to 5 (low, medium, high, very high, absolute), type: int",
40
+ "Decision": "A decision that has to be Accept or Reject, type: str",
41
+ },
42
+ llm=llm_json_auto_correct,
43
+ )
44
+ text = json.loads(res)
45
+ return text
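+ # Note: despite its name, format_llm_review_json returns a parsed dict (strict_json is called with
+ # return_as_json=True and its output is passed through json.loads), not a formatted string.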
46
+
47
+
48
+ reviewer_system_prompt_base = (
49
+ "You are an AI researcher who is reviewing a paper that was submitted to a prestigious ML venue. "
50
+ "Be critical and cautious in your decision. "
51
+ )
52
+
53
+ reviewer_system_prompt_neg = (
54
+ reviewer_system_prompt_base
55
+ + "If a paper is bad or you are unsure, give it bad scores and reject it."
56
+ )
57
+ reviewer_system_prompt_pos = (
58
+ reviewer_system_prompt_base
59
+ + "If a paper is good or you are unsure, give it good scores and accept it."
60
+ )
61
+
62
+ template_instructions = """
63
+ Respond in the following format:
64
+
65
+ THOUGHT:
66
+ <THOUGHT>
67
+
68
+ REVIEW JSON:
69
+ ```json
70
+ <JSON>
71
+ ```
72
+
73
+ In <THOUGHT>, first briefly discuss your intuitions and reasoning for the evaluation.
74
+ Detail your high-level arguments, necessary choices and desired outcomes of the review.
75
+ Do not make generic comments here, but be specific to your current paper.
76
+ Treat this as the note-taking phase of your review.
77
+
78
+ In <JSON>, provide the review in JSON format with the following fields in the order:
79
+ - "Summary": A summary of the paper content and its contributions.
80
+ - "Strengths": A list of strengths of the paper.
81
+ - "Weaknesses": A list of weaknesses of the paper.
82
+ - "Originality": A rating from 1 to 4 (low, medium, high, very high).
83
+ - "Quality": A rating from 1 to 4 (low, medium, high, very high).
84
+ - "Clarity": A rating from 1 to 4 (low, medium, high, very high).
85
+ - "Significance": A rating from 1 to 4 (low, medium, high, very high).
86
+ - "Questions": A set of clarifying questions to be answered by the paper authors.
87
+ - "Limitations": A set of limitations and potential negative societal impacts of the work.
88
+ - "Ethical Concerns": A boolean value indicating whether there are ethical concerns.
89
+ - "Soundness": A rating from 1 to 4 (poor, fair, good, excellent).
90
+ - "Presentation": A rating from 1 to 4 (poor, fair, good, excellent).
91
+ - "Contribution": A rating from 1 to 4 (poor, fair, good, excellent).
92
+ - "Overall": A rating from 1 to 10 (very strong reject to award quality).
93
+ - "Confidence": A rating from 1 to 5 (low, medium, high, very high, absolute).
94
+ - "Decision": A decision that has to be one of the following: Accept, Reject.
95
+
96
+ For the "Decision" field, don't use Weak Accept, Borderline Accept, Borderline Reject, or Strong Reject. Instead, only use Accept or Reject.
97
+ This JSON will be automatically parsed, so ensure the format is precise.
98
+ """
99
+
100
+ neurips_form = (
101
+ """
102
+ ## Review Form
103
+ Below is a description of the questions you will be asked on the review form for each paper and some guidelines on what to consider when answering these questions.
104
+ When writing your review, please keep in mind that after decisions have been made, reviews and meta-reviews of accepted papers and opted-in rejected papers will be made public.
105
+
106
+ 1. Summary: Briefly summarize the paper and its contributions. This is not the place to critique the paper; the authors should generally agree with a well-written summary.
107
+ - Strengths and Weaknesses: Please provide a thorough assessment of the strengths and weaknesses of the paper, touching on each of the following dimensions:
108
+ - Originality: Are the tasks or methods new? Is the work a novel combination of well-known techniques? (This can be valuable!) Is it clear how this work differs from previous contributions? Is related work adequately cited
109
+ - Quality: Is the submission technically sound? Are claims well supported (e.g., by theoretical analysis or experimental results)? Are the methods used appropriate? Is this a complete piece of work or work in progress? Are the authors careful and honest about evaluating both the strengths and weaknesses of their work
110
+ - Clarity: Is the submission clearly written? Is it well organized? (If not, please make constructive suggestions for improving its clarity.) Does it adequately inform the reader? (Note that a superbly written paper provides enough information for an expert reader to reproduce its results.)
111
+ - Significance: Are the results important? Are others (researchers or practitioners) likely to use the ideas or build on them? Does the submission address a difficult task in a better way than previous work? Does it advance the state of the art in a demonstrable way? Does it provide unique data, unique conclusions about existing data, or a unique theoretical or experimental approach?
112
+
113
+ 2. Questions: Please list up and carefully describe any questions and suggestions for the authors. Think of the things where a response from the author can change your opinion, clarify a confusion or address a limitation. This can be very important for a productive rebuttal and discussion phase with the authors.
114
+
115
+ 3. Limitations: Have the authors adequately addressed the limitations and potential negative societal impact of their work? If not, please include constructive suggestions for improvement.
116
+ In general, authors should be rewarded rather than punished for being up front about the limitations of their work and any potential negative societal impact. You are encouraged to think through whether any critical points are missing and provide these as feedback for the authors.
117
+
118
+ 4. Ethical concerns: If there are ethical issues with this paper, please flag the paper for an ethics review. For guidance on when this is appropriate, please review the NeurIPS ethics guidelines.
119
+
120
+ 5. Soundness: Please assign the paper a numerical rating on the following scale to indicate the soundness of the technical claims, experimental and research methodology and on whether the central claims of the paper are adequately supported with evidence.
121
+ 4: excellent
122
+ 3: good
123
+ 2: fair
124
+ 1: poor
125
+
126
+ 6. Presentation: Please assign the paper a numerical rating on the following scale to indicate the quality of the presentation. This should take into account the writing style and clarity, as well as contextualization relative to prior work.
127
+ 4: excellent
128
+ 3: good
129
+ 2: fair
130
+ 1: poor
131
+
132
+ 7. Contribution: Please assign the paper a numerical rating on the following scale to indicate the quality of the overall contribution this paper makes to the research area being studied. Are the questions being asked important? Does the paper bring a significant originality of ideas and/or execution? Are the results valuable to share with the broader NeurIPS community.
133
+ 4: excellent
134
+ 3: good
135
+ 2: fair
136
+ 1: poor
137
+
138
+ 8. Overall: Please provide an "overall score" for this submission. Choices:
139
+ 10: Award quality: Technically flawless paper with groundbreaking impact on one or more areas of AI, with exceptionally strong evaluation, reproducibility, and resources, and no unaddressed ethical considerations.
140
+ 9: Very Strong Accept: Technically flawless paper with groundbreaking impact on at least one area of AI and excellent impact on multiple areas of AI, with flawless evaluation, resources, and reproducibility, and no unaddressed ethical considerations.
141
+ 8: Strong Accept: Technically strong paper, with novel ideas, excellent impact on at least one area of AI or high-to-excellent impact on multiple areas of AI, with excellent evaluation, resources, and reproducibility, and no unaddressed ethical considerations.
142
+ 7: Accept: Technically solid paper, with high impact on at least one sub-area of AI or moderate-to-high impact on more than one area of AI, with good-to-excellent evaluation, resources, reproducibility, and no unaddressed ethical considerations.
143
+ 6: Weak Accept: Technically solid, moderate-to-high impact paper, with no major concerns with respect to evaluation, resources, reproducibility, ethical considerations.
144
+ 5: Borderline accept: Technically solid paper where reasons to accept outweigh reasons to reject, e.g., limited evaluation. Please use sparingly.
145
+ 4: Borderline reject: Technically solid paper where reasons to reject, e.g., limited evaluation, outweigh reasons to accept, e.g., good evaluation. Please use sparingly.
146
+ 3: Reject: For instance, a paper with technical flaws, weak evaluation, inadequate reproducibility and incompletely addressed ethical considerations.
147
+ 2: Strong Reject: For instance, a paper with major technical flaws, and/or poor evaluation, limited impact, poor reproducibility and mostly unaddressed ethical considerations.
148
+ 1: Very Strong Reject: For instance, a paper with trivial results or unaddressed ethical considerations.
149
+
150
+ 9. Confidence: Please provide a "confidence score" for your assessment of this submission to indicate how confident you are in your evaluation. Choices:
151
+ 5: You are absolutely certain about your assessment. You are very familiar with the related work and checked the math/other details carefully.
152
+ 4: You are confident in your assessment, but not absolutely certain. It is unlikely, but not impossible, that you did not understand some parts of the submission or that you are unfamiliar with some pieces of related work.
153
+ 3: You are fairly confident in your assessment. It is possible that you did not understand some parts of the submission or that you are unfamiliar with some pieces of related work. Math/other details were not carefully checked.
154
+ 2: You are willing to defend your assessment, but it is quite likely that you did not understand the central parts of the submission or that you are unfamiliar with some pieces of related work. Math/other details were not carefully checked.
155
+ 1: Your assessment is an educated guess. The submission is not in your area or the submission was difficult to understand. Math/other details were not carefully checked.
156
+ """
157
+ + template_instructions
158
+ )
159
+
160
+
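+ # Produce a structured review of the given paper text with an LLM: optionally prepend few-shot
+ # example reviews, ensemble several independent reviews into a meta-review, and run reflection rounds.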
161
+ def perform_review(
162
+ text,
163
+ model,
164
+ client,
165
+ num_reflections=1,
166
+ num_fs_examples=1,
167
+ num_reviews_ensemble=1,
168
+ temperature=0.75,
169
+ msg_history=None,
170
+ return_msg_history=False,
171
+ reviewer_system_prompt=reviewer_system_prompt_neg,
172
+ review_instruction_form=neurips_form,
173
+ ):
174
+ if num_fs_examples > 0:
175
+ fs_prompt = get_review_fewshot_examples(num_fs_examples)
176
+ base_prompt = review_instruction_form + fs_prompt
177
+ else:
178
+ base_prompt = review_instruction_form
179
+
180
+ base_prompt += f"""
181
+ Here is the paper you are asked to review:
182
+ ```
183
+ {text}
184
+ ```"""
185
+
186
+ if num_reviews_ensemble > 1:
187
+ llm_review, msg_histories = get_batch_responses_from_llm(
188
+ base_prompt,
189
+ model=model,
190
+ client=client,
191
+ system_message=reviewer_system_prompt,
192
+ print_debug=False,
193
+ msg_history=msg_history,
194
+ # Higher temperature to encourage diversity.
195
+ temperature=0.75,
196
+ n_responses=num_reviews_ensemble,
197
+ )
198
+ parsed_reviews = []
199
+ for idx, rev in enumerate(llm_review):
200
+ try:
201
+ parsed_reviews.append(format_llm_review_json(rev))
202
+ except Exception as e:
203
+ print(f"Ensemble review {idx} failed: {e}")
204
+ parsed_reviews = [r for r in parsed_reviews if r is not None]
205
+ review = get_meta_review(model, client, temperature, parsed_reviews)
206
+ ## Format the content in JSON
207
+ review = format_llm_review_json(review)
208
+ # take first valid in case meta-reviewer fails
209
+ if review is None:
210
+ review = parsed_reviews[0]
211
+ # print(parsed_reviews, "\n\n\n", review) # debug
212
+ # Replace numerical scores with the average of the ensemble.
213
+ for score, limits in [
214
+ ("Originality", (1, 4)),
215
+ ("Quality", (1, 4)),
216
+ ("Clarity", (1, 4)),
217
+ ("Significance", (1, 4)),
218
+ ("Soundness", (1, 4)),
219
+ ("Presentation", (1, 4)),
220
+ ("Contribution", (1, 4)),
221
+ ("Overall", (1, 10)),
222
+ ("Confidence", (1, 5)),
223
+ ]:
224
+ scores = []
225
+ for r in parsed_reviews:
226
+ if score in r and limits[1] >= r[score] >= limits[0]:
227
+ scores.append(r[score])
228
+ review[score] = int(round(np.mean(scores)))
229
+
230
+ # Rewrite the message history with the valid one and new aggregated review.
231
+ msg_history = msg_histories[0][:-1]
232
+ msg_history += [
233
+ {
234
+ "role": "assistant",
235
+ "content": f"""
236
+ THOUGHT:
237
+ I will start by aggregating the opinions of {num_reviews_ensemble} reviewers that I previously obtained.
238
+
239
+ REVIEW JSON:
240
+ ```json
241
+ {json.dumps(review)}
242
+ ```
243
+ """,
244
+ }
245
+ ]
246
+ else:
247
+ llm_review, msg_history = get_response_from_llm(
248
+ base_prompt,
249
+ model=model,
250
+ client=client,
251
+ system_message=reviewer_system_prompt,
252
+ print_debug=False,
253
+ msg_history=msg_history,
254
+ temperature=temperature,
255
+ )
256
+ review = format_llm_review_json(llm_review)
257
+
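+ # Optional self-reflection rounds: the model re-reads its own review and refines it until it replies "I am done".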
258
+ if num_reflections > 1:
259
+ for j in range(num_reflections - 1):
260
+ # print(f"Relection: {j + 2}/{num_reflections}")
261
+ text, msg_history = get_response_from_llm(
262
+ reviewer_reflection_prompt,
263
+ client=client,
264
+ model=model,
265
+ system_message=reviewer_system_prompt,
266
+ msg_history=msg_history,
267
+ temperature=temperature,
268
+ )
269
+ review = format_llm_review_json(text)
270
+ assert review is not None, "Failed to extract JSON from LLM output"
271
+
272
+ if "I am done" in text:
273
+ # print(f"Review generation converged after {j + 2} iterations.")
274
+ break
275
+
276
+ if return_msg_history:
277
+ return review, msg_history
278
+ else:
279
+ return review
280
+
281
+
282
+ reviewer_reflection_prompt = """Round {current_round}/{num_reflections}.
283
+ In your thoughts, first carefully consider the accuracy and soundness of the review you just created.
284
+ Include any other factors that you think are important in evaluating the paper.
285
+ Ensure the review is clear and concise, and the JSON is in the correct format.
286
+ Do not make things overly complicated.
287
+ In the next attempt, try and refine and improve your review.
288
+ Stick to the spirit of the original review unless there are glaring issues.
289
+
290
+ Respond in the same format as before:
291
+ THOUGHT:
292
+ <THOUGHT>
293
+
294
+ REVIEW JSON:
295
+ ```json
296
+ <JSON>
297
+ ```
298
+
299
+ If there is nothing to improve, simply repeat the previous JSON EXACTLY after the thought and include "I am done" at the end of the thoughts but before the JSON.
300
+ ONLY INCLUDE "I am done" IF YOU ARE MAKING NO MORE CHANGES."""
301
+
302
+
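+ # Extract the paper text from a PDF, trying pymupdf4llm first and falling back to pymupdf, then pypdf.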
303
+ def load_paper(pdf_path, num_pages=None, min_size=100):
304
+ try:
305
+ if num_pages is None:
306
+ text = pymupdf4llm.to_markdown(pdf_path)
307
+ else:
308
+ reader = PdfReader(pdf_path)
309
+ min_pages = min(len(reader.pages), num_pages)
310
+ text = pymupdf4llm.to_markdown(pdf_path, pages=list(range(min_pages)))
311
+ if len(text) < min_size:
312
+ raise Exception("Text too short")
313
+ except Exception as e:
314
+ print(f"Error with pymupdf4llm, falling back to pymupdf: {e}")
315
+ try:
316
+ doc = pymupdf.open(pdf_path) # open a document
317
+ if num_pages:
318
+ doc = doc[:num_pages]
319
+ text = ""
320
+ for page in doc: # iterate the document pages
321
+ text = text + page.get_text() # get plain text encoded as UTF-8
322
+ if len(text) < min_size:
323
+ raise Exception("Text too short")
324
+ except Exception as e:
325
+ print(f"Error with pymupdf, falling back to pypdf: {e}")
326
+ reader = PdfReader(pdf_path)
327
+ if num_pages is None:
328
+ text = "".join(page.extract_text() for page in reader.pages)
329
+ else:
330
+ text = "".join(page.extract_text() for page in reader.pages[:num_pages])
331
+ if len(text) < min_size:
332
+ raise Exception("Text too short")
333
+
334
+ return text
335
+
336
+
337
+ def load_review(path):
338
+ with open(path, "r") as json_file:
339
+ loaded = json.load(json_file)
340
+ return loaded["review"]
341
+
342
+
343
+ # get directory of this file
344
+ dir_path = os.path.dirname(os.path.realpath(__file__))
345
+
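+ # Few-shot example papers and their matching human-written reviews used to prime the reviewer.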
346
+ fewshot_papers = [
347
+ os.path.join(dir_path, "fewshot_examples/132_automated_relational.pdf"),
348
+ os.path.join(dir_path, "fewshot_examples/attention.pdf"),
349
+ os.path.join(dir_path, "fewshot_examples/2_carpe_diem.pdf"),
350
+ ]
351
+
352
+ fewshot_reviews = [
353
+ os.path.join(dir_path, "fewshot_examples/132_automated_relational.json"),
354
+ os.path.join(dir_path, "fewshot_examples/attention.json"),
355
+ os.path.join(dir_path, "fewshot_examples/2_carpe_diem.json"),
356
+ ]
357
+
358
+
359
+ def get_review_fewshot_examples(num_fs_examples=1):
360
+ fewshot_prompt = """
361
+ Below are some sample reviews, copied from previous machine learning conferences.
362
+ Note that while each review is formatted differently according to each reviewer's style, the reviews are well-structured and therefore easy to navigate.
363
+ """
364
+ for paper, review in zip(
365
+ fewshot_papers[:num_fs_examples], fewshot_reviews[:num_fs_examples]
366
+ ):
367
+ txt_path = paper.replace(".pdf", ".txt")
368
+ if os.path.exists(txt_path):
369
+ with open(txt_path, "r") as f:
370
+ paper_text = f.read()
371
+ else:
372
+ paper_text = load_paper(paper)
373
+ review_text = load_review(review)
374
+ fewshot_prompt += f"""
375
+ Paper:
376
+
377
+ ```
378
+ {paper_text}
379
+ ```
380
+
381
+ Review:
382
+
383
+ ```
384
+ {review_text}
385
+ ```
386
+ """
387
+
388
+ return fewshot_prompt
389
+
390
+
391
+ meta_reviewer_system_prompt = """You are an Area Chair at a machine learning conference.
392
+ You are in charge of meta-reviewing a paper that was reviewed by {reviewer_count} reviewers.
393
+ Your job is to aggregate the reviews into a single meta-review in the same format.
394
+ Be critical and cautious in your decision, find consensus, and respect the opinion of all the reviewers."""
395
+
396
+
397
+ def get_meta_review(model, client, temperature, reviews):
398
+ # Write a meta-review from a set of individual reviews
399
+ review_text = ""
400
+ for i, r in enumerate(reviews):
401
+ review_text += f"""
402
+ Review {i + 1}/{len(reviews)}:
403
+ ```
404
+ {json.dumps(r)}
405
+ ```
406
+ """
407
+ base_prompt = neurips_form + review_text
408
+
409
+ llm_review, msg_history = get_response_from_llm(
410
+ base_prompt,
411
+ model=model,
412
+ client=client,
413
+ system_message=meta_reviewer_system_prompt.format(reviewer_count=len(reviews)),
414
+ print_debug=False,
415
+ msg_history=None,
416
+ temperature=temperature,
417
+ )
418
+ meta_review = format_llm_review_json(llm_review)
419
+ return meta_review
420
+
421
+
422
+ def perform_improvement(review, coder):
423
+ improvement_prompt = '''The following review has been created for your research paper:
424
+ """
425
+ {review}
426
+ """
427
+
428
+ Improve the text using the review.'''.format(
429
+ review=json.dumps(review)
430
+ )
431
+ coder_out = coder.run(improvement_prompt)
ai_scientist/perform_writeup.py ADDED
@@ -0,0 +1,707 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import os.path as osp
5
+ import re
6
+ import shutil
7
+ import subprocess
8
+ from typing import Optional, Tuple
9
+
10
+ from strictjson import strict_json
11
+
12
+ from ai_scientist.generate_ideas import search_for_papers
13
+ from ai_scientist.llm import (
14
+ allchoices,
15
+ extract_json_between_markers,
16
+ get_response_from_llm,
17
+ llm_json_auto_correct,
18
+ )
19
+
20
+
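+ # Coerce free-form LLM output into the expected citation-edit JSON schema using strict_json.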
21
+ def format_citation_first_json(text):
22
+ res = strict_json(
23
+ system_prompt="You are a JSON formatter",
24
+ user_prompt=text,
25
+ return_as_json=True,
26
+ output_format={
27
+ "Description": "A precise description of the required edit, along with the proposed text and location where it should be made",
28
+ "Query": "The search query to find the paper (e.g. attention is all you need)",
29
+ },
30
+ llm=llm_json_auto_correct,
31
+ )
32
+ text = json.loads(res)
33
+ return text
34
+
35
+
36
+ def format_citation_second_json(text):
37
+ res = strict_json(
38
+ system_prompt="You are a JSON formatter",
39
+ user_prompt=text,
40
+ return_as_json=True,
41
+ output_format={
42
+ "Selected": "A list of the indices of the selected papers to be cited, e.g. '[0, 1]'. Can be '[]' if no papers are selected. This must be a string",
43
+ "Description": "Update the previous description of the required edit if needed. Ensure that any cites precisely match the name in the bibtex",
44
+ },
45
+ llm=llm_json_auto_correct,
46
+ )
47
+ text = json.loads(res)
48
+ return text
49
+
50
+
51
+ # GENERATE LATEX
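+ # Validate the LaTeX draft (missing citations, missing or duplicated figures, duplicate sections),
+ # let the coder fix chktex warnings, then compile the PDF.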
52
+ def generate_latex(coder, folder_name, pdf_file, timeout=30, num_error_corrections=5):
53
+ folder = osp.abspath(folder_name)
54
+ cwd = osp.join(folder, "latex") # Fixed potential issue with path
55
+ writeup_file = osp.join(cwd, "template.tex")
56
+
57
+ # Check all references are valid and in the references.bib file
58
+ with open(writeup_file, "r") as f:
59
+ tex_text = f.read()
60
+ cites = re.findall(r"\\cite[a-z]*{([^}]*)}", tex_text)
61
+ references_bib = re.search(
62
+ r"\\begin{filecontents}{references.bib}(.*?)\\end{filecontents}",
63
+ tex_text,
64
+ re.DOTALL,
65
+ )
66
+ if references_bib is None:
67
+ print("No references.bib found in template.tex")
68
+ return
69
+ bib_text = references_bib.group(1)
70
+ cites = [cite.strip() for item in cites for cite in item.split(",")]
71
+ for cite in cites:
72
+ if cite not in bib_text:
73
+ print(f"Reference {cite} not found in references.")
74
+ prompt = f"""Reference {cite} not found in references.bib. Is this included under a different name?
75
+ If so, please modify the citation in template.tex to match the name in references.bib at the top. Otherwise, remove the cite."""
76
+ coder.run(prompt)
77
+
78
+ # Check all included figures are actually in the directory.
79
+ with open(writeup_file, "r") as f:
80
+ tex_text = f.read()
81
+ referenced_figs = re.findall(r"\\includegraphics.*?{(.*?)}", tex_text)
82
+ all_figs = [f for f in os.listdir(folder) if f.endswith(".png")]
83
+ for figure in referenced_figs:
84
+ if figure not in all_figs:
85
+ print(f"Figure {figure} not found in directory.")
86
+ prompt = f"""The image {figure} not found in the directory. The images in the directory are: {all_figs}.
87
+ Please ensure that the figure is in the directory and that the filename is correct. Check the notes to see what each figure contains."""
88
+ coder.run(prompt)
89
+
90
+ # Remove duplicate figures.
91
+ with open(writeup_file, "r") as f:
92
+ tex_text = f.read()
93
+ referenced_figs = re.findall(r"\\includegraphics.*?{(.*?)}", tex_text)
94
+ duplicates = {x for x in referenced_figs if referenced_figs.count(x) > 1}
95
+ if duplicates:
96
+ for dup in duplicates:
97
+ print(f"Duplicate figure found: {dup}.")
98
+ prompt = f"""Duplicate figures found: {dup}. Ensure any figure is only included once.
99
+ If duplicated, identify the best location for the figure and remove any other."""
100
+ coder.run(prompt)
101
+
102
+ # Remove duplicate section headers.
103
+ with open(writeup_file, "r") as f:
104
+ tex_text = f.read()
105
+ sections = re.findall(r"\\section{([^}]*)}", tex_text)
106
+ duplicates = {x for x in sections if sections.count(x) > 1}
107
+ if duplicates:
108
+ for dup in duplicates:
109
+ print(f"Duplicate section header found: {dup}")
110
+ prompt = f"""Duplicate section header found: {dup}. Ensure any section header is declared once.
111
+ If duplicated, identify the best location for the section header and remove any other."""
112
+ coder.run(prompt)
113
+
114
+ # Iteratively fix any LaTeX bugs
115
+ for i in range(num_error_corrections):
116
+ # Filter trivial bugs in chktex
117
+ check_output = os.popen(f"chktex {writeup_file} -q -n2 -n24 -n13 -n1").read()
118
+ if check_output:
119
+ prompt = f"""Please fix the following LaTeX errors in `template.tex` guided by the output of `chktek`:
120
+ {check_output}.
121
+
122
+ Make the minimal fix required and do not remove or change any packages.
123
+ Pay attention to any accidental uses of HTML syntax, e.g. </end instead of \\end.
124
+ """
125
+ coder.run(prompt)
126
+ else:
127
+ break
128
+ compile_latex(cwd, pdf_file, timeout=timeout)
129
+
130
+
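+ # Run pdflatex, then bibtex, then pdflatex twice in the latex directory and move the resulting PDF into place.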
131
+ def compile_latex(cwd, pdf_file, timeout=30):
132
+ print("GENERATING LATEX")
133
+
134
+ commands = [
135
+ ["pdflatex", "-interaction=nonstopmode", "template.tex"],
136
+ ["bibtex", "template"],
137
+ ["pdflatex", "-interaction=nonstopmode", "template.tex"],
138
+ ["pdflatex", "-interaction=nonstopmode", "template.tex"],
139
+ ]
140
+
141
+ for command in commands:
142
+ try:
143
+ result = subprocess.run(
144
+ command,
145
+ cwd=cwd,
146
+ stdout=subprocess.PIPE,
147
+ stderr=subprocess.PIPE,
148
+ text=True,
149
+ timeout=timeout,
150
+ )
151
+ print("Standard Output:\n", result.stdout)
152
+ print("Standard Error:\n", result.stderr)
153
+ except subprocess.TimeoutExpired:
154
+ print(f"Latex timed out after {timeout} seconds")
155
+ except subprocess.CalledProcessError as e:
156
+ print(f"Error running command {' '.join(command)}: {e}")
157
+
158
+ print("FINISHED GENERATING LATEX")
159
+
160
+ # Attempt to move the PDF to the desired location
161
+ try:
162
+ shutil.move(osp.join(cwd, "template.pdf"), pdf_file)
163
+ except FileNotFoundError:
164
+ print("Failed to rename PDF.")
165
+
166
+
167
+ per_section_tips = {
168
+ "Abstract": """
169
+ - TL;DR of the paper
170
+ - What are we trying to do and why is it relevant?
171
+ - Why is this hard?
172
+ - How do we solve it (i.e. our contribution!)
173
+ - How do we verify that we solved it (e.g. Experiments and results)
174
+
175
+ Please make sure the abstract reads smoothly and is well-motivated. This should be one continuous paragraph with no breaks between the lines.
176
+ """,
177
+ "Introduction": """
178
+ - Longer version of the Abstract, i.e. of the entire paper
179
+ - What are we trying to do and why is it relevant?
180
+ - Why is this hard?
181
+ - How do we solve it (i.e. our contribution!)
182
+ - How do we verify that we solved it (e.g. Experiments and results)
183
+ - New trend: specifically list your contributions as bullet points
184
+ - Extra space? Future work!
185
+ """,
186
+ "Related Work": """
187
+ - Academic siblings of our work, i.e. alternative attempts in literature at trying to solve the same problem.
188
+ - Goal is to “Compare and contrast” - how does their approach differ in either assumptions or method? If their method is applicable to our Problem Setting I expect a comparison in the experimental section. If not, there needs to be a clear statement why a given method is not applicable.
189
+ - Note: Just describing what another paper is doing is not enough. We need to compare and contrast.
190
+ """,
191
+ "Background": """
192
+ - Academic Ancestors of our work, i.e. all concepts and prior work that are required for understanding our method.
193
+ - Usually includes a subsection, Problem Setting, which formally introduces the problem setting and notation (Formalism) for our method. Highlights any specific assumptions that are made that are unusual.
194
+ - Note: If our paper introduces a novel problem setting as part of its contributions, it's best to have a separate Section.
195
+ """,
196
+ "Method": """
197
+ - What we do. Why we do it. All described using the general Formalism introduced in the Problem Setting and building on top of the concepts / foundations introduced in Background.
198
+ """,
199
+ "Experimental Setup": """
200
+ - How do we test that our stuff works? Introduces a specific instantiation of the Problem Setting and specific implementation details of our Method for this Problem Setting.
201
+ - Do not imagine unknown hardware details.
202
+ - Includes a description of the dataset, evaluation metrics, important hyperparameters, and implementation details.
203
+ """,
204
+ "Results": """
205
+ - Shows the results of running Method on our problem described in Experimental Setup.
206
+ - Includes statements on hyperparameters and other potential issues of fairness.
207
+ - Only includes results that have actually been run and saved in the logs. Do not hallucinate results that don't exist.
208
+ - If results exist: compares to baselines and includes statistics and confidence intervals.
209
+ - If results exist: includes ablation studies to show that specific parts of the method are relevant.
210
+ - Discusses limitations of the method.
211
+ - Make sure to include all the results from the experiments, and include all relevant figures.
212
+ """,
213
+ "Conclusion": """
214
+ - Brief recap of the entire paper.
215
+ - To keep going with the analogy, you can think of future work as (potential) academic offspring.
216
+ """,
217
+ }
218
+
219
+ error_list = """- Unenclosed math symbols
220
+ - Only reference figures that exist in our directory
221
+ - LaTeX syntax errors
222
+ - Numerical results that do not come from explicit experiments and logs
223
+ - Repeatedly defined figure labels
224
+ - References to papers that are not in the .bib file, DO NOT ADD ANY NEW CITATIONS!
225
+ - Unnecessary verbosity or repetition, unclear text
226
+ - Results or insights in the `notes.txt` that have not yet been included
227
+ - Any relevant figures that have not yet been included in the text
228
+ - Closing any \\begin{{figure}} with a \\end{{figure}} and \\begin{{table}} with a \\end{{table}}, etc.
229
+ - Duplicate headers, e.g. duplicated \\section{{Introduction}} or \\end{{document}}
230
+ - Unescaped symbols, e.g. shakespeare_char should be shakespeare\\_char in text
231
+ - Incorrect closing of environments, e.g. </end{{figure}}> instead of \\end{{figure}}
232
+ """
233
+
234
+ refinement_prompt = (
235
+ """Great job! Now criticize and refine only the {section} that you just wrote.
236
+ Make this complete in this pass, do not leave any placeholders.
237
+
238
+ Pay particular attention to fixing any errors such as:
239
+ """
240
+ + error_list
241
+ )
242
+
243
+ second_refinement_prompt = (
244
+ """Criticize and refine the {section} only. Recall the advice:
245
+ {tips}
246
+ Make this complete in this pass, do not leave any placeholders.
247
+
248
+ Pay attention to how it fits in with the rest of the paper.
249
+ Identify any redundancies (e.g. repeated figures or repeated text), if there are any, decide where in the paper things should be cut.
250
+ Identify where we can save space, and be more concise without weakening the message of the text.
251
+ Fix any remaining errors as before:
252
+ """
253
+ + error_list
254
+ )
255
+
256
+ # CITATION HELPERS
257
+ citation_system_msg = """You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field.
258
+ You have already written an initial draft of the paper and now you are looking to add missing citations to related papers throughout the paper.
259
+ The related work section already has some initial comments on which papers to add and discuss.
260
+
261
+ Focus on completing the existing write-up and do not add entirely new elements unless necessary.
262
+ Ensure every point in the paper is substantiated with sufficient evidence.
263
+ Feel free to add more cites to a particular point if there is only one or two references.
264
+ Ensure no paper is cited without a corresponding reference in the `references.bib` file.
265
+ Ensure each paragraph of the related work has sufficient background, e.g. a few papers cited.
266
+ You will be given access to the Semantic Scholar API, only add citations that you have found using the API.
267
+ Aim to discuss a broad range of relevant papers, not just the most popular ones.
268
+ Make sure not to copy verbatim from prior literature to avoid plagiarism.
269
+
270
+ You will be prompted to give a precise description of where and how to add the cite, and a search query for the paper to be cited.
271
+ Finally, you will select the most relevant cite from the search results (top 10 results will be shown).
272
+ You will have {total_rounds} rounds to add to the references, but do not need to use them all.
273
+
274
+ DO NOT ADD A CITATION THAT ALREADY EXISTS!"""
275
+
276
+ citation_first_prompt = '''Round {current_round}/{total_rounds}:
277
+
278
+ You have written this LaTeX draft so far:
279
+
280
+ """
281
+ {draft}
282
+ """
283
+
284
+ Identify the most important citation that you still need to add, and the query to find the paper.
285
+
286
+ Respond in the following format:
287
+
288
+ THOUGHT:
289
+ <THOUGHT>
290
+
291
+ RESPONSE:
292
+ ```json
293
+ <JSON>
294
+ ```
295
+
296
+ In <THOUGHT>, first briefly reason over the paper and identify where citations should be added.
297
+ If no more citations are needed, add "No more citations needed" to your thoughts.
298
+ Do not add "No more citations needed" if you are adding citations this round.
299
+
300
+ In <JSON>, respond in JSON format with the following fields:
301
+ - "Description": A precise description of the required edit, along with the proposed text and location where it should be made.
302
+ - "Query": The search query to find the paper (e.g. attention is all you need).
303
+
304
+ Ensure the description is sufficient to make the change without further context. Someone else will make the change.
305
+ The query will work best if you are able to recall the exact name of the paper you are looking for, or the authors.
306
+ This JSON will be automatically parsed, so ensure the format is precise.'''
307
+
308
+ citation_second_prompt = """Search has recovered the following articles:
309
+
310
+ {papers}
311
+
312
+ Respond in the following format:
313
+
314
+ THOUGHT:
315
+ <THOUGHT>
316
+
317
+ RESPONSE:
318
+ ```json
319
+ <JSON>
320
+ ```
321
+
322
+ In <THOUGHT>, first briefly reason over the search results and identify which citation best fits your paper and where it should be added.
323
+ If none are appropriate, add "Do not add any" to your thoughts.
324
+
325
+ In <JSON>, respond in JSON format with the following fields:
326
+ - "Selected": A list of the indices of the selected papers to be cited, e.g. "[0, 1]". Can be "[]" if no papers are selected. This must be a string.
327
+ - "Description": Update the previous description of the required edit if needed. Ensure that any cites precisely match the name in the bibtex!!!
328
+
329
+ Do not select papers that are already in the `references.bib` file at the top of the draft, or if the same citation exists under a different name.
330
+ This JSON will be automatically parsed, so ensure the format is precise."""
331
+
332
+
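+ # One citation round: ask the LLM for the next needed citation and a search query, look the query up
+ # with search_for_papers, let the LLM pick from the results, and return an aider edit prompt plus a done flag.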
333
+ def get_citation_aider_prompt(
334
+ client, model, draft, current_round, total_rounds
335
+ ) -> Tuple[Optional[str], bool]:
336
+ msg_history = []
337
+ try:
338
+ text, msg_history = get_response_from_llm(
339
+ citation_first_prompt.format(
340
+ draft=draft, current_round=current_round, total_rounds=total_rounds
341
+ ),
342
+ client=client,
343
+ model=model,
344
+ system_message=citation_system_msg.format(total_rounds=total_rounds),
345
+ msg_history=msg_history,
346
+ )
347
+ if "No more citations needed" in text:
348
+ print("No more citations needed.")
349
+ return None, True
350
+
351
+ ## PARSE OUTPUT
352
+ json_output = format_citation_first_json(text)
353
+ assert json_output is not None, "Failed to extract JSON from LLM output"
354
+ query = json_output["Query"]
355
+ papers = search_for_papers(query)
356
+ except Exception as e:
357
+ print(f"Error: {e}")
358
+ return None, False
359
+
360
+ if papers is None:
361
+ print("No papers found.")
362
+ return None, False
363
+
364
+ paper_strings = []
365
+ for i, paper in enumerate(papers):
366
+ paper_strings.append(
367
+ """{i}: {title}. {authors}. {venue}, {year}.\nAbstract: {abstract}""".format(
368
+ i=i,
369
+ title=paper["title"],
370
+ authors=paper["authors"],
371
+ venue=paper["venue"],
372
+ year=paper["year"],
373
+ abstract=paper["abstract"],
374
+ )
375
+ )
376
+ papers_str = "\n\n".join(paper_strings)
377
+
378
+ try:
379
+ text, msg_history = get_response_from_llm(
380
+ citation_second_prompt.format(
381
+ papers=papers_str,
382
+ current_round=current_round,
383
+ total_rounds=total_rounds,
384
+ ),
385
+ client=client,
386
+ model=model,
387
+ system_message=citation_system_msg.format(total_rounds=total_rounds),
388
+ msg_history=msg_history,
389
+ )
390
+ if "Do not add any" in text:
391
+ print("Do not add any.")
392
+ return None, False
393
+ ## PARSE OUTPUT
394
+ json_output = format_citation_second_json(text)
395
+ assert json_output is not None, "Failed to extract JSON from LLM output"
396
+ desc = json_output["Description"]
397
+ selected_papers = json_output["Selected"]
398
+ selected_papers = str(selected_papers)
399
+
400
+ # convert to list
401
+ if selected_papers != "[]":
402
+ selected_papers = list(map(int, selected_papers.strip("[]").split(",")))
403
+ assert all(
404
+ [0 <= i < len(papers) for i in selected_papers]
405
+ ), "Invalid paper index"
406
+ bibtexs = [papers[i]["citationStyles"]["bibtex"] for i in selected_papers]
407
+ bibtex_string = "\n".join(bibtexs)
408
+ else:
409
+ return None, False
410
+
411
+ except Exception as e:
412
+ print(f"Error: {e}")
413
+ return None, False
414
+
415
+ # Add citation to draft
416
+ aider_format = '''The following citations have just been added to the end of the `references.bib` file definition at the top of the file:
417
+ """
418
+ {bibtex}
419
+ """
420
+ You do not need to add them yourself.
421
+ ABSOLUTELY DO NOT ADD IT AGAIN!!!
422
+
423
+ Make the proposed change to the draft incorporating these new cites:
424
+ {description}
425
+
426
+ Use your judgment for whether these should be cited anywhere else.
427
+ Make sure that any citation precisely matches the name in `references.bib`. Change its name to the correct name in the bibtex if needed.
428
+ Ensure the citation is well-integrated into the text.'''
429
+
430
+ aider_prompt = (
431
+ aider_format.format(bibtex=bibtex_string, description=desc)
432
+ + """\n You must use \cite or \citet to reference papers, do not manually type out author names."""
433
+ )
434
+ return aider_prompt, False
435
+
436
+
437
+ # PERFORM WRITEUP
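+ # Fill in the LaTeX template section by section with the aider coder, sketch the related work,
+ # add citations over several rounds, run a second refinement pass over every section, then compile the PDF.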
438
+ def perform_writeup(
439
+ idea, folder_name, coder, cite_client, cite_model, num_cite_rounds=20
440
+ ):
441
+ # CURRENTLY ASSUMES LATEX
442
+ abstract_prompt = f"""We've provided the `latex/template.tex` file to the project. We will be filling it in section by section.
443
+
444
+ First, please fill in the "Title" and "Abstract" sections of the writeup.
445
+
446
+ Some tips are provided below:
447
+ {per_section_tips["Abstract"]}
448
+
449
+ Before every paragraph, please include a brief description of what you plan to write in that paragraph in a comment.
450
+
451
+ Be sure to first name the file and use *SEARCH/REPLACE* blocks to perform these edits.
452
+ """
453
+ coder_out = coder.run(abstract_prompt)
454
+ coder_out = coder.run(
455
+ refinement_prompt.format(section="Abstract")
456
+ .replace(r"{{", "{")
457
+ .replace(r"}}", "}")
458
+ )
459
+ for section in [
460
+ "Introduction",
461
+ "Background",
462
+ "Method",
463
+ "Experimental Setup",
464
+ "Results",
465
+ "Conclusion",
466
+ ]:
467
+ section_prompt = f"""Please fill in the {section} of the writeup. Some tips are provided below:
468
+ {per_section_tips[section]}
469
+
470
+ Be sure to use \cite or \citet where relevant, referring to the works provided in the file.
471
+ Do not cite anything that is not already in `references.bib`. Do not add any new entries to this.
472
+
473
+ Keep the experimental results (figures and tables) only in the Results section, and make sure that any captions are filled in.
474
+ In this pass, do not reference anything in later sections of the paper.
475
+
476
+ Before every paragraph, please include a brief description of what you plan to write in that paragraph in a comment.
477
+
478
+ Be sure to first name the file and use *SEARCH/REPLACE* blocks to perform these edits.
479
+ """
480
+ coder_out = coder.run(section_prompt)
481
+ coder_out = coder.run(
482
+ refinement_prompt.format(section=section)
483
+ .replace(r"{{", "{")
484
+ .replace(r"}}", "}")
485
+ )
486
+
487
+ # SKETCH THE RELATED WORK
488
+ section_prompt = f"""Please fill in the Related Work of the writeup. Some tips are provided below:
489
+
490
+ {per_section_tips["Related Work"]}
491
+
492
+ For this section, very briefly sketch out the structure of the section, and clearly indicate what papers you intend to include.
493
+ Do this all in LaTeX comments using %.
494
+ The related work should be concise, only plan to discuss the most relevant work.
495
+ Do not modify `references.bib` to add any new citations, this will be filled in at a later stage.
496
+
497
+ Be sure to first name the file and use *SEARCH/REPLACE* blocks to perform these edits.
498
+ """
499
+ coder_out = coder.run(section_prompt)
500
+
501
+ # Fill paper with cites.
502
+ for _ in range(num_cite_rounds):
503
+ with open(osp.join(folder_name, "latex", "template.tex"), "r") as f:
504
+ draft = f.read()
505
+ prompt, done = get_citation_aider_prompt(
506
+ cite_client, cite_model, draft, _, num_cite_rounds
507
+ )
508
+ if done:
509
+ break
510
+ if prompt is not None:
511
+ # extract bibtex string
512
+ bibtex_string = prompt.split('"""')[1]
513
+ # insert this into draft before the "\end{filecontents}" line
514
+ search_str = r"\end{filecontents}"
515
+ draft = draft.replace(search_str, f"{bibtex_string}{search_str}")
516
+ with open(osp.join(folder_name, "latex", "template.tex"), "w") as f:
517
+ f.write(draft)
518
+ coder_out = coder.run(prompt)
519
+
520
+ coder_out = coder.run(
521
+ refinement_prompt.format(section="Related Work")
522
+ .replace(r"{{", "{")
523
+ .replace(r"}}", "}")
524
+ )
525
+
526
+ ## SECOND REFINEMENT LOOP
527
+ coder.run(
528
+ """Great job! Now that there is a complete draft of the entire paper, let's refine each section again.
529
+ First, re-think the Title if necessary. Keep this concise and descriptive of the paper's concept, but try to be creative with it."""
530
+ )
531
+ for section in [
532
+ "Abstract",
533
+ "Related Work",
534
+ "Introduction",
535
+ "Background",
536
+ "Method",
537
+ "Experimental Setup",
538
+ "Results",
539
+ "Conclusion",
540
+ ]:
541
+ coder_out = coder.run(
542
+ second_refinement_prompt.format(
543
+ section=section, tips=per_section_tips[section]
544
+ )
545
+ .replace(r"{{", "{")
546
+ .replace(r"}}", "}")
547
+ )
548
+
549
+ generate_latex(coder, folder_name, f"{folder_name}/{idea['Name']}.pdf")
550
+
551
+
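+ # CLI entry point: pick the LLM backend from --model, locate the idea matching the folder name,
+ # set up the aider coder, and either regenerate only the LaTeX (--no-writing) or perform the full writeup.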
552
+ if __name__ == "__main__":
553
+ import json
554
+
555
+ from aider.coders import Coder
556
+ from aider.io import InputOutput
557
+ from aider.models import Model
558
+
559
+ parser = argparse.ArgumentParser(description="Perform writeup for a project")
560
+ parser.add_argument("--folder", type=str)
561
+ parser.add_argument("--no-writing", action="store_true", help="Only generate")
562
+ parser.add_argument(
563
+ "--model",
564
+ type=str,
565
+ default="gpt-4o-2024-05-13",
566
+ choices=allchoices,
567
+ help="Model to use for AI Scientist.",
568
+ )
569
+ args = parser.parse_args()
570
+ if args.model == "claude-3-5-sonnet-20240620":
571
+ import anthropic
572
+
573
+ print(f"Using Anthropic API with model {args.model}.")
574
+ client_model = "claude-3-5-sonnet-20240620"
575
+ client = anthropic.Anthropic()
576
+ elif args.model.startswith("bedrock") and "claude" in args.model:
577
+ import anthropic
578
+
579
+ # Expects: bedrock/<MODEL_ID>
580
+ client_model = args.model.split("/")[-1]
581
+
582
+ print(f"Using Amazon Bedrock with model {client_model}.")
583
+ client = anthropic.AnthropicBedrock()
584
+ elif args.model.startswith("vertex_ai") and "claude" in args.model:
585
+ import anthropic
586
+
587
+ # Expects: vertex_ai/<MODEL_ID>
588
+ client_model = args.model.split("/")[-1]
589
+
590
+ print(f"Using Vertex AI with model {client_model}.")
591
+ client = anthropic.AnthropicVertex()
592
+ elif args.model == "gpt-4o-2024-05-13":
593
+ import openai
594
+
595
+ print(f"Using OpenAI API with model {args.model}.")
596
+ client_model = "gpt-4o-2024-05-13"
597
+ client = openai.OpenAI()
598
+ elif args.model == "deepseek-coder-v2-0724":
599
+ import openai
600
+
601
+ print(f"Using OpenAI API with {args.model}.")
602
+ client_model = "deepseek-coder-v2-0724"
603
+ client = openai.OpenAI(
604
+ api_key=os.environ["DEEPSEEK_API_KEY"], base_url="https://api.deepseek.com"
605
+ )
606
+
607
+ # ----------------------------------------------------
608
+
609
+ elif args.model == "Qwen/Qwen2.5-72B-Instruct":
610
+ # elif args.model.startswith("hyperbolic"):
611
+ print(f"Welcome to the PARADISE of debug <launch_scientist.py> {args.model}.")
612
+
613
+ import openai
614
+ import os
615
+ # client_model = args.model[11:]
616
+ client_model = args.model
617
+ client = openai.OpenAI(
618
+ api_key=os.environ["OPENAI_API_KEY"], base_url="https://api.hyperbolic.xyz/v1"
619
+ )
620
+ # ----------------------------------------------------
621
+ elif args.model == "llama3.1-405b":
622
+ import openai
623
+
624
+ print(f"Using OpenAI API with {args.model}.")
625
+ client_model = "meta-llama/llama-3.1-405b-instruct"
626
+ client = openai.OpenAI(
627
+ api_key=os.environ["OPENROUTER_API_KEY"],
628
+ base_url="https://openrouter.ai/api/v1",
629
+ )
630
+
631
+ elif args.model.startswith("ollama"):
632
+ import openai
633
+
634
+ print(f"Using Ollama with {args.model}.")
635
+ client_model = args.model.split("/")[-1]
636
+ client = openai.OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
637
+ else:
638
+ raise ValueError(f"Model {args.model} not recognized.")
639
+
640
+
641
+ print("Make sure you cleaned the Aider logs if re-generating the writeup!")
642
+ folder_name = args.folder
643
+ idea_name = osp.basename(folder_name)
644
+ exp_file = osp.join(folder_name, "experiment.py")
645
+ vis_file = osp.join(folder_name, "plot.py")
646
+ notes = osp.join(folder_name, "notes.txt")
647
+
648
+ model = args.model
649
+
650
+ writeup_file = osp.join(folder_name, "latex", "template.tex")
651
+ ideas_file = osp.join(folder_name, "ideas.json")
652
+ with open(ideas_file, "r") as f:
653
+ ideas = json.load(f)
654
+ for idea in ideas:
655
+ if idea["Name"] in idea_name:
656
+ print(f"Found idea: {idea['Name']}")
657
+ break
658
+ if idea["Name"] not in idea_name:
659
+ raise ValueError(f"Idea {idea_name} not found")
660
+ fnames = [exp_file, writeup_file, notes]
661
+ io = InputOutput(yes=True, chat_history_file=f"{folder_name}/{idea_name}_aider.txt")
662
+
663
+
664
+
665
+ # AIDER CHAT INITIALIZATION CODE
666
+
667
+ if args.model == "deepseek-ai/DeepSeek-V2.5":
668
+ print("aider chosen deepseek")
669
+ main_model = Model("deepseek-chat")
670
+
671
+ elif args.model == "deepseek-coder-v2-0724":
672
+ main_model = Model("deepseek-ai/DeepSeek-V2.5")
673
+
674
+ elif args.model == "llama3.1-405b":
675
+ main_model = Model("openrouter/meta-llama/llama-3.1-405b-instruct")
676
+
677
+ # ----------------------------------------------------
678
+
679
+ elif args.model == "hyperbolic/Qwen/Qwen2.5-72B-Instruct":
680
+ print("aider model chosen")
681
+ # main_model = Model("fireworks_ai/accounts/fireworks/models/qwen2-72b-instruct")
682
+ main_model = Model("hyperbolic/Qwen/Qwen2.5-72B-Instruct")
683
+
684
+ elif args.model == "hyperbolic/meta-llama/Meta-Llama-3.1-70B-Instruct":
685
+ main_model = Model("hyperbolic/meta-llama/Meta-Llama-3.1-70B-Instruct")
686
+
687
+ # ----------------------------------------------------
688
+
689
+ else:
690
+ print("hello world")
691
+ main_model = Model(model)
692
+
693
+ coder = Coder.create(
694
+ main_model=main_model,
695
+ fnames=fnames,
696
+ io=io,
697
+ stream=False,
698
+ use_git=False,
699
+ edit_format="diff",
700
+ )
701
+ if args.no_writing:
702
+ generate_latex(coder, args.folder, f"{args.folder}/test.pdf")
703
+ else:
704
+ try:
705
+ perform_writeup(idea, folder_name, coder, client, client_model)
706
+ except Exception as e:
707
+ print(f"Failed to perform writeup: {e}")
cuda-keyring_1.0-1_all.deb ADDED
Binary file (4.33 kB). View file
 
cuda-keyring_1.0-1_all.deb.1 ADDED
Binary file (4.33 kB). View file
 
cuda-repo-ubuntu2004-11-0-local_11.0.3-450.51.06-1_amd64.deb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f03712f53fbe7a39a8d82f8ba186e9dc44fece31b8b2230dded07153a57a27dc
3
+ size 2306981744
data/enwik8/enwik8 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b49720ec4d78c3c9fabaee6e4179a5e997302b3a70029f30f2d582218c024a8
3
+ size 100000000
data/enwik8/enwik8.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:547994d9980ebed1288380d652999f38a14fe291a6247c157c3d33d4932534bc
3
+ size 36445475
data/enwik8/meta.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e6ccdb5f07974d4d462bf8acbc1b5a7225f4da63453df9f67dd7bf6d976c386
3
+ size 2211
data/enwik8/prepare.py ADDED
@@ -0,0 +1,75 @@
1
+ """
2
+ Prepare the enwik8 dataset for character-level language modeling.
3
+ So instead of encoding with GPT-2 BPE tokens, we just map characters to ints.
4
+ Will save train.bin, val.bin, and test.bin containing the ids, and meta.pkl containing the
5
+ encoder and decoder and some other related info.
6
+ """
7
+ import os
8
+ import pickle
9
+ import requests
10
+ import numpy as np
11
+
12
+ # download the enwik8 dataset
13
+ input_file_path = os.path.join(os.path.dirname(__file__), 'enwik8')
14
+ if not os.path.exists(input_file_path):
15
+ data_url = 'http://mattmahoney.net/dc/enwik8.zip'
16
+ r = requests.get(data_url)
17
+ with open(os.path.join(os.path.dirname(__file__), 'enwik8.zip'), 'wb') as f:
18
+ f.write(r.content)
19
+
20
+ # unzip the enwik8 dataset
21
+ import zipfile
22
+ with zipfile.ZipFile(os.path.join(os.path.dirname(__file__), 'enwik8.zip'), 'r') as zip_ref:
23
+ zip_ref.extractall(os.path.dirname(__file__))
24
+
25
+ with open(input_file_path, 'r', encoding='latin-1') as f:
26
+ data = f.read()
27
+ print(f"length of dataset in characters: {len(data):,}")
28
+
29
+ # get all the unique characters that occur in this text
30
+ chars = sorted(list(set(data)))
31
+ vocab_size = len(chars)
32
+ print("all the unique characters:", ''.join(chars))
33
+ print(f"vocab size: {vocab_size:,}")
34
+
35
+ # create a mapping from characters to integers
36
+ stoi = { ch:i for i,ch in enumerate(chars) }
37
+ itos = { i:ch for i,ch in enumerate(chars) }
38
+ def encode(s):
39
+ return [stoi[c] for c in s] # encoder: take a string, output a list of integers
40
+ def decode(l):
41
+ return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
42
+
43
+ # create the train, validation, and test splits
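+ # standard enwik8 split: first 90M characters for training, next 5M for validation, final 5M for testing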
44
+ n = len(data)
45
+ num_test_chars = 5000000
46
+ train_data = data[: -2 * num_test_chars]
47
+ val_data = data[-2 * num_test_chars: -num_test_chars]
48
+ test_data = data[-num_test_chars:]
49
+
50
+ # encode all splits to integers
51
+ train_ids = encode(train_data)
52
+ val_ids = encode(val_data)
53
+ test_ids = encode(test_data)
54
+
55
+ print(f"train has {len(train_ids):,} tokens")
56
+ print(f"val has {len(val_ids):,} tokens")
57
+ print(f"test has {len(test_ids):,} tokens")
58
+
59
+ # export to bin files
60
+ train_ids = np.array(train_ids, dtype=np.uint16)
61
+ val_ids = np.array(val_ids, dtype=np.uint16)
62
+ test_ids = np.array(test_ids, dtype=np.uint16)
63
+
64
+ train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
65
+ val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))
66
+ test_ids.tofile(os.path.join(os.path.dirname(__file__), 'test.bin'))
67
+
68
+ # save the meta information as well, to help us encode/decode later
69
+ meta = {
70
+ 'vocab_size': vocab_size,
71
+ 'itos': itos,
72
+ 'stoi': stoi,
73
+ }
74
+ with open(os.path.join(os.path.dirname(__file__), 'meta.pkl'), 'wb') as f:
75
+ pickle.dump(meta, f)
data/enwik8/test.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1627449a3615fb097ee0dedfe5a702bd574c3084c2c69338a1819a40c0a4eb6c
3
+ size 10000000
data/enwik8/train.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd58505a13f96a801113b7578e886b550c8c74678f425d711559e6728c791f14
3
+ size 180000000
data/enwik8/val.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f4b53738dd5eacb29568913815a43415eaaa195c33a2935580dc2055e13954c
3
+ size 10000000