Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +16 -0
- .gitignore +2 -0
- .ipynb_checkpoints/launch_scientist-checkpoint.py +500 -0
- DockerFile +0 -0
- LICENSE +201 -0
- Miniconda3-latest-Linux-x86_64.sh +3 -0
- README.md +312 -0
- ai_scientist/.ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
- ai_scientist/.ipynb_checkpoints/generate_ideas-checkpoint.py +637 -0
- ai_scientist/.ipynb_checkpoints/llm-checkpoint.py +358 -0
- ai_scientist/.ipynb_checkpoints/perform_experiments-checkpoint.py +166 -0
- ai_scientist/.ipynb_checkpoints/perform_writeup-checkpoint.py +707 -0
- ai_scientist/Untitled.ipynb +62 -0
- ai_scientist/__init__.py +0 -0
- ai_scientist/__pycache__/__init__.cpython-311.pyc +0 -0
- ai_scientist/__pycache__/__init__.cpython-312.pyc +0 -0
- ai_scientist/__pycache__/generate_ideas.cpython-311.pyc +0 -0
- ai_scientist/__pycache__/generate_ideas.cpython-312.pyc +0 -0
- ai_scientist/__pycache__/llm.cpython-311.pyc +0 -0
- ai_scientist/__pycache__/llm.cpython-312.pyc +0 -0
- ai_scientist/__pycache__/perform_experiments.cpython-311.pyc +0 -0
- ai_scientist/__pycache__/perform_experiments.cpython-312.pyc +0 -0
- ai_scientist/__pycache__/perform_review.cpython-311.pyc +0 -0
- ai_scientist/__pycache__/perform_review.cpython-312.pyc +0 -0
- ai_scientist/__pycache__/perform_writeup.cpython-311.pyc +0 -0
- ai_scientist/__pycache__/perform_writeup.cpython-312.pyc +0 -0
- ai_scientist/fewshot_examples/132_automated_relational.json +3 -0
- ai_scientist/fewshot_examples/132_automated_relational.pdf +3 -0
- ai_scientist/fewshot_examples/132_automated_relational.txt +1190 -0
- ai_scientist/fewshot_examples/2_carpe_diem.json +3 -0
- ai_scientist/fewshot_examples/2_carpe_diem.pdf +0 -0
- ai_scientist/fewshot_examples/2_carpe_diem.txt +1035 -0
- ai_scientist/fewshot_examples/attention.json +3 -0
- ai_scientist/fewshot_examples/attention.pdf +0 -0
- ai_scientist/fewshot_examples/attention.txt +662 -0
- ai_scientist/generate_ideas.py +637 -0
- ai_scientist/llm.py +359 -0
- ai_scientist/perform_experiments.py +166 -0
- ai_scientist/perform_review.py +431 -0
- ai_scientist/perform_writeup.py +707 -0
- cuda-keyring_1.0-1_all.deb +0 -0
- cuda-keyring_1.0-1_all.deb.1 +0 -0
- cuda-repo-ubuntu2004-11-0-local_11.0.3-450.51.06-1_amd64.deb +3 -0
- data/enwik8/enwik8 +3 -0
- data/enwik8/enwik8.zip +3 -0
- data/enwik8/meta.pkl +3 -0
- data/enwik8/prepare.py +75 -0
- data/enwik8/test.bin +3 -0
- data/enwik8/train.bin +3 -0
- data/enwik8/val.bin +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,19 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
Miniconda3-latest-Linux-x86_64.sh filter=lfs diff=lfs merge=lfs -text
|
37 |
+
ai_scientist/fewshot_examples/132_automated_relational.pdf filter=lfs diff=lfs merge=lfs -text
|
38 |
+
cuda-repo-ubuntu2004-11-0-local_11.0.3-450.51.06-1_amd64.deb filter=lfs diff=lfs merge=lfs -text
|
39 |
+
data/enwik8/enwik8 filter=lfs diff=lfs merge=lfs -text
|
40 |
+
data/text8/text8 filter=lfs diff=lfs merge=lfs -text
|
41 |
+
docs/adaptive_dual_scale_denoising.jpeg filter=lfs diff=lfs merge=lfs -text
|
42 |
+
example_papers/adaptive_dual_scale_denoising/adaptive_dual_scale_denoising.pdf filter=lfs diff=lfs merge=lfs -text
|
43 |
+
example_papers/adaptive_dual_scale_denoising.pdf filter=lfs diff=lfs merge=lfs -text
|
44 |
+
example_papers/data_augmentation_grokking/data_augmentation_grokking.pdf filter=lfs diff=lfs merge=lfs -text
|
45 |
+
example_papers/data_augmentation_grokking.pdf filter=lfs diff=lfs merge=lfs -text
|
46 |
+
example_papers/grid_based_noise_adaptation/grid_based_noise_adaptation.pdf filter=lfs diff=lfs merge=lfs -text
|
47 |
+
example_papers/grid_based_noise_adaptation.pdf filter=lfs diff=lfs merge=lfs -text
|
48 |
+
example_papers/layerwise_lr_grokking/layerwise_lr_grokking.pdf filter=lfs diff=lfs merge=lfs -text
|
49 |
+
example_papers/layerwise_lr_grokking.pdf filter=lfs diff=lfs merge=lfs -text
|
50 |
+
example_papers/weight_initialization_grokking/weight_initialization_grokking.pdf filter=lfs diff=lfs merge=lfs -text
|
51 |
+
example_papers/weight_initialization_grokking.pdf filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
.aider*
|
2 |
+
.env
|
.ipynb_checkpoints/launch_scientist-checkpoint.py
ADDED
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import multiprocessing
|
4 |
+
import os
|
5 |
+
import os.path as osp
|
6 |
+
import shutil
|
7 |
+
import sys
|
8 |
+
import time
|
9 |
+
from datetime import datetime
|
10 |
+
|
11 |
+
import openai
|
12 |
+
import torch
|
13 |
+
from aider.coders import Coder
|
14 |
+
from aider.io import InputOutput
|
15 |
+
from aider.models import Model
|
16 |
+
|
17 |
+
from ai_scientist.generate_ideas import check_idea_novelty, generate_ideas
|
18 |
+
from ai_scientist.llm import allchoices
|
19 |
+
from ai_scientist.perform_experiments import perform_experiments
|
20 |
+
from ai_scientist.perform_review import load_paper, perform_improvement, perform_review
|
21 |
+
from ai_scientist.perform_writeup import generate_latex, perform_writeup
|
22 |
+
|
23 |
+
NUM_REFLECTIONS = 3
|
24 |
+
|
25 |
+
|
26 |
+
def print_time():
|
27 |
+
print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
28 |
+
|
29 |
+
|
30 |
+
def parse_arguments():
|
31 |
+
parser = argparse.ArgumentParser(description="Run AI scientist experiments")
|
32 |
+
parser.add_argument(
|
33 |
+
"--skip-idea-generation",
|
34 |
+
action="store_true",
|
35 |
+
help="Skip idea generation and load existing ideas",
|
36 |
+
)
|
37 |
+
parser.add_argument(
|
38 |
+
"--skip-novelty-check",
|
39 |
+
action="store_true",
|
40 |
+
help="Skip novelty check and use existing ideas",
|
41 |
+
)
|
42 |
+
# add type of experiment (nanoGPT, Boston, etc.)
|
43 |
+
parser.add_argument(
|
44 |
+
"--experiment",
|
45 |
+
type=str,
|
46 |
+
default="nanoGPT_lite",
|
47 |
+
help="Experiment to run AI Scientist on.",
|
48 |
+
)
|
49 |
+
parser.add_argument(
|
50 |
+
"--model",
|
51 |
+
type=str,
|
52 |
+
default="Qwen/Qwen2.5-72B-Instruct",
|
53 |
+
choices=allchoices,
|
54 |
+
help="Model to use for AI Scientist.",
|
55 |
+
)
|
56 |
+
parser.add_argument(
|
57 |
+
"--writeup",
|
58 |
+
type=str,
|
59 |
+
default="latex",
|
60 |
+
choices=["latex"],
|
61 |
+
help="What format to use for writeup",
|
62 |
+
)
|
63 |
+
parser.add_argument(
|
64 |
+
"--parallel",
|
65 |
+
type=int,
|
66 |
+
default=0,
|
67 |
+
help="Number of parallel processes to run. 0 for sequential execution.",
|
68 |
+
)
|
69 |
+
parser.add_argument(
|
70 |
+
"--improvement",
|
71 |
+
action="store_true",
|
72 |
+
help="Improve based on reviews.",
|
73 |
+
)
|
74 |
+
parser.add_argument(
|
75 |
+
"--gpus",
|
76 |
+
type=str,
|
77 |
+
default=None,
|
78 |
+
help="Comma-separated list of GPU IDs to use (e.g., '0,1,2'). If not specified, all available GPUs will be used.",
|
79 |
+
)
|
80 |
+
parser.add_argument(
|
81 |
+
"--num-ideas",
|
82 |
+
type=int,
|
83 |
+
default=2,
|
84 |
+
help="Number of ideas to generate",
|
85 |
+
)
|
86 |
+
return parser.parse_args()
|
87 |
+
|
88 |
+
|
89 |
+
def get_available_gpus(gpu_ids=None):
|
90 |
+
if gpu_ids is not None:
|
91 |
+
return [int(gpu_id) for gpu_id in gpu_ids.split(",")]
|
92 |
+
return list(range(torch.cuda.device_count()))
|
93 |
+
|
94 |
+
|
95 |
+
def worker(
|
96 |
+
queue,
|
97 |
+
base_dir,
|
98 |
+
results_dir,
|
99 |
+
model,
|
100 |
+
client,
|
101 |
+
client_model,
|
102 |
+
writeup,
|
103 |
+
improvement,
|
104 |
+
gpu_id,
|
105 |
+
):
|
106 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
107 |
+
print(f"Worker {gpu_id} started.")
|
108 |
+
while True:
|
109 |
+
idea = queue.get()
|
110 |
+
if idea is None:
|
111 |
+
break
|
112 |
+
success = do_idea(
|
113 |
+
base_dir,
|
114 |
+
results_dir,
|
115 |
+
idea,
|
116 |
+
model,
|
117 |
+
client,
|
118 |
+
client_model,
|
119 |
+
writeup,
|
120 |
+
improvement,
|
121 |
+
log_file=True,
|
122 |
+
)
|
123 |
+
print(f"Completed idea: {idea['Name']}, Success: {success}")
|
124 |
+
print(f"Worker {gpu_id} finished.")
|
125 |
+
|
126 |
+
|
127 |
+
def do_idea(
|
128 |
+
base_dir,
|
129 |
+
results_dir,
|
130 |
+
idea,
|
131 |
+
model,
|
132 |
+
client,
|
133 |
+
client_model,
|
134 |
+
writeup,
|
135 |
+
improvement,
|
136 |
+
log_file=False,
|
137 |
+
):
|
138 |
+
## CREATE PROJECT FOLDER
|
139 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
140 |
+
idea_name = f"{timestamp}_{idea['Name']}"
|
141 |
+
folder_name = osp.join(results_dir, idea_name)
|
142 |
+
assert not osp.exists(folder_name), f"Folder {folder_name} already exists."
|
143 |
+
destination_dir = folder_name
|
144 |
+
shutil.copytree(base_dir, destination_dir, dirs_exist_ok=True)
|
145 |
+
with open(osp.join(base_dir, "run_0", "final_info.json"), "r") as f:
|
146 |
+
baseline_results = json.load(f)
|
147 |
+
baseline_results = {k: v["means"] for k, v in baseline_results.items()}
|
148 |
+
exp_file = osp.join(folder_name, "experiment.py")
|
149 |
+
vis_file = osp.join(folder_name, "plot.py")
|
150 |
+
notes = osp.join(folder_name, "notes.txt")
|
151 |
+
with open(notes, "w") as f:
|
152 |
+
f.write(f"# Title: {idea['Title']}\n")
|
153 |
+
f.write(f"# Experiment description: {idea['Experiment']}\n")
|
154 |
+
f.write(f"## Run 0: Baseline\n")
|
155 |
+
f.write(f"Results: {baseline_results}\n")
|
156 |
+
f.write(f"Description: Baseline results.\n")
|
157 |
+
if log_file:
|
158 |
+
original_stdout = sys.stdout
|
159 |
+
original_stderr = sys.stderr
|
160 |
+
log_path = osp.join(folder_name, "log.txt")
|
161 |
+
log = open(log_path, "a")
|
162 |
+
sys.stdout = log
|
163 |
+
sys.stderr = log
|
164 |
+
try:
|
165 |
+
print_time()
|
166 |
+
print(f"*Starting idea: {idea_name}*")
|
167 |
+
## PERFORM EXPERIMENTS
|
168 |
+
fnames = [exp_file, vis_file, notes]
|
169 |
+
io = InputOutput(
|
170 |
+
yes=True, chat_history_file=f"{folder_name}/{idea_name}_aider.txt"
|
171 |
+
)
|
172 |
+
if model == "hybrid":
|
173 |
+
main_model = Model("claude-3-5-sonnet-20240620")
|
174 |
+
elif model == "deepseek-coder-v2-0724":
|
175 |
+
main_model = Model("deepseek-ai/DeepSeek-V2.5")
|
176 |
+
elif model == "llama3.1-405b":
|
177 |
+
main_model = Model("openrouter/meta-llama/llama-3.1-405b-instruct")
|
178 |
+
|
179 |
+
# ----------------------------------------------------
|
180 |
+
|
181 |
+
elif args.model == "Qwen/Qwen2.5-72B-Instruct":
|
182 |
+
print("aider model chosen")
|
183 |
+
|
184 |
+
# main_model = Model("fireworks_ai/accounts/fireworks/models/qwen2-72b-instruct")
|
185 |
+
# main_model = Model("openai/Qwen2.5-72B-Instruct")
|
186 |
+
main_model = Model("friendli/Qwen2.5-72B-Instruct")
|
187 |
+
|
188 |
+
elif model == "hyperbolic/meta-llama/Meta-Llama-3.1-70B-Instruct":
|
189 |
+
main_model = Model("hyperbolic/meta-llama/Meta-Llama-3.1-70B-Instruct")
|
190 |
+
|
191 |
+
# ----------------------------------------------------
|
192 |
+
|
193 |
+
else:
|
194 |
+
main_model = Model(model)
|
195 |
+
coder = Coder.create(
|
196 |
+
main_model=main_model,
|
197 |
+
fnames=fnames,
|
198 |
+
io=io,
|
199 |
+
stream=False,
|
200 |
+
use_git=False,
|
201 |
+
edit_format="diff",
|
202 |
+
)
|
203 |
+
|
204 |
+
print_time()
|
205 |
+
print(f"*Starting Experiments*")
|
206 |
+
try:
|
207 |
+
success = perform_experiments(idea, folder_name, coder, baseline_results)
|
208 |
+
except Exception as e:
|
209 |
+
print(f"Error during experiments: {e}")
|
210 |
+
print(f"Experiments failed for idea {idea_name}")
|
211 |
+
return False
|
212 |
+
|
213 |
+
if not success:
|
214 |
+
print(f"Experiments failed for idea {idea_name}")
|
215 |
+
return False
|
216 |
+
|
217 |
+
print_time()
|
218 |
+
print(f"*Starting Writeup*")
|
219 |
+
## PERFORM WRITEUP
|
220 |
+
if writeup == "latex":
|
221 |
+
writeup_file = osp.join(folder_name, "latex", "template.tex")
|
222 |
+
fnames = [exp_file, writeup_file, notes]
|
223 |
+
if model == "deepseek-coder-v2-0724":
|
224 |
+
main_model = Model("deepseek-ai/DeepSeek-V2.5")
|
225 |
+
elif model == "llama3.1-405b":
|
226 |
+
main_model = Model("openrouter/meta-llama/llama-3.1-405b-instruct")
|
227 |
+
|
228 |
+
# ----------------------------------------------------
|
229 |
+
|
230 |
+
elif args.model == "Qwen/Qwen2.5-72B-Instruct":
|
231 |
+
print("aider model chosen")
|
232 |
+
# main_model = Model("fireworks_ai/accounts/fireworks/models/qwen2-72b-instruct")
|
233 |
+
main_model = Model("openai/Qwen/Qwen2.5-72B-Instruct")
|
234 |
+
|
235 |
+
elif model == "hyperbolic/meta-llama/Meta-Llama-3.1-70B-Instruct":
|
236 |
+
main_model = Model("hyperbolic/meta-llama/Meta-Llama-3.1-70B-Instruct")
|
237 |
+
|
238 |
+
# ----------------------------------------------------
|
239 |
+
else:
|
240 |
+
main_model = Model(model)
|
241 |
+
coder = Coder.create(
|
242 |
+
main_model=main_model,
|
243 |
+
fnames=fnames,
|
244 |
+
io=io,
|
245 |
+
stream=False,
|
246 |
+
use_git=False,
|
247 |
+
edit_format="diff",
|
248 |
+
)
|
249 |
+
try:
|
250 |
+
perform_writeup(idea, folder_name, coder, client, client_model)
|
251 |
+
except Exception as e:
|
252 |
+
print(f"Failed to perform writeup: {e}")
|
253 |
+
return False
|
254 |
+
print("Done writeup")
|
255 |
+
else:
|
256 |
+
raise ValueError(f"Writeup format {writeup} not supported.")
|
257 |
+
|
258 |
+
print_time()
|
259 |
+
print(f"*Starting Review*")
|
260 |
+
## REVIEW PAPER
|
261 |
+
if writeup == "latex":
|
262 |
+
try:
|
263 |
+
paper_text = load_paper(f"{folder_name}/{idea['Name']}.pdf")
|
264 |
+
if model == "gpt-4o-2024-05-13":
|
265 |
+
main_model = Model(model)
|
266 |
+
review = perform_review(
|
267 |
+
paper_text,
|
268 |
+
model=main_model,
|
269 |
+
client=openai.OpenAI(),
|
270 |
+
num_reflections=5,
|
271 |
+
num_fs_examples=1,
|
272 |
+
num_reviews_ensemble=5,
|
273 |
+
temperature=0.1,
|
274 |
+
)
|
275 |
+
elif model.startswith("ollama"):
|
276 |
+
# Use Ollama API for review generation
|
277 |
+
review = perform_review(
|
278 |
+
paper_text,
|
279 |
+
model=model.split("/")[-1],
|
280 |
+
client=openai.OpenAI(
|
281 |
+
api_key="ollama", base_url="http://localhost:11434/v1"
|
282 |
+
),
|
283 |
+
num_reflections=5,
|
284 |
+
num_fs_examples=1,
|
285 |
+
num_reviews_ensemble=5,
|
286 |
+
temperature=0.1,
|
287 |
+
)
|
288 |
+
# Store the review in separate review.txt file
|
289 |
+
with open(osp.join(folder_name, "review.txt"), "w") as f:
|
290 |
+
f.write(json.dumps(review, indent=4))
|
291 |
+
except Exception as e:
|
292 |
+
print(f"Failed to perform review: {e}")
|
293 |
+
return False
|
294 |
+
|
295 |
+
## IMPROVE WRITEUP
|
296 |
+
if writeup == "latex" and improvement:
|
297 |
+
print_time()
|
298 |
+
print(f"*Starting Improvement*")
|
299 |
+
try:
|
300 |
+
perform_improvement(review, coder)
|
301 |
+
generate_latex(
|
302 |
+
coder, folder_name, f"{folder_name}/{idea['Name']}_improved.pdf"
|
303 |
+
)
|
304 |
+
paper_text = load_paper(f"{folder_name}/{idea['Name']}_improved.pdf")
|
305 |
+
|
306 |
+
if model == "gpt-4o-2024-05-13":
|
307 |
+
main_model = Model(model)
|
308 |
+
review = perform_review(
|
309 |
+
paper_text,
|
310 |
+
model=main_model,
|
311 |
+
client=openai.OpenAI(),
|
312 |
+
num_reflections=5,
|
313 |
+
num_fs_examples=1,
|
314 |
+
num_reviews_ensemble=5,
|
315 |
+
temperature=0.1,
|
316 |
+
)
|
317 |
+
elif model.startswith("ollama"):
|
318 |
+
# Use Ollama API for review generation
|
319 |
+
review = perform_review(
|
320 |
+
paper_text,
|
321 |
+
model=model.split("/")[-1],
|
322 |
+
client=openai.OpenAI(
|
323 |
+
api_key="ollama", base_url="http://localhost:11434/v1"
|
324 |
+
),
|
325 |
+
num_reflections=5,
|
326 |
+
num_fs_examples=1,
|
327 |
+
num_reviews_ensemble=5,
|
328 |
+
temperature=0.1,
|
329 |
+
)
|
330 |
+
# Store the review in separate review.txt file
|
331 |
+
with open(osp.join(folder_name, "review_improved.txt"), "w") as f:
|
332 |
+
f.write(json.dumps(review))
|
333 |
+
except Exception as e:
|
334 |
+
print(f"Failed to perform improvement: {e}")
|
335 |
+
return False
|
336 |
+
return True
|
337 |
+
except Exception as e:
|
338 |
+
print(f"Failed to evaluate idea {idea_name}: {str(e)}")
|
339 |
+
return False
|
340 |
+
finally:
|
341 |
+
print("FINISHED IDEA")
|
342 |
+
if log_file:
|
343 |
+
sys.stdout = original_stdout
|
344 |
+
sys.stderr = original_stderr
|
345 |
+
log.close()
|
346 |
+
|
347 |
+
|
348 |
+
if __name__ == "__main__":
|
349 |
+
import traceback
|
350 |
+
try:
|
351 |
+
args = parse_arguments()
|
352 |
+
|
353 |
+
# Check available GPUs and adjust parallel processes if necessary
|
354 |
+
available_gpus = get_available_gpus(args.gpus)
|
355 |
+
if args.parallel > len(available_gpus):
|
356 |
+
print(
|
357 |
+
f"Warning: Requested {args.parallel} parallel processes, but only {len(available_gpus)} GPUs available. Adjusting to {len(available_gpus)}."
|
358 |
+
)
|
359 |
+
args.parallel = len(available_gpus)
|
360 |
+
|
361 |
+
print(f"Using GPUs: {available_gpus}")
|
362 |
+
|
363 |
+
# Create client
|
364 |
+
if args.model == "claude-3-5-sonnet-20240620":
|
365 |
+
import anthropic
|
366 |
+
|
367 |
+
print(f"Using Anthropic API with model {args.model}.")
|
368 |
+
client_model = "claude-3-5-sonnet-20240620"
|
369 |
+
client = anthropic.Anthropic()
|
370 |
+
elif args.model.startswith("bedrock") and "claude" in args.model:
|
371 |
+
import anthropic
|
372 |
+
|
373 |
+
# Expects: bedrock/<MODEL_ID>
|
374 |
+
client_model = args.model.split("/")[-1]
|
375 |
+
|
376 |
+
print(f"Using Amazon Bedrock with model {client_model}.")
|
377 |
+
client = anthropic.AnthropicBedrock(
|
378 |
+
aws_access_key=os.getenv("AWS_ACCESS_KEY_ID"),
|
379 |
+
aws_secret_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
|
380 |
+
aws_region=os.getenv("AWS_REGION_NAME"),
|
381 |
+
)
|
382 |
+
elif args.model.startswith("vertex_ai") and "claude" in args.model:
|
383 |
+
import anthropic
|
384 |
+
|
385 |
+
# Expects: vertex_ai/<MODEL_ID>
|
386 |
+
client_model = args.model.split("/")[-1]
|
387 |
+
|
388 |
+
print(f"Using Vertex AI with model {client_model}.")
|
389 |
+
client = anthropic.AnthropicVertex()
|
390 |
+
elif args.model == "gpt-4o-2024-05-13":
|
391 |
+
import openai
|
392 |
+
|
393 |
+
print(f"Using OpenAI API with model {args.model}.")
|
394 |
+
client_model = "gpt-4o-2024-05-13"
|
395 |
+
client = openai.OpenAI()
|
396 |
+
|
397 |
+
# ----------------------------------------------------
|
398 |
+
elif args.model == "Qwen/Qwen2.5-72B-Instruct":
|
399 |
+
# elif args.model.startswith("hyperbolic"):
|
400 |
+
print(f"Welcome to the PARADISE of debug <launch_scientist.py> {args.model}.")
|
401 |
+
|
402 |
+
import openai
|
403 |
+
import os
|
404 |
+
# client_model = args.model[11:]
|
405 |
+
client_model = args.model
|
406 |
+
client = openai.OpenAI(
|
407 |
+
api_key=os.environ["OPENAI_API_KEY"], base_url="https://api.hyperbolic.xyz/v1"
|
408 |
+
)
|
409 |
+
# ----------------------------------------------------
|
410 |
+
|
411 |
+
elif args.model.startswith("ollama"):
|
412 |
+
import openai
|
413 |
+
|
414 |
+
print(f"Using Ollama with {args.model}.")
|
415 |
+
client_model = args.model.split("/")[-1]
|
416 |
+
client = openai.OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
|
417 |
+
else:
|
418 |
+
raise ValueError(f"Model {args.model} not supported.")
|
419 |
+
|
420 |
+
base_dir = osp.join("templates", args.experiment)
|
421 |
+
results_dir = osp.join("results", args.experiment)
|
422 |
+
ideas = generate_ideas(
|
423 |
+
base_dir,
|
424 |
+
client=client,
|
425 |
+
model=client_model,
|
426 |
+
skip_generation=args.skip_idea_generation,
|
427 |
+
max_num_generations=args.num_ideas,
|
428 |
+
num_reflections=NUM_REFLECTIONS,
|
429 |
+
)
|
430 |
+
ideas = check_idea_novelty(
|
431 |
+
ideas,
|
432 |
+
base_dir=base_dir,
|
433 |
+
client=client,
|
434 |
+
model=client_model,
|
435 |
+
)
|
436 |
+
|
437 |
+
with open(osp.join(base_dir, "ideas.json"), "w") as f:
|
438 |
+
json.dump(ideas, f, indent=4)
|
439 |
+
|
440 |
+
novel_ideas = [idea for idea in ideas if idea["novel"]]
|
441 |
+
# novel_ideas = list(reversed(novel_ideas))
|
442 |
+
|
443 |
+
if args.parallel > 0:
|
444 |
+
print(f"Running {args.parallel} parallel processes")
|
445 |
+
queue = multiprocessing.Queue()
|
446 |
+
for idea in novel_ideas:
|
447 |
+
queue.put(idea)
|
448 |
+
|
449 |
+
processes = []
|
450 |
+
for i in range(args.parallel):
|
451 |
+
gpu_id = available_gpus[i % len(available_gpus)]
|
452 |
+
p = multiprocessing.Process(
|
453 |
+
target=worker,
|
454 |
+
args=(
|
455 |
+
queue,
|
456 |
+
base_dir,
|
457 |
+
results_dir,
|
458 |
+
args.model,
|
459 |
+
client,
|
460 |
+
client_model,
|
461 |
+
args.writeup,
|
462 |
+
args.improvement,
|
463 |
+
gpu_id,
|
464 |
+
),
|
465 |
+
)
|
466 |
+
p.start()
|
467 |
+
time.sleep(150)
|
468 |
+
processes.append(p)
|
469 |
+
|
470 |
+
# Signal workers to exit
|
471 |
+
for _ in range(args.parallel):
|
472 |
+
queue.put(None)
|
473 |
+
|
474 |
+
for p in processes:
|
475 |
+
p.join()
|
476 |
+
|
477 |
+
print("All parallel processes completed.")
|
478 |
+
else:
|
479 |
+
for idea in novel_ideas:
|
480 |
+
print(f"Processing idea: {idea['Name']}")
|
481 |
+
try:
|
482 |
+
success = do_idea(
|
483 |
+
base_dir,
|
484 |
+
results_dir,
|
485 |
+
idea,
|
486 |
+
args.model,
|
487 |
+
client,
|
488 |
+
client_model,
|
489 |
+
args.writeup,
|
490 |
+
args.improvement,
|
491 |
+
)
|
492 |
+
print(f"Completed idea: {idea['Name']}, Success: {success}")
|
493 |
+
except Exception as e:
|
494 |
+
print(f"Failed to evaluate idea {idea['Name']}: {str(e)}")
|
495 |
+
|
496 |
+
print("All ideas evaluated.")
|
497 |
+
|
498 |
+
except Exception as e:
|
499 |
+
print("error aya re baba")
|
500 |
+
traceback.print_exc()
|
DockerFile
ADDED
File without changes
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2020 Rémi Louf
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
Miniconda3-latest-Linux-x86_64.sh
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8d936ba600300e08eca3d874dee88c61c6f39303597b2b66baee54af4f7b4122
|
3 |
+
size 148337011
|
README.md
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<h1 align="center">
|
2 |
+
<a href="https://github.com/SakanaAI/AI-Scientist/blob/main/docs/logo_2.png">
|
3 |
+
<img src="docs/logo_2.png" width="215" /></a><br>
|
4 |
+
<b>The AI Scientist: Towards Fully Automated</b><br>
|
5 |
+
<b>Open-Ended Scientific Discovery 🧑🔬</b><br>
|
6 |
+
</h1>
|
7 |
+
|
8 |
+
<p align="center">
|
9 |
+
📚 <a href="https://arxiv.org/abs/2408.06292">[Paper]</a> |
|
10 |
+
📝 <a href="https://sakana.ai/ai-scientist/">[Blog Post]</a> |
|
11 |
+
📂 <a href="https://drive.google.com/drive/folders/1G7A0wTqfXVa-cpexjk0oaXakaSJwffEt">[Drive Folder]</a>
|
12 |
+
</p>
|
13 |
+
|
14 |
+
One of the grand challenges of artificial intelligence is developing agents capable of conducting scientific research and discovering new knowledge. While frontier models have already been used to aid human scientists, e.g. for brainstorming ideas or writing code, they still require extensive manual supervision or are heavily constrained to a specific task.
|
15 |
+
|
16 |
+
We're excited to introduce **The AI Scientist**, the first comprehensive system for fully automatic scientific discovery, enabling Foundation Models such as Large Language Models (LLMs) to perform research independently.
|
17 |
+
|
18 |
+
We further provide all runs and data from our paper [here](https://drive.google.com/drive/folders/1G7A0wTqfXVa-cpexjk0oaXakaSJwffEt?usp=sharing), where we run each base model on each template for ~50 ideas. We _highly_ recommend reading through some of the [Claude papers](https://drive.google.com/drive/folders/1Mmpz6M1FK4q8e-SewgZcUzdeD0Q2zC39?usp=sharing), (especially the diffusion ones), to get a sense of its strengths and weaknesses. Here are some example papers generated by **The AI Scientist** 📝:
|
19 |
+
|
20 |
+
1. [DualScale Diffusion: Adaptive Feature Balancing for Low-Dimensional Generative Models](https://github.com/SakanaAI/AI-Scientist/blob/main/example_papers/adaptive_dual_scale_denoising.pdf)
|
21 |
+
2. [Multi-scale Grid Noise Adaptation: Enhancing Diffusion Models For Low-dimensional Data](https://github.com/SakanaAI/AI-Scientist/blob/main/example_papers/grid_based_noise_adaptation.pdf)
|
22 |
+
3. [GAN-Enhanced Diffusion: Boosting Sample Quality and Diversity](https://github.com/SakanaAI/AI-Scientist/blob/main/example_papers/gan_diffusion.pdf)
|
23 |
+
4. [DualDiff: Enhancing Mode Capture in Low-dimensional Diffusion Models via Dual-expert Denoising](https://github.com/SakanaAI/AI-Scientist/tree/main/example_papers/dual_expert_denoiser.pdf)
|
24 |
+
5. [StyleFusion: Adaptive Multi-style Generation in Character-Level Language Models](https://github.com/SakanaAI/AI-Scientist/blob/main/example_papers/multi_style_adapter.pdf)
|
25 |
+
6. [Adaptive Learning Rates for Transformers via Q-Learning](https://github.com/SakanaAI/AI-Scientist/tree/main/example_papers/rl_lr_adaptation.pdf)
|
26 |
+
7. [Unlocking Grokking: A Comparative Study of Weight Initialization Strategies in Transformer Models](https://github.com/SakanaAI/AI-Scientist/tree/main/example_papers/weight_initialization_grokking.pdf)
|
27 |
+
8. [Grokking Accelerated: Layer-wise Learning Rates for Transformer Generalization](https://github.com/SakanaAI/AI-Scientist/tree/main/example_papers/layerwise_lr_grokking.pdf)
|
28 |
+
9. [Grokking Through Compression: Unveiling Sudden Generalization via Minimal Description Length](https://github.com/SakanaAI/AI-Scientist/tree/main/example_papers/mdl_grokking_correlation.pdf)
|
29 |
+
10. [Accelerating Mathematical Insight: Boosting Grokking Through Strategic Data Augmentation](https://github.com/SakanaAI/AI-Scientist/tree/main/example_papers/data_augmentation_grokking.pdf)
|
30 |
+
|
31 |
+
**Note**: Caution! This codebase will execute LLM-written code. There are various risks and challenges associated with this autonomy. This includes e.g. the use of potentially dangerous packages, web access, and potential spawning of processes. Use at your own discretion. Please make sure to [containerize](#containerization) and restrict web access appropriately.
|
32 |
+
|
33 |
+
<p align="center">
|
34 |
+
<a href="https://github.com/SakanaAI/AI-Scientist/blob/main/example_papers/adaptive_dual_scale_denoising/adaptive_dual_scale_denoising.pdf"><img src="https://github.com/SakanaAI/AI-Scientist/blob/main/docs/anim-ai-scientist.gif" alt="Adaptive Dual Scale Denoising" width="80%" />
|
35 |
+
</p>
|
36 |
+
|
37 |
+
## Table of Contents
|
38 |
+
|
39 |
+
1. [Requirements](#requirements)
|
40 |
+
2. [Run AI Scientist Paper Generation Experiments](#run-ai-scientist-paper-generation-experiments)
|
41 |
+
3. [Getting an LLM Generated Paper Review](#getting-an-llm-generated-paper-review)
|
42 |
+
4. [Making your own Template](#making-your-own-template)
|
43 |
+
5. [Template Resources](#template-resources)
|
44 |
+
6. [Citing The AI Scientist](#citing-the-ai-scientist)
|
45 |
+
7. [Frequently Asked Questions](#faq)
|
46 |
+
8. [Containerization](#containerization)
|
47 |
+
|
48 |
+
## Requirements
|
49 |
+
|
50 |
+
This code was designed for NVIDIA GPUs with CUDA using PyTorch. Support for other GPU architectures may be possible by following [PyTorch guidelines](https://pytorch.org/get-started/locally/). Current templates would likely take an infeasible amount of time on CPU-only machines. All code is designed to be run on Linux, other operating systems will likely require major adjustments.
|
51 |
+
|
52 |
+
### Installation
|
53 |
+
|
54 |
+
```bash
|
55 |
+
conda create -n ai_scientist python=3.11
|
56 |
+
conda activate ai_scientist
|
57 |
+
# Install pdflatex
|
58 |
+
sudo apt-get install texlive-full
|
59 |
+
|
60 |
+
# Install pypi requirements
|
61 |
+
pip install -r requirements.txt
|
62 |
+
```
|
63 |
+
|
64 |
+
When installing `texlive-full`, you may need to [hold Enter](https://askubuntu.com/questions/956006/pregenerating-context-markiv-format-this-may-take-some-time-takes-forever).
|
65 |
+
|
66 |
+
### Supported Models and API Keys
|
67 |
+
|
68 |
+
We support a wide variety of models including open-weight and API-only models. In general, we recommend only using frontier models above the capability of the original GPT-4.
|
69 |
+
|
70 |
+
#### OpenAI API (GPT-4)
|
71 |
+
|
72 |
+
By default, this uses the `OPENAI_API_KEY` environment variable.
|
73 |
+
|
74 |
+
#### Anthropic API (Claude Sonnet 3.5)
|
75 |
+
|
76 |
+
By default, this uses the `ANTHROPIC_API_KEY` environment variable.
|
77 |
+
|
78 |
+
##### Claude models via Bedrock
|
79 |
+
|
80 |
+
For Claude models provided by [Amazon Bedrock](https://aws.amazon.com/bedrock/), please install these additional packages:
|
81 |
+
|
82 |
+
```bash
|
83 |
+
pip install anthropic[bedrock]
|
84 |
+
```
|
85 |
+
|
86 |
+
Next, specify a set of valid [AWS Credentials](https://docs.aws.amazon.com/cli/v1/userguide/cli-configure-envvars.html) and the target [AWS Region](https://docs.aws.amazon.com/bedrock/latest/userguide/bedrock-regions.html):
|
87 |
+
|
88 |
+
Set these environment variables: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_REGION_NAME`.
|
89 |
+
|
90 |
+
##### Claude models via Vertex AI
|
91 |
+
|
92 |
+
For Claude models provided by [Vertex AI Model Garden](https://cloud.google.com/model-garden?hl=en), please install these additional packages:
|
93 |
+
|
94 |
+
```bash
|
95 |
+
pip install google-cloud-aiplatform
|
96 |
+
pip install anthropic[vertex]
|
97 |
+
```
|
98 |
+
|
99 |
+
Next, set up a valid authentication for a [Google Cloud project](https://cloud.google.com/vertex-ai/docs/authentication), for example by providing region and project ID like so:
|
100 |
+
|
101 |
+
```bash
|
102 |
+
export CLOUD_ML_REGION="REGION" # for Model Garden call
|
103 |
+
export ANTHROPIC_VERTEX_PROJECT_ID="PROJECT_ID" # for Model Garden call
|
104 |
+
export VERTEXAI_LOCATION="REGION" # for Aider/LiteLLM call, as per https://docs.litellm.ai/docs/providers/vertex#set-vertex-project--vertex-location
|
105 |
+
export VERTEXAI_PROJECT="PROJECT_ID" # for Aider/LiteLLM call as per https://docs.litellm.ai/docs/providers/vertex#set-vertex-project--vertex-location
|
106 |
+
```
|
107 |
+
|
108 |
+
#### DeepSeek API (DeepSeek-Coder-V2)
|
109 |
+
|
110 |
+
By default, this uses the `DEEPSEEK_API_KEY` environment variable.
|
111 |
+
|
112 |
+
#### OpenRouter API (Llama3.1)
|
113 |
+
|
114 |
+
By default, this uses the `OPENROUTER_API_KEY` environment variable.
|
115 |
+
|
116 |
+
#### Semantic Scholar API (Literature Search)
|
117 |
+
|
118 |
+
Our code can also optionally use a Semantic Scholar API Key (`S2_API_KEY`) for higher throughput [if you have one](https://www.semanticscholar.org/product/api), though in principle it should work without it.
|
119 |
+
|
120 |
+
Be sure to provide the key for the model used for your runs, e.g.
|
121 |
+
|
122 |
+
```bash
|
123 |
+
export OPENAI_API_KEY="YOUR KEY HERE"
|
124 |
+
export S2_API_KEY="YOUR KEY HERE"
|
125 |
+
```
|
126 |
+
|
127 |
+
### Setup NanoGPT
|
128 |
+
|
129 |
+
Here, and below, we give instructions for setting up the data and baseline evaluations for each template. You can only run setup steps for templates you are interested in. This is necessary to run on your machine as training times may vary depending on your hardware.
|
130 |
+
|
131 |
+
```bash
|
132 |
+
# Prepare NanoGPT data
|
133 |
+
python data/enwik8/prepare.py
|
134 |
+
python data/shakespeare_char/prepare.py
|
135 |
+
python data/text8/prepare.py
|
136 |
+
```
|
137 |
+
|
138 |
+
#### Create baseline runs (machine dependent)
|
139 |
+
|
140 |
+
```bash
|
141 |
+
# Set up NanoGPT baseline run
|
142 |
+
# NOTE: YOU MUST FIRST RUN THE PREPARE SCRIPTS ABOVE!
|
143 |
+
cd templates/nanoGPT && python experiment.py --out_dir run_0 && python plot.py
|
144 |
+
```
|
145 |
+
|
146 |
+
#### Create NanoGPT_lite baseline run. We use this for sanity-checking
|
147 |
+
|
148 |
+
```bash
|
149 |
+
# NOTE: YOU MUST FIRST RUN THE PREPARE SCRIPTS ABOVE!
|
150 |
+
cd templates/nanoGPT_lite && python experiment.py --out_dir run_0 && python plot.py
|
151 |
+
```
|
152 |
+
|
153 |
+
### Setup 2D Diffusion
|
154 |
+
|
155 |
+
```bash
|
156 |
+
# Set up 2D Diffusion
|
157 |
+
git clone https://github.com/gregversteeg/NPEET.git
|
158 |
+
cd NPEET
|
159 |
+
pip install .
|
160 |
+
pip install scikit-learn
|
161 |
+
|
162 |
+
# Set up 2D Diffusion baseline run
|
163 |
+
cd templates/2d_diffusion && python experiment.py --out_dir run_0 && python plot.py
|
164 |
+
```
|
165 |
+
|
166 |
+
### Setup Grokking
|
167 |
+
|
168 |
+
```bash
|
169 |
+
# Set up Grokking
|
170 |
+
pip install einops
|
171 |
+
|
172 |
+
# Set up Grokking baseline run
|
173 |
+
cd templates/grokking && python experiment.py --out_dir run_0 && python plot.py
|
174 |
+
```
|
175 |
+
|
176 |
+
## Run AI Scientist Paper Generation Experiments
|
177 |
+
|
178 |
+
**Note:** please ensure the setup steps above are completed.
|
179 |
+
|
180 |
+
```bash
|
181 |
+
conda activate ai_scientist
|
182 |
+
# Run the paper generation.
|
183 |
+
python launch_scientist.py --model "gpt-4o-2024-05-13" --experiment nanoGPT_lite --num-ideas 2
|
184 |
+
python launch_scientist.py --model "claude-3-5-sonnet-20240620" --experiment nanoGPT_lite --num-ideas 2
|
185 |
+
python launch_scientist.py --model "ollama/mistral-nemo" --experiment nanoGPT_lite --num-ideas 2
|
186 |
+
```
|
187 |
+
|
188 |
+
If you have more than 1 GPU, use the `parallel` option to parallelize ideas across multiple GPUs.
|
189 |
+
|
190 |
+
## Getting an LLM Generated Paper Review
|
191 |
+
|
192 |
+
```python
|
193 |
+
import openai
|
194 |
+
from ai_scientist.perform_review import load_paper, perform_review
|
195 |
+
|
196 |
+
client = openai.OpenAI()
|
197 |
+
model = "gpt-4o-2024-05-13"
|
198 |
+
|
199 |
+
# Load paper from pdf file (raw text)
|
200 |
+
paper_txt = load_paper("report.pdf")
|
201 |
+
# Get the review dict of the review
|
202 |
+
review = perform_review(
|
203 |
+
paper_txt,
|
204 |
+
model,
|
205 |
+
client,
|
206 |
+
num_reflections=5,
|
207 |
+
num_fs_examples=1,
|
208 |
+
num_reviews_ensemble=5,
|
209 |
+
temperature=0.1,
|
210 |
+
)
|
211 |
+
|
212 |
+
# Inspect review results
|
213 |
+
review["Overall"] # overall score 1-10
|
214 |
+
review["Decision"] # ['Accept', 'Reject']
|
215 |
+
review["Weaknesses"] # List of weaknesses (str)
|
216 |
+
```
|
217 |
+
|
218 |
+
To run batch analysis:
|
219 |
+
|
220 |
+
```bash
|
221 |
+
cd review_iclr_bench
|
222 |
+
python iclr_analysis.py --num_reviews 500 --batch_size 100 --num_fs_examples 1 --num_reflections 5 --temperature 0.1 --num_reviews_ensemble 5
|
223 |
+
```
|
224 |
+
|
225 |
+
## Making your own Template
|
226 |
+
|
227 |
+
If there is an area of study you would like **The AI Scientist** to explore, it should be very easy to create your own templates. In general, follow the structure of the existing templates, which consists of:
|
228 |
+
|
229 |
+
- `experiment.py` -- This is a single file where the 'meat' of the content is. It takes in an argument for `out_dir`, which is where it should create the folder and save the relevant information from the run.
|
230 |
+
- `plot.py` -- This should take in the information from the `run` folders and create plots. The code should be clear and easy to edit.
|
231 |
+
- `prompt.json` -- Put information about your template here.
|
232 |
+
- `seed_ideas.json` -- Put example ideas here. You can also try to generate ideas without any examples, and then pick the best one or two to put here.
|
233 |
+
- `latex/template.tex` -- We recommend using our latex folder, but be sure to replace the pre-loaded citations with ones that you would expect to be more relevant.
|
234 |
+
|
235 |
+
## Template Resources
|
236 |
+
|
237 |
+
We provide 3 templates, which heavily use code from other repositories, which we credit below. (Normally, we would do this in the files themselves, but it's unclear how this would affect The AI Scientist since it would be visible).
|
238 |
+
|
239 |
+
The NanoGPT template used code from [NanoGPT](https://github.com/karpathy/nanoGPT) and this [PR](https://github.com/karpathy/nanoGPT/pull/254).
|
240 |
+
|
241 |
+
The 2D Diffusion template used code from [tiny-diffusion](https://github.com/tanelp/tiny-diffusion), [ema-pytorch](https://github.com/lucidrains/ema-pytorch), and [Datasaur](https://www.research.autodesk.com/publications/same-stats-different-graphs/).
|
242 |
+
|
243 |
+
The Grokking template used code from [Sea-Snell/grokking](https://github.com/Sea-Snell/grokking) and [danielmamay/grokking](https://github.com/danielmamay/grokking).
|
244 |
+
|
245 |
+
We would like to thank the developers of the open-source models and packages for their contributions and for making their work available.
|
246 |
+
|
247 |
+
## Citing The AI Scientist
|
248 |
+
|
249 |
+
If you use **The AI Scientist** in your research, please cite it as follows:
|
250 |
+
|
251 |
+
```
|
252 |
+
@article{lu2024aiscientist,
|
253 |
+
title={The {AI} {S}cientist: Towards Fully Automated Open-Ended Scientific Discovery},
|
254 |
+
author={Lu, Chris and Lu, Cong and Lange, Robert Tjarko and Foerster, Jakob and Clune, Jeff and Ha, David},
|
255 |
+
journal={arXiv preprint arXiv:2408.06292},
|
256 |
+
year={2024}
|
257 |
+
}
|
258 |
+
```
|
259 |
+
|
260 |
+
## FAQ
|
261 |
+
|
262 |
+
We recommend reading our paper in the first instance for any questions you have on The AI Scientist.
|
263 |
+
|
264 |
+
### Why am I missing files when running The AI Scientist?
|
265 |
+
|
266 |
+
Make sure you have completed all the setup and preparation steps before the main experiment script.
|
267 |
+
|
268 |
+
### Why has a PDF or a review not been generated?
|
269 |
+
|
270 |
+
The AI Scientist finishes an idea with a success rate that depends on both the template, the base foundation model, and the complexity of the idea. We advise referring to our main paper. The highest success rates are observed with Claude Sonnet 3.5.
|
271 |
+
Reviews are best done with GPT-4o, all other models have issues with positivity bias or failure to conform to required outputs.
|
272 |
+
|
273 |
+
### What is the cost of each idea generated?
|
274 |
+
|
275 |
+
Typically less than $15 per paper with Claude Sonnet 3.5. We recommend DeepSeek Coder V2 for a much more cost-effective approach. A good place to look for new models is the [Aider leaderboard](https://aider.chat/docs/leaderboards/).
|
276 |
+
|
277 |
+
### How do I change the base conference format associated with the write-ups?
|
278 |
+
|
279 |
+
Change the base `template.tex` files contained within each template.
|
280 |
+
|
281 |
+
### How do I run The AI Scientist for different subject fields?
|
282 |
+
|
283 |
+
Please refer to the instructions for different templates. In this current iteration, this is restricted to ideas that can be expressed in code. However, lifting this restriction would represent exciting future work! :)
|
284 |
+
|
285 |
+
### How do I add support for a new foundation model?
|
286 |
+
|
287 |
+
Please see this [PR](https://github.com/SakanaAI/AI-Scientist/pull/7) for an example of how to add a new model, e.g. this time for Claude via Bedrock.
|
288 |
+
We do not advise any model that is significantly weaker than GPT-4 level for The AI Scientist.
|
289 |
+
|
290 |
+
### Why do I need to run the baseline runs myself?
|
291 |
+
These appear as `run_0` and should be run per machine you execute The AI Scientist on for accurate run-time comparisons due to hardware differences.
|
292 |
+
|
293 |
+
## Containerization
|
294 |
+
|
295 |
+
We include a [community-contributed](https://github.com/SakanaAI/AI-Scientist/pull/21) Docker image that may assist with your containerization efforts in `experimental/Dockerfile`.
|
296 |
+
|
297 |
+
You can use this image like this:
|
298 |
+
|
299 |
+
```bash
|
300 |
+
# Endpoint Script
|
301 |
+
docker run -e OPENAI_API_KEY=$OPENAI_API_KEY -v `pwd`/templates:/app/AI-Scientist/templates <AI_SCIENTIST_IMAGE> \
|
302 |
+
--model gpt-4o-2024-05-13 \
|
303 |
+
--experiment 2d_diffusion \
|
304 |
+
--num-ideas 2
|
305 |
+
```
|
306 |
+
|
307 |
+
```bash
|
308 |
+
# Interactive
|
309 |
+
docker run -it -e OPENAI_API_KEY=$OPENAI_API_KEY \
|
310 |
+
--entrypoint /bin/bash \
|
311 |
+
<AI_SCIENTIST_IMAGE>
|
312 |
+
```
|
ai_scientist/.ipynb_checkpoints/Untitled-checkpoint.ipynb
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [],
|
3 |
+
"metadata": {},
|
4 |
+
"nbformat": 4,
|
5 |
+
"nbformat_minor": 5
|
6 |
+
}
|
ai_scientist/.ipynb_checkpoints/generate_ideas-checkpoint.py
ADDED
@@ -0,0 +1,637 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import os.path as osp
|
4 |
+
import time
|
5 |
+
from typing import Dict, List, Union
|
6 |
+
|
7 |
+
import backoff
|
8 |
+
import requests
|
9 |
+
from strictjson import strict_json
|
10 |
+
|
11 |
+
from ai_scientist.llm import (
|
12 |
+
allchoices,
|
13 |
+
extract_json_between_markers,
|
14 |
+
get_response_from_llm,
|
15 |
+
llm_json_auto_correct,
|
16 |
+
)
|
17 |
+
|
18 |
+
S2_API_KEY = os.getenv("S2_API_KEY")
|
19 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
20 |
+
|
21 |
+
|
22 |
+
idea_first_prompt = """{task_description}
|
23 |
+
<experiment.py>
|
24 |
+
{code}
|
25 |
+
</experiment.py>
|
26 |
+
|
27 |
+
Here are the ideas that you have already generated:
|
28 |
+
|
29 |
+
'''
|
30 |
+
{prev_ideas_string}
|
31 |
+
'''
|
32 |
+
|
33 |
+
Come up with the next impactful and creative idea for research experiments and directions you can feasibly investigate with the code provided.
|
34 |
+
Note that you will not have access to any additional resources or datasets.
|
35 |
+
Make sure any idea is not overfit the specific training dataset or model, and has wider significance.
|
36 |
+
|
37 |
+
Respond in the following format:
|
38 |
+
|
39 |
+
THOUGHT:
|
40 |
+
<THOUGHT>
|
41 |
+
|
42 |
+
NEW IDEA JSON:
|
43 |
+
```json
|
44 |
+
<JSON>
|
45 |
+
```
|
46 |
+
|
47 |
+
In <THOUGHT>, first briefly discuss your intuitions and motivations for the idea. Detail your high-level plan, necessary design choices and ideal outcomes of the experiments. Justify how the idea is different from the existing ones.
|
48 |
+
|
49 |
+
Add '```json' before the <JSON> and '```' after the <JSON> as above. In <JSON>, provide the new idea in JSON format with the following keys and values:
|
50 |
+
- "Name": A shortened descriptor of the idea. Lowercase, no spaces, underscores allowed.
|
51 |
+
- "Title": A title for the idea, will be used for the report writing.
|
52 |
+
- "Experiment": An outline of the implementation. E.g. which functions need to be added or modified, how results will be obtained, ...
|
53 |
+
- "Interestingness": A rating from 1 to 10 (lowest to highest).
|
54 |
+
- "Feasibility": A rating from 1 to 10 (lowest to highest).
|
55 |
+
- "Novelty": A rating from 1 to 10 (lowest to highest).
|
56 |
+
|
57 |
+
Be cautious and realistic on your ratings.
|
58 |
+
This JSON will be automatically parsed, so ensure the format is precise.
|
59 |
+
You will have {num_reflections} rounds to iterate on the idea, but do not need to use them all.
|
60 |
+
"""
|
61 |
+
|
62 |
+
idea_reflection_prompt = """Round {current_round}/{num_reflections}.
|
63 |
+
In your thoughts, first carefully consider the quality, novelty, and feasibility of the idea you just created.
|
64 |
+
Include any other factors that you think are important in evaluating the idea.
|
65 |
+
Ensure the idea is clear and concise, and the JSON is the correct format.
|
66 |
+
Do not make things overly complicated.
|
67 |
+
In the next attempt, try and refine and improve your idea.
|
68 |
+
Stick to the spirit of the original idea unless there are glaring issues.
|
69 |
+
|
70 |
+
Respond in the exactly the same format as before:
|
71 |
+
THOUGHT:
|
72 |
+
<THOUGHT>
|
73 |
+
|
74 |
+
NEW IDEA JSON:
|
75 |
+
```json
|
76 |
+
<JSON>
|
77 |
+
```
|
78 |
+
|
79 |
+
If there is nothing to improve, simply repeat the previous JSON EXACTLY after the thought and include "I am done" at the end of the thoughts but before the JSON.
|
80 |
+
ONLY INCLUDE "I am done" IF YOU ARE MAKING NO MORE CHANGES.
|
81 |
+
"""
|
82 |
+
|
83 |
+
|
84 |
+
# Format the content in JSON
|
85 |
+
def format_idea_json(text):
|
86 |
+
json_start_marker = "```json"
|
87 |
+
json_end_marker = "```"
|
88 |
+
start_index = text.find(json_start_marker)
|
89 |
+
if start_index != -1:
|
90 |
+
start_index += len(json_start_marker) # Move past the marker
|
91 |
+
end_index = text.find(json_end_marker, start_index)
|
92 |
+
json_string = text[start_index:end_index].strip()
|
93 |
+
res = strict_json(
|
94 |
+
system_prompt="You are a JSON formatter",
|
95 |
+
user_prompt=json_string,
|
96 |
+
output_format={
|
97 |
+
"Name": "A shortened descriptor of the idea",
|
98 |
+
"Title": "A title for the idea, will be used for the report writing",
|
99 |
+
"Experiment": "An outline of the implementation, type: list",
|
100 |
+
"Interestingness": "A rating from 1 to 10 (lowest to highest), type: int",
|
101 |
+
"Feasibility": "A rating from 1 to 10 (lowest to highest), type: int",
|
102 |
+
"Novelty": "A rating from 1 to 10 (lowest to highest), type: int",
|
103 |
+
},
|
104 |
+
llm=llm_json_auto_correct,
|
105 |
+
)
|
106 |
+
text = "```json\n" + json.dumps(res) + "```\n"
|
107 |
+
return text
|
108 |
+
|
109 |
+
|
110 |
+
def format_novelty_json(text):
|
111 |
+
json_start_marker = "```json"
|
112 |
+
json_end_marker = "```"
|
113 |
+
start_index = text.find(json_start_marker)
|
114 |
+
if start_index != -1:
|
115 |
+
start_index += len(json_start_marker) # Move past the marker
|
116 |
+
end_index = text.find(json_end_marker, start_index)
|
117 |
+
json_string = text[start_index:end_index].strip()
|
118 |
+
res = strict_json(
|
119 |
+
system_prompt="You are a JSON formatter",
|
120 |
+
user_prompt=json_string,
|
121 |
+
output_format={
|
122 |
+
"Query": "An optional search query to search the literature (e.g. attention is all you need)",
|
123 |
+
},
|
124 |
+
llm=llm_json_auto_correct,
|
125 |
+
)
|
126 |
+
text = "```json\n" + json.dumps(res) + "```\n"
|
127 |
+
return text
|
128 |
+
|
129 |
+
|
130 |
+
# GENERATE IDEAS
|
131 |
+
def generate_ideas(
|
132 |
+
base_dir,
|
133 |
+
client,
|
134 |
+
model,
|
135 |
+
skip_generation=False,
|
136 |
+
max_num_generations=20,
|
137 |
+
num_reflections=5,
|
138 |
+
):
|
139 |
+
if skip_generation:
|
140 |
+
# Load existing ideas from file
|
141 |
+
try:
|
142 |
+
with open(osp.join(base_dir, "ideas.json"), "r") as f:
|
143 |
+
ideas = json.load(f)
|
144 |
+
print("Loaded existing ideas:")
|
145 |
+
for idea in ideas:
|
146 |
+
print(idea)
|
147 |
+
return ideas
|
148 |
+
except FileNotFoundError:
|
149 |
+
print("No existing ideas found. Generating new ideas.")
|
150 |
+
except json.JSONDecodeError:
|
151 |
+
print("Error decoding existing ideas. Generating new ideas.")
|
152 |
+
|
153 |
+
idea_str_archive = []
|
154 |
+
with open(osp.join(base_dir, "seed_ideas.json"), "r") as f:
|
155 |
+
seed_ideas = json.load(f)
|
156 |
+
for seed_idea in seed_ideas:
|
157 |
+
idea_str_archive.append(json.dumps(seed_idea))
|
158 |
+
|
159 |
+
with open(osp.join(base_dir, "experiment.py"), "r") as f:
|
160 |
+
code = f.read()
|
161 |
+
|
162 |
+
with open(osp.join(base_dir, "prompt.json"), "r") as f:
|
163 |
+
prompt = json.load(f)
|
164 |
+
|
165 |
+
idea_system_prompt = prompt["system"]
|
166 |
+
|
167 |
+
for _ in range(max_num_generations):
|
168 |
+
print()
|
169 |
+
print(f"Generating idea {_ + 1}/{max_num_generations}")
|
170 |
+
import traceback
|
171 |
+
try:
|
172 |
+
prev_ideas_string = "\n\n".join(idea_str_archive)
|
173 |
+
|
174 |
+
msg_history = []
|
175 |
+
print(f"Iteration 1/{num_reflections}")
|
176 |
+
text, msg_history = get_response_from_llm(
|
177 |
+
idea_first_prompt.format(
|
178 |
+
task_description=prompt["task_description"],
|
179 |
+
code=code,
|
180 |
+
prev_ideas_string=prev_ideas_string,
|
181 |
+
num_reflections=num_reflections,
|
182 |
+
),
|
183 |
+
client=client,
|
184 |
+
model=model,
|
185 |
+
system_message=idea_system_prompt,
|
186 |
+
msg_history=msg_history,
|
187 |
+
)
|
188 |
+
## Format the content in JSON
|
189 |
+
text = format_idea_json(text)
|
190 |
+
|
191 |
+
## PARSE OUTPUT
|
192 |
+
json_output = extract_json_between_markers(text)
|
193 |
+
assert json_output is not None, "Failed to extract JSON from LLM output"
|
194 |
+
# print(json_output)
|
195 |
+
|
196 |
+
# Iteratively improve task.
|
197 |
+
if num_reflections > 1:
|
198 |
+
for j in range(num_reflections - 1):
|
199 |
+
print(f"Iteration {j + 2}/{num_reflections}")
|
200 |
+
text, msg_history = get_response_from_llm(
|
201 |
+
idea_reflection_prompt.format(
|
202 |
+
current_round=j + 2, num_reflections=num_reflections
|
203 |
+
),
|
204 |
+
client=client,
|
205 |
+
model=model,
|
206 |
+
system_message=idea_system_prompt,
|
207 |
+
msg_history=msg_history,
|
208 |
+
)
|
209 |
+
## Format the content in JSON if using weak LLM
|
210 |
+
text = format_idea_json(text)
|
211 |
+
## PARSE OUTPUT
|
212 |
+
json_output = extract_json_between_markers(text)
|
213 |
+
assert (
|
214 |
+
json_output is not None
|
215 |
+
), "Failed to extract JSON from LLM output"
|
216 |
+
# print(json_output)
|
217 |
+
|
218 |
+
if "I am done" in text:
|
219 |
+
print(f"Idea generation converged after {j + 2} iterations.")
|
220 |
+
break
|
221 |
+
|
222 |
+
idea_str_archive.append(json.dumps(json_output))
|
223 |
+
except Exception as e:
|
224 |
+
print(f"Failed to generate idea: {e}")
|
225 |
+
traceback.print_exc()
|
226 |
+
continue
|
227 |
+
|
228 |
+
## SAVE IDEAS
|
229 |
+
ideas = []
|
230 |
+
for idea_str in idea_str_archive:
|
231 |
+
ideas.append(json.loads(idea_str))
|
232 |
+
|
233 |
+
with open(osp.join(base_dir, "ideas.json"), "w") as f:
|
234 |
+
json.dump(ideas, f, indent=4)
|
235 |
+
|
236 |
+
return ideas
|
237 |
+
|
238 |
+
|
239 |
+
# GENERATE IDEAS OPEN-ENDED
|
240 |
+
def generate_next_idea(
|
241 |
+
base_dir,
|
242 |
+
client,
|
243 |
+
model,
|
244 |
+
prev_idea_archive=[],
|
245 |
+
num_reflections=5,
|
246 |
+
max_attempts=10,
|
247 |
+
):
|
248 |
+
idea_archive = prev_idea_archive
|
249 |
+
original_archive_size = len(idea_archive)
|
250 |
+
|
251 |
+
print(f"Generating idea {original_archive_size + 1}")
|
252 |
+
|
253 |
+
if len(prev_idea_archive) == 0:
|
254 |
+
print(f"First iteration, taking seed ideas")
|
255 |
+
# seed the archive on the first run with pre-existing ideas
|
256 |
+
with open(osp.join(base_dir, "seed_ideas.json"), "r") as f:
|
257 |
+
seed_ideas = json.load(f)
|
258 |
+
for seed_idea in seed_ideas[:1]:
|
259 |
+
idea_archive.append(seed_idea)
|
260 |
+
else:
|
261 |
+
with open(osp.join(base_dir, "experiment.py"), "r") as f:
|
262 |
+
code = f.read()
|
263 |
+
with open(osp.join(base_dir, "prompt.json"), "r") as f:
|
264 |
+
prompt = json.load(f)
|
265 |
+
idea_system_prompt = prompt["system"]
|
266 |
+
|
267 |
+
for _ in range(max_attempts):
|
268 |
+
import traceback
|
269 |
+
try:
|
270 |
+
idea_strings = []
|
271 |
+
for idea in idea_archive:
|
272 |
+
idea_strings.append(json.dumps(idea))
|
273 |
+
prev_ideas_string = "\n\n".join(idea_strings)
|
274 |
+
|
275 |
+
msg_history = []
|
276 |
+
print(f"Iteration 1/{num_reflections}")
|
277 |
+
text, msg_history = get_response_from_llm(
|
278 |
+
idea_first_prompt.format(
|
279 |
+
task_description=prompt["task_description"],
|
280 |
+
code=code,
|
281 |
+
prev_ideas_string=prev_ideas_string,
|
282 |
+
num_reflections=num_reflections,
|
283 |
+
)
|
284 |
+
+ """
|
285 |
+
Completed ideas have an additional "Score" field which indicates the assessment by an expert ML reviewer.
|
286 |
+
This is on a standard 1-10 ML conference scale.
|
287 |
+
Scores of 0 indicate the idea failed either during experimentation, writeup or reviewing.
|
288 |
+
""",
|
289 |
+
client=client,
|
290 |
+
model=model,
|
291 |
+
system_message=idea_system_prompt,
|
292 |
+
msg_history=msg_history,
|
293 |
+
)
|
294 |
+
## Format the content in JSON if using weak LLM
|
295 |
+
text = format_idea_json(text)
|
296 |
+
## PARSE OUTPUT
|
297 |
+
json_output = extract_json_between_markers(text)
|
298 |
+
assert json_output is not None, "Failed to extract JSON from LLM output"
|
299 |
+
# print(json_output)
|
300 |
+
|
301 |
+
# Iteratively improve task.
|
302 |
+
if num_reflections > 1:
|
303 |
+
for j in range(num_reflections - 1):
|
304 |
+
print(f"Iteration {j + 2}/{num_reflections}")
|
305 |
+
text, msg_history = get_response_from_llm(
|
306 |
+
idea_reflection_prompt.format(
|
307 |
+
current_round=j + 2, num_reflections=num_reflections
|
308 |
+
),
|
309 |
+
client=client,
|
310 |
+
model=model,
|
311 |
+
system_message=idea_system_prompt,
|
312 |
+
msg_history=msg_history,
|
313 |
+
)
|
314 |
+
## Format the content in JSON if using weak LLM
|
315 |
+
text = format_idea_json(text)
|
316 |
+
## PARSE OUTPUT
|
317 |
+
json_output = extract_json_between_markers(text)
|
318 |
+
assert (
|
319 |
+
json_output is not None
|
320 |
+
), "Failed to extract JSON from LLM output"
|
321 |
+
# print(json_output)
|
322 |
+
|
323 |
+
if "I am done" in text:
|
324 |
+
print(
|
325 |
+
f"Idea generation converged after {j + 2} iterations."
|
326 |
+
)
|
327 |
+
break
|
328 |
+
|
329 |
+
idea_archive.append(json_output)
|
330 |
+
break
|
331 |
+
except Exception as e:
|
332 |
+
print(f"Failed to generate idea: {e}")
|
333 |
+
traceback.print_exc()
|
334 |
+
continue
|
335 |
+
|
336 |
+
## SAVE IDEAS
|
337 |
+
with open(osp.join(base_dir, "ideas.json"), "w") as f:
|
338 |
+
json.dump(idea_archive, f, indent=4)
|
339 |
+
|
340 |
+
return idea_archive
|
341 |
+
|
342 |
+
|
343 |
+
def on_backoff(details):
|
344 |
+
print(
|
345 |
+
f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries "
|
346 |
+
f"calling function {details['target'].__name__} at {time.strftime('%X')}"
|
347 |
+
)
|
348 |
+
|
349 |
+
|
350 |
+
@backoff.on_exception(
|
351 |
+
backoff.expo, requests.exceptions.HTTPError, on_backoff=on_backoff
|
352 |
+
)
|
353 |
+
def search_for_papers(query, result_limit=10) -> Union[None, List[Dict]]:
|
354 |
+
if not query:
|
355 |
+
return None
|
356 |
+
rsp = requests.get(
|
357 |
+
"https://api.semanticscholar.org/graph/v1/paper/search",
|
358 |
+
headers={"X-API-KEY": S2_API_KEY},
|
359 |
+
params={
|
360 |
+
"query": query,
|
361 |
+
"limit": result_limit,
|
362 |
+
"fields": "title,authors,venue,year,abstract,citationStyles,citationCount",
|
363 |
+
},
|
364 |
+
)
|
365 |
+
print(f"Response Status Code: {rsp.status_code}")
|
366 |
+
print(
|
367 |
+
f"Response Content: {rsp.text[:500]}"
|
368 |
+
) # Print the first 500 characters of the response content
|
369 |
+
rsp.raise_for_status()
|
370 |
+
results = rsp.json()
|
371 |
+
total = results["total"]
|
372 |
+
if not total:
|
373 |
+
return None
|
374 |
+
time.sleep(2)
|
375 |
+
papers = results["data"]
|
376 |
+
return papers
|
377 |
+
|
378 |
+
|
379 |
+
novelty_system_msg = """You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field.
|
380 |
+
You have an idea and you want to check if it is novel or not. I.e., not overlapping significantly with existing literature or already well explored.
|
381 |
+
Be a harsh critic for novelty, ensure there is a sufficient contribution in the idea for a new conference or workshop paper.
|
382 |
+
You will be given access to the Semantic Scholar API, which you may use to survey the literature and find relevant papers to help you make your decision.
|
383 |
+
The top 10 results for any search query will be presented to you with the abstracts.
|
384 |
+
|
385 |
+
You will be given {num_rounds} to decide on the paper, but you do not need to use them all.
|
386 |
+
At any round, you may exit early and decide on the novelty of the idea.
|
387 |
+
Decide a paper idea is novel if after sufficient searching, you have not found a paper that significantly overlaps with your idea.
|
388 |
+
Decide a paper idea is not novel, if you have found a paper that significantly overlaps with your idea.
|
389 |
+
|
390 |
+
{task_description}
|
391 |
+
<experiment.py>
|
392 |
+
{code}
|
393 |
+
</experiment.py>
|
394 |
+
"""
|
395 |
+
|
396 |
+
novelty_prompt = '''Round {current_round}/{num_rounds}.
|
397 |
+
You have this idea:
|
398 |
+
|
399 |
+
"""
|
400 |
+
{idea}
|
401 |
+
"""
|
402 |
+
|
403 |
+
The results of the last query are (empty on first round):
|
404 |
+
"""
|
405 |
+
{last_query_results}
|
406 |
+
"""
|
407 |
+
|
408 |
+
Respond in the following format:
|
409 |
+
|
410 |
+
THOUGHT:
|
411 |
+
<THOUGHT>
|
412 |
+
|
413 |
+
RESPONSE:
|
414 |
+
```json
|
415 |
+
<JSON>
|
416 |
+
```
|
417 |
+
|
418 |
+
In <THOUGHT>, first briefly reason over the idea and identify any query that could help you make your decision.
|
419 |
+
If you have made your decision, add "Decision made: novel." or "Decision made: not novel." to your thoughts.
|
420 |
+
|
421 |
+
In <JSON>, respond in JSON format with ONLY the following field:
|
422 |
+
- "Query": An optional search query to search the literature (e.g. attention is all you need). You must make a query if you have not decided this round.
|
423 |
+
|
424 |
+
A query will work best if you are able to recall the exact name of the paper you are looking for, or the authors.
|
425 |
+
This JSON will be automatically parsed, so ensure the format is precise.
|
426 |
+
'''
|
427 |
+
|
428 |
+
|
429 |
+
def check_idea_novelty(
|
430 |
+
ideas,
|
431 |
+
base_dir,
|
432 |
+
client,
|
433 |
+
model,
|
434 |
+
max_num_iterations=10,
|
435 |
+
):
|
436 |
+
with open(osp.join(base_dir, "experiment.py"), "r") as f:
|
437 |
+
code = f.read()
|
438 |
+
with open(osp.join(base_dir, "prompt.json"), "r") as f:
|
439 |
+
prompt = json.load(f)
|
440 |
+
task_description = prompt["task_description"]
|
441 |
+
|
442 |
+
for idx, idea in enumerate(ideas):
|
443 |
+
if "novel" in idea:
|
444 |
+
print(f"Skipping idea {idx}, already checked.")
|
445 |
+
continue
|
446 |
+
|
447 |
+
print(f"\nChecking novelty of idea {idx}: {idea['Name']}")
|
448 |
+
|
449 |
+
novel = False
|
450 |
+
msg_history = []
|
451 |
+
papers_str = ""
|
452 |
+
|
453 |
+
for j in range(max_num_iterations):
|
454 |
+
try:
|
455 |
+
text, msg_history = get_response_from_llm(
|
456 |
+
novelty_prompt.format(
|
457 |
+
current_round=j + 1,
|
458 |
+
num_rounds=max_num_iterations,
|
459 |
+
idea=idea,
|
460 |
+
last_query_results=papers_str,
|
461 |
+
),
|
462 |
+
client=client,
|
463 |
+
model=model,
|
464 |
+
system_message=novelty_system_msg.format(
|
465 |
+
num_rounds=max_num_iterations,
|
466 |
+
task_description=task_description,
|
467 |
+
code=code,
|
468 |
+
),
|
469 |
+
msg_history=msg_history,
|
470 |
+
)
|
471 |
+
if "decision made: novel" in text.lower():
|
472 |
+
print("Decision made: novel after round", j)
|
473 |
+
novel = True
|
474 |
+
break
|
475 |
+
if "decision made: not novel" in text.lower():
|
476 |
+
print("Decision made: not novel after round", j)
|
477 |
+
break
|
478 |
+
|
479 |
+
## Format the content in JSON
|
480 |
+
text = format_novelty_json(text)
|
481 |
+
print("text after formating\n", text)
|
482 |
+
## PARSE OUTPUT
|
483 |
+
json_output = extract_json_between_markers(text)
|
484 |
+
assert json_output is not None, "Failed to extract JSON from LLM output"
|
485 |
+
|
486 |
+
## SEARCH FOR PAPERS
|
487 |
+
query = json_output["Query"]
|
488 |
+
papers = search_for_papers(query, result_limit=10)
|
489 |
+
if papers is None:
|
490 |
+
papers_str = "No papers found."
|
491 |
+
|
492 |
+
paper_strings = []
|
493 |
+
for i, paper in enumerate(papers):
|
494 |
+
paper_strings.append(
|
495 |
+
"""{i}: {title}. {authors}. {venue}, {year}.\nNumber of citations: {cites}\nAbstract: {abstract}""".format(
|
496 |
+
i=i,
|
497 |
+
title=paper["title"],
|
498 |
+
authors=paper["authors"],
|
499 |
+
venue=paper["venue"],
|
500 |
+
year=paper["year"],
|
501 |
+
cites=paper["citationCount"],
|
502 |
+
abstract=paper["abstract"],
|
503 |
+
)
|
504 |
+
)
|
505 |
+
papers_str = "\n\n".join(paper_strings)
|
506 |
+
|
507 |
+
except Exception as e:
|
508 |
+
print(f"Error: {e}")
|
509 |
+
continue
|
510 |
+
|
511 |
+
idea["novel"] = novel
|
512 |
+
|
513 |
+
# Save results to JSON file
|
514 |
+
results_file = osp.join(base_dir, "ideas.json")
|
515 |
+
with open(results_file, "w") as f:
|
516 |
+
json.dump(ideas, f, indent=4)
|
517 |
+
|
518 |
+
return ideas
|
519 |
+
|
520 |
+
|
521 |
+
if __name__ == "__main__":
|
522 |
+
MAX_NUM_GENERATIONS = 32
|
523 |
+
NUM_REFLECTIONS = 5
|
524 |
+
import argparse
|
525 |
+
|
526 |
+
parser = argparse.ArgumentParser(description="Generate AI scientist ideas")
|
527 |
+
# add type of experiment (nanoGPT, Boston, etc.)
|
528 |
+
parser.add_argument(
|
529 |
+
"--experiment",
|
530 |
+
type=str,
|
531 |
+
default="nanoGPT",
|
532 |
+
help="Experiment to run AI Scientist on.",
|
533 |
+
)
|
534 |
+
parser.add_argument(
|
535 |
+
"--model",
|
536 |
+
type=str,
|
537 |
+
default="deepseek-ai/DeepSeek-V2.5",
|
538 |
+
choices=allchoices,
|
539 |
+
help="Model to use for AI Scientist.",
|
540 |
+
)
|
541 |
+
parser.add_argument(
|
542 |
+
"--skip-idea-generation",
|
543 |
+
action="store_true",
|
544 |
+
help="Skip idea generation and use existing ideas.",
|
545 |
+
)
|
546 |
+
parser.add_argument(
|
547 |
+
"--check-novelty",
|
548 |
+
action="store_true",
|
549 |
+
help="Check novelty of ideas.",
|
550 |
+
)
|
551 |
+
args = parser.parse_args()
|
552 |
+
|
553 |
+
# Create client
|
554 |
+
|
555 |
+
# ------------------------------------------------------------------------------------------------------
|
556 |
+
|
557 |
+
if args.model == "Qwen/Qwen2.5-72B-Instruct":
|
558 |
+
# elif args.model.startswith("hyperbolic"):
|
559 |
+
print(f"Welcome to the PARADISE of debug <generate_scientist.py> {args.model}.")
|
560 |
+
|
561 |
+
import openai
|
562 |
+
import os
|
563 |
+
# client_model = args.model[11:]
|
564 |
+
client_model = args.model
|
565 |
+
client = openai.OpenAI(
|
566 |
+
api_key=os.environ["OPENAI_API_KEY"], base_url="https://api.hyperbolic.xyz/v1"
|
567 |
+
)
|
568 |
+
|
569 |
+
# ------------------------------------------------------------------------------------------------------
|
570 |
+
|
571 |
+
|
572 |
+
elif args.model == "claude-3-5-sonnet-20240620":
|
573 |
+
import anthropic
|
574 |
+
|
575 |
+
print(f"Using Anthropic API with model {args.model}.")
|
576 |
+
client_model = "claude-3-5-sonnet-20240620"
|
577 |
+
client = anthropic.Anthropic()
|
578 |
+
elif args.model.startswith("bedrock") and "claude" in args.model:
|
579 |
+
import anthropic
|
580 |
+
|
581 |
+
# Expects: bedrock/<MODEL_ID>
|
582 |
+
client_model = args.model.split("/")[-1]
|
583 |
+
|
584 |
+
print(f"Using Amazon Bedrock with model {client_model}.")
|
585 |
+
client = anthropic.AnthropicBedrock()
|
586 |
+
elif args.model == "gpt-4o-2024-05-13" or args.model == "hybrid":
|
587 |
+
import openai
|
588 |
+
|
589 |
+
print(f"Using OpenAI API with model {args.model}.")
|
590 |
+
client_model = "gpt-4o-2024-05-13"
|
591 |
+
client = openai.OpenAI()
|
592 |
+
elif args.model == "deepseek-coder-v2-0724":
|
593 |
+
import openai
|
594 |
+
|
595 |
+
print(f"Using OpenAI API with {args.model}.")
|
596 |
+
client_model = "deepseek-coder-v2-0724"
|
597 |
+
client = openai.OpenAI(
|
598 |
+
api_key=os.environ["DEEPSEEK_API_KEY"], base_url="https://api.hyperbolic.xyz/v1"
|
599 |
+
)
|
600 |
+
elif args.model == "llama3.1-405b":
|
601 |
+
import openai
|
602 |
+
|
603 |
+
print(f"Using OpenAI API with {args.model}.")
|
604 |
+
client_model = "meta-llama/llama-3.1-405b-instruct"
|
605 |
+
client = openai.OpenAI(
|
606 |
+
api_key=os.environ["OPENROUTER_API_KEY"],
|
607 |
+
base_url="https://openrouter.ai/api/v1",
|
608 |
+
)
|
609 |
+
elif args.model.startswith("ollama"):
|
610 |
+
import openai
|
611 |
+
|
612 |
+
print(f"Using Ollama with {args.model}.")
|
613 |
+
client_model = args.model.split("/")[-1]
|
614 |
+
# client_model = args.model
|
615 |
+
client = openai.OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
|
616 |
+
|
617 |
+
else:
|
618 |
+
raise ValueError(f"Model {args.model} not supported.")
|
619 |
+
|
620 |
+
base_dir = osp.join("templates", args.experiment)
|
621 |
+
results_dir = osp.join("results", args.experiment)
|
622 |
+
print("going into line 623...")
|
623 |
+
ideas = generate_ideas(
|
624 |
+
base_dir,
|
625 |
+
client=client,
|
626 |
+
model=client_model,
|
627 |
+
skip_generation=args.skip_idea_generation,
|
628 |
+
max_num_generations=MAX_NUM_GENERATIONS,
|
629 |
+
num_reflections=NUM_REFLECTIONS,
|
630 |
+
)
|
631 |
+
if args.check_novelty:
|
632 |
+
ideas = check_idea_novelty(
|
633 |
+
ideas,
|
634 |
+
base_dir=base_dir,
|
635 |
+
client=client,
|
636 |
+
model=client_model,
|
637 |
+
)
|
ai_scientist/.ipynb_checkpoints/llm-checkpoint.py
ADDED
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import backoff
|
4 |
+
import openai
|
5 |
+
|
6 |
+
# Ollama
|
7 |
+
ollama_choices = [
|
8 |
+
"mistral-nemo",
|
9 |
+
"llama3.1",
|
10 |
+
]
|
11 |
+
|
12 |
+
# hyperbolic
|
13 |
+
hyperbolic_choices = [
|
14 |
+
"Qwen/Qwen2.5-72B-Instruct",
|
15 |
+
"meta-llama/Meta-Llama-3.1-70B-Instruct",
|
16 |
+
]
|
17 |
+
|
18 |
+
|
19 |
+
allchoices = [
|
20 |
+
"Qwen/Qwen2.5-72B-Instruct",
|
21 |
+
"deepseek-ai/DeepSeek-V2.5",
|
22 |
+
"claude-3-5-sonnet-20240620",
|
23 |
+
"gpt-4o-2024-05-13",
|
24 |
+
"deepseek-coder-v2-0724",
|
25 |
+
"llama3.1-405b",
|
26 |
+
# Anthropic Claude models via Amazon Bedrock
|
27 |
+
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
|
28 |
+
"bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
|
29 |
+
"bedrock/anthropic.claude-3-haiku-20240307-v1:0",
|
30 |
+
"bedrock/anthropic.claude-3-opus-20240229-v1:0",
|
31 |
+
]
|
32 |
+
|
33 |
+
for item in ollama_choices:
|
34 |
+
allchoices.append("ollama/" + item)
|
35 |
+
|
36 |
+
for item in hyperbolic_choices:
|
37 |
+
allchoices.append("hyperbolic/" + item)
|
38 |
+
|
39 |
+
|
40 |
+
# Get N responses from a single message, used for ensembling.
|
41 |
+
@backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
|
42 |
+
def get_batch_responses_from_llm(
|
43 |
+
msg,
|
44 |
+
client,
|
45 |
+
model,
|
46 |
+
system_message,
|
47 |
+
print_debug=False,
|
48 |
+
msg_history=None,
|
49 |
+
temperature=0.75,
|
50 |
+
n_responses=1,
|
51 |
+
):
|
52 |
+
if msg_history is None:
|
53 |
+
msg_history = []
|
54 |
+
|
55 |
+
if model in [
|
56 |
+
"gpt-4o-2024-05-13",
|
57 |
+
"gpt-4o-mini-2024-07-18",
|
58 |
+
"gpt-4o-2024-08-06",
|
59 |
+
"Qwen/Qwen2.5-72B-Instruct"
|
60 |
+
]:
|
61 |
+
new_msg_history = msg_history + [{"role": "user", "content": msg}]
|
62 |
+
response = client.chat.completions.create(
|
63 |
+
model=model,
|
64 |
+
messages=[
|
65 |
+
{"role": "system", "content": system_message},
|
66 |
+
*new_msg_history,
|
67 |
+
],
|
68 |
+
temperature=temperature,
|
69 |
+
max_tokens=3000,
|
70 |
+
n=n_responses,
|
71 |
+
stop=None,
|
72 |
+
seed=0,
|
73 |
+
)
|
74 |
+
content = [r.message.content for r in response.choices]
|
75 |
+
new_msg_history = [
|
76 |
+
new_msg_history + [{"role": "assistant", "content": c}] for c in content
|
77 |
+
]
|
78 |
+
elif model == "deepseek-coder-v2-0724":
|
79 |
+
new_msg_history = msg_history + [{"role": "user", "content": msg}]
|
80 |
+
response = client.chat.completions.create(
|
81 |
+
model="deepseek-coder",
|
82 |
+
messages=[
|
83 |
+
{"role": "system", "content": system_message},
|
84 |
+
*new_msg_history,
|
85 |
+
],
|
86 |
+
temperature=temperature,
|
87 |
+
max_tokens=3000,
|
88 |
+
n=n_responses,
|
89 |
+
stop=None,
|
90 |
+
)
|
91 |
+
content = [r.message.content for r in response.choices]
|
92 |
+
new_msg_history = [
|
93 |
+
new_msg_history + [{"role": "assistant", "content": c}] for c in content
|
94 |
+
]
|
95 |
+
|
96 |
+
# ------------------------------------------------------------------------------------------------------
|
97 |
+
|
98 |
+
elif model == "Qwen/Qwen2.5-72B-Instruct":
|
99 |
+
new_msg_history = msg_history + [{"role": "user", "content": msg}]
|
100 |
+
response = client.chat.completions.create(
|
101 |
+
model="Qwen/Qwen2.5-72B-Instruct",
|
102 |
+
messages=[
|
103 |
+
{"role": "system", "content": system_message},
|
104 |
+
*new_msg_history,
|
105 |
+
],
|
106 |
+
temperature=temperature,
|
107 |
+
max_tokens=3000,
|
108 |
+
n=n_responses,
|
109 |
+
stop=None,
|
110 |
+
)
|
111 |
+
content = [r.message.content for r in response.choices]
|
112 |
+
new_msg_history = [
|
113 |
+
new_msg_history + [{"role": "assistant", "content": c}] for c in content
|
114 |
+
]
|
115 |
+
|
116 |
+
# elif model in hyperbolic_choices:
|
117 |
+
# content, new_msg_history = [], []
|
118 |
+
# for i in range(n_responses):
|
119 |
+
# print(f"Getting {i+1}/{n_responses} response from {model}")
|
120 |
+
# c, hist = get_response_from_llm(
|
121 |
+
# msg,
|
122 |
+
# client,
|
123 |
+
# model,
|
124 |
+
# system_message,
|
125 |
+
# print_debug=False,
|
126 |
+
# msg_history=None,
|
127 |
+
# temperature=temperature,
|
128 |
+
# )
|
129 |
+
# content.append(c)
|
130 |
+
# new_msg_history.append(hist)
|
131 |
+
|
132 |
+
# ------------------------------------------------------------------------------------------------------
|
133 |
+
|
134 |
+
elif model == "llama-3-1-405b-instruct":
|
135 |
+
new_msg_history = msg_history + [{"role": "user", "content": msg}]
|
136 |
+
response = client.chat.completions.create(
|
137 |
+
model="meta-llama/llama-3.1-405b-instruct",
|
138 |
+
messages=[
|
139 |
+
{"role": "system", "content": system_message},
|
140 |
+
*new_msg_history,
|
141 |
+
],
|
142 |
+
temperature=temperature,
|
143 |
+
max_tokens=3000,
|
144 |
+
n=n_responses,
|
145 |
+
stop=None,
|
146 |
+
)
|
147 |
+
content = [r.message.content for r in response.choices]
|
148 |
+
new_msg_history = [
|
149 |
+
new_msg_history + [{"role": "assistant", "content": c}] for c in content
|
150 |
+
]
|
151 |
+
elif model == "claude-3-5-sonnet-20240620":
|
152 |
+
content, new_msg_history = [], []
|
153 |
+
for _ in range(n_responses):
|
154 |
+
c, hist = get_response_from_llm(
|
155 |
+
msg,
|
156 |
+
client,
|
157 |
+
model,
|
158 |
+
system_message,
|
159 |
+
print_debug=False,
|
160 |
+
msg_history=None,
|
161 |
+
temperature=temperature,
|
162 |
+
)
|
163 |
+
content.append(c)
|
164 |
+
new_msg_history.append(hist)
|
165 |
+
|
166 |
+
# ollama models
|
167 |
+
elif model in ollama_choices:
|
168 |
+
content, new_msg_history = [], []
|
169 |
+
for i in range(n_responses):
|
170 |
+
print(f"Getting {i+1}/{n_responses} response from {model}")
|
171 |
+
c, hist = get_response_from_llm(
|
172 |
+
msg,
|
173 |
+
client,
|
174 |
+
model,
|
175 |
+
system_message,
|
176 |
+
print_debug=False,
|
177 |
+
msg_history=None,
|
178 |
+
temperature=temperature,
|
179 |
+
)
|
180 |
+
content.append(c)
|
181 |
+
new_msg_history.append(hist)
|
182 |
+
else:
|
183 |
+
raise ValueError(f"Model {model} not supported.")
|
184 |
+
|
185 |
+
if print_debug:
|
186 |
+
# Just print the first one.
|
187 |
+
print()
|
188 |
+
print("*" * 20 + " LLM START " + "*" * 20)
|
189 |
+
for j, msg in enumerate(new_msg_history[0]):
|
190 |
+
print(f'{j}, {msg["role"]}: {msg["content"]}')
|
191 |
+
print(content)
|
192 |
+
print("*" * 21 + " LLM END " + "*" * 21)
|
193 |
+
print()
|
194 |
+
|
195 |
+
return content, new_msg_history
|
196 |
+
|
197 |
+
|
198 |
+
@backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
|
199 |
+
def get_response_from_llm(
|
200 |
+
msg,
|
201 |
+
client,
|
202 |
+
model,
|
203 |
+
system_message,
|
204 |
+
print_debug=False,
|
205 |
+
msg_history=None,
|
206 |
+
temperature=0.75,
|
207 |
+
):
|
208 |
+
if msg_history is None:
|
209 |
+
msg_history = []
|
210 |
+
|
211 |
+
if model == "claude-3-5-sonnet-20240620":
|
212 |
+
new_msg_history = msg_history + [
|
213 |
+
{
|
214 |
+
"role": "user",
|
215 |
+
"content": [
|
216 |
+
{
|
217 |
+
"type": "text",
|
218 |
+
"text": msg,
|
219 |
+
}
|
220 |
+
],
|
221 |
+
}
|
222 |
+
]
|
223 |
+
response = client.messages.create(
|
224 |
+
model="claude-3-5-sonnet-20240620",
|
225 |
+
max_tokens=3000,
|
226 |
+
temperature=temperature,
|
227 |
+
system=system_message,
|
228 |
+
messages=new_msg_history,
|
229 |
+
)
|
230 |
+
content = response.content[0].text
|
231 |
+
new_msg_history = new_msg_history + [
|
232 |
+
{
|
233 |
+
"role": "assistant",
|
234 |
+
"content": [
|
235 |
+
{
|
236 |
+
"type": "text",
|
237 |
+
"text": content,
|
238 |
+
}
|
239 |
+
],
|
240 |
+
}
|
241 |
+
]
|
242 |
+
# ------------------------------------------------------------------------------------------------------
|
243 |
+
|
244 |
+
elif model in [
|
245 |
+
"gpt-4o-2024-05-13",
|
246 |
+
"gpt-4o-mini-2024-07-18",
|
247 |
+
"gpt-4o-2024-08-06",
|
248 |
+
"Qwen/Qwen2.5-72B-Instruct"
|
249 |
+
]:
|
250 |
+
new_msg_history = msg_history + [{"role": "user", "content": msg}]
|
251 |
+
response = client.chat.completions.create(
|
252 |
+
model=model,
|
253 |
+
messages=[
|
254 |
+
{"role": "system", "content": system_message},
|
255 |
+
*new_msg_history,
|
256 |
+
],
|
257 |
+
temperature=temperature,
|
258 |
+
max_tokens=3000,
|
259 |
+
n=1,
|
260 |
+
stop=None,
|
261 |
+
seed=0,
|
262 |
+
)
|
263 |
+
content = response.choices[0].message.content
|
264 |
+
new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
|
265 |
+
|
266 |
+
|
267 |
+
# ------------------------------------------------------------------------------------------------------
|
268 |
+
|
269 |
+
|
270 |
+
elif model in ["meta-llama/llama-3.1-405b-instruct", "llama-3-1-405b-instruct"]:
|
271 |
+
new_msg_history = msg_history + [{"role": "user", "content": msg}]
|
272 |
+
response = client.chat.completions.create(
|
273 |
+
model="meta-llama/llama-3.1-405b-instruct",
|
274 |
+
messages=[
|
275 |
+
{"role": "system", "content": system_message},
|
276 |
+
*new_msg_history,
|
277 |
+
],
|
278 |
+
temperature=temperature,
|
279 |
+
max_tokens=3000,
|
280 |
+
n=1,
|
281 |
+
stop=None,
|
282 |
+
)
|
283 |
+
content = response.choices[0].message.content
|
284 |
+
new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
|
285 |
+
|
286 |
+
|
287 |
+
elif model in ollama_choices:
|
288 |
+
new_msg_history = msg_history + [{"role": "user", "content": msg}]
|
289 |
+
response = client.chat.completions.create(
|
290 |
+
model=model,
|
291 |
+
messages=[
|
292 |
+
{"role": "system", "content": system_message},
|
293 |
+
*new_msg_history,
|
294 |
+
],
|
295 |
+
temperature=temperature,
|
296 |
+
max_tokens=6000,
|
297 |
+
n=1,
|
298 |
+
stop=None,
|
299 |
+
seed=0,
|
300 |
+
)
|
301 |
+
content = response.choices[0].message.content
|
302 |
+
# print("\nget_response_from_llm\n")
|
303 |
+
# print(content)
|
304 |
+
new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
|
305 |
+
|
306 |
+
else:
|
307 |
+
raise ValueError(f"Model {model} not supported.")
|
308 |
+
|
309 |
+
if print_debug:
|
310 |
+
print()
|
311 |
+
print("*" * 20 + " LLM START " + "*" * 20)
|
312 |
+
for j, msg in enumerate(new_msg_history):
|
313 |
+
print(f'{j}, {msg["role"]}: {msg["content"]}')
|
314 |
+
print(content)
|
315 |
+
print("*" * 21 + " LLM END " + "*" * 21)
|
316 |
+
print()
|
317 |
+
|
318 |
+
return content, new_msg_history
|
319 |
+
|
320 |
+
|
321 |
+
def llm_json_auto_correct(system_prompt: str, user_prompt: str) -> str:
|
322 |
+
import os
|
323 |
+
client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"], base_url="https://api.hyperbolic.xyz/v1")
|
324 |
+
response = client.chat.completions.create(
|
325 |
+
model="Qwen/Qwen2.5-72B-Instruct",
|
326 |
+
temperature=0,
|
327 |
+
messages=[
|
328 |
+
{"role": "system", "content": system_prompt},
|
329 |
+
{"role": "user", "content": user_prompt},
|
330 |
+
],
|
331 |
+
)
|
332 |
+
return response.choices[0].message.content
|
333 |
+
|
334 |
+
|
335 |
+
def extract_json_between_markers(llm_output):
|
336 |
+
json_start_marker = "```json"
|
337 |
+
json_end_marker = "```"
|
338 |
+
|
339 |
+
# Find the start and end indices of the JSON string
|
340 |
+
start_index = llm_output.find(json_start_marker)
|
341 |
+
if start_index != -1:
|
342 |
+
start_index += len(json_start_marker) # Move past the marker
|
343 |
+
end_index = llm_output.find(json_end_marker, start_index)
|
344 |
+
else:
|
345 |
+
return None # JSON markers not found
|
346 |
+
|
347 |
+
if end_index == -1:
|
348 |
+
return None # End marker not found
|
349 |
+
|
350 |
+
# Extract the JSON string
|
351 |
+
json_string = llm_output[start_index:end_index].strip()
|
352 |
+
# print(json_string)
|
353 |
+
try:
|
354 |
+
parsed_json = json.loads(json_string)
|
355 |
+
|
356 |
+
return parsed_json
|
357 |
+
except json.JSONDecodeError:
|
358 |
+
return None # Invalid JSON format
|
ai_scientist/.ipynb_checkpoints/perform_experiments-checkpoint.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import shutil
|
2 |
+
import os.path as osp
|
3 |
+
import subprocess
|
4 |
+
from subprocess import TimeoutExpired
|
5 |
+
import sys
|
6 |
+
import json
|
7 |
+
|
8 |
+
MAX_ITERS = 4
|
9 |
+
MAX_RUNS = 5
|
10 |
+
MAX_STDERR_OUTPUT = 1500
|
11 |
+
|
12 |
+
coder_prompt = """Your goal is to implement the following idea: {title}.
|
13 |
+
The proposed experiment is as follows: {idea}.
|
14 |
+
You are given a total of up to {max_runs} runs to complete the necessary experiments. You do not need to use all {max_runs}.
|
15 |
+
|
16 |
+
First, plan the list of experiments you would like to run. For example, if you are sweeping over a specific hyperparameter, plan each value you would like to test for each run.
|
17 |
+
|
18 |
+
Note that we already provide the vanilla baseline results, so you do not need to re-run it.
|
19 |
+
|
20 |
+
For reference, the baseline results are as follows:
|
21 |
+
|
22 |
+
{baseline_results}
|
23 |
+
|
24 |
+
After you complete each change, we will run the command `python experiment.py --out_dir=run_i' where i is the run number and evaluate the results.
|
25 |
+
YOUR PROPOSED CHANGE MUST USE THIS COMMAND FORMAT, DO NOT ADD ADDITIONAL COMMAND LINE ARGS.
|
26 |
+
You can then implement the next thing on your list."""
|
27 |
+
|
28 |
+
|
29 |
+
# RUN EXPERIMENT
|
30 |
+
def run_experiment(folder_name, run_num, timeout=7200):
|
31 |
+
cwd = osp.abspath(folder_name)
|
32 |
+
# COPY CODE SO WE CAN SEE IT.
|
33 |
+
shutil.copy(
|
34 |
+
osp.join(folder_name, "experiment.py"),
|
35 |
+
osp.join(folder_name, f"run_{run_num}.py"),
|
36 |
+
)
|
37 |
+
|
38 |
+
# LAUNCH COMMAND
|
39 |
+
command = [
|
40 |
+
"python",
|
41 |
+
"experiment.py",
|
42 |
+
f"--out_dir=run_{run_num}",
|
43 |
+
]
|
44 |
+
try:
|
45 |
+
result = subprocess.run(
|
46 |
+
command, cwd=cwd, stderr=subprocess.PIPE, text=True, timeout=timeout
|
47 |
+
)
|
48 |
+
|
49 |
+
if result.stderr:
|
50 |
+
print(result.stderr, file=sys.stderr)
|
51 |
+
|
52 |
+
if result.returncode != 0:
|
53 |
+
print(f"Run {run_num} failed with return code {result.returncode}")
|
54 |
+
if osp.exists(osp.join(cwd, f"run_{run_num}")):
|
55 |
+
shutil.rmtree(osp.join(cwd, f"run_{run_num}"))
|
56 |
+
print(f"Run failed with the following error {result.stderr}")
|
57 |
+
stderr_output = result.stderr
|
58 |
+
if len(stderr_output) > MAX_STDERR_OUTPUT:
|
59 |
+
stderr_output = "..." + stderr_output[-MAX_STDERR_OUTPUT:]
|
60 |
+
next_prompt = f"Run failed with the following error {stderr_output}"
|
61 |
+
else:
|
62 |
+
with open(osp.join(cwd, f"run_{run_num}", "final_info.json"), "r") as f:
|
63 |
+
results = json.load(f)
|
64 |
+
results = {k: v["means"] for k, v in results.items()}
|
65 |
+
|
66 |
+
next_prompt = f"""Run {run_num} completed. Here are the results:
|
67 |
+
{results}
|
68 |
+
|
69 |
+
Decide if you need to re-plan your experiments given the result (you often will not need to).
|
70 |
+
|
71 |
+
Someone else will be using `notes.txt` to perform a writeup on this in the future.
|
72 |
+
Please include *all* relevant information for the writeup on Run {run_num}, including an experiment description and the run number. Be as verbose as necessary.
|
73 |
+
|
74 |
+
Then, implement the next thing on your list.
|
75 |
+
We will then run the command `python experiment.py --out_dir=run_{run_num + 1}'.
|
76 |
+
YOUR PROPOSED CHANGE MUST USE THIS COMMAND FORMAT, DO NOT ADD ADDITIONAL COMMAND LINE ARGS.
|
77 |
+
If you are finished with experiments, respond with 'ALL_COMPLETED'."""
|
78 |
+
return result.returncode, next_prompt
|
79 |
+
except TimeoutExpired:
|
80 |
+
print(f"Run {run_num} timed out after {timeout} seconds")
|
81 |
+
if osp.exists(osp.join(cwd, f"run_{run_num}")):
|
82 |
+
shutil.rmtree(osp.join(cwd, f"run_{run_num}"))
|
83 |
+
next_prompt = f"Run timed out after {timeout} seconds"
|
84 |
+
return 1, next_prompt
|
85 |
+
|
86 |
+
|
87 |
+
# RUN PLOTTING
|
88 |
+
def run_plotting(folder_name, timeout=600):
|
89 |
+
cwd = osp.abspath(folder_name)
|
90 |
+
# LAUNCH COMMAND
|
91 |
+
command = [
|
92 |
+
"python",
|
93 |
+
"plot.py",
|
94 |
+
]
|
95 |
+
try:
|
96 |
+
result = subprocess.run(
|
97 |
+
command, cwd=cwd, stderr=subprocess.PIPE, text=True, timeout=timeout
|
98 |
+
)
|
99 |
+
|
100 |
+
if result.stderr:
|
101 |
+
print(result.stderr, file=sys.stderr)
|
102 |
+
|
103 |
+
if result.returncode != 0:
|
104 |
+
print(f"Plotting failed with return code {result.returncode}")
|
105 |
+
next_prompt = f"Plotting failed with the following error {result.stderr}"
|
106 |
+
else:
|
107 |
+
next_prompt = ""
|
108 |
+
return result.returncode, next_prompt
|
109 |
+
except TimeoutExpired:
|
110 |
+
print(f"Plotting timed out after {timeout} seconds")
|
111 |
+
next_prompt = f"Plotting timed out after {timeout} seconds"
|
112 |
+
return 1, next_prompt
|
113 |
+
|
114 |
+
|
115 |
+
# PERFORM EXPERIMENTS
|
116 |
+
def perform_experiments(idea, folder_name, coder, baseline_results) -> bool:
|
117 |
+
## RUN EXPERIMENT
|
118 |
+
current_iter = 0
|
119 |
+
run = 1
|
120 |
+
next_prompt = coder_prompt.format(
|
121 |
+
title=idea["Title"],
|
122 |
+
idea=idea["Experiment"],
|
123 |
+
max_runs=MAX_RUNS,
|
124 |
+
baseline_results=baseline_results,
|
125 |
+
)
|
126 |
+
while run < MAX_RUNS + 1:
|
127 |
+
if current_iter >= MAX_ITERS:
|
128 |
+
print("Max iterations reached")
|
129 |
+
break
|
130 |
+
coder_out = coder.run(next_prompt)
|
131 |
+
print(coder_out)
|
132 |
+
if "ALL_COMPLETED" in coder_out:
|
133 |
+
break
|
134 |
+
return_code, next_prompt = run_experiment(folder_name, run)
|
135 |
+
if return_code == 0:
|
136 |
+
run += 1
|
137 |
+
current_iter = 0
|
138 |
+
current_iter += 1
|
139 |
+
if current_iter >= MAX_ITERS:
|
140 |
+
print("Not all experiments completed.")
|
141 |
+
return False
|
142 |
+
|
143 |
+
current_iter = 0
|
144 |
+
next_prompt = """
|
145 |
+
Great job! Please modify `plot.py` to generate the most relevant plots for the final writeup.
|
146 |
+
|
147 |
+
In particular, be sure to fill in the "labels" dictionary with the correct names for each run that you want to plot.
|
148 |
+
|
149 |
+
Only the runs in the `labels` dictionary will be plotted, so make sure to include all relevant runs.
|
150 |
+
|
151 |
+
We will be running the command `python plot.py` to generate the plots.
|
152 |
+
"""
|
153 |
+
while True:
|
154 |
+
coder_out = coder.run(next_prompt)
|
155 |
+
return_code, next_prompt = run_plotting(folder_name)
|
156 |
+
current_iter += 1
|
157 |
+
if return_code == 0 or current_iter >= MAX_ITERS:
|
158 |
+
break
|
159 |
+
next_prompt = """
|
160 |
+
Please modify `notes.txt` with a description of what each plot shows along with the filename of the figure. Please do so in-depth.
|
161 |
+
|
162 |
+
Somebody else will be using `notes.txt` to write a report on this in the future.
|
163 |
+
"""
|
164 |
+
coder.run(next_prompt)
|
165 |
+
|
166 |
+
return True
|
ai_scientist/.ipynb_checkpoints/perform_writeup-checkpoint.py
ADDED
@@ -0,0 +1,707 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
import re
|
6 |
+
import shutil
|
7 |
+
import subprocess
|
8 |
+
from typing import Optional, Tuple
|
9 |
+
|
10 |
+
from strictjson import strict_json
|
11 |
+
|
12 |
+
from ai_scientist.generate_ideas import search_for_papers
|
13 |
+
from ai_scientist.llm import (
|
14 |
+
allchoices,
|
15 |
+
extract_json_between_markers,
|
16 |
+
get_response_from_llm,
|
17 |
+
llm_json_auto_correct,
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
def format_citation_first_json(text):
|
22 |
+
res = strict_json(
|
23 |
+
system_prompt="You are a JSON formatter",
|
24 |
+
user_prompt=text,
|
25 |
+
return_as_json=True,
|
26 |
+
output_format={
|
27 |
+
"Description": "A precise description of the required edit, along with the proposed text and location where it should be made",
|
28 |
+
"Query": "The search query to find the paper (e.g. attention is all you need)",
|
29 |
+
},
|
30 |
+
llm=llm_json_auto_correct,
|
31 |
+
)
|
32 |
+
text = json.loads(res)
|
33 |
+
return text
|
34 |
+
|
35 |
+
|
36 |
+
def format_citation_second_json(text):
|
37 |
+
res = strict_json(
|
38 |
+
system_prompt="You are a JSON formatter",
|
39 |
+
user_prompt=text,
|
40 |
+
return_as_json=True,
|
41 |
+
output_format={
|
42 |
+
"Selected": "A list of the indices of the selected papers to be cited, e.g. '[0, 1]'. Can be '[]' if no papers are selected. This must be a string",
|
43 |
+
"Description": "Update the previous description of the required edit if needed. Ensure that any cites precisely match the name in the bibtex",
|
44 |
+
},
|
45 |
+
llm=llm_json_auto_correct,
|
46 |
+
)
|
47 |
+
text = json.loads(res)
|
48 |
+
return text
|
49 |
+
|
50 |
+
|
51 |
+
# GENERATE LATEX
|
52 |
+
def generate_latex(coder, folder_name, pdf_file, timeout=30, num_error_corrections=5):
|
53 |
+
folder = osp.abspath(folder_name)
|
54 |
+
cwd = osp.join(folder, "latex") # Fixed potential issue with path
|
55 |
+
writeup_file = osp.join(cwd, "template.tex")
|
56 |
+
|
57 |
+
# Check all references are valid and in the references.bib file
|
58 |
+
with open(writeup_file, "r") as f:
|
59 |
+
tex_text = f.read()
|
60 |
+
cites = re.findall(r"\\cite[a-z]*{([^}]*)}", tex_text)
|
61 |
+
references_bib = re.search(
|
62 |
+
r"\\begin{filecontents}{references.bib}(.*?)\\end{filecontents}",
|
63 |
+
tex_text,
|
64 |
+
re.DOTALL,
|
65 |
+
)
|
66 |
+
if references_bib is None:
|
67 |
+
print("No references.bib found in template.tex")
|
68 |
+
return
|
69 |
+
bib_text = references_bib.group(1)
|
70 |
+
cites = [cite.strip() for item in cites for cite in item.split(",")]
|
71 |
+
for cite in cites:
|
72 |
+
if cite not in bib_text:
|
73 |
+
print(f"Reference {cite} not found in references.")
|
74 |
+
prompt = f"""Reference {cite} not found in references.bib. Is this included under a different name?
|
75 |
+
If so, please modify the citation in template.tex to match the name in references.bib at the top. Otherwise, remove the cite."""
|
76 |
+
coder.run(prompt)
|
77 |
+
|
78 |
+
# Check all included figures are actually in the directory.
|
79 |
+
with open(writeup_file, "r") as f:
|
80 |
+
tex_text = f.read()
|
81 |
+
referenced_figs = re.findall(r"\\includegraphics.*?{(.*?)}", tex_text)
|
82 |
+
all_figs = [f for f in os.listdir(folder) if f.endswith(".png")]
|
83 |
+
for figure in referenced_figs:
|
84 |
+
if figure not in all_figs:
|
85 |
+
print(f"Figure {figure} not found in directory.")
|
86 |
+
prompt = f"""The image {figure} not found in the directory. The images in the directory are: {all_figs}.
|
87 |
+
Please ensure that the figure is in the directory and that the filename is correct. Check the notes to see what each figure contains."""
|
88 |
+
coder.run(prompt)
|
89 |
+
|
90 |
+
# Remove duplicate figures.
|
91 |
+
with open(writeup_file, "r") as f:
|
92 |
+
tex_text = f.read()
|
93 |
+
referenced_figs = re.findall(r"\\includegraphics.*?{(.*?)}", tex_text)
|
94 |
+
duplicates = {x for x in referenced_figs if referenced_figs.count(x) > 1}
|
95 |
+
if duplicates:
|
96 |
+
for dup in duplicates:
|
97 |
+
print(f"Duplicate figure found: {dup}.")
|
98 |
+
prompt = f"""Duplicate figures found: {dup}. Ensure any figure is only included once.
|
99 |
+
If duplicated, identify the best location for the figure and remove any other."""
|
100 |
+
coder.run(prompt)
|
101 |
+
|
102 |
+
# Remove duplicate section headers.
|
103 |
+
with open(writeup_file, "r") as f:
|
104 |
+
tex_text = f.read()
|
105 |
+
sections = re.findall(r"\\section{([^}]*)}", tex_text)
|
106 |
+
duplicates = {x for x in sections if sections.count(x) > 1}
|
107 |
+
if duplicates:
|
108 |
+
for dup in duplicates:
|
109 |
+
print(f"Duplicate section header found: {dup}")
|
110 |
+
prompt = f"""Duplicate section header found: {dup}. Ensure any section header is declared once.
|
111 |
+
If duplicated, identify the best location for the section header and remove any other."""
|
112 |
+
coder.run(prompt)
|
113 |
+
|
114 |
+
# Iteratively fix any LaTeX bugs
|
115 |
+
for i in range(num_error_corrections):
|
116 |
+
# Filter trivial bugs in chktex
|
117 |
+
check_output = os.popen(f"chktex {writeup_file} -q -n2 -n24 -n13 -n1").read()
|
118 |
+
if check_output:
|
119 |
+
prompt = f"""Please fix the following LaTeX errors in `template.tex` guided by the output of `chktek`:
|
120 |
+
{check_output}.
|
121 |
+
|
122 |
+
Make the minimal fix required and do not remove or change any packages.
|
123 |
+
Pay attention to any accidental uses of HTML syntax, e.g. </end instead of \\end.
|
124 |
+
"""
|
125 |
+
coder.run(prompt)
|
126 |
+
else:
|
127 |
+
break
|
128 |
+
compile_latex(cwd, pdf_file, timeout=timeout)
|
129 |
+
|
130 |
+
|
131 |
+
def compile_latex(cwd, pdf_file, timeout=30):
|
132 |
+
print("GENERATING LATEX")
|
133 |
+
|
134 |
+
commands = [
|
135 |
+
["pdflatex", "-interaction=nonstopmode", "template.tex"],
|
136 |
+
["bibtex", "template"],
|
137 |
+
["pdflatex", "-interaction=nonstopmode", "template.tex"],
|
138 |
+
["pdflatex", "-interaction=nonstopmode", "template.tex"],
|
139 |
+
]
|
140 |
+
|
141 |
+
for command in commands:
|
142 |
+
try:
|
143 |
+
result = subprocess.run(
|
144 |
+
command,
|
145 |
+
cwd=cwd,
|
146 |
+
stdout=subprocess.PIPE,
|
147 |
+
stderr=subprocess.PIPE,
|
148 |
+
text=True,
|
149 |
+
timeout=timeout,
|
150 |
+
)
|
151 |
+
print("Standard Output:\n", result.stdout)
|
152 |
+
print("Standard Error:\n", result.stderr)
|
153 |
+
except subprocess.TimeoutExpired:
|
154 |
+
print(f"Latex timed out after {timeout} seconds")
|
155 |
+
except subprocess.CalledProcessError as e:
|
156 |
+
print(f"Error running command {' '.join(command)}: {e}")
|
157 |
+
|
158 |
+
print("FINISHED GENERATING LATEX")
|
159 |
+
|
160 |
+
# Attempt to move the PDF to the desired location
|
161 |
+
try:
|
162 |
+
shutil.move(osp.join(cwd, "template.pdf"), pdf_file)
|
163 |
+
except FileNotFoundError:
|
164 |
+
print("Failed to rename PDF.")
|
165 |
+
|
166 |
+
|
167 |
+
per_section_tips = {
|
168 |
+
"Abstract": """
|
169 |
+
- TL;DR of the paper
|
170 |
+
- What are we trying to do and why is it relevant?
|
171 |
+
- Why is this hard?
|
172 |
+
- How do we solve it (i.e. our contribution!)
|
173 |
+
- How do we verify that we solved it (e.g. Experiments and results)
|
174 |
+
|
175 |
+
Please make sure the abstract reads smoothly and is well-motivated. This should be one continuous paragraph with no breaks between the lines.
|
176 |
+
""",
|
177 |
+
"Introduction": """
|
178 |
+
- Longer version of the Abstract, i.e. of the entire paper
|
179 |
+
- What are we trying to do and why is it relevant?
|
180 |
+
- Why is this hard?
|
181 |
+
- How do we solve it (i.e. our contribution!)
|
182 |
+
- How do we verify that we solved it (e.g. Experiments and results)
|
183 |
+
- New trend: specifically list your contributions as bullet points
|
184 |
+
- Extra space? Future work!
|
185 |
+
""",
|
186 |
+
"Related Work": """
|
187 |
+
- Academic siblings of our work, i.e. alternative attempts in literature at trying to solve the same problem.
|
188 |
+
- Goal is to “Compare and contrast” - how does their approach differ in either assumptions or method? If their method is applicable to our Problem Setting I expect a comparison in the experimental section. If not, there needs to be a clear statement why a given method is not applicable.
|
189 |
+
- Note: Just describing what another paper is doing is not enough. We need to compare and contrast.
|
190 |
+
""",
|
191 |
+
"Background": """
|
192 |
+
- Academic Ancestors of our work, i.e. all concepts and prior work that are required for understanding our method.
|
193 |
+
- Usually includes a subsection, Problem Setting, which formally introduces the problem setting and notation (Formalism) for our method. Highlights any specific assumptions that are made that are unusual.
|
194 |
+
- Note: If our paper introduces a novel problem setting as part of its contributions, it's best to have a separate Section.
|
195 |
+
""",
|
196 |
+
"Method": """
|
197 |
+
- What we do. Why we do it. All described using the general Formalism introduced in the Problem Setting and building on top of the concepts / foundations introduced in Background.
|
198 |
+
""",
|
199 |
+
"Experimental Setup": """
|
200 |
+
- How do we test that our stuff works? Introduces a specific instantiation of the Problem Setting and specific implementation details of our Method for this Problem Setting.
|
201 |
+
- Do not imagine unknown hardware details.
|
202 |
+
- Includes a description of the dataset, evaluation metrics, important hyperparameters, and implementation details.
|
203 |
+
""",
|
204 |
+
"Results": """
|
205 |
+
- Shows the results of running Method on our problem described in Experimental Setup.
|
206 |
+
- Includes statements on hyperparameters and other potential issues of fairness.
|
207 |
+
- Only includes results that have actually been run and saved in the logs. Do not hallucinate results that don't exist.
|
208 |
+
- If results exist: compares to baselines and includes statistics and confidence intervals.
|
209 |
+
- If results exist: includes ablation studies to show that specific parts of the method are relevant.
|
210 |
+
- Discusses limitations of the method.
|
211 |
+
- Make sure to include all the results from the experiments, and include all relevant figures.
|
212 |
+
""",
|
213 |
+
"Conclusion": """
|
214 |
+
- Brief recap of the entire paper.
|
215 |
+
- To keep going with the analogy, you can think of future work as (potential) academic offspring.
|
216 |
+
""",
|
217 |
+
}
|
218 |
+
|
219 |
+
error_list = """- Unenclosed math symbols
|
220 |
+
- Only reference figures that exist in our directory
|
221 |
+
- LaTeX syntax errors
|
222 |
+
- Numerical results that do not come from explicit experiments and logs
|
223 |
+
- Repeatedly defined figure labels
|
224 |
+
- References to papers that are not in the .bib file, DO NOT ADD ANY NEW CITATIONS!
|
225 |
+
- Unnecessary verbosity or repetition, unclear text
|
226 |
+
- Results or insights in the `notes.txt` that have not yet need included
|
227 |
+
- Any relevant figures that have not yet been included in the text
|
228 |
+
- Closing any \\begin{{figure}} with a \\end{{figure}} and \\begin{{table}} with a \\end{{table}}, etc.
|
229 |
+
- Duplicate headers, e.g. duplicated \\section{{Introduction}} or \\end{{document}}
|
230 |
+
- Unescaped symbols, e.g. shakespeare_char should be shakespeare\\_char in text
|
231 |
+
- Incorrect closing of environments, e.g. </end{{figure}}> instead of \\end{{figure}}
|
232 |
+
"""
|
233 |
+
|
234 |
+
refinement_prompt = (
|
235 |
+
"""Great job! Now criticize and refine only the {section} that you just wrote.
|
236 |
+
Make this complete in this pass, do not leave any placeholders.
|
237 |
+
|
238 |
+
Pay particular attention to fixing any errors such as:
|
239 |
+
"""
|
240 |
+
+ error_list
|
241 |
+
)
|
242 |
+
|
243 |
+
second_refinement_prompt = (
|
244 |
+
"""Criticize and refine the {section} only. Recall the advice:
|
245 |
+
{tips}
|
246 |
+
Make this complete in this pass, do not leave any placeholders.
|
247 |
+
|
248 |
+
Pay attention to how it fits in with the rest of the paper.
|
249 |
+
Identify any redundancies (e.g. repeated figures or repeated text), if there are any, decide where in the paper things should be cut.
|
250 |
+
Identify where we can save space, and be more concise without weakening the message of the text.
|
251 |
+
Fix any remaining errors as before:
|
252 |
+
"""
|
253 |
+
+ error_list
|
254 |
+
)
|
255 |
+
|
256 |
+
# CITATION HELPERS
|
257 |
+
citation_system_msg = """You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field.
|
258 |
+
You have already written an initial draft of the paper and now you are looking to add missing citations to related papers throughout the paper.
|
259 |
+
The related work section already has some initial comments on which papers to add and discuss.
|
260 |
+
|
261 |
+
Focus on completing the existing write-up and do not add entirely new elements unless necessary.
|
262 |
+
Ensure every point in the paper is substantiated with sufficient evidence.
|
263 |
+
Feel free to add more cites to a particular point if there is only one or two references.
|
264 |
+
Ensure no paper is cited without a corresponding reference in the `references.bib` file.
|
265 |
+
Ensure each paragraph of the related work has sufficient background, e.g. a few papers cited.
|
266 |
+
You will be given access to the Semantic Scholar API, only add citations that you have found using the API.
|
267 |
+
Aim to discuss a broad range of relevant papers, not just the most popular ones.
|
268 |
+
Make sure not to copy verbatim from prior literature to avoid plagiarism.
|
269 |
+
|
270 |
+
You will be prompted to give a precise description of where and how to add the cite, and a search query for the paper to be cited.
|
271 |
+
Finally, you will select the most relevant cite from the search results (top 10 results will be shown).
|
272 |
+
You will have {total_rounds} rounds to add to the references, but do not need to use them all.
|
273 |
+
|
274 |
+
DO NOT ADD A CITATION THAT ALREADY EXISTS!"""
|
275 |
+
|
276 |
+
citation_first_prompt = '''Round {current_round}/{total_rounds}:
|
277 |
+
|
278 |
+
You have written this LaTeX draft so far:
|
279 |
+
|
280 |
+
"""
|
281 |
+
{draft}
|
282 |
+
"""
|
283 |
+
|
284 |
+
Identify the most important citation that you still need to add, and the query to find the paper.
|
285 |
+
|
286 |
+
Respond in the following format:
|
287 |
+
|
288 |
+
THOUGHT:
|
289 |
+
<THOUGHT>
|
290 |
+
|
291 |
+
RESPONSE:
|
292 |
+
```json
|
293 |
+
<JSON>
|
294 |
+
```
|
295 |
+
|
296 |
+
In <THOUGHT>, first briefly reason over the paper and identify where citations should be added.
|
297 |
+
If no more citations are needed, add "No more citations needed" to your thoughts.
|
298 |
+
Do not add "No more citations needed" if you are adding citations this round.
|
299 |
+
|
300 |
+
In <JSON>, respond in JSON format with the following fields:
|
301 |
+
- "Description": A precise description of the required edit, along with the proposed text and location where it should be made.
|
302 |
+
- "Query": The search query to find the paper (e.g. attention is all you need).
|
303 |
+
|
304 |
+
Ensure the description is sufficient to make the change without further context. Someone else will make the change.
|
305 |
+
The query will work best if you are able to recall the exact name of the paper you are looking for, or the authors.
|
306 |
+
This JSON will be automatically parsed, so ensure the format is precise.'''
|
307 |
+
|
308 |
+
citation_second_prompt = """Search has recovered the following articles:
|
309 |
+
|
310 |
+
{papers}
|
311 |
+
|
312 |
+
Respond in the following format:
|
313 |
+
|
314 |
+
THOUGHT:
|
315 |
+
<THOUGHT>
|
316 |
+
|
317 |
+
RESPONSE:
|
318 |
+
```json
|
319 |
+
<JSON>
|
320 |
+
```
|
321 |
+
|
322 |
+
In <THOUGHT>, first briefly reason over the search results and identify which citation best fits your paper and the location is to be added at.
|
323 |
+
If none are appropriate, add "Do not add any" to your thoughts.
|
324 |
+
|
325 |
+
In <JSON>, respond in JSON format with the following fields:
|
326 |
+
- "Selected": A list of the indices of the selected papers to be cited, e.g. "[0, 1]". Can be "[]" if no papers are selected. This must be a string.
|
327 |
+
- "Description": Update the previous description of the required edit if needed. Ensure that any cites precisely match the name in the bibtex!!!
|
328 |
+
|
329 |
+
Do not select papers that are already in the `references.bib` file at the top of the draft, or if the same citation exists under a different name.
|
330 |
+
This JSON will be automatically parsed, so ensure the format is precise."""
|
331 |
+
|
332 |
+
|
333 |
+
def get_citation_aider_prompt(
|
334 |
+
client, model, draft, current_round, total_rounds
|
335 |
+
) -> Tuple[Optional[str], bool]:
|
336 |
+
msg_history = []
|
337 |
+
try:
|
338 |
+
text, msg_history = get_response_from_llm(
|
339 |
+
citation_first_prompt.format(
|
340 |
+
draft=draft, current_round=current_round, total_rounds=total_rounds
|
341 |
+
),
|
342 |
+
client=client,
|
343 |
+
model=model,
|
344 |
+
system_message=citation_system_msg.format(total_rounds=total_rounds),
|
345 |
+
msg_history=msg_history,
|
346 |
+
)
|
347 |
+
if "No more citations needed" in text:
|
348 |
+
print("No more citations needed.")
|
349 |
+
return None, True
|
350 |
+
|
351 |
+
## PARSE OUTPUT
|
352 |
+
json_output = format_citation_first_json(text)
|
353 |
+
assert json_output is not None, "Failed to extract JSON from LLM output"
|
354 |
+
query = json_output["Query"]
|
355 |
+
papers = search_for_papers(query)
|
356 |
+
except Exception as e:
|
357 |
+
print(f"Error: {e}")
|
358 |
+
return None, False
|
359 |
+
|
360 |
+
if papers is None:
|
361 |
+
print("No papers found.")
|
362 |
+
return None, False
|
363 |
+
|
364 |
+
paper_strings = []
|
365 |
+
for i, paper in enumerate(papers):
|
366 |
+
paper_strings.append(
|
367 |
+
"""{i}: {title}. {authors}. {venue}, {year}.\nAbstract: {abstract}""".format(
|
368 |
+
i=i,
|
369 |
+
title=paper["title"],
|
370 |
+
authors=paper["authors"],
|
371 |
+
venue=paper["venue"],
|
372 |
+
year=paper["year"],
|
373 |
+
abstract=paper["abstract"],
|
374 |
+
)
|
375 |
+
)
|
376 |
+
papers_str = "\n\n".join(paper_strings)
|
377 |
+
|
378 |
+
try:
|
379 |
+
text, msg_history = get_response_from_llm(
|
380 |
+
citation_second_prompt.format(
|
381 |
+
papers=papers_str,
|
382 |
+
current_round=current_round,
|
383 |
+
total_rounds=total_rounds,
|
384 |
+
),
|
385 |
+
client=client,
|
386 |
+
model=model,
|
387 |
+
system_message=citation_system_msg.format(total_rounds=total_rounds),
|
388 |
+
msg_history=msg_history,
|
389 |
+
)
|
390 |
+
if "Do not add any" in text:
|
391 |
+
print("Do not add any.")
|
392 |
+
return None, False
|
393 |
+
## PARSE OUTPUT
|
394 |
+
json_output = format_citation_second_json(text)
|
395 |
+
assert json_output is not None, "Failed to extract JSON from LLM output"
|
396 |
+
desc = json_output["Description"]
|
397 |
+
selected_papers = json_output["Selected"]
|
398 |
+
selected_papers = str(selected_papers)
|
399 |
+
|
400 |
+
# convert to list
|
401 |
+
if selected_papers != "[]":
|
402 |
+
selected_papers = list(map(int, selected_papers.strip("[]").split(",")))
|
403 |
+
assert all(
|
404 |
+
[0 <= i < len(papers) for i in selected_papers]
|
405 |
+
), "Invalid paper index"
|
406 |
+
bibtexs = [papers[i]["citationStyles"]["bibtex"] for i in selected_papers]
|
407 |
+
bibtex_string = "\n".join(bibtexs)
|
408 |
+
else:
|
409 |
+
return None, False
|
410 |
+
|
411 |
+
except Exception as e:
|
412 |
+
print(f"Error: {e}")
|
413 |
+
return None, False
|
414 |
+
|
415 |
+
# Add citation to draft
|
416 |
+
aider_format = '''The following citations have just been added to the end of the `references.bib` file definition at the top of the file:
|
417 |
+
"""
|
418 |
+
{bibtex}
|
419 |
+
"""
|
420 |
+
You do not need to add them yourself.
|
421 |
+
ABSOLUTELY DO NOT ADD IT AGAIN!!!
|
422 |
+
|
423 |
+
Make the proposed change to the draft incorporating these new cites:
|
424 |
+
{description}
|
425 |
+
|
426 |
+
Use your judgment for whether these should be cited anywhere else.
|
427 |
+
Make sure that any citation precisely matches the name in `references.bib`. Change its name to the correct name in the bibtex if needed.
|
428 |
+
Ensure the citation is well-integrated into the text.'''
|
429 |
+
|
430 |
+
aider_prompt = (
|
431 |
+
aider_format.format(bibtex=bibtex_string, description=desc)
|
432 |
+
+ """\n You must use \cite or \citet to reference papers, do not manually type out author names."""
|
433 |
+
)
|
434 |
+
return aider_prompt, False
|
435 |
+
|
436 |
+
|
437 |
+
# PERFORM WRITEUP
|
438 |
+
def perform_writeup(
|
439 |
+
idea, folder_name, coder, cite_client, cite_model, num_cite_rounds=20
|
440 |
+
):
|
441 |
+
# CURRENTLY ASSUMES LATEX
|
442 |
+
abstract_prompt = f"""We've provided the `latex/template.tex` file to the project. We will be filling it in section by section.
|
443 |
+
|
444 |
+
First, please fill in the "Title" and "Abstract" sections of the writeup.
|
445 |
+
|
446 |
+
Some tips are provided below:
|
447 |
+
{per_section_tips["Abstract"]}
|
448 |
+
|
449 |
+
Before every paragraph, please include a brief description of what you plan to write in that paragraph in a comment.
|
450 |
+
|
451 |
+
Be sure to first name the file and use *SEARCH/REPLACE* blocks to perform these edits.
|
452 |
+
"""
|
453 |
+
coder_out = coder.run(abstract_prompt)
|
454 |
+
coder_out = coder.run(
|
455 |
+
refinement_prompt.format(section="Abstract")
|
456 |
+
.replace(r"{{", "{")
|
457 |
+
.replace(r"}}", "}")
|
458 |
+
)
|
459 |
+
for section in [
|
460 |
+
"Introduction",
|
461 |
+
"Background",
|
462 |
+
"Method",
|
463 |
+
"Experimental Setup",
|
464 |
+
"Results",
|
465 |
+
"Conclusion",
|
466 |
+
]:
|
467 |
+
section_prompt = f"""Please fill in the {section} of the writeup. Some tips are provided below:
|
468 |
+
{per_section_tips[section]}
|
469 |
+
|
470 |
+
Be sure to use \cite or \citet where relevant, referring to the works provided in the file.
|
471 |
+
Do not cite anything that is not already in `references.bib`. Do not add any new entries to this.
|
472 |
+
|
473 |
+
Keep the experimental results (figures and tables) only in the Results section, and make sure that any captions are filled in.
|
474 |
+
In this pass, do not reference anything in later sections of the paper.
|
475 |
+
|
476 |
+
Before every paragraph, please include a brief description of what you plan to write in that paragraph in a comment.
|
477 |
+
|
478 |
+
Be sure to first name the file and use *SEARCH/REPLACE* blocks to perform these edits.
|
479 |
+
"""
|
480 |
+
coder_out = coder.run(section_prompt)
|
481 |
+
coder_out = coder.run(
|
482 |
+
refinement_prompt.format(section=section)
|
483 |
+
.replace(r"{{", "{")
|
484 |
+
.replace(r"}}", "}")
|
485 |
+
)
|
486 |
+
|
487 |
+
# SKETCH THE RELATED WORK
|
488 |
+
section_prompt = f"""Please fill in the Related Work of the writeup. Some tips are provided below:
|
489 |
+
|
490 |
+
{per_section_tips["Related Work"]}
|
491 |
+
|
492 |
+
For this section, very briefly sketch out the structure of the section, and clearly indicate what papers you intend to include.
|
493 |
+
Do this all in LaTeX comments using %.
|
494 |
+
The related work should be concise, only plan to discuss the most relevant work.
|
495 |
+
Do not modify `references.bib` to add any new citations, this will be filled in at a later stage.
|
496 |
+
|
497 |
+
Be sure to first name the file and use *SEARCH/REPLACE* blocks to perform these edits.
|
498 |
+
"""
|
499 |
+
coder_out = coder.run(section_prompt)
|
500 |
+
|
501 |
+
# Fill paper with cites.
|
502 |
+
for _ in range(num_cite_rounds):
|
503 |
+
with open(osp.join(folder_name, "latex", "template.tex"), "r") as f:
|
504 |
+
draft = f.read()
|
505 |
+
prompt, done = get_citation_aider_prompt(
|
506 |
+
cite_client, cite_model, draft, _, num_cite_rounds
|
507 |
+
)
|
508 |
+
if done:
|
509 |
+
break
|
510 |
+
if prompt is not None:
|
511 |
+
# extract bibtex string
|
512 |
+
bibtex_string = prompt.split('"""')[1]
|
513 |
+
# insert this into draft before the "\end{filecontents}" line
|
514 |
+
search_str = r"\end{filecontents}"
|
515 |
+
draft = draft.replace(search_str, f"{bibtex_string}{search_str}")
|
516 |
+
with open(osp.join(folder_name, "latex", "template.tex"), "w") as f:
|
517 |
+
f.write(draft)
|
518 |
+
coder_out = coder.run(prompt)
|
519 |
+
|
520 |
+
coder_out = coder.run(
|
521 |
+
refinement_prompt.format(section="Related Work")
|
522 |
+
.replace(r"{{", "{")
|
523 |
+
.replace(r"}}", "}")
|
524 |
+
)
|
525 |
+
|
526 |
+
## SECOND REFINEMENT LOOP
|
527 |
+
coder.run(
|
528 |
+
"""Great job! Now that there is a complete draft of the entire paper, let's refine each section again.
|
529 |
+
First, re-think the Title if necessary. Keep this concise and descriptive of the paper's concept, but try by creative with it."""
|
530 |
+
)
|
531 |
+
for section in [
|
532 |
+
"Abstract",
|
533 |
+
"Related Work",
|
534 |
+
"Introduction",
|
535 |
+
"Background",
|
536 |
+
"Method",
|
537 |
+
"Experimental Setup",
|
538 |
+
"Results",
|
539 |
+
"Conclusion",
|
540 |
+
]:
|
541 |
+
coder_out = coder.run(
|
542 |
+
second_refinement_prompt.format(
|
543 |
+
section=section, tips=per_section_tips[section]
|
544 |
+
)
|
545 |
+
.replace(r"{{", "{")
|
546 |
+
.replace(r"}}", "}")
|
547 |
+
)
|
548 |
+
|
549 |
+
generate_latex(coder, folder_name, f"{folder_name}/{idea['Name']}.pdf")
|
550 |
+
|
551 |
+
|
552 |
+
if __name__ == "__main__":
|
553 |
+
import json
|
554 |
+
|
555 |
+
from aider.coders import Coder
|
556 |
+
from aider.io import InputOutput
|
557 |
+
from aider.models import Model
|
558 |
+
|
559 |
+
parser = argparse.ArgumentParser(description="Perform writeup for a project")
|
560 |
+
parser.add_argument("--folder", type=str)
|
561 |
+
parser.add_argument("--no-writing", action="store_true", help="Only generate")
|
562 |
+
parser.add_argument(
|
563 |
+
"--model",
|
564 |
+
type=str,
|
565 |
+
default="gpt-4o-2024-05-13",
|
566 |
+
choices=allchoices,
|
567 |
+
help="Model to use for AI Scientist.",
|
568 |
+
)
|
569 |
+
args = parser.parse_args()
|
570 |
+
if args.model == "claude-3-5-sonnet-20240620":
|
571 |
+
import anthropic
|
572 |
+
|
573 |
+
print(f"Using Anthropic API with model {args.model}.")
|
574 |
+
client_model = "claude-3-5-sonnet-20240620"
|
575 |
+
client = anthropic.Anthropic()
|
576 |
+
elif args.model.startswith("bedrock") and "claude" in args.model:
|
577 |
+
import anthropic
|
578 |
+
|
579 |
+
# Expects: bedrock/<MODEL_ID>
|
580 |
+
client_model = args.model.split("/")[-1]
|
581 |
+
|
582 |
+
print(f"Using Amazon Bedrock with model {client_model}.")
|
583 |
+
client = anthropic.AnthropicBedrock()
|
584 |
+
elif args.model.startswith("vertex_ai") and "claude" in args.model:
|
585 |
+
import anthropic
|
586 |
+
|
587 |
+
# Expects: vertex_ai/<MODEL_ID>
|
588 |
+
client_model = args.model.split("/")[-1]
|
589 |
+
|
590 |
+
print(f"Using Vertex AI with model {client_model}.")
|
591 |
+
client = anthropic.AnthropicVertex()
|
592 |
+
elif args.model == "gpt-4o-2024-05-13":
|
593 |
+
import openai
|
594 |
+
|
595 |
+
print(f"Using OpenAI API with model {args.model}.")
|
596 |
+
client_model = "gpt-4o-2024-05-13"
|
597 |
+
client = openai.OpenAI()
|
598 |
+
elif args.model == "deepseek-coder-v2-0724":
|
599 |
+
import openai
|
600 |
+
|
601 |
+
print(f"Using OpenAI API with {args.model}.")
|
602 |
+
client_model = "deepseek-coder-v2-0724"
|
603 |
+
client = openai.OpenAI(
|
604 |
+
api_key=os.environ["DEEPSEEK_API_KEY"], base_url="https://api.deepseek.com"
|
605 |
+
)
|
606 |
+
|
607 |
+
# ----------------------------------------------------
|
608 |
+
|
609 |
+
elif args.model == "Qwen/Qwen2.5-72B-Instruct":
|
610 |
+
# elif args.model.startswith("hyperbolic"):
|
611 |
+
print(f"Welcome to the PARADISE of debug <launch_scientist.py> {args.model}.")
|
612 |
+
|
613 |
+
import openai
|
614 |
+
import os
|
615 |
+
# client_model = args.model[11:]
|
616 |
+
client_model = args.model
|
617 |
+
client = openai.OpenAI(
|
618 |
+
api_key=os.environ["OPENAI_API_KEY"], base_url="https://api.hyperbolic.xyz/v1"
|
619 |
+
)
|
620 |
+
# ----------------------------------------------------
|
621 |
+
elif args.model == "llama3.1-405b":
|
622 |
+
import openai
|
623 |
+
|
624 |
+
print(f"Using OpenAI API with {args.model}.")
|
625 |
+
client_model = "meta-llama/llama-3.1-405b-instruct"
|
626 |
+
client = openai.OpenAI(
|
627 |
+
api_key=os.environ["OPENROUTER_API_KEY"],
|
628 |
+
base_url="https://openrouter.ai/api/v1",
|
629 |
+
)
|
630 |
+
|
631 |
+
elif args.model.startswith("ollama"):
|
632 |
+
import openai
|
633 |
+
|
634 |
+
print(f"Using Ollama with {args.model}.")
|
635 |
+
client_model = args.model.split("/")[-1]
|
636 |
+
client = openai.OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
|
637 |
+
else:
|
638 |
+
raise ValueError(f"Model {args.model} not recognized.")
|
639 |
+
|
640 |
+
|
641 |
+
print("Make sure you cleaned the Aider logs if re-generating the writeup!")
|
642 |
+
folder_name = args.folder
|
643 |
+
idea_name = osp.basename(folder_name)
|
644 |
+
exp_file = osp.join(folder_name, "experiment.py")
|
645 |
+
vis_file = osp.join(folder_name, "plot.py")
|
646 |
+
notes = osp.join(folder_name, "notes.txt")
|
647 |
+
|
648 |
+
model = args.model
|
649 |
+
|
650 |
+
writeup_file = osp.join(folder_name, "latex", "template.tex")
|
651 |
+
ideas_file = osp.join(folder_name, "ideas.json")
|
652 |
+
with open(ideas_file, "r") as f:
|
653 |
+
ideas = json.load(f)
|
654 |
+
for idea in ideas:
|
655 |
+
if idea["Name"] in idea_name:
|
656 |
+
print(f"Found idea: {idea['Name']}")
|
657 |
+
break
|
658 |
+
if idea["Name"] not in idea_name:
|
659 |
+
raise ValueError(f"Idea {idea_name} not found")
|
660 |
+
fnames = [exp_file, writeup_file, notes]
|
661 |
+
io = InputOutput(yes=True, chat_history_file=f"{folder_name}/{idea_name}_aider.txt")
|
662 |
+
|
663 |
+
|
664 |
+
|
665 |
+
# AIDER CHAT INITIALIZATION CODE
|
666 |
+
|
667 |
+
if args.model == "deepseek-ai/DeepSeek-V2.5":
|
668 |
+
print("aider chosen deepseek")
|
669 |
+
main_model = Model("deepseek-chat")
|
670 |
+
|
671 |
+
elif args.model == "deepseek-coder-v2-0724":
|
672 |
+
main_model = Model("deepseek-ai/DeepSeek-V2.5")
|
673 |
+
|
674 |
+
elif args.model == "llama3.1-405b":
|
675 |
+
main_model = Model("openrouter/meta-llama/llama-3.1-405b-instruct")
|
676 |
+
|
677 |
+
# ----------------------------------------------------
|
678 |
+
|
679 |
+
elif args.model == "hyperbolic/Qwen/Qwen2.5-72B-Instruct":
|
680 |
+
print("aider model chosen")
|
681 |
+
# main_model = Model("fireworks_ai/accounts/fireworks/models/qwen2-72b-instruct")
|
682 |
+
main_model = Model("hyperbolic/Qwen/Qwen2.5-72B-Instruct")
|
683 |
+
|
684 |
+
elif args.model == "hyperbolic/meta-llama/Meta-Llama-3.1-70B-Instruct":
|
685 |
+
main_model = Model("hyperbolic/meta-llama/Meta-Llama-3.1-70B-Instruct")
|
686 |
+
|
687 |
+
# ----------------------------------------------------
|
688 |
+
|
689 |
+
else:
|
690 |
+
print("hello world")
|
691 |
+
main_model = Model(model)
|
692 |
+
|
693 |
+
coder = Coder.create(
|
694 |
+
main_model=main_model,
|
695 |
+
fnames=fnames,
|
696 |
+
io=io,
|
697 |
+
stream=False,
|
698 |
+
use_git=False,
|
699 |
+
edit_format="diff",
|
700 |
+
)
|
701 |
+
if args.no_writing:
|
702 |
+
generate_latex(coder, args.folder, f"{args.folder}/test.pdf")
|
703 |
+
else:
|
704 |
+
try:
|
705 |
+
perform_writeup(idea, folder_name, coder, client, client_model)
|
706 |
+
except Exception as e:
|
707 |
+
print(f"Failed to perform writeup: {e}")
|
ai_scientist/Untitled.ipynb
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 3,
|
6 |
+
"id": "011cecbb-41c0-4943-a37b-73ebd91d1e46",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"import openai\n",
|
11 |
+
"import os \n",
|
12 |
+
"\n",
|
13 |
+
"client_model = \"deepseek-ai/DeepSeek-V2.5\"\n",
|
14 |
+
"client = openai.OpenAI(\n",
|
15 |
+
" # api_key=os.environ[\"DEEPSEEK_API_KEY\"], base_url=\"https://api.deepseek.com\"\n",
|
16 |
+
" api_key=DEEPSEEK_API_KEY, base_url=\"https://api.hyperbolic.xyz/v1\"\n",
|
17 |
+
")\n"
|
18 |
+
]
|
19 |
+
},
|
20 |
+
{
|
21 |
+
"cell_type": "code",
|
22 |
+
"execution_count": 1,
|
23 |
+
"id": "d55b29d2-bcf9-4e92-9b79-1f2e20478069",
|
24 |
+
"metadata": {},
|
25 |
+
"outputs": [],
|
26 |
+
"source": [
|
27 |
+
"DEEPSEEK_API_KEY=\"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJqb2huZG9lIn0.adqb-sj6G4xl7w7t9A4oRUBf1jNmcrc-0IcHjGhbz3o\""
|
28 |
+
]
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"cell_type": "code",
|
32 |
+
"execution_count": null,
|
33 |
+
"id": "86a236ee-791f-4ae7-92c4-c10d8b085f76",
|
34 |
+
"metadata": {},
|
35 |
+
"outputs": [],
|
36 |
+
"source": [
|
37 |
+
"from aider."
|
38 |
+
]
|
39 |
+
}
|
40 |
+
],
|
41 |
+
"metadata": {
|
42 |
+
"kernelspec": {
|
43 |
+
"display_name": "Python 3 (ipykernel)",
|
44 |
+
"language": "python",
|
45 |
+
"name": "python3"
|
46 |
+
},
|
47 |
+
"language_info": {
|
48 |
+
"codemirror_mode": {
|
49 |
+
"name": "ipython",
|
50 |
+
"version": 3
|
51 |
+
},
|
52 |
+
"file_extension": ".py",
|
53 |
+
"mimetype": "text/x-python",
|
54 |
+
"name": "python",
|
55 |
+
"nbconvert_exporter": "python",
|
56 |
+
"pygments_lexer": "ipython3",
|
57 |
+
"version": "3.11.10"
|
58 |
+
}
|
59 |
+
},
|
60 |
+
"nbformat": 4,
|
61 |
+
"nbformat_minor": 5
|
62 |
+
}
|
ai_scientist/__init__.py
ADDED
File without changes
|
ai_scientist/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (166 Bytes). View file
|
|
ai_scientist/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (161 Bytes). View file
|
|
ai_scientist/__pycache__/generate_ideas.cpython-311.pyc
ADDED
Binary file (26 kB). View file
|
|
ai_scientist/__pycache__/generate_ideas.cpython-312.pyc
ADDED
Binary file (22.7 kB). View file
|
|
ai_scientist/__pycache__/llm.cpython-311.pyc
ADDED
Binary file (10.2 kB). View file
|
|
ai_scientist/__pycache__/llm.cpython-312.pyc
ADDED
Binary file (9.02 kB). View file
|
|
ai_scientist/__pycache__/perform_experiments.cpython-311.pyc
ADDED
Binary file (8.02 kB). View file
|
|
ai_scientist/__pycache__/perform_experiments.cpython-312.pyc
ADDED
Binary file (7.43 kB). View file
|
|
ai_scientist/__pycache__/perform_review.cpython-311.pyc
ADDED
Binary file (22.6 kB). View file
|
|
ai_scientist/__pycache__/perform_review.cpython-312.pyc
ADDED
Binary file (21.2 kB). View file
|
|
ai_scientist/__pycache__/perform_writeup.cpython-311.pyc
ADDED
Binary file (34.1 kB). View file
|
|
ai_scientist/__pycache__/perform_writeup.cpython-312.pyc
ADDED
Binary file (29.9 kB). View file
|
|
ai_scientist/fewshot_examples/132_automated_relational.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"review": "{\n \"Summary\": \"The paper provides an interesting direction in the meta-learning field. In particular, it proposes to enhance meta learning performance by fully exploring relations across multiple tasks. To capture such information, the authors develop a heterogeneity-aware meta-learning framework by introducing a novel architecture--meta-knowledge graph, which can dynamically find the most relevant structure for new tasks.\",\n \"Strengths\": [\n \"The paper takes one of the most important issues of meta-learning: task heterogeneity. For me, the problem itself is real and practical.\",\n \"The proposed meta-knowledge graph is novel for capturing the relation between tasks and addressing the problem of task heterogeneity. Graph structure provides a more flexible way of modeling relations. The design for using the prototype-based relational graph to query the meta-knowledge graph is reasonable and interesting.\",\n \"This paper provides comprehensive experiments, including both qualitative analysis and quantitative results, to show the effectiveness of the proposed framework. The newly constructed Art-Multi dataset further enhances the difficulty of tasks and makes the performance more convincing.\"\n ],\n \"Weaknesses\": [\n \"Although the proposed method provides several ablation studies, I still suggest the authors conduct the following ablation studies to enhance the quality of the paper: (1) It might be valuable to investigate the modulation function. In the paper, the authors compare sigmoid, tanh, and Film layer. Can the authors analyze the results by reducing the number of gating parameters in Eq. 10 by sharing the gate value of each filter in Conv layers? (2) What is the performance of the proposed model by changing the type of aggregators?\",\n \"For the autoencoder aggregator, it would be better to provide more details about it, which seems not very clear to me.\",\n \"In the qualitative analysis (i.e., Figure 2 and Figure 3), the authors provide one visualization for each task. It would be more convincing if the authors can provide more cases in the rebuttal period.\"\n ],\n \"Originality\": 3,\n \"Quality\": 3,\n \"Clarity\": 3,\n \"Significance\": 4,\n \"Questions\": [\n \"Please address and clarify the cons above.\"\n ],\n \"Limitations\": [\n \"My major concern is about the clarity of the paper and some additional ablation models (see cons below). Hopefully the authors can address my concern in the rebuttal period.\"\n ],\n \"Ethical Concerns\": false,\n \"Soundness\": 3,\n \"Presentation\": 3,\n \"Contribution\": 3,\n \"Overall\": 7,\n \"Confidence\": 5,\n \"Decision\": \"Accept\"\n}"
|
3 |
+
}
|
ai_scientist/fewshot_examples/132_automated_relational.pdf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a29ed4d84f6be5b9547097c2bc8bd57bfe197e91dc8f4ec9bcde6b545e7abe59
|
3 |
+
size 1348476
|
ai_scientist/fewshot_examples/132_automated_relational.txt
ADDED
@@ -0,0 +1,1190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AUTOMATED RELATIONAL META-LEARNING
|
2 |
+
|
3 |
+
**Anonymous authors**
|
4 |
+
Paper under double-blind review
|
5 |
+
|
6 |
+
ABSTRACT
|
7 |
+
|
8 |
+
In order to efficiently learn with small amount of data on new tasks, meta-learning
|
9 |
+
transfers knowledge learned from previous tasks to the new ones. However, a
|
10 |
+
critical challenge in meta-learning is the task heterogeneity which cannot be well
|
11 |
+
handled by traditional globally shared meta-learning methods. In addition, current
|
12 |
+
task-specific meta-learning methods may either suffer from hand-crafted structure
|
13 |
+
design or lack the capability to capture complex relations between tasks. In this
|
14 |
+
paper, motivated by the way of knowledge organization in knowledge bases, we
|
15 |
+
propose an automated relational meta-learning (ARML) framework that automatically extracts the cross-task relations and constructs the meta-knowledge graph.
|
16 |
+
When a new task arrives, it can quickly find the most relevant structure and tailor
|
17 |
+
the learned structure knowledge to the meta-learner. As a result, the proposed
|
18 |
+
framework not only addresses the challenge of task heterogeneity by a learned
|
19 |
+
meta-knowledge graph, but also increases the model interpretability. We conduct
|
20 |
+
extensive experiments on 2D toy regression and few-shot image classification and
|
21 |
+
the results demonstrate the superiority of ARML over state-of-the-art baselines.
|
22 |
+
|
23 |
+
1 INTRODUCTION
|
24 |
+
|
25 |
+
Learning quickly with a few samples is the key characteristic of human intelligence, which remains a
|
26 |
+
daunting problem in machine intelligence. The mechanism of learning to learn (a.k.a., meta-learning)
|
27 |
+
is widely used to generalize and transfer prior knowledge learned from previous tasks to improve
|
28 |
+
the effectiveness of learning on new tasks, which has benefited various applications, ranging from
|
29 |
+
computer vision (Kang et al., 2019; Liu et al., 2019) to natural language processing (Gu et al., 2018;
|
30 |
+
Lin et al., 2019). Most of existing meta-learning algorithms learn a globally shared meta-learner
|
31 |
+
(e.g., parameter initialization (Finn et al., 2017), meta-optimizer (Ravi & Larochelle, 2016), metric
|
32 |
+
space (Snell et al., 2017)). However, globally shared meta-learners fail to handle tasks lying in
|
33 |
+
different distributions, which is known as task heterogeneity. Task heterogeneity has been regarded as
|
34 |
+
one of the most challenging issues in few-shot learning, and thus it is desirable to design meta-learning
|
35 |
+
models that effectively optimize each of the heterogeneous tasks.
|
36 |
+
|
37 |
+
The key challenge to deal with task heterogeneity is how to customize globally shared meta-learner
|
38 |
+
by using task-aware information? Recently, a handful of works try to solve the problem by learning
|
39 |
+
a task-specific representation for tailoring the transferred knowledge to each task (Oreshkin et al.,
|
40 |
+
2018; Vuorio et al., 2019; Lee & Choi, 2018). However, the success of these methods relies on the
|
41 |
+
impaired knowledge generalization among closely correlated tasks (e.g., the tasks sampled from the
|
42 |
+
same distribution). Recently, learning the underlying structure among tasks provide a more effective
|
43 |
+
way for balancing the customization and generalization. Representatively, Yao et al. propose a
|
44 |
+
hierarchically structured meta-learning method to customize the globally shared knowledge to each
|
45 |
+
cluster in a hierarchical way (Yao et al., 2019). Nonetheless, the hierarchical clustering structure
|
46 |
+
completely relies on the handcrafted design which needs to be tuned carefully and may lack the
|
47 |
+
capability to capture complex relationships.
|
48 |
+
|
49 |
+
Hence, we are motivated to propose a framework to automatically extract underlying relational
|
50 |
+
structures from previously learned tasks and leverage those relational structures to facilitate knowledge
|
51 |
+
customization on a new task. This inspiration comes from the way of structuring knowledge in
|
52 |
+
knowledge bases (i.e., knowledge graphs). In knowledge bases, the underlying relational structures
|
53 |
+
across text entities are automatically constructed and applied to a new query to improve the searching
|
54 |
+
efficiency. In the meta-learning problem, similarly, we aim at automatically establishing the metaknowledge graph between prior knowledge learned from previous tasks. When a new task arrives,
|
55 |
+
it queries the meta-knowledge graph and quickly attends to the most relevant entities (nodes), and
|
56 |
+
then takes advantage of the relational knowledge structures between them to boost the learning
|
57 |
+
effectiveness with the limited training data.
|
58 |
+
|
59 |
+
|
60 |
+
-----
|
61 |
+
|
62 |
+
The proposed meta-learning framework is named as Automated Relational Meta-Learning (ARML).
|
63 |
+
Specifically, the ARML framework automatically builds the meta-knowledge graph from metatraining tasks to memorize and organize learned knowledge from historical tasks, where each vertex
|
64 |
+
represent one type of meta-knowledge (e.g., the common contour between birds and aircrafts). To
|
65 |
+
learn the meta-knowledge graph at meta-training time, for each task, we construct a prototype-based
|
66 |
+
relational graph for each class, where each vertex represents one prototype. The prototype-based
|
67 |
+
relational graph not only captures the underlying relationship behind samples, but alleviates the
|
68 |
+
potential effects of abnormal samples. The meta-knowledge graph is then learned by and summarizing
|
69 |
+
the information from the corresponding prototype-based relational graphs of meta-training tasks.
|
70 |
+
After constructing the meta-knowledge graph, when a new task comes in, the prototype-based
|
71 |
+
relational graph of the new task taps into the meta-knowledge graph for acquiring the most relevant
|
72 |
+
knowledge, which further enhances the task representation and facilitates its training process.
|
73 |
+
|
74 |
+
Our major contributions of the proposed ARML are three-fold: (1) it automatically constructs the
|
75 |
+
meta-knowledge graph to facilitate learning a new task; (2) it empirically outperforms state-of-the-art
|
76 |
+
meta-learning algorithms; (3) the meta-knowledge graph well captures the relationship among tasks
|
77 |
+
and improves the interpretability of meta-learning algorithms.
|
78 |
+
|
79 |
+
2 RELATED WORK
|
80 |
+
|
81 |
+
Meta-learning, allowing machines to learn new skills or adapt to new environments rapidly with a
|
82 |
+
few training examples, has been demonstrated to be successful in both supervised learning tasks
|
83 |
+
(e.g., few-shot image classification) and reinforcement learning settings. There are mainly three
|
84 |
+
research lines of meta-learning: (1) black-box amortized methods design black-box meta-learners
|
85 |
+
(e.g., neural networks) to infer the model parameters (Ravi & Larochelle, 2016; Andrychowicz et al.,
|
86 |
+
2016; Mishra et al., 2018); (2) gradient-based methods aim to learn an optimized initialization of
|
87 |
+
model parameters, which can be adapted to new tasks by a few steps of gradient descent (Finn et al.,
|
88 |
+
2017; 2018; Lee & Choi, 2018); (3) non-parameteric methods combine parameteric meta-learners
|
89 |
+
and non-parameteric learners to learn an appropriate distance metric for few-shot classification (Snell
|
90 |
+
et al., 2017; Vinyals et al., 2016; Yang et al., 2018; Oreshkin et al., 2018; Yoon et al., 2019).
|
91 |
+
|
92 |
+
Our work is built upon the gradient-based meta-learning methods. In the line of gradient-based
|
93 |
+
meta-learning, most algorithms learn a globally shared meta-learners from all previous tasks (Finn
|
94 |
+
et al., 2017; Li et al., 2017; Flennerhag et al., 2019), to improve the effectiveness of learning process
|
95 |
+
on new tasks. However, these algorithms typically lack the ability to handle heterogeneous tasks
|
96 |
+
(i.e., tasks sample from sufficient different distributions). To tackle this challenge, recent works
|
97 |
+
tailor the globally shared initialization to different tasks by leveraging task-specific information (Lee
|
98 |
+
& Choi, 2018; Vuorio et al., 2019; Oreshkin et al., 2018) and using probabilistic models (Grant
|
99 |
+
et al., 2018; Yoon et al., 2018; Gordon et al., 2019). Recently, HSML customizes the global shared
|
100 |
+
initialization with a manually designed hierarchical clustering structure to balance the generalization
|
101 |
+
and customization between previous tasks (Yao et al., 2019). However, the hierarchical structure
|
102 |
+
may not accurately reflect the real structure since it highly relies on the hand-crafted design. In
|
103 |
+
addition, the clustering structure further constricts the complexity of relational structures. However, to
|
104 |
+
customize each task, our proposed ARML leverages the most relevant structure from meta-knowledge
|
105 |
+
graph which are automatically constructed by previous knowledge. Thus, ARML not only discovers
|
106 |
+
more accurate underlying structures to improve the effectiveness of meta-learning algorithms, but
|
107 |
+
also the meta-knowledge graph can further enhance the model interpretability.
|
108 |
+
|
109 |
+
3 PRELIMINARIES
|
110 |
+
|
111 |
+
**Few-shot Learning** Considering a task Ti, the goal of few-shot learning is to learn a model with
|
112 |
+
a dataset Di = {Di[tr][,][ D]i[ts][}][, where the labeled training set][ D]i[tr] = {x[tr]j _[,][ y]j[tr][|∀][j][ ∈]_ [[1][, N][ tr][]][}][ only has a]
|
113 |
+
few samples and Di[ts] [represents the corresponding test set. A learning model (a.k.a., base model)][ f]
|
114 |
+
with parameters θ are used to evaluate the effectiveness on Di[ts] [by minimizing the expected empirical]
|
115 |
+
loss on Di[tr][, i.e.,][ L][(][D]T[tr]i _[, θ][)][, and obtain the optimal parameters][ θ][i][. For the regression problem, the loss]_
|
116 |
+
function is defined based on the mean square error (i.e., (xj _,yj_ )∈Di[tr] 2[) and for the clas-]
|
117 |
+
|
118 |
+
sification problem, the loss function uses the cross entropy loss (i.e., −[∥][f][P][θ][(]([x]x[j]j[)],y[−]j )[y]∈D[j][∥]i[tr][2] [log][ p][(][y][j][|][x][j][, f][θ][)][).]
|
119 |
+
|
120 |
+
Usually, optimizing and learning parameter θ for the task[P] _Ti with a few labeled training samples_
|
121 |
+
is difficult. To address this limitation, meta-learning provides us a new perspective to improve the
|
122 |
+
performance by leveraging knowledge from multiple tasks.
|
123 |
+
|
124 |
+
|
125 |
+
-----
|
126 |
+
|
127 |
+
**Meta-learning and Model-agnostic Meta-learning** In meta-learning, a sequence of tasks
|
128 |
+
_{T1, ..., TI_ _} are sampled from a task-level probability distribution p(T ), where each one is a few-shot_
|
129 |
+
learning task. To facilitate the adaption for incoming tasks, the meta-learning algorithm aims to find
|
130 |
+
a well-generalized meta-learner on I training tasks at meta-learning phase. At meta-testing phase, the
|
131 |
+
optimal meta-learner is applied to adapt the new tasks Tt. In this way, meta-learning algorithms are
|
132 |
+
capable of adapting to new tasks efficiently even with a shortage of training data for a new task.
|
133 |
+
|
134 |
+
Model-agnostic meta-learning (MAML) (Finn et al., 2017), one of the representative algorithms in
|
135 |
+
gradient-based meta-learning, regards the meta-learner as the initialization of parameter θ, i.e., θ0,
|
136 |
+
and learns a well-generalized initialization θ0[∗] [during the meta-training process. The optimization]
|
137 |
+
problem is formulated as (one gradient step as exemplary):
|
138 |
+
|
139 |
+
|
140 |
+
_θ0[∗]_ [:= arg min]
|
141 |
+
_θ0_
|
142 |
+
|
143 |
+
|
144 |
+
(fθi _,_ _i_ [) = arg min]
|
145 |
+
_L_ _D[ts]_ _θ0_
|
146 |
+
_i=1_
|
147 |
+
|
148 |
+
X
|
149 |
+
|
150 |
+
|
151 |
+
_L(fθ0−α∇θ_ _L(fθ_ _,Ditr_ [)][,][ D]i[ts][)][.] (1)
|
152 |
+
_i=1_
|
153 |
+
|
154 |
+
X
|
155 |
+
|
156 |
+
|
157 |
+
At the meta-testing phase, to obtain the adaptive parameter θt for each new task Tt, we finetune the
|
158 |
+
initialization of parameter θ0[∗] [by performing gradient updates a few steps, i.e.,][ f]θt [=][ f]θ0[∗] _t_ [)][.]
|
159 |
+
|
160 |
+
_[−][α][∇][θ]_ _[L][(][f][θ]_ _[,][D][tr]_
|
161 |
+
|
162 |
+
4 METHODOLOGY
|
163 |
+
|
164 |
+
In this section, we introduce the details of the proposed ARML. To better explain how it works,
|
165 |
+
we show its framework in Figure 1. The goal of ARML is to facilitate the learning process of new
|
166 |
+
tasks by leveraging transferable knowledge learned from historical tasks. To achieve this goal, we
|
167 |
+
introduce a meta-knowledge graph, which is automatically constructed at the meta-training time, to
|
168 |
+
organize and memorize historical learned knowledge. Given a task, which is built as a prototypebased relational structure, it taps into the meta-knowledge graph to acquire relevant knowledge for
|
169 |
+
enhancing its own representation. The enhanced prototype representation further aggregate and
|
170 |
+
incorporate with meta-learner for fast and effective adaptions by utilizing a modulating function. In
|
171 |
+
the following subsections, we elaborate three key components: prototype-based sample structuring,
|
172 |
+
automated meta-knowledge graph construction and utilization, and task-specific knowledge fusion
|
173 |
+
and adaptation, respectively.
|
174 |
+
|
175 |
+
**Propagation**
|
176 |
+
|
177 |
+
**Prototype-based** **Meta-knowledge**
|
178 |
+
|
179 |
+
**Prototypes** **Relational** **Graph )**
|
180 |
+
|
181 |
+
**Structure ℛ#**
|
182 |
+
|
183 |
+
+#(,
|
184 |
+
|
185 |
+
…
|
186 |
+
…
|
187 |
+
… !"
|
188 |
+
|
189 |
+
**Aggregator**
|
190 |
+
|
191 |
+
ℒ( **Modulation**
|
192 |
+
|
193 |
+
**Aggregator**
|
194 |
+
|
195 |
+
+#(- ℒ' ∇%ℒ !"#
|
196 |
+
|
197 |
+
!#
|
198 |
+
|
199 |
+
|
200 |
+
Figure 1: The framework of ARML. For each task _i, ARML first builds a prototype-based relational_
|
201 |
+
_T_
|
202 |
+
structure Ri by mapping the training samples Di[tr] [into prototypes, with each prototype represents]
|
203 |
+
one class. Then, Ri interacts with the meta-knowledge graph G to acquire the most relevant historical
|
204 |
+
knowledge by information propagation. Finally, the task-specific modulation tailors the globally
|
205 |
+
shared initialization θ0 by aggregating of raw prototypes and enriched prototypes, which absorbs
|
206 |
+
relevant historical information from the meta-knowledge graph.
|
207 |
+
|
208 |
+
4.1 PROTOTYPE-BASED SAMPLE STRUCTURING
|
209 |
+
|
210 |
+
Given a task which involves either classifications or regressions regarding a set of samples, we first
|
211 |
+
investigate the relationships among these samples. Such relationship is represented by a graph, called
|
212 |
+
prototype-based relational graph in this work, where the vertices in the graph denote the prototypes
|
213 |
+
of different classes while the edges and the corresponding edge weights are created based on the
|
214 |
+
|
215 |
+
|
216 |
+
-----
|
217 |
+
|
218 |
+
similarities between prototypes. Constructing the relational graph based on prototypes instead of raw
|
219 |
+
samples allows us to alleviate the issue raised by abnormal samples. As the abnormal samples, which
|
220 |
+
locate far away from normal samples, could pose significant concerns especially when only a limited
|
221 |
+
number of samples are available for training. Specifically, for classification problem, the prototype,
|
222 |
+
denoted by c[k]i
|
223 |
+
|
224 |
+
_[∈]_ [R][d][, is defined as:] _N_ _[tr]_
|
225 |
+
|
226 |
+
|
227 |
+
**c[k]i** [=]
|
228 |
+
|
229 |
+
|
230 |
+
_E(xj),_ (2)
|
231 |
+
_j=1_
|
232 |
+
|
233 |
+
X
|
234 |
+
|
235 |
+
|
236 |
+
_Nk[tr]_
|
237 |
+
|
238 |
+
|
239 |
+
where Nk[tr] [denotes the number of samples in class][ k][.][ E][ is an embedding function, which projects]
|
240 |
+
**xj into a hidden space where samples from the same class are located closer to each other while**
|
241 |
+
samples from different classes stay apart. For regression problem, it is not straightforward to construct
|
242 |
+
the prototypes explicitly based on class information. Therefore, we cluster samples by learning an
|
243 |
+
assignment matrix Pi R[K][×][N] _[tr]_ . Specifically, we formulate the process as:
|
244 |
+
_∈_
|
245 |
+
|
246 |
+
**Pi = Softmax(WpE** [T](X) + bp), c[k]i [=][ P]i[[][k][]][F] [(][X][)][,] (3)
|
247 |
+
|
248 |
+
where Pi[k] represents the k-th row of Pi. Thus, training samples are clustered to K clusters, which
|
249 |
+
serve as the representation of prototypes.
|
250 |
+
|
251 |
+
After calculating all prototype representations **c[k]i**
|
252 |
+
_{_ _[|∀][k][ ∈]_ [[1][, K][]][}][, which serve as the vertices in the the]
|
253 |
+
prototype-based relational graph Ri, we further define the edges and the corresponding edge weights.
|
254 |
+
The edge weight ARi (c[j]i _[,][ c]i[m][)][ between two prototypes][ c]i[j]_ [and][ c]i[m] [is gauged by the the similarity]
|
255 |
+
between them. Formally:
|
256 |
+
|
257 |
+
_ARi_ (c[j]i _[,][ c]i[m][) =][ σ][(][W]r[(][|][c][j]i_ _i_ _r[) +][ b]r[)][,]_ (4)
|
258 |
+
|
259 |
+
_[−]_ **[c][m][|][/γ]**
|
260 |
+
|
261 |
+
where Wr and br represents learnable parameters, γr is a scalar and σ is the Sigmoid function, which
|
262 |
+
normalizes the weight between 0 and 1. For simplicity, we denote the prototype-based relational graph
|
263 |
+
as Ri = (CRi _, ARi_ ), where CRi = {c[j]i _[|∀][j][ ∈]_ [[1][, K][]][} ∈] [R][K][×][d][ represent a set of vertices, with each]
|
264 |
+
one corresponds to the prototype from a class, while ARi = {|ARi (c[j]i _[,][ c]i[m][)][|∀][j, m][ ∈]_ [[1][, K][]][} ∈] [R][K][×][K]
|
265 |
+
gives the adjacency matrix, which indicates the proximity between prototypes.
|
266 |
+
|
267 |
+
4.2 AUTOMATED META-KNOWLEDGE GRAPH CONSTRUCTION AND UTILIZATION
|
268 |
+
|
269 |
+
In this section, we first discuss how to organize and distill knowledge from historical learning process
|
270 |
+
and then expound how to leverage such knowledge to benefit the training of new tasks. To organize
|
271 |
+
and distill knowledge from historical learning process, we construct and maintain a meta-knowledge
|
272 |
+
graph. The vertices represent different types of meta-knowledge (e.g., the common contour between
|
273 |
+
aircrafts and birds) and the edges are automatically constructed and reflect the relationship between
|
274 |
+
meta-knowledge. When serving a new task, we refer to the meta-knowledge, which allows us to
|
275 |
+
efficiently and automatically identify relational knowledge from previous tasks. In this way, the
|
276 |
+
training of a new task can benefit from related training experience and get optimized much faster
|
277 |
+
than otherwise possible. In this paper, the meta-knowledge graph is automatically constructed at the
|
278 |
+
meta-training phase. The details of the construction are elaborated as follows:
|
279 |
+
|
280 |
+
Assuming the representation of an vertex g is given by h[g] _∈_ R[d], we define the meta-knowledge
|
281 |
+
graph as G = (HG, AG), where HG = {h[j]|∀j ∈ [1, G]} ∈ R[G][×][d] and AG = {AG(h[j], h[m])|∀j, m ∈
|
282 |
+
|
283 |
+
[1, G]} ∈ R[G][×][G] denote the vertex feature matrix and vertex adjacency matrix, respectively. To better
|
284 |
+
explain the construction of the meta-knowledge graph, we first discuss the vertex representation H .
|
285 |
+
_G_
|
286 |
+
During meta-training, tasks arrive one after another in a sequence and their corresponding vertices
|
287 |
+
representations are expected to be updated dynamically in a timely manner. Therefore, the vertex
|
288 |
+
representation of meta-knowledge graph are defined to get parameterized and learned at the training
|
289 |
+
time. Moreover, to encourage the diversity of meta-knowledge encoded in the meta-knowledge graph,
|
290 |
+
the vertex representations are randomly initialized. Analogous to the definition of weight in the
|
291 |
+
prototype-based relational graph Ri in equation 4, the weight between a pair of vertices j and m is
|
292 |
+
constructed as:
|
293 |
+
_A_ (h[j], h[m]) = σ(Wo( **h[j]** **h[m]** _/γo) + bo),_ (5)
|
294 |
+
_G_ _|_ _−_ _|_
|
295 |
+
where Wo and bo represent learnable parameters and γo is a scalar.
|
296 |
+
|
297 |
+
To enhance the learning of new tasks with involvement of historical knowledge, we query the
|
298 |
+
prototype-based relational graph in the meta-knowledge graph to obtain the relevant knowledge in
|
299 |
+
history. The ideal query mechanism is expected to optimize both graph representations simultaneously
|
300 |
+
|
301 |
+
|
302 |
+
-----
|
303 |
+
|
304 |
+
at the meta-training time, with the training of one graph facilitating the training of the other. In light
|
305 |
+
of this, we construct a super-graph Si by connecting the prototype-based relational graph Ri with the
|
306 |
+
meta-knowledge graph G for each task Ti. The union of the vertices in Ri and G contributes to the
|
307 |
+
vertices in the super-graph. The edges in Ri and G are also reserved in the super-graph. We connect
|
308 |
+
_Ri with G by creating links between the prototype-based relational graph with the meta-knowledge_
|
309 |
+
graph. The link between prototype c[j]i [in prototype-based relational graph and vertex][ h][m][ in meta-]
|
310 |
+
knowledge graph is weighted by the similarity between them. More precisely, for each prototype c[j]i [,]
|
311 |
+
the link weight AS (c[j]i _[,][ h][m][)][ is calculated by applying softmax over Euclidean distances between][ c][j]i_
|
312 |
+
and {h[m]|∀m ∈ [1, G]} as follows:
|
313 |
+
|
314 |
+
_AS_ (c[j]i _[,][ h][k][) =]_ _Kexp(−∥(c[j]i_ _[−]_ **[h][k][)][/γ][s][∥]2[2][/][2)]** _,_ (6)
|
315 |
+
_k[′]_ =1 [exp(][−∥][(][c]i[j] _[−]_ **[h][k][′][ )][/γ][s][∥]2[2][/][2)]**
|
316 |
+
|
317 |
+
where γs is a scaling factor. We denote the intra-adjacent matrix asP **AS = {AS** (c[j]i _[,][ h][m][)][|∀][j][ ∈]_
|
318 |
+
|
319 |
+
[1, K], m ∈ [1, G]} ∈ R[K][×][G]. Thus, for task Ti, the adjacent matrix and feature matrix of super-graph
|
320 |
+
_i = (Ai, Hi) is defined as Ai = (A_ _i_ _, A_ ; A[T] [= (][C][R]i [;][ H][G][)][ ∈]
|
321 |
+
_S_ _R_ _S_ _S_ _[,][ A][G][)][ ∈]_ [R][(][K][+][G][)][×][(][K][+][G][)][ and][ H][i]
|
322 |
+
R[(][K][+][G][)][×][d], respectively.
|
323 |
+
|
324 |
+
After constructing the super-graph Si, we are able to propagate the most relevant knowledge from
|
325 |
+
meta-knowledge graph G to the prototype-based relational graph Ri by introducing a Graph Neural
|
326 |
+
Networks (GNN). In this work, following the “message-passing” framework (Gilmer et al., 2017),
|
327 |
+
the GNN is formulated as:
|
328 |
+
**Hi[(][l][+1)]** = MP(Ai, H[(]i[l][)][;][ W][(][l][)][)][,] (7)
|
329 |
+
where MP(·) is the message passing function and has several possible implementations (Hamilton
|
330 |
+
et al., 2017; Kipf & Welling, 2017; Velickoviˇ c et al., 2018),´ **H[(]i[l][)]** is the vertex embedding after l
|
331 |
+
layers of GNN and W[(][l][)] is a learnable weight matrix of layer l. The input H[(0)]i = Hi. After stacking
|
332 |
+
_L GNN layers, we get the information-propagated feature representation for the prototype-based_
|
333 |
+
relational graph Ri as the top-K rows of Hi[(][L][)], which is denoted as **C[ˆ]** _Ri = {cˆ[j]i_ _[|][j][ ∈]_ [[1][, K][]][}][.]
|
334 |
+
|
335 |
+
4.3 TASK-SPECIFIC KNOWLEDGE FUSION AND ADAPTATION
|
336 |
+
|
337 |
+
After propagating information form meta-knowledge graph to prototype-based relational graph, in
|
338 |
+
this section, we discuss how to learn a well-generalized meta-learner for fast and effective adaptions
|
339 |
+
to new tasks with limited training data. To tackle the challenge of task heterogeneity, in this
|
340 |
+
paper, we incorporate task-specific information to customize the globally shared meta-learner (e.g.,
|
341 |
+
initialization here) by leveraging a modulating function, which has been proven to be effective to
|
342 |
+
provide customized initialization in previous studies (Wang et al., 2019; Vuorio et al., 2019).
|
343 |
+
|
344 |
+
The modulating function relies on well-discriminated task representations, while it is difficult to learn
|
345 |
+
all representations by merely utilizing the loss signal derived from the test set Di[ts][. To encourage such]
|
346 |
+
stability, we introduce two reconstructions by utilizing two auto-encoders. There are two collections
|
347 |
+
of parameters, i.e, CRi and **C[ˆ]** _Ri, which contribute the most to the creation of the task-specific_
|
348 |
+
meta-learner. CRi express the raw prototype information without tapping into the meta-knowledge
|
349 |
+
graph, while **C[ˆ]** _Ri give the prototype representations after absorbing the relevant knowledge from the_
|
350 |
+
meta-knowledge graph. Therefore, the two reconstructions are built on CRi and **C[ˆ]** _Ri_ . To reconstruct
|
351 |
+
**CRi**, an aggregator AG[q](·) (e.g., recurrent network, fully connected layers) is involved to encode CRi
|
352 |
+
into a dense representation, which is further fed into a decoder AG[q]dec[(][·][)][ to achieve reconstructions.]
|
353 |
+
Then, the corresponded task representation qi of CRi is summarized by applying a mean pooling
|
354 |
+
operator over prototypes on the encoded dense representation. Formally,
|
355 |
+
|
356 |
+
_N_ _[tr]_
|
357 |
+
|
358 |
+
|
359 |
+
**qi = MeanPool(AG[q](CRi** )) =
|
360 |
+
|
361 |
+
|
362 |
+
(AG[q](c[j]i [))][,][ L][q][ =][ ∥][C][R]i _dec[(AG][q][(][C][R]i_ [))][∥]F[2] (8)
|
363 |
+
_j=1_ _[−]_ [AG][q]
|
364 |
+
|
365 |
+
X
|
366 |
+
|
367 |
+
|
368 |
+
_N_ _[tr]_
|
369 |
+
|
370 |
+
|
371 |
+
Similarly, we reconstruct **C[ˆ]** _Ri and get the corresponded task representation ti as follows:_
|
372 |
+
|
373 |
+
_N_ _[tr]_
|
374 |
+
|
375 |
+
|
376 |
+
**ti = MeanPool(AG[t]( C[ˆ]** _Ri_ )) =
|
377 |
+
|
378 |
+
|
379 |
+
_j=1(AG[t](ˆc[j]i_ [))][,][ L][t][ =][ ∥]C[ˆ] _Ri −_ AG[t]dec[(AG][t][( ˆ]CRi ))∥F[2] (9)
|
380 |
+
|
381 |
+
X
|
382 |
+
|
383 |
+
|
384 |
+
_N_ _[tr]_
|
385 |
+
|
386 |
+
|
387 |
+
The reconstruction errors in Equations 8 and 9 pose an extra constraint to enhance the training
|
388 |
+
stability, leading to improvement of task representation learning.
|
389 |
+
|
390 |
+
|
391 |
+
-----
|
392 |
+
|
393 |
+
**Algorithm 1 Meta-Training Process of ARML**
|
394 |
+
|
395 |
+
**Require: p(T ): distribution over tasks; K: Number of vertices in meta-knowledge graph; α: stepsize**
|
396 |
+
for gradient descent of each task (i.e., inner loop stepsize); β: stepsize for meta-optimization (i.e.,
|
397 |
+
outer loop stepsize); µ1, µ2: balancing factors in loss function
|
398 |
+
|
399 |
+
1: Randomly initialize all learnable parameters Φ
|
400 |
+
2: while not done do
|
401 |
+
3: Sample a batch of tasks {Ti|i ∈ [1, I]} from p(T )
|
402 |
+
|
403 |
+
4: **for all Ti do**
|
404 |
+
|
405 |
+
5: Sample training set Di[tr] [and testing set][ D]i[ts]
|
406 |
+
|
407 |
+
6: Construct the prototype-based relational graph Ri by computing prototype in equation 2
|
408 |
+
and weight in equation 4
|
409 |
+
|
410 |
+
7: Compute the similarity between each prototype and meta-knowledge vertex in equation 6
|
411 |
+
and construct the super-graph Si
|
412 |
+
|
413 |
+
8: Apply GNN on super-graph Si and get the information-propagated representation **C[ˆ]** _Ri_
|
414 |
+
|
415 |
+
9: Aggregate CRi in equation 8 and **C[ˆ]** _Ri in equation 9 to get the representations qi, ti and_
|
416 |
+
reconstruction loss Lq, Lt
|
417 |
+
|
418 |
+
10: Compute the task-specific initialization θ0i in equation 10 and update parameters θi =
|
419 |
+
_θ0i −_ _α∇θL(fθ, Di[tr][)]_
|
420 |
+
|
421 |
+
11: **end for**
|
422 |
+
|
423 |
+
12: Update Φ Φ _β_ Φ _Ii=1_ _i_ _[,][ D]i[ts][) +][ µ][i][L][t]_ [+][ µ][2][L][q]
|
424 |
+
|
425 |
+
13: end while _←_ _−_ _∇_ _[L][(][f][θ]_
|
426 |
+
|
427 |
+
P
|
428 |
+
|
429 |
+
|
430 |
+
After getting the task representation qi and ti, the modulating function is then used to tailor the
|
431 |
+
task-specific information to the globally shared initialization θ0, which is formulated as:
|
432 |
+
|
433 |
+
_θ0i = σ(Wg(ti ⊕_ **qi) + bg) ◦** _θ0,_ (10)
|
434 |
+
|
435 |
+
where Wg and bg is learnable parameters of a fully connected layer. Note that we adopt the Sigmoid
|
436 |
+
gating as exemplary and more discussion about different modulating functions can be found in
|
437 |
+
ablation studies of Section 5.
|
438 |
+
|
439 |
+
For each task Ti, we perform the gradient descent process from θ0i and reach its optimal parameter θi.
|
440 |
+
Combining the reconstruction loss Lt and Lq with the meta-learning loss defined in equation 1, the
|
441 |
+
overall objective function of ARML is:
|
442 |
+
|
443 |
+
_I_
|
444 |
+
|
445 |
+
minΦ Φ Φ _L(fθ0−α∇θ_ _L(fθ_ _,Ditr_ [)][,][ D]i[ts][) +][ µ]1[L]t [+][ µ]2[L]q[,] (11)
|
446 |
+
|
447 |
+
_[L][all][ = min]_ _[L][ +][ µ][1][L][t][ +][ µ][2][L][q][ = min]_ _i=1_
|
448 |
+
|
449 |
+
X
|
450 |
+
|
451 |
+
where µ1 and µ2 are introduced to balance the importance of these three items. Φ represents all
|
452 |
+
learnable parameters. The algorithm of meta-training process of ARML is shown in Alg. 2. The
|
453 |
+
details of the meta-testing process of ARML are available in Appendix A.
|
454 |
+
|
455 |
+
5 EXPERIMENTS
|
456 |
+
|
457 |
+
In this section, we conduct extensive experiments to demonstrate the effectiveness of the ARML on
|
458 |
+
2D regression and few-shot classification with the goal of answering the following questions: (1) Can
|
459 |
+
ARML outperform other meta-learning methods?; (2) Can our proposed components improve the
|
460 |
+
learning performance?; (3) Can ARML framework improve the model interpretability by discovering
|
461 |
+
reasonable meta-knowledge graph?
|
462 |
+
|
463 |
+
5.1 EXPERIMENTAL SETTINGS
|
464 |
+
|
465 |
+
**Methods for Comparison** We compare our proposed ARML with two types of baselines: gradientbased meta-learning algorithms and non-parameteric meta-learning algorithms.
|
466 |
+
|
467 |
+
_For gradient-based meta-learning methods: both globally shared methods (MAML (Finn et al.,_
|
468 |
+
2017), Meta-SGD (Li et al., 2017)) and task-specific methods (MT-Net (Lee & Choi, 2018), MUMOMAML (Vuorio et al., 2019), HSML (Yao et al., 2019)) are considered for comparison.
|
469 |
+
|
470 |
+
_For non-parametric meta-learning methods: we select globally shared method Prototypical Network_
|
471 |
+
(ProtoNet) (Snell et al., 2017) and task-specific method TADAM (Oreshkin et al., 2018) as baselines.
|
472 |
+
Note that, following the traditional settings, non-parametric baselines are only used in few-shot
|
473 |
+
classification problem. The detailed implementations of baselines are discussed in Appendix B.3.
|
474 |
+
|
475 |
+
|
476 |
+
-----
|
477 |
+
|
478 |
+
**Hyperparameter Settings** For the aggregated function in autoencoder structure (AG[t], AG[t]dec
|
479 |
+
AG[q], AG[q]dec[), we use the GRU as the encoder and decoder in this autoencoder framework. We]
|
480 |
+
adopt one layer GCN (Kipf & Welling, 2017) with tanh activation as the implementation of GNN
|
481 |
+
in equation 7. For the modulation network, we try both sigmoid, tanh Film modulation and find that
|
482 |
+
sigmoid modulation performs better. Thus, in the future experiment, we use the sigmoid modulation as
|
483 |
+
modulating function. More detailed discussion about experiment settings are presented in Appendix B.
|
484 |
+
|
485 |
+
5.2 2D REGRESSION
|
486 |
+
|
487 |
+
**Dataset Description.** In 2D regression problem, we adopt the similar regression problem settings
|
488 |
+
as (Finn et al., 2018; Vuorio et al., 2019; Yao et al., 2019; Rusu et al., 2019), which includes several
|
489 |
+
families of functions. In this paper, to model more complex relational structures, we design a 2D
|
490 |
+
regression problem rather than traditional 1D regression. Input x ∼ _U_ [0.0, 5.0] and y ∼ _U_ [0.0, 5.0]
|
491 |
+
are sampled randomly and random Gaussian noisy with standard deviation 0.3 is added to the
|
492 |
+
output. Furthermore, six underlying functions are selected, including (1) Sinusoids: z(x, y) =
|
493 |
+
_assin(wsx + bs), where as ∼_ _U_ [0.1, 5.0], bs ∼ _U_ [0, 2π] ws ∼ _U_ [0.8, 1.2]; (2) Line: z(x, y) = alx + bl,
|
494 |
+
where al ∼ _U_ [−3.0, 3.0], bl ∼ _U_ [−3.0, 3.0]; (3) Quadratic: z(x, y) = aqx[2] + bqx + cq, where aq ∼
|
495 |
+
_U_ [−0.2, 0.2], bq ∼ _U_ [−2.0, 2.0], cq ∼ _U_ [−3.0, 3.0]; (4) Cubic: z(x, y) = acx[3] + bcx[2] + ccx + dc,
|
496 |
+
where ac ∼ _U_ [−0.1, 0.1], bc ∼ _U_ [−0.2, 0.2], cc ∼ _U_ [−2.0, 2.0], dc ∼ _U_ [−3.0, 3.0]; (5) Quadratic
|
497 |
+
_Surface: z(x, y) = aqsx[2]_ + bqsy[2], where aqs ∼ _U_ [−1.0, 1.0], bqs ∼ _U_ [−1.0, 1.0]; (6) Ripple: z(x, y) =
|
498 |
+
_sin(−ar(x[2]_ + y[2])) + br, where ar ∼ _U_ [−0.2, 0.2], br ∼ _U_ [−3.0, 3.0]. Note that, function 1-4 are
|
499 |
+
located in the subspace of y = 1. Follow (Finn et al., 2017), we use two fully connected layers with
|
500 |
+
40 neurons as the base model. The number of vertices of meta-knowledge graph is set as 6.
|
501 |
+
|
502 |
+
**Results and Analysis.** In Figure 2, we summarize the interpretation of meta-knowledge graph
|
503 |
+
(see top figure) and the the qualitative results (see bottom table) of 10-shot 2D regression. In the
|
504 |
+
bottom table, we can observe that ARML achieves the best performance as compared to competitive
|
505 |
+
gradient-based meta-learning methods, i.e., globally shared models and task-specific models. This
|
506 |
+
finding demonstrates that the meta-knowledge graph is necessary to model and capture task-specific
|
507 |
+
information. The superior performance can also be interpreted in the top figure. In the left, we
|
508 |
+
show the heatmap between prototypes and meta-knowledge vertices (deeper color means higher
|
509 |
+
similarity). We can see that sinusoids and line activate V1 and V4, which may represent curve and
|
510 |
+
line, respectively. V1 and V4 also contribute to quadratic and quadratic surface, which also show
|
511 |
+
the similarity between these two families of functions. V3 is activated in P0 of all functions and the
|
512 |
+
quadratic surface and ripple further activate V1 in P0, which may show the different between 2D
|
513 |
+
functions and 3D functions (sinusoid, line, quadratic and cubic lie in the subspace). Specifically,
|
514 |
+
in the right figure, we illustrate the meta-knowledge graph, where we set a threshold to filter the
|
515 |
+
link with low similarity score and show the rest. We can see that V3 is the most popular vertice and
|
516 |
+
|
517 |
+
|Model|MAML Meta-SGD MT-Net MUMOMAML HSML ARML|
|
518 |
+
|---|---|
|
519 |
+
|
520 |
+
|
521 |
+
|10-shot|2.292 ± 0.163 2.908 ± 0.229 1.757 ± 0.120 0.523 ± 0.036 0.494 ± 0.038 0.438 ± 0.029|
|
522 |
+
|---|---|
|
523 |
+
|
524 |
+
|
525 |
+
connected with V1, V5 (represent curve) and V4 (represent line). V1 is further connected with V5,
|
526 |
+
demonstrating the similarity of curve representation.
|
527 |
+
|
528 |
+
V1
|
529 |
+
|
530 |
+
V2
|
531 |
+
|
532 |
+
Sinusoids Line
|
533 |
+
|
534 |
+
V0 V3
|
535 |
+
|
536 |
+
Quadratic Cubic
|
537 |
+
|
538 |
+
V5 V4
|
539 |
+
|
540 |
+
Quadratic Surface Ripple
|
541 |
+
|
542 |
+
Model MAML Meta-SGD MT-Net MUMOMAML HSML ARML
|
543 |
+
|
544 |
+
10-shot 2.292 0.163 2.908 0.229 1.757 0.120 0.523 0.036 0.494 0.038 **0.438** **0.029**
|
545 |
+
|
546 |
+
|
547 |
+
Figure 2: In the top figure, we show the interpretation of meta-knowledge graph. The left heatmap
|
548 |
+
shows the similarity between prototypes (P0, P1) and meta-knowledge vertices (V0-V5). The right
|
549 |
+
part show the meta-knowledge graph. In the bottom table, we show the overall performance (mean
|
550 |
+
square error with 95% confidence) of 10-shot 2D regression.
|
551 |
+
|
552 |
+
|
553 |
+
-----
|
554 |
+
|
555 |
+
5.3 FEW-SHOT CLASSIFICATION
|
556 |
+
|
557 |
+
**Dataset Description and Settings** In the few-shot classification problem, we first use the benchmark proposed in (Yao et al., 2019), where four fine-grained image classification datasets are included
|
558 |
+
(i.e., CUB-200-2011 (Bird), Describable Textures Dataset (Texture), FGVC of Aircraft (Aircraft),
|
559 |
+
and FGVCx-Fungi (Fungi)). For each few-shot classification task, it samples classes from one of four
|
560 |
+
datasets. In this paper, we call this dataset as Plain-Multi and each fine-grained dataset as subdataset.
|
561 |
+
|
562 |
+
Then, to demonstrate the effectiveness of our proposed model for handling more complex underlying
|
563 |
+
structures, in this paper, we increase the difficulty of few-shot classification problem by introducing
|
564 |
+
two image filters: blur filter and pencil filter. Similar as (Jerfel et al., 2019), for each image in PlainMulti, one artistic filters are applied to simulate a changing distribution of few-shot classification
|
565 |
+
tasks. After applying the filters, the total number of subdatasets is 12 and each tasks is sampled from
|
566 |
+
one of them. This data is named as Art-Multi. More detailed descriptions of the effect of different
|
567 |
+
filters is discussed in Appendix C.
|
568 |
+
|
569 |
+
Following the traditional meta-learning settings, all datasets are divided into meta-training, metavalidation and meta-testing classes. The traditional N-way K-shot settings are used to split training and
|
570 |
+
test set for each task. We adopt the standard four-block convolutional layers as the base learner (Finn
|
571 |
+
et al., 2017; Snell et al., 2017). The number of vertices of meta-knowledge graph for Plain-Multi
|
572 |
+
and Art-Multi datasets are set as 4 and 8, respectively. Additionally, for the miniImagenet, similar
|
573 |
+
as (Finn et al., 2018), which tasks are constructed from a single domain and do not have heterogeneity,
|
574 |
+
we compare our proposed ARML with other baselines and present the results in Appendix D.
|
575 |
+
|
576 |
+
5.3.1 PERFORMANCE VALIDATION
|
577 |
+
|
578 |
+
**Overall Qualitative Analyses** Experimental results for Plain-Multi and Art-Multi are shown in
|
579 |
+
Table 1 and Table 2, respectively. For each dataset, the performance accuracy with 95% confidence
|
580 |
+
interval are reported. Note that, due to the space limitation, in Art-Multi dataset, we only show
|
581 |
+
the average value of each filter and the full results table are shown in Table 9 of Appendix E. In
|
582 |
+
these two tables, first, we can observe that task-specific models (MT-Net, MUMOMAML, HSML,
|
583 |
+
TADAM) significantly outperforms globally shared models (MAML, Meta-SGD, ProtoNet) in both
|
584 |
+
gradient-based and non-parametric meta-learning research lines. Second, compared ARML with
|
585 |
+
other task-specific gradient-based meta-learning methods, the better performance confirms that
|
586 |
+
ARML can model and extract task-specific information more accurately by leveraging the constructed
|
587 |
+
meta-knowledge graph. Especially, the performance gap between the ARML and HSML verifies the
|
588 |
+
benefits of relational structure compared with isolated clustering structure. Finally, as a gradient-based
|
589 |
+
meta-learning algorithm, ARML can also outperform ProtoNet and TADAM, two representative
|
590 |
+
non-parametric meta-learning algorithms.
|
591 |
+
|
592 |
+
Table 1: Overall few-shot classification results (accuracy ± 95% confidence) on Plain-Multi dataset.
|
593 |
+
|
594 |
+
|Settings|Algorithms|Data: Bird Data: Texture Data: Aircraft Data: Fungi|
|
595 |
+
|---|---|---|
|
596 |
+
|
597 |
+
|MAML 53.94 ± 1.45% 31.66 ± 1.31% 51.37 ± 1.38% 42.12 ± 1.36% MetaSGD 55.58 ± 1.43% 32.38 ± 1.32% 52.99 ± 1.36% 41.74 ± 1.34% MT-Net 58.72 ± 1.43% 32.80 ± 1.35% 47.72 ± 1.46% 43.11 ± 1.42% 5-way MUMOMAML 56.82 ± 1.49% 33.81 ± 1.36% 53.14 ± 1.39% 42.22 ± 1.40% 1-shot HSML 60.98 ± 1.50% 35.01 ± 1.36% 57.38 ± 1.40% 44.02 ± 1.39% ProtoNet 54.11 ± 1.38% 32.52 ± 1.28% 50.63 ± 1.35% 41.05 ± 1.37% TADAM 56.58 ± 1.34% 33.34 ± 1.27% 53.24 ± 1.33% 43.06 ± 1.33% ARML 62.33 ± 1.47% 35.65 ± 1.40% 58.56 ± 1.41% 44.82 ± 1.38%|MAML MetaSGD MT-Net MUMOMAML HSML|53.94 ± 1.45% 31.66 ± 1.31% 51.37 ± 1.38% 42.12 ± 1.36% 55.58 ± 1.43% 32.38 ± 1.32% 52.99 ± 1.36% 41.74 ± 1.34% 58.72 ± 1.43% 32.80 ± 1.35% 47.72 ± 1.46% 43.11 ± 1.42% 56.82 ± 1.49% 33.81 ± 1.36% 53.14 ± 1.39% 42.22 ± 1.40% 60.98 ± 1.50% 35.01 ± 1.36% 57.38 ± 1.40% 44.02 ± 1.39%|
|
598 |
+
|---|---|---|
|
599 |
+
||ProtoNet TADAM|54.11 ± 1.38% 32.52 ± 1.28% 50.63 ± 1.35% 41.05 ± 1.37% 56.58 ± 1.34% 33.34 ± 1.27% 53.24 ± 1.33% 43.06 ± 1.33%|
|
600 |
+
||ARML|62.33 ± 1.47% 35.65 ± 1.40% 58.56 ± 1.41% 44.82 ± 1.38%|
|
601 |
+
|
602 |
+
|ARML 62.33 ± 1.47% 35.65 ± 1.40% 58.56 ± 1.41% 44.82 ± 1.38%|ARML|62.33 ± 1.47% 35.65 ± 1.40% 58.56 ± 1.41% 44.82 ± 1.38%|
|
603 |
+
|---|---|---|
|
604 |
+
|MAML 68.52 ± 0.79% 44.56 ± 0.68% 66.18 ± 0.71% 51.85 ± 0.85% MetaSGD 67.87 ± 0.74% 45.49 ± 0.68% 66.84 ± 0.70% 52.51 ± 0.81% MT-Net 69.22 ± 0.75% 46.57 ± 0.70% 63.03 ± 0.69% 53.49 ± 0.83% 5-way MUMOMAML 70.49 ± 0.76% 45.89 ± 0.69% 67.31 ± 0.68% 53.96 ± 0.82% 5-shot HSML 71.68 ± 0.73% 48.08 ± 0.69% 73.49 ± 0.68% 56.32 ± 0.80% ProtoNet 68.67 ± 0.72% 45.21 ± 0.67% 65.29 ± 0.68% 51.27 ± 0.81% TADAM 69.13 ± 0.75% 45.78 ± 0.65% 69.87 ± 0.66% 53.15 ± 0.82% ARML 73.34 ± 0.70% 49.67 ± 0.67% 74.88 ± 0.64% 57.55 ± 0.82%|MAML MetaSGD MT-Net MUMOMAML HSML|68.52 ± 0.79% 44.56 ± 0.68% 66.18 ± 0.71% 51.85 ± 0.85% 67.87 ± 0.74% 45.49 ± 0.68% 66.84 ± 0.70% 52.51 ± 0.81% 69.22 ± 0.75% 46.57 ± 0.70% 63.03 ± 0.69% 53.49 ± 0.83% 70.49 ± 0.76% 45.89 ± 0.69% 67.31 ± 0.68% 53.96 ± 0.82% 71.68 ± 0.73% 48.08 ± 0.69% 73.49 ± 0.68% 56.32 ± 0.80%|
|
605 |
+
||ProtoNet TADAM|68.67 ± 0.72% 45.21 ± 0.67% 65.29 ± 0.68% 51.27 ± 0.81% 69.13 ± 0.75% 45.78 ± 0.65% 69.87 ± 0.66% 53.15 ± 0.82%|
|
606 |
+
||ARML|73.34 ± 0.70% 49.67 ± 0.67% 74.88 ± 0.64% 57.55 ± 0.82%|
|
607 |
+
|
608 |
+
|
609 |
+
-----
|
610 |
+
|
611 |
+
Table 2: Overall few-shot classification results (accuracy ± 95% confidence) on Art-Multi dataset.
|
612 |
+
|
613 |
+
|Settings|Algorithms|Avg. Origninal Avg. Blur Avg. Pencil|
|
614 |
+
|---|---|---|
|
615 |
+
|
616 |
+
|
617 |
+
|MAML 42.70 ± 1.35% 40.53 ± 1.38% 36.71 ± 1.37% MetaSGD 44.21 ± 1.38% 42.36 ± 1.39% 37.21 ± 1.39% MT-Net 43.94 ± 1.40% 41.64 ± 1.37% 37.79 ± 1.38% 5-way, 1-shot MUMOMAML 45.63 ± 1.39% 41.59 ± 1.38% 39.24 ± 1.36% HSML 45.68 ± 1.37% 42.62 ± 1.38% 39.78 ± 1.36% Protonet 42.08 ± 1.34% 40.51 ± 1.37% 36.24 ± 1.35% TADAM 44.73 ± 1.33% 42.44 ± 1.35% 39.02 ± 1.34% ARML 47.92 ± 1.34% 44.43 ± 1.34% 41.44 ± 1.34%|MAML MetaSGD MT-Net MUMOMAML HSML|42.70 ± 1.35% 40.53 ± 1.38% 36.71 ± 1.37% 44.21 ± 1.38% 42.36 ± 1.39% 37.21 ± 1.39% 43.94 ± 1.40% 41.64 ± 1.37% 37.79 ± 1.38% 45.63 ± 1.39% 41.59 ± 1.38% 39.24 ± 1.36% 45.68 ± 1.37% 42.62 ± 1.38% 39.78 ± 1.36%|
|
618 |
+
|---|---|---|
|
619 |
+
||Protonet TADAM|42.08 ± 1.34% 40.51 ± 1.37% 36.24 ± 1.35% 44.73 ± 1.33% 42.44 ± 1.35% 39.02 ± 1.34%|
|
620 |
+
||ARML|47.92 ± 1.34% 44.43 ± 1.34% 41.44 ± 1.34%|
|
621 |
+
|
622 |
+
|
623 |
+
|ARML 47.92 ± 1.34% 44.43 ± 1.34% 41.44 ± 1.34%|ARML|47.92 ± 1.34% 44.43 ± 1.34% 41.44 ± 1.34%|
|
624 |
+
|---|---|---|
|
625 |
+
|MAML 58.30 ± 0.74% 55.71 ± 0.74% 49.59 ± 0.73% MetaSGD 57.82 ± 0.72% 55.54 ± 0.73% 50.24 ± 0.72% MT-Net 57.95 ± 0.74% 54.65 ± 0.73% 49.18 ± 0.73% 5-way, 5-shot MUMOMAML 58.60 ± 0.75% 56.29 ± 0.72% 51.15 ± 0.73% HSML 60.63 ± 0.73% 57.91 ± 0.72% 53.93 ± 0.72% Protonet 58.12 ± 0.74% 55.07 ± 0.73% 50.15 ± 0.74% TADAM 60.35 ± 0.72% 58.36 ± 0.73% 53.15 ± 0.74% ARML 61.78 ± 0.74% 58.73 ± 0.75% 55.27 ± 0.73%|MAML MetaSGD MT-Net MUMOMAML HSML|58.30 ± 0.74% 55.71 ± 0.74% 49.59 ± 0.73% 57.82 ± 0.72% 55.54 ± 0.73% 50.24 ± 0.72% 57.95 ± 0.74% 54.65 ± 0.73% 49.18 ± 0.73% 58.60 ± 0.75% 56.29 ± 0.72% 51.15 ± 0.73% 60.63 ± 0.73% 57.91 ± 0.72% 53.93 ± 0.72%|
|
626 |
+
||Protonet TADAM|58.12 ± 0.74% 55.07 ± 0.73% 50.15 ± 0.74% 60.35 ± 0.72% 58.36 ± 0.73% 53.15 ± 0.74%|
|
627 |
+
||ARML|61.78 ± 0.74% 58.73 ± 0.75% 55.27 ± 0.73%|
|
628 |
+
|
629 |
+
|
630 |
+
|
631 |
+
**Model Ablation Study** In this section, we perform the ablation study of the proposed ARML to
|
632 |
+
demonstrate the effectiveness of each component. The results of ablation study on 5-way, 5-shot
|
633 |
+
scenario for Art-Multi dataset are presented in Table 3. In Appendix F, we also show the full results
|
634 |
+
for Art-Multi in Table 6 and the ablation study of Plain-Multi in Table 7. Specifically, to show
|
635 |
+
the effectiveness of prototype construction, in ablation I, we use the mean pooling aggregation
|
636 |
+
of each sample rather than the prototype-based relational graph to interact with meta-knowledge
|
637 |
+
graph. In ablation II, we use all samples to construct the sample-level relational graph without
|
638 |
+
using the prototype. Compared with ablation I and II, the better performance of ARML shows
|
639 |
+
that structuring samples can (1) better handling the underlying relations (2) alleviating the effect of
|
640 |
+
potential anomalies by structuring samples as prototypes.
|
641 |
+
|
642 |
+
In ablation III, we remove the meta-knowledge graph and use the prototype-based relational graph
|
643 |
+
structure with aggregator AG[q] as the task representation. The better performance of ARML demonstrates the effectiveness of meta-knowledge graph for capturing the relational structure and facilitating
|
644 |
+
the classification performance. We further remove the reconstruction loss and show the results in
|
645 |
+
ablation IV and the results demonstrate that the autoencoder structure can benefit the process of
|
646 |
+
learning the representation.
|
647 |
+
|
648 |
+
In ablation VI and VII, we change the modulate function to film (Perez et al., 2018) and tanh,
|
649 |
+
respectively. We can see that ARML is not very sensitive to the modulating function, and sigmoid
|
650 |
+
function is slightly better than other activation functions in most cases.
|
651 |
+
|
652 |
+
Table 3: Results (accuracy ± 95% confidence) of Ablation Models (5-way, 5-shot) on Art-Multi.
|
653 |
+
|
654 |
+
|Ablation Models|Ave. Original Ave. Blur Ave. Pencil|
|
655 |
+
|---|---|
|
656 |
+
|
657 |
+
|I. no prototype-based graph II. no prototype|60.80 ± 0.74% 58.36 ± 0.73% 54.79 ± 0.73% 61.34 ± 0.73% 58.34 ± 0.74% 54.81 ± 0.73%|
|
658 |
+
|---|---|
|
659 |
+
|
660 |
+
|III. no meta-knowledge graph IV. no reconstruction loss|59.99 ± 0.75% 57.79 ± 0.73% 53.68 ± 0.74% 59.07 ± 0.73% 57.20 ± 0.74% 52.45 ± 0.73%|
|
661 |
+
|---|---|
|
662 |
+
|
663 |
+
|V. tanh modulation VI. film modulation|62.34 ± 0.74% 58.58 ± 0.75% 54.01 ± 0.74% 60.06 ± 0.75% 57.47 ± 0.73% 52.06 ± 0.74%|
|
664 |
+
|---|---|
|
665 |
+
|
666 |
+
|ARML|61.78 ± 0.74% 58.73 ± 0.75% 55.27 ± 0.73%|
|
667 |
+
|---|---|
|
668 |
+
|
669 |
+
|
670 |
+
5.3.2 ANALYSIS OF CONSTRUCTED META-KNOWLEDGE GRAPH
|
671 |
+
|
672 |
+
In this section, we conduct extensive analysis for the constructed meta-knowledge graph, which is
|
673 |
+
regarded as the key component in ARML. Due to the space limit, we only present the results on ArtMulti datasets. For Plain-Multi, the analysis with similar observations are discussed in Appendix G.
|
674 |
+
|
675 |
+
|
676 |
+
-----
|
677 |
+
|
678 |
+
**Performance v.s. Vertice Numbers** We first investigate the impact of vertice numbers in metaknowledge graph. The results are shown in Table 4. From the results, we can notice that the
|
679 |
+
performance saturates as the number of vertices researches around 8. One potential reason is that 8
|
680 |
+
vertices is enough to capture the potential relations. If we have a larger datasets with more complex
|
681 |
+
relations, more vertices may be needed. In addition, if the meta-knowledge graph do not have enough
|
682 |
+
vertices, the worse performance suggests that the graph may not be able to capture enough relations
|
683 |
+
across tasks.
|
684 |
+
|
685 |
+
Table 4: Sensitivity analysis with different # of vertices in meta-knowledge graph (5-way, 5-shot).
|
686 |
+
|
687 |
+
|# of vertices|Ave. Original Ave. Blur Ave. Pencil|
|
688 |
+
|---|---|
|
689 |
+
|
690 |
+
|
691 |
+
|4 8 12 16 20|61.18 ± 0.72% 58.13 ± 0.73% 54.88 ± 0.75% 61.78 ± 0.74% 58.73 ± 0.75% 55.27 ± 0.73% 61.66 ± 0.73% 58.61 ± 0.72% 55.07 ± 0.74% 61.75 ± 0.73% 58.67 ± 0.74% 55.26 ± 0.73% 61.91 ± 0.74% 58.92 ± 0.73% 55.24 ± 0.72%|
|
692 |
+
|---|---|
|
693 |
+
|
694 |
+
|
695 |
+
|
696 |
+
**Model Interpretation Analysis of Meta-Knowledge Graph** We then analyze the learned metaknowledge graph. For each subdataset, we randomly select one task as exemplary. For each task,
|
697 |
+
in the left part of Figure 3 we show the similarity heatmap between prototypes and vertices in
|
698 |
+
meta-knowledge graph, where deeper color means higher similarity. V0-V8 and P1-P5 denotes
|
699 |
+
the different vertices and prototypes, respectively. The meta-knowledge graph is also illustrated
|
700 |
+
in the right part. Similar as the graph in 2D regression, we set a threshold to filter links with low
|
701 |
+
similarity and illustrate the rest of them. First, We can see that the V1 is mainly activated by bird
|
702 |
+
and aircraft (including all filters), which may reflect the shape similarity between bird and aircraft.
|
703 |
+
Second, V2, V3, V4 are firstly activated by texture and they form a loop in the meta-knowledge
|
704 |
+
graph. Especially, V2 also benefits images with blur and pencil filters. Thus, V2 may represent the
|
705 |
+
main texture and facilitate the training process on other subdatasets. The meta-knowledge graph also
|
706 |
+
shows the importance of V2 since it is connected with almost all other vertices. Third, when we use
|
707 |
+
blur filter, in most cases (bird blur, texture blur, fungi blur), V7 is activated. Thus, V7 may show the
|
708 |
+
similarity of images with blur filter. In addition, the connection between V7 and V2 and V3 show that
|
709 |
+
classify blur images may depend on the texture information. Fourth, V6 (activated by aircraft mostly)
|
710 |
+
connects with V2 and V3, justifying the importance of texture information to classify the aircrafts.
|
711 |
+
|
712 |
+
V1
|
713 |
+
|
714 |
+
V2
|
715 |
+
|
716 |
+
Bird Texture Aircraft Fungi
|
717 |
+
|
718 |
+
V0 V3
|
719 |
+
|
720 |
+
Bird Blur Texture Blur Aircraft Blur Fungi Blur V7
|
721 |
+
|
722 |
+
V4
|
723 |
+
|
724 |
+
V6
|
725 |
+
|
726 |
+
V5
|
727 |
+
|
728 |
+
Bird Pencil Texture Pencil Aircraft Pencil Fungi Pencil
|
729 |
+
|
730 |
+
|
731 |
+
Figure 3: Interpretation of meta-knowledge graph on Art-Multi dataset. For each subdataset, we
|
732 |
+
randomly select one task from them. In the left, we show the similarity heatmap between prototypes
|
733 |
+
(P0-P5) and meta-knowledge vertices (V0-V7). In the right part, we show the meta-knowledge graph.
|
734 |
+
|
735 |
+
6 CONCLUSION
|
736 |
+
|
737 |
+
In this paper, to improve the effectiveness of meta-learning for handling heterogeneous task, we
|
738 |
+
propose a new framework called ARML, which automatically extract relation across tasks and
|
739 |
+
construct a meta-knowledge graph. When a new task comes in, it can quickly find the most relevant
|
740 |
+
relations through the meta-knowledge graph and use this knowledge to facilitate its training process.
|
741 |
+
The experiments demonstrate the effectiveness of our proposed algorithm.
|
742 |
+
|
743 |
+
|
744 |
+
-----
|
745 |
+
|
746 |
+
REFERENCES
|
747 |
+
|
748 |
+
Marcin Andrychowicz, Misha Denil, Sergio Gomez, Matthew W Hoffman, David Pfau, Tom Schaul,
|
749 |
+
Brendan Shillingford, and Nando De Freitas. Learning to learn by gradient descent by gradient
|
750 |
+
descent. In NeurIPS, pp. 3981–3989, 2016.
|
751 |
+
|
752 |
+
Chelsea Finn and Sergey Levine. Meta-learning and universality: Deep representations and gradient
|
753 |
+
descent can approximate any learning algorithm. In ICLR, 2018.
|
754 |
+
|
755 |
+
Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of
|
756 |
+
deep networks. In ICML, pp. 1126–1135, 2017.
|
757 |
+
|
758 |
+
Chelsea Finn, Kelvin Xu, and Sergey Levine. Probabilistic model-agnostic meta-learning. In NeurIPS,
|
759 |
+
2018.
|
760 |
+
|
761 |
+
Sebastian Flennerhag, Pablo G Moreno, Neil D Lawrence, and Andreas Damianou. Transferring
|
762 |
+
knowledge across learning processes. ICLR, 2019.
|
763 |
+
|
764 |
+
Justin Gilmer, Samuel S Schoenholz, Patrick F Riley, Oriol Vinyals, and George E Dahl. Neural
|
765 |
+
message passing for quantum chemistry. In ICML, pp. 1263–1272. JMLR. org, 2017.
|
766 |
+
|
767 |
+
Jonathan Gordon, John Bronskill, Matthias Bauer, Sebastian Nowozin, and Richard E Turner. Metalearning probabilistic inference for prediction. In ICLR, 2019.
|
768 |
+
|
769 |
+
Erin Grant, Chelsea Finn, Sergey Levine, Trevor Darrell, and Thomas Griffiths. Recasting gradientbased meta-learning as hierarchical bayes. In ICLR, 2018.
|
770 |
+
|
771 |
+
Jiatao Gu, Yong Wang, Yun Chen, Kyunghyun Cho, and Victor OK Li. Meta-learning for low-resource
|
772 |
+
neural machine translation. In EMNLP, 2018.
|
773 |
+
|
774 |
+
Will Hamilton, Zhitao Ying, and Jure Leskovec. Inductive representation learning on large graphs. In
|
775 |
+
_NeurIPS, pp. 1024–1034, 2017._
|
776 |
+
|
777 |
+
Ghassen Jerfel, Erin Grant, Thomas L Griffiths, and Katherine Heller. Reconciling meta-learning and
|
778 |
+
continual learning with online mixtures of tasks. NeurIPS, 2019.
|
779 |
+
|
780 |
+
Bingyi Kang, Zhuang Liu, Xin Wang, Fisher Yu, Jiashi Feng, and Trevor Darrell. Few-shot object
|
781 |
+
detection via feature reweighting. In ICCV, 2019.
|
782 |
+
|
783 |
+
Thomas N Kipf and Max Welling. Semi-supervised classification with graph convolutional networks.
|
784 |
+
In ICLR, 2017.
|
785 |
+
|
786 |
+
Yoonho Lee and Seungjin Choi. Gradient-based meta-learning with learned layerwise metric and
|
787 |
+
subspace. In ICML, pp. 2933–2942, 2018.
|
788 |
+
|
789 |
+
Zhenguo Li, Fengwei Zhou, Fei Chen, and Hang Li. Meta-sgd: Learning to learn quickly for few
|
790 |
+
shot learning. arXiv preprint arXiv:1707.09835, 2017.
|
791 |
+
|
792 |
+
Zhaojiang Lin, Andrea Madotto, Chien-Sheng Wu, and Pascale Fung. Personalizing dialogue agents
|
793 |
+
via meta-learning. 2019.
|
794 |
+
|
795 |
+
Ming-Yu Liu, Xun Huang, Arun Mallya, Tero Karras, Timo Aila, Jaakko Lehtinen, and Jan Kautz.
|
796 |
+
Few-shot unsupervised image-to-image translation. arXiv preprint arXiv:1905.01723, 2019.
|
797 |
+
|
798 |
+
Nikhil Mishra, Mostafa Rohaninejad, Xi Chen, and Pieter Abbeel. A simple neural attentive metalearner. ICLR, 2018.
|
799 |
+
|
800 |
+
Alex Nichol and John Schulman. Reptile: a scalable metalearning algorithm. arXiv preprint
|
801 |
+
_arXiv:1803.02999, 2018._
|
802 |
+
|
803 |
+
Boris Oreshkin, Pau Rodr´ıguez Lopez, and Alexandre Lacoste. Tadam: Task dependent adaptive´
|
804 |
+
metric for improved few-shot learning. In NeurIPS, pp. 721–731, 2018.
|
805 |
+
|
806 |
+
Ethan Perez, Florian Strub, Harm de Vries, Vincent Dumoulin, and Aaron C. Courville. Film: Visual
|
807 |
+
reasoning with a general conditioning layer. In AAAI, 2018.
|
808 |
+
|
809 |
+
|
810 |
+
-----
|
811 |
+
|
812 |
+
Sachin Ravi and Hugo Larochelle. Optimization as a model for few-shot learning. ICLR, 2016.
|
813 |
+
|
814 |
+
Andrei A Rusu, Dushyant Rao, Jakub Sygnowski, Oriol Vinyals, Razvan Pascanu, Simon Osindero,
|
815 |
+
and Raia Hadsell. Meta-learning with latent embedding optimization. In ICLR, 2019.
|
816 |
+
|
817 |
+
Jake Snell, Kevin Swersky, and Richard Zemel. Prototypical networks for few-shot learning. In
|
818 |
+
_NeurIPS, pp. 4077–4087, 2017._
|
819 |
+
|
820 |
+
Petar Velickoviˇ c, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Lio, and Yoshua´
|
821 |
+
Bengio. Graph attention networks. In ICLR, 2018.
|
822 |
+
|
823 |
+
Oriol Vinyals, Charles Blundell, Timothy Lillicrap, Daan Wierstra, et al. Matching networks for one
|
824 |
+
shot learning. In NeurIPS, pp. 3630–3638, 2016.
|
825 |
+
|
826 |
+
Risto Vuorio, Shao-Hua Sun, Hexiang Hu, and Joseph J Lim. Toward multimodal model-agnostic
|
827 |
+
meta-learning. NeurIPS, 2019.
|
828 |
+
|
829 |
+
Xin Wang, Fisher Yu, Ruth Wang, Trevor Darrell, and Joseph E Gonzalez. Tafe-net: Task-aware
|
830 |
+
feature embeddings for low shot learning. In CVPR, pp. 1831–1840, 2019.
|
831 |
+
|
832 |
+
Flood Sung Yongxin Yang, Li Zhang, Tao Xiang, Philip HS Torr, and Timothy M Hospedales.
|
833 |
+
Learning to compare: Relation network for few-shot learning. In CVPR, 2018.
|
834 |
+
|
835 |
+
Huaxiu Yao, Ying Wei, Junzhou Huang, and Zhenhui Li. Hierarchically structured meta-learning. In
|
836 |
+
_ICML, pp. 7045–7054, 2019._
|
837 |
+
|
838 |
+
Jaesik Yoon, Taesup Kim, Ousmane Dia, Sungwoong Kim, Yoshua Bengio, and Sungjin Ahn.
|
839 |
+
Bayesian model-agnostic meta-learning. In NeurIPS, pp. 7343–7353, 2018.
|
840 |
+
|
841 |
+
Sung Whan Yoon, Jun Seo, and Jaekyun Moon. Tapnet: Neural network augmented with task-adaptive
|
842 |
+
projection for few-shot learning. In ICML, 2019.
|
843 |
+
|
844 |
+
|
845 |
+
-----
|
846 |
+
|
847 |
+
A ALGORITHM IN META-TESTING PROCESS
|
848 |
+
|
849 |
+
**Algorithm 2 Meta-Testing Process of ARML**
|
850 |
+
|
851 |
+
**Require: Training data** _t_ [of a new task][ T][t]
|
852 |
+
_D[tr]_
|
853 |
+
|
854 |
+
1: Construct the prototype-based relational graph Rt by computing prototype in equation 2 and
|
855 |
+
weight in equation 4
|
856 |
+
|
857 |
+
2: Compute the similarity between each prototype and meta-knowledge vertice in equation 6 and
|
858 |
+
construct the super-graph St
|
859 |
+
|
860 |
+
3: Apply GNN on super-graph St and get the updated prototype representation **C[ˆ]** _Rt_
|
861 |
+
|
862 |
+
4: Aggregate CRt in equation 8, **C[ˆ]** _Rt in equation 9 and get the representations qt, tt_
|
863 |
+
|
864 |
+
5: Compute the task-specific initialization θ0t in equation 10
|
865 |
+
6: Update parameters θt = θ0t − _α∇θL(fθ, Dt[tr][)]_
|
866 |
+
|
867 |
+
|
868 |
+
B HYPERPARAMETERS SETTINGS
|
869 |
+
|
870 |
+
B.1 2D REGRESSION
|
871 |
+
|
872 |
+
In 2D regression problem, we set the inner-loop stepsize (i.e., α) and outer-loop stepsize (i.e., β) as
|
873 |
+
0.001 and 0.001, respectively. The embedding function E is set as one layer with 40 neurons. The
|
874 |
+
autoencoder aggregator is constructed by the gated recurrent structures. We set the meta-batch size as
|
875 |
+
25 and the inner loop gradient steps as 5.
|
876 |
+
|
877 |
+
B.2 FEW-SHOT IMAGE CLASSIFICATION
|
878 |
+
|
879 |
+
In few-shot image classification, for both Plain-Multi and Art-Multi datasets, we set the corresponding
|
880 |
+
inner stepsize (i.e., α) as 0.001 and the outer stepsize (i.e., β) as 0.01. For the embedding function E,
|
881 |
+
we employ two convolutional layers with 3 × 3 filters. The channel size of these two convolutional
|
882 |
+
layers are 32. After convolutional layers, we use two fully connected layers with 384 and 64 neurons
|
883 |
+
for each layer. Similar as the hyperparameter settings in 2D regression, the autoencoder aggregator
|
884 |
+
is constructed by the gated recurrent structures, i.e., AG[t], AG[t]dec [AG][q][,][ AG]dec[q] [are all GRUs. The]
|
885 |
+
meta-batch size is set as 4. For the inner loop, we use 5 gradient steps.
|
886 |
+
|
887 |
+
B.3 DETAILED BASELINE SETTINGS
|
888 |
+
|
889 |
+
For the gradient-based baselines (i.e., MAML, MetaSGD, MT-Net, BMAML. MUMOMAML,
|
890 |
+
HSML), we use the same inner loop stepsize and outer loop stepsize rate as our ARML. As for
|
891 |
+
non-parametric based meta-learning algorithms, both TADAM and Prototypical network, we use the
|
892 |
+
same meta-training and meta-testing process as gradient-based models. Additionally, TADAM uses
|
893 |
+
the same embedding function E as ARML for fair comparison (i.e., similar expressive ability).
|
894 |
+
|
895 |
+
C ADDITIONAL DISCUSSION OF DATASETS
|
896 |
+
|
897 |
+
In this dataset, we use pencil and blur filers to change the task distribution. To investigate the effect
|
898 |
+
of pencil and blur filters, we provide one example in Figure 4. We can observe that different filters
|
899 |
+
result in different data distributions. All used filter are provided by OpenCV[1].
|
900 |
+
|
901 |
+
D RESULTS ON MINIIMAGENET
|
902 |
+
|
903 |
+
For miniimagenet, since it do not have the characteristic of task heterogeneity, we show the results in
|
904 |
+
Table 5. In this table, we compare the MiniImagenet dataset with other gradient-based meta-learning
|
905 |
+
models (the first four baselines are globally shared models and the next four are task-specific models).
|
906 |
+
Similar as (Finn et al., 2018), we also apply the standard 4-block convolutional layers for each
|
907 |
+
|
908 |
+
1https://opencv.org/
|
909 |
+
|
910 |
+
|
911 |
+
-----
|
912 |
+
|
913 |
+
(a) : Plain Image (b) : with blur filter (c) : with pencil filter
|
914 |
+
|
915 |
+
Figure 4: Effect of different filters.
|
916 |
+
|
917 |
+
baseline. For MT-Net, we use the reported results in (Yao et al., 2019), which control the model with
|
918 |
+
the same expressive power. The results indicate that our proposed ARML can outperform the original
|
919 |
+
MAML and achieves comparable performance with task-specific models (e.g., MT-Net, PLATIPUS,
|
920 |
+
HSML). Most task-specific models achieve the similar performance on the standard benchmark due
|
921 |
+
to the homogeneity between tasks.
|
922 |
+
|
923 |
+
Table 5: Performance comparison on the 5-way, 1-shot MiniImagenet dataset.
|
924 |
+
|
925 |
+
|Algorithms|5-way 1-shot Accuracy|
|
926 |
+
|---|---|
|
927 |
+
|
928 |
+
|MAML (Finn et al., 2017) LLAMA (Finn & Levine, 2018) Reptile (Nichol & Schulman, 2018) MetaSGD (Li et al., 2017)|48.70 1.84% ± 49.40 1.83% ± 49.97 0.32% ± 50.47 1.87% ±|
|
929 |
+
|---|---|
|
930 |
+
|
931 |
+
|MT-Net (Lee & Choi, 2018) MUMOMAML (Vuorio et al., 2019) HSML (Yao et al., 2019) PLATIPUS (Finn et al., 2018)|49.75 1.83% ± 49.86 1.85% ± 50.38 1.85% ± 50.13 1.86% ±|
|
932 |
+
|---|---|
|
933 |
+
|
934 |
+
|ARML|50.42 1.73% ±|
|
935 |
+
|---|---|
|
936 |
+
|
937 |
+
|
938 |
+
E ADDITIONAL RESULTS OF FEW-SHOT IMAGE CLASSIFICATION
|
939 |
+
|
940 |
+
E.1 FULL OVERALL RESULTS TABLE OF ART-MULTI DATASET
|
941 |
+
|
942 |
+
We provide the full results table of Art-Multi Dataset in Table 9. In this table, we can see our proposed
|
943 |
+
ARML outperforms almost all baselines in every sub-datasets.
|
944 |
+
|
945 |
+
F FURTHER INVESTIGATION OF ABLATION STUDY
|
946 |
+
|
947 |
+
In this section, we first show the full evaluation results of model ablation study on Art-Multi dataset
|
948 |
+
in 6. Note that, for the tanh activation (ablation model V), the performance is similar as applying
|
949 |
+
the sigmoid activation. On some subdatasets, the results are even better. We choose the sigmoid
|
950 |
+
activation for ARML because it achieves overall better performance than the tanh activation on more
|
951 |
+
subdatasets. Then, for Plain-Multi dataset, we show the results in 7. The conclusion of ablation study
|
952 |
+
in Plain-Multi dataset is similar as the conclusion drawn from the results on Art-Multi dataset. The
|
953 |
+
improvement on these two datasets verifies the necessity of the joint framework in ARML.
|
954 |
+
|
955 |
+
G ADDITIONAL ANALYSIS OF META-KNOWLEDGE GRAPH
|
956 |
+
|
957 |
+
In this section, we add more interpretation analysis of meta-knowledge graph. First, we show the full
|
958 |
+
evaluation results of sensitivity analysis on Art-Multi dataset in Table 8.
|
959 |
+
|
960 |
+
|
961 |
+
-----
|
962 |
+
|
963 |
+
Table 6: Full evaluation results of model ablation study on Art-Multi dataset. B, T, A, F represent
|
964 |
+
bird, texture, aircraft, fungi, respectively. Plain means original image.
|
965 |
+
|
966 |
+
|Model|B Plain B Blur B Pencil T Plain T Blur T Pencil|
|
967 |
+
|---|---|
|
968 |
+
|
969 |
+
|
970 |
+
|I. no prototype-based graph II. no prototype|72.08% 71.06% 66.83% 45.23% 39.97% 41.67% 72.99% 70.92% 67.19% 45.17% 40.05% 41.04%|
|
971 |
+
|---|---|
|
972 |
+
|
973 |
+
|
974 |
+
|III. no meta-knowledge graph IV. no reconstruction loss|70.79% 69.53% 64.87% 43.37% 39.86% 41.23% 70.82% 69.87% 65.32% 44.02% 40.18% 40.52%|
|
975 |
+
|---|---|
|
976 |
+
|
977 |
+
|
978 |
+
|V. tanh VI. film|72.70% 69.53% 66.85% 45.81% 40.79% 38.64% 71.52% 68.70% 64.23% 43.83% 40.52% 39.49%|
|
979 |
+
|---|---|
|
980 |
+
|
981 |
+
|
982 |
+
|Model|A Plain A Blur A Pencil F Plain F Blur F Pencil|
|
983 |
+
|---|---|
|
984 |
+
|
985 |
+
|
986 |
+
|I. no prototype-based graph II. no prototype|70.06% 68.02% 60.66% 55.81% 54.39% 50.01% 71.10% 67.59% 61.07% 56.11% 54.82% 49.95%|
|
987 |
+
|---|---|
|
988 |
+
|
989 |
+
|
990 |
+
|III. no meta-knowledge graph IV. no reconstruction loss|69.97% 68.03% 59.72% 55.84% 53.72% 48.91% 66.83% 65.73% 55.98% 54.62% 53.02% 48.01%|
|
991 |
+
|---|---|
|
992 |
+
|
993 |
+
|
994 |
+
|V. tanh VI. film|73.96% 69.70% 60.75% 56.87% 54.30% 49.82% 69.13% 66.93% 55.59% 55.77% 53.72% 48.92%|
|
995 |
+
|---|---|
|
996 |
+
|
997 |
+
|
998 |
+
|ARML|71.89% 68.59% 61.41% 56.83% 54.87% 50.53%|
|
999 |
+
|---|---|
|
1000 |
+
|
1001 |
+
|
1002 |
+
ARML **73.05%** **71.31%** **67.14%** 45.32% 40.15% **41.98%**
|
1003 |
+
|
1004 |
+
|
1005 |
+
Table 7: Results of Model Ablation (5-way, 5-shot results) on Plain-Multi dataset.
|
1006 |
+
|
1007 |
+
|Ablation Models|Bird|Texture|Aircraft|Fungi|
|
1008 |
+
|---|---|---|---|---|
|
1009 |
+
|
1010 |
+
|I. no sample-level graph II. no prototype|71.96 ± 0.72% 72.86 ± 0.74%|48.79 ± 0.67% 49.03 ± 0.69%|74.02 ± 0.65% 74.36 ± 0.65%|56.83 ± 0.80% 57.02 ± 0.81%|
|
1011 |
+
|---|---|---|---|---|
|
1012 |
+
|
1013 |
+
|III. no meta-knowledge graph IV. no reconstruction loss|71.23 ± 0.75% 70.99 ± 0.74%|47.96 ± 0.68% 48.03 ± 0.69%|73.71 ± 0.69% 69.86 ± 0.66%|55.97 ± 0.82% 55.78 ± 0.83%|
|
1014 |
+
|---|---|---|---|---|
|
1015 |
+
|
1016 |
+
|V. tanh VI. film|73.45 ± 0.71% 72.95 ± 0.73%|49.23 ± 0.66% 49.18 ± 0.69%|74.39 ± 0.65% 73.82 ± 0.68%|57.38 ± 0.80% 56.89 ± 0.80%|
|
1017 |
+
|---|---|---|---|---|
|
1018 |
+
|
1019 |
+
|ARML|73.34 ± 0.70%|49.67 ± 0.67%|74.88 ± 0.64%|57.55 ± 0.82%|
|
1020 |
+
|---|---|---|---|---|
|
1021 |
+
|
1022 |
+
|
1023 |
+
Then, we analyze the meta-knowledge graph on Plain-Multi dataset by visualizing the learned metaknowledge graph on Plain-Multi dataset (as shown in Figure 5). In this figure, we can see that
|
1024 |
+
different subdatasets activate different vertices. Specifically, V2, which is mainly activated by texture,
|
1025 |
+
plays a significantly important role in aircraft and fungi. Thus, V2 connects with V3 and V1 in the
|
1026 |
+
meta-knowledge graph, which are mainly activated by fungi and aircraft, respectively. In addition,
|
1027 |
+
V0 is also activated by aircraft because of the similar contour between aircraft and bird. Furthermore,
|
1028 |
+
in meta-knowledge graph, V0 connects with V3, which shows the similarity of environment between
|
1029 |
+
bird images and fungi images.
|
1030 |
+
|
1031 |
+
|
1032 |
+
-----
|
1033 |
+
|
1034 |
+
Bird
|
1035 |
+
|
1036 |
+
|
1037 |
+
Texture
|
1038 |
+
|
1039 |
+
|
1040 |
+
V1
|
1041 |
+
|
1042 |
+
V2
|
1043 |
+
|
1044 |
+
V0
|
1045 |
+
|
1046 |
+
V3
|
1047 |
+
|
1048 |
+
|
1049 |
+
Aircraft Fungi
|
1050 |
+
|
1051 |
+
Figure 5: Interpretation of meta-knowledge graph on Plain-Multi dataset. For each subdataset, one
|
1052 |
+
task is randomly selected from them. In the left figure, we show the similarity heatmap between
|
1053 |
+
prototypes (P1-P5) and meta-knowledge vertices (denoted as E1-E4), where deeper color means
|
1054 |
+
higher similarity. In the right part, we show the meta-knowledge graph, where a threshold is also set
|
1055 |
+
to filter low similarity links.
|
1056 |
+
|
1057 |
+
Table 8: Full evaluation results of performance v.s. # vertices of meta-knowledge graph on Art-Multi.
|
1058 |
+
B, T, A, F represent bird, texture, aircraft, fungi, respectively. Plain means original image.
|
1059 |
+
|
1060 |
+
|# of Vertices|B Plain B Blur B Pencil T Plain T Blur T Pencil|
|
1061 |
+
|---|---|
|
1062 |
+
|
1063 |
+
|# of Vertices|A Plain A Blur A Pencil F Plain F Blur F Pencil|
|
1064 |
+
|---|---|
|
1065 |
+
|
1066 |
+
|4 8 12 16 20|70.98% 67.36% 60.46% 56.07% 53.77% 50.08% 71.89% 68.59% 61.41% 56.83% 54.87% 50.53% 71.78% 67.26% 60.97% 56.87% 55.14% 50.86% 71.96% 68.55% 61.14% 56.76% 54.54% 49.41% 72.02% 68.29% 60.59% 55.95% 54.53% 50.13%|
|
1067 |
+
|---|---|
|
1068 |
+
|
1069 |
+
|
1070 |
+
4 72.29% 70.36% 67.88% 45.37% 41.05% 41.43%
|
1071 |
+
8 73.05% 71.31% 67.14% 45.32% 40.15% 41.98%
|
1072 |
+
12 73.45% 70.64% 67.41% 44.53% 41.41% 41.05%
|
1073 |
+
16 72.68% 70.18% 68.34% 45.63% 41.43% 42.18%
|
1074 |
+
20 73.41% 71.07% 68.64% 46.26% 41.80% 41.61%
|
1075 |
+
|
1076 |
+
|
1077 |
+
-----
|
1078 |
+
|
1079 |
+
|55.27% 52.62% 48.58% 30.57% 28.65% 28.39% 45.59% 42.24% 34.52% 39.37% 38.58% 35.38% 55.23% 53.08% 48.18% 29.28% 28.70% 28.38% 51.24% 47.29% 35.98% 41.08% 40.38% 36.30% 56.99% 54.21% 50.25% 32.13% 29.63% 29.23% 43.64% 40.08% 33.73% 43.02% 42.64% 37.96% 57.73% 53.18% 50.96% 31.88% 29.72% 29.90% 49.95% 43.36% 39.61% 42.97% 40.08% 36.52% 58.15% 53.20% 51.09% 32.01% 30.21% 30.17% 49.98% 45.79% 40.87% 42.58% 41.29% 37.01%|53.67% 50.98% 46.66% 31.37% 29.08% 28.48% 45.54% 43.94% 35.49% 37.71% 38.00% 34.36% 54.76% 52.18% 48.85% 32.03% 29.90% 30.82% 50.42% 47.59% 40.17% 41.73% 40.09% 36.27%|59.67% 54.89% 52.97% 32.31% 30.77% 31.51% 51.99% 47.92% 41.93% 44.69% 42.13% 38.36%|
|
1080 |
+
|---|---|---|
|
1081 |
+
|MAML MetaSGD MT-Net MUMOMAML HSML|ProtoNet TADAM|ARML|
|
1082 |
+
|
1083 |
+
|71.51% 68.65% 63.93% 42.96% 39.59% 38.87% 64.68% 62.54% 49.20% 54.08% 52.02% 46.39% 71.31% 68.73% 64.33% 41.89% 37.79% 37.91% 64.88% 63.36% 52.31% 53.18% 52.26% 46.43% 71.18% 69.29% 68.28% 43.23% 39.42% 39.20% 63.39% 58.29% 46.12% 54.01% 51.70% 47.02% 71.57% 70.50% 64.57% 44.57% 40.31% 40.07% 63.36% 61.55% 52.17% 54.89% 52.82% 47.79% 71.75% 69.31% 65.62% 44.68% 40.13% 41.33% 70.12% 67.63% 59.40% 55.97% 54.60% 49.40%|70.42% 67.90% 61.82% 44.78% 38.43% 38.40% 65.84% 63.41% 54.08% 51.45% 50.56% 46.33% 70.08% 69.05% 65.45% 44.93% 41.80% 40.18% 70.35% 68.56% 59.09% 56.04% 54.04% 47.85%|73.05% 71.31% 67.14% 45.32% 40.15% 41.98% 71.89% 68.59% 61.41% 56.83% 54.87% 50.53%|
|
1084 |
+
|---|---|---|
|
1085 |
+
|MAML MetaSGD MT-Net MUMOMAML HSML|ProtoNet TADAM|ARML|
|
1086 |
+
|
1087 |
+
|
1088 |
+
F Pencil
|
1089 |
+
|
1090 |
+
F Blur
|
1091 |
+
|
1092 |
+
F Plain
|
1093 |
+
A Pencil
|
1094 |
+
|
1095 |
+
A Blur
|
1096 |
+
|
1097 |
+
A Plain
|
1098 |
+
T Pencil
|
1099 |
+
|
1100 |
+
T Blur
|
1101 |
+
|
1102 |
+
T Plain
|
1103 |
+
B Pencil
|
1104 |
+
|
1105 |
+
B Blur
|
1106 |
+
|
1107 |
+
B Plain
|
1108 |
+
|
1109 |
+
Algorithms
|
1110 |
+
Settings
|
1111 |
+
|
1112 |
+
|
1113 |
+
%
|
1114 |
+
**36.38**
|
1115 |
+
|
1116 |
+
%
|
1117 |
+
**13.42**
|
1118 |
+
|
1119 |
+
%
|
1120 |
+
**69.44**
|
1121 |
+
|
1122 |
+
%
|
1123 |
+
**93.41**
|
1124 |
+
|
1125 |
+
%
|
1126 |
+
**92.47**
|
1127 |
+
|
1128 |
+
%
|
1129 |
+
**99.51**
|
1130 |
+
|
1131 |
+
%
|
1132 |
+
**51.31**
|
1133 |
+
|
1134 |
+
%
|
1135 |
+
**77.30**
|
1136 |
+
|
1137 |
+
%
|
1138 |
+
**31.32**
|
1139 |
+
|
1140 |
+
%
|
1141 |
+
**97.52**
|
1142 |
+
|
1143 |
+
%
|
1144 |
+
**89.54**
|
1145 |
+
|
1146 |
+
%
|
1147 |
+
**67.59**
|
1148 |
+
|
1149 |
+
ARML
|
1150 |
+
|
1151 |
+
|
1152 |
+
%
|
1153 |
+
**53.50**
|
1154 |
+
|
1155 |
+
%
|
1156 |
+
**87.54**
|
1157 |
+
|
1158 |
+
%
|
1159 |
+
**83.56**
|
1160 |
+
|
1161 |
+
%
|
1162 |
+
**41.61**
|
1163 |
+
|
1164 |
+
%
|
1165 |
+
**59.68**
|
1166 |
+
|
1167 |
+
%
|
1168 |
+
**89.71**
|
1169 |
+
|
1170 |
+
%
|
1171 |
+
**98.41**
|
1172 |
+
15%.40
|
1173 |
+
|
1174 |
+
%
|
1175 |
+
**32.45**
|
1176 |
+
|
1177 |
+
%
|
1178 |
+
**14.67**
|
1179 |
+
|
1180 |
+
%
|
1181 |
+
**31.71**
|
1182 |
+
|
1183 |
+
%
|
1184 |
+
**05.73**
|
1185 |
+
|
1186 |
+
ARML
|
1187 |
+
|
1188 |
+
|
1189 |
+
-----
|
1190 |
+
|
ai_scientist/fewshot_examples/2_carpe_diem.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"review": "{\n \"Summary\": \"This paper proposes Recency Bias, an adaptive mini batch selection method for training deep neural networks. To select informative minibatches for training, the proposed method maintains a fixed size sliding window of past model predictions for each data sample. At a given iteration, samples which have highly inconsistent predictions within the sliding window are added to the minibatch. The main contribution of this paper is the introduction of a sliding window to remember past model predictions, as an improvement over the SOTA approach: Active Bias, which maintains a growing window of model predictions. Empirical studies are performed to show the superiority of Recency Bias over two SOTA approaches. Results are shown on the task of (1) image classification from scratch and (2) image classification by fine-tuning pretrained networks.\",\n \"Strengths\": [\n \"The idea of using a sliding window over a growing window in active batch selection is interesting.\",\n \"Overall, the paper is well written. In particular, the Related Work section has a nice flow and puts the proposed method into context. Despite the method having limited novelty (sliding window instead of a growing window), the method has been well motivated by pointing out the limitations in SOTA methods.\",\n \"The results section is well structured. It's nice to see hyperparameter tuning results; and loss convergence graphs in various learning settings for each dataset.\"\n ],\n \"Weaknesses\": [\n \"The key concern about the paper is the lack of rigorous experimentation to study the usefulness of the proposed method. Despite the paper stating that there have been earlier work (Joseph et al, 2019 and Wang et al, 2019) that attempt mini-batch selection, the paper does not compare with them. This is limiting. Further, since the proposed method is not specific to the domain of images, evaluating it on tasks other than image classification, such as text classification for instance, would have helped validate its applicability across domains.\",\n \"Considering the limited results, a deeper analysis of the proposed method would have been nice. The idea of a sliding window over a growing window is a generic one, and there have been many efforts to theoretically analyze active learning over the last two decades. How does the proposed method fit in there? (For e.g., how does the expected model variance change in this setting?) Some form of theoretical/analytical reasoning behind the effectiveness of recency bias (which is missing) would provide greater insights to the community and facilitate further research in this direction.\",\n \"The claim of 20.5% reduction in test error mentioned in the abstract has not been clearly addressed and pointed out in the results section of the paper.\",\n \"The results would have been more complete if results were shown in a setting where just recency bias is used without the use of the selection pressure parameter. In other words, an ablation study on the effect of the selection pressure parameter would have been very useful.\",\n \"The intuition behind the method is described well, however, the proposed method would have been really solidified if it were analysed in the context of a simple machine learning problem (such as logistic regression). As an example, verifying if the chosen minibatch samples are actually close to the decision boundary of a model (even if the model is very simple) would have helped analyze the proposed method well.\"\n ],\n \"Originality\": 3,\n \"Quality\": 2,\n \"Clarity\": 4,\n \"Significance\": 2,\n \"Questions\": [\n \"How important is the warm-up phase to the proposed method? Considering the paper states that this is required to get good estimates of the quantization index of the samples, some ablation studies on reducing/increasing the warm-up phase and showing the results would have been useful to understand this.\",\n \"Fig 4: Why are there sharp dips periodically in all the graphs? What do these correspond to?\",\n \"The results are not conclusively in favor of the proposed method, and only is marginally better than the competitors. Why does online batch perform consistently than the proposed method? There is no discussion of these inferences from the results.\"\n ],\n \"Limitations\": [\n \"The primary concern is about the strength of the experimental results, which showed only a modest benefit on relatively simple datasets.\"\n ],\n \"Ethical Concerns\": false,\n \"Soundness\": 2,\n \"Presentation\": 3,\n \"Contribution\": 2,\n \"Overall\": 4,\n \"Confidence\": 3,\n \"Decision\": \"Reject\"\n}"
|
3 |
+
}
|
ai_scientist/fewshot_examples/2_carpe_diem.pdf
ADDED
Binary file (858 kB). View file
|
|
ai_scientist/fewshot_examples/2_carpe_diem.txt
ADDED
@@ -0,0 +1,1035 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CARPE DIEM, SEIZE THE SAMPLES UNCERTAIN “AT
|
2 |
+
## THE MOMENT” FOR ADAPTIVE BATCH SELECTION
|
3 |
+
|
4 |
+
**Anonymous authors**
|
5 |
+
Paper under double-blind review
|
6 |
+
|
7 |
+
ABSTRACT
|
8 |
+
|
9 |
+
The performance of deep neural networks is significantly affected by how well
|
10 |
+
mini-batches are constructed. In this paper, we propose a novel adaptive batch
|
11 |
+
selection algorithm called Recency Bias that exploits the uncertain samples
|
12 |
+
predicted inconsistently in recent iterations. The historical label predictions of
|
13 |
+
each sample are used to evaluate its predictive uncertainty within a sliding window.
|
14 |
+
By taking advantage of this design, Recency Bias not only accelerates the training
|
15 |
+
step but also achieves a more accurate network. We demonstrate the superiority
|
16 |
+
of Recency Bias by extensive evaluation on two independent tasks. Compared with
|
17 |
+
existing batch selection methods, the results showed that Recency Bias reduced
|
18 |
+
the test error by up to 20.5% in a fixed wall-clock training time. At the same time,
|
19 |
+
it improved the training time by up to 59.3% to reach the same test error.
|
20 |
+
|
21 |
+
1 INTRODUCTION
|
22 |
+
|
23 |
+
Stochastic gradient descent (SGD) for randomly selected mini-batch samples is commonly used to
|
24 |
+
train deep network netowrks (DNNs). However, many recent studies have pointed out that the performance of DNNs is heavily dependent on how well the mini-batch samples are selected (Shrivastava
|
25 |
+
et al., 2016; Chang et al., 2017; Katharopoulos & Fleuret, 2018). In earlier approaches, a sample’s difficulty is employed to identify proper mini-batch samples, and these approaches achieve
|
26 |
+
a more accurate and robust network (Han et al., 2018) or expedite the training convergence of
|
27 |
+
SGD (Loshchilov & Hutter, 2016). However, the two opposing difficulty-based strategies, i.e., preferring easy samples (Kumar et al., 2010; Han et al., 2018) versus hard samples (Loshchilov & Hutter,
|
28 |
+
2016; Shrivastava et al., 2016), work well in different situations. Thus, for practical reasons to cover
|
29 |
+
more diverse situations, recent approaches begin to exploit a sample’s uncertainty that indicates the
|
30 |
+
consistency of previous predictions (Chang et al., 2017; Song et al., 2019).
|
31 |
+
|
32 |
+
An important question here is how to evaluate the sample’s uncertainty based on its historical
|
33 |
+
predictions during the training process. Intuitively, because a series of historical predictions can
|
34 |
+
be seen as a series of data indexed in chronological order, the uncertainty can be measured based on
|
35 |
+
_two forms of handling time-series observations: (i) a growing window (Figure 1(a)) that consistently_
|
36 |
+
increases the size of a window to use all available observations and (ii) a sliding window (Figure 1(b))
|
37 |
+
that maintains a window of a fixed size on the most recent observations by deleting outdated ones.
|
38 |
+
While the state-of-the-art algorithm, Active Bias (Chang et al., 2017), adopts the growing window,
|
39 |
+
we propose to use the sliding window in this paper.
|
40 |
+
|
41 |
+
|Historical observations|Col2|
|
42 |
+
|---|---|
|
43 |
+
|||
|
44 |
+
|||
|
45 |
+
|
46 |
+
|
47 |
+
Historical observations Historical observations
|
48 |
+
|
49 |
+
Growing Sliding
|
50 |
+
|
51 |
+
All available observations Outdated observations Recent observations
|
52 |
+
|
53 |
+
|
54 |
+
(a) Growing Window. (b) Sliding Window.
|
55 |
+
|
56 |
+
Figure 1: Two forms of handling the time-series observations.
|
57 |
+
|
58 |
+
In more detail, Active Bias recognizes uncertain samples based on the inconsistency of the predictions
|
59 |
+
in the entire history of past SGD iterations. Then, it emphasizes such uncertain samples by choosing
|
60 |
+
them with high probability for the next mini-batch. However, according to our experiments presented
|
61 |
+
|
62 |
+
|
63 |
+
-----
|
64 |
+
|
65 |
+
|… Horse Horse Horse|Col2|
|
66 |
+
|---|---|
|
67 |
+
|
68 |
+
|… Deer Deer Deer|Col2|
|
69 |
+
|---|---|
|
70 |
+
|
71 |
+
|
72 |
+
Images Inconsistent Predictions Consistent Predictions Sample Method
|
73 |
+
|
74 |
+
(Horse) History Uncertainty
|
75 |
+
|
76 |
+
Outdated Recent (too easy)
|
77 |
+
|
78 |
+
**High**
|
79 |
+
|
80 |
+
Horse Deer Horse Deer Deer Horse Deer … Horse Horse … Horse Active Bias
|
81 |
+
|
82 |
+
**Low**
|
83 |
+
|
84 |
+
Outdated Recent (too hard)
|
85 |
+
|
86 |
+
Deer Horse Horse Deer Horse Deer Horse … Deer Deer … Deer **High** **Recency Bias**
|
87 |
+
|
88 |
+
**Low**
|
89 |
+
|
90 |
+
Previous Training Iterations
|
91 |
+
|
92 |
+
|
93 |
+
Figure 2: The difference in sample uncertainty estimated by Active Bias and Recency Bias.
|
94 |
+
|
95 |
+
in Section 5.2, such uncertain samples slowed down the convergence speed of training, though they
|
96 |
+
ultimately reduced the generalization error. This weakness is attributed to the inherent limitation of
|
97 |
+
the growing window, where older observations could be too outdated (Torgo, 2011). In other words,
|
98 |
+
the outdated predictions no longer represent a network’s current behavior. As illustrated in Figure
|
99 |
+
2, when the label predictions of two samples were inconsistent for a long time, Active Bias invariably
|
100 |
+
regards them as highly uncertain, although their recent label predictions become consistent along
|
101 |
+
with the network’s training progress. This characteristic evidently entails the risk of emphasizing
|
102 |
+
uninformative samples that are too easy or too hard at the current moment, thereby slowing down
|
103 |
+
the convergence speed of training.
|
104 |
+
|
105 |
+
Therefore, we propose a simple but effective batch selection method, called Recency Bias, that takes
|
106 |
+
advantage of the sliding window to evaluate the uncertainty in fresher observations. As opposed to
|
107 |
+
_Active Bias, Recency Bias excludes the outdated predictions by managing a sliding window of a fixed_
|
108 |
+
size and picks up the samples predicted inconsistently within the sliding window. Thus, as shown
|
109 |
+
in Figure 2, the two samples uninformative at the moment are no longer selected by Recency Bias
|
110 |
+
simply because their recent predictions are consistent. Consequently, since informative samples are
|
111 |
+
effectively selected throughout the training process, this strategy not only accelerates the training
|
112 |
+
speed but also leads to a more accurate network.
|
113 |
+
|
114 |
+
To validate the superiority of Recency Bias, two popular convolutional neural networks (CNNs) were
|
115 |
+
trained for two independent tasks: image classification and fine tuning. We compared Recency Bias
|
116 |
+
with not only random batch selection (baseline) but also two state-of-the-art batch selection strategies.
|
117 |
+
Compared with three batch selection strategies, Recency Bias provided a relative reduction of test
|
118 |
+
error by 1.81%–20.5% in a fixed wall-clock training time. At the same time, it significantly reduced
|
119 |
+
the execution time by 24.6%–59.3% to reach the same test error.
|
120 |
+
|
121 |
+
2 RELATED WORK
|
122 |
+
|
123 |
+
Let D = {(xi, yi)|1 ≤ _i ≤_ _N_ _} be the entire training dataset composed of a sample xi with its_
|
124 |
+
true label yi, where N is the total number of training samples. Then, a straightforward strategy to
|
125 |
+
construct a mini-batch = (xi, yi) 1 _i_ _b_ is to select b samples uniformly at random (i.e.,
|
126 |
+
_M_ _{_ _|_ _≤_ _≤_ _}_
|
127 |
+
_P_ (xi ) = 1/N ) from the training dataset .
|
128 |
+
_|D_ _D_
|
129 |
+
|
130 |
+
Because not all samples have an equal impact on training, many research efforts have been devoted
|
131 |
+
to develop advanced sampling schemes. Bengio et al. (2009) first took easy samples and then
|
132 |
+
gradually increased the difficulty of samples using heuristic rules. Kumar et al. (2010) determined the
|
133 |
+
easiness of the samples using their prediction errors. Recently, Tsvetkov et al. (2016) used Bayesian
|
134 |
+
optimization to learn an optimal curriculum for training dense, distributed word representations.
|
135 |
+
Sachan & Xing (2016) emphasized that the right curriculum must introduce a small number of the
|
136 |
+
samples dissimilar to those previously seen. Fan et al. (2017) proposed a neural data filter based on
|
137 |
+
reinforcement learning to select training samples adaptively. However, it is common for deep learning
|
138 |
+
to emphasize hard samples because of the plethora of easy ones (Katharopoulos & Fleuret, 2018).
|
139 |
+
|
140 |
+
Loshchilov & Hutter (2016) proposed a difficulty-based sampling scheme, called Online Batch,
|
141 |
+
that uses the rank of the loss computed from previous epochs. Online Batch sorts the previously
|
142 |
+
computed losses of samples in descending order and exponentially decays the sampling probability
|
143 |
+
of a sample according to its rank r. Then, the r-th ranked sample x(r) is selected with the probability
|
144 |
+
dropping by a factor of exp log(se)/N, where se is the selection pressure parameter that affects
|
145 |
+
the probability gap between the most and the least important samples. When normalized to sum
|
146 |
+
to 1.0, the probability P (x( r) ; se) is defined by Eq. (1). It has been reported that _Online Batch_
|
147 |
+
_|D_
|
148 |
+
|
149 |
+
|
150 |
+
-----
|
151 |
+
|
152 |
+
accelerates the convergence of training but deteriorates the generalization error because of the
|
153 |
+
overfitting to hard training samples (Loshchilov & Hutter, 2016).
|
154 |
+
|
155 |
+
_r_
|
156 |
+
1/ exp log(se)/N
|
157 |
+
_P_ (x(r) ; se) = _N_ _j_ (1)
|
158 |
+
_|D_ _j=1_ [1][/][ exp] log(se)/N
|
159 |
+
|
160 |
+
Most close to our work, Chang et al. (2017) devised anP _uncertainty _ -based sampling scheme, called
|
161 |
+
_Active Bias, that chooses uncertain samples with high probability for the next batch. Active Bias_
|
162 |
+
maintains the history _i_ that stores all h(yi _xi) before the current iteration t (i.e., growing window),_
|
163 |
+
_H[t][−][1]_ _|_
|
164 |
+
where h(yi|xi) is the softmax probability of a given sample xi for its true label yi. Then, it measures
|
165 |
+
the uncertainty of the sample xi by computing the variance over all h(yi _xi) in_ _i_ and draws the
|
166 |
+
_|_ _H[t][−][1]_
|
167 |
+
next mini-batch samples based on the normalized probability P (xi _,_ _i_ ; ϵ) in Eq. (2), where ϵ is
|
168 |
+
_|D_ _H[t][−][1]_
|
169 |
+
the smoothness constant to prevent the low variance samples from never being selected again. As
|
170 |
+
mentioned earlier in Section 1, Active Bias slows down the training process because the oldest part in
|
171 |
+
the history _i_ no longer represents the current behavior of the network.
|
172 |
+
_H[t][−][1]_
|
173 |
+
|
174 |
+
_P_ (xi|D, Hi[t][−][1]; ϵ) = _Nj=1stdˆ_ ˆstdi(Hji[t]([−][1]j) +) + ϵ _ϵ_ _,_ _stdˆ_ (Hi[t][−][1]) = vuvar _h(yi|xi)_ + _[var] h(iyi|xi)2_
|
175 |
+
_H[t][−][1]_ ut _|H[t][−][1]|_ (2)
|
176 |
+
P
|
177 |
+
|
178 |
+
For the completeness of the survey, we include the recent studies on submodular batch selection.
|
179 |
+
Joseph et al. (2019) and Wang et al. (2019) designed their own submodular objectives that cover
|
180 |
+
diverse aspects, such as sample redundancy and sample representativeness, for more effective
|
181 |
+
batch selection. Differently from their work, we explore the issue of truly uncertain samples in
|
182 |
+
an orthogonal perspective. Our uncertainty measure can be easily injected into their submodular
|
183 |
+
optimization framework as a measure of sample informativeness.
|
184 |
+
|
185 |
+
In Section 5, we will confirm that Recency Bias outperforms Online Batch and Active Bias, which are
|
186 |
+
regarded as two state-of-the-art adaptive batch selection methods for deep learning.
|
187 |
+
|
188 |
+
3 _Recency Bias COMPONENTS_
|
189 |
+
|
190 |
+
3.1 CRITERION OF AN UNCERTAIN SAMPLE
|
191 |
+
|
192 |
+
The main challenge of Recency Bias is to identify the samples whose recent label predictions are
|
193 |
+
highly inconsistent, which are neither too easy nor too hard at the moment. Thus, we adopt the
|
194 |
+
_predictive uncertainty (Song et al., 2019) in Definition 3.1 that uses the information entropy (Chandler,_
|
195 |
+
1987) to measure the inconsistency of recent label predictions. Here, the sample with high predictive
|
196 |
+
uncertainty is regarded as uncertain and selected with high probability for the next mini-batch.
|
197 |
+
**Definition 3.1. (Predictive Uncertainty) Let ˆyt = Φ(xi, θt) be the predicted label of a sample xi at**
|
198 |
+
time t and Hxi (q) = {yˆt1 _, ˆyt2_ _, . . ., ˆytq_ _} be the label history of the sample xi that stores the predicted_
|
199 |
+
labels at the previous q times, where Φ is a neural network. The label history _xi_ (q) corresponds
|
200 |
+
_H_
|
201 |
+
to the sliding window of size q to compute the uncertainty of the sample xi. Next, p(yi _xi; q) is_
|
202 |
+
_|_
|
203 |
+
formulated such that it provides the probability of the label y ∈{1, 2, ..., k} estimated as the label of
|
204 |
+
the sample xi based on Hxi (q) as in Eq. (3), where [·] is the Iverson bracket[1].
|
205 |
+
|
206 |
+
_p(y_ _xi; q) =_ _yˆ∈Hxi_ (q)[[ˆ]y = y] (3)
|
207 |
+
_|_ P _xi_ (q)
|
208 |
+
|
209 |
+
_|H_ _|_
|
210 |
+
|
211 |
+
Then, to quantify the uncertainty of the sample xi, the predictive uncertainty F (xi; q) is defined by
|
212 |
+
Eq. (4), where δ is the standardization term to normalize the value to [0, 1].
|
213 |
+
|
214 |
+
|
215 |
+
_F_ (xi; q) = (1/δ)
|
216 |
+
_−_
|
217 |
+
|
218 |
+
|
219 |
+
_p(j_ _xi; q) log p(j_ _xi; q)_
|
220 |
+
_|_ _|_
|
221 |
+
_j=1_
|
222 |
+
|
223 |
+
X
|
224 |
+
|
225 |
+
|
226 |
+
(4)
|
227 |
+
|
228 |
+
|
229 |
+
_δ = −_ log (1/k) □
|
230 |
+
|
231 |
+
1The Iverson bracket [p] returns 1 if p is true; 0 otherwise.
|
232 |
+
|
233 |
+
|
234 |
+
-----
|
235 |
+
|
236 |
+
3.2 SAMPLING PROBABILITY FOR MINI-BATCH CONSTRUCTION
|
237 |
+
|
238 |
+
To construct next mini-batch samples, we assign the sampling probability according to the predictive
|
239 |
+
uncertainty in Definition 3.1. Motivated by Loshchilov & Hutter (2016), the sampling probability
|
240 |
+
of a given sample xi is exponentially decayed with its predictive uncertainty F (xi; q). In detail,
|
241 |
+
we adopt the quantization method (Chen & Wornell, 2001) and use the quantization index to decay
|
242 |
+
the sampling probability. The index is obtained by the simple quantizer Q in Eq. (5), where ∆ is
|
243 |
+
the quantization step size. Compared with the rank-based index (Loshchilov & Hutter, 2016), the
|
244 |
+
quantization index is known to well reflect the difference in actual values (Widrow et al., 1996).
|
245 |
+
|
246 |
+
_Q_ _F_ (xi; q) = 1 _F_ (xi; q) _/∆_ _, 0_ _F_ (xi; q) 1 (5)
|
247 |
+
_⌈_ _−_ _⌉_ _≤_ _≤_
|
248 |
+
|
249 |
+
In Eq. (5), we set ∆ to be 1/N such that the index is bounded to _N (the total number of samples)._
|
250 |
+
Then, the sampling probability P (xi ; se) is defined as in Eq. (6). The higher the predictive
|
251 |
+
_|D_
|
252 |
+
uncertainty, the smaller the quantization index. Therefore, a higher sampling probability is assigned
|
253 |
+
for uncertain samples in Eq. (6).
|
254 |
+
|
255 |
+
1/ exp log(se)/N _Q(F (xi;q))_
|
256 |
+
_P_ (xi|D; se) = _N_ _Q(F (xj_ ;q)) (6)
|
257 |
+
|
258 |
+
_j=1_ [1][/][ exp] log(se)/N
|
259 |
+
|
260 |
+
Meanwhile, it is known that using only some part of training data exacerbates the overfitting problemP
|
261 |
+
at a late stage of training (Loshchilov & Hutter, 2016; Zhou & Bilmes, 2018). Thus, to alleviate
|
262 |
+
the problem, we include more training samples as the training progresses by exponentially decaying
|
263 |
+
the selection pressure se as in Eq. (7). At each epoch e from e0 to eend, the selection pressure
|
264 |
+
_se exponentially decreases from se0 to 1. Because this technique gradually reduces the sampling_
|
265 |
+
probability gap between the most and the least uncertain samples, more diverse samples are selected
|
266 |
+
for the next mini-batch at a later epoch. When the selection pressure se becomes 1, the mini-batch
|
267 |
+
samples are randomly chosen from the entire dataset.
|
268 |
+
|
269 |
+
0
|
270 |
+
_se = se0_ exp log (1/se0 )/(eend − _e0)_ (7)
|
271 |
+
[][e][−][e]
|
272 |
+
|
273 |
+
4 _Recency Bias ALGORITHM_
|
274 |
+
|
275 |
+
**Algorithm 1 Recency Bias Algorithm**
|
276 |
+
|
277 |
+
INPUT: : data, epochs, b: batch size, q: window size, se0 : initial selection pressure, γ: warm-up
|
278 |
+
_D_
|
279 |
+
OUTPUT: θt: model parameter
|
280 |
+
|
281 |
+
1: t ← 1;
|
282 |
+
2: θt ← Initialize the model parameter;
|
283 |
+
3: for i = 1 to epochs do
|
284 |
+
4: /* Sampling Probability Derivation */
|
285 |
+
|
286 |
+
5: **if i > γ then**
|
287 |
+
|
288 |
+
6: _se ←_ Decay_Selection_Pressure(se0, i); /* Decaying se by Eq. (7) */
|
289 |
+
|
290 |
+
7: **for m = 1 to N do** /* Updating the index and the sampling probability in a batch */
|
291 |
+
|
292 |
+
8: _q_dict[xm] = Q_ _F_ (xm; q) ; /* By Eq. (5) */
|
293 |
+
|
294 |
+
|
295 |
+
9: _p_table ←_ Compute_Prob(q_dict, se); /* By Eq. (6) */
|
296 |
+
|
297 |
+
10: /* Network Training */
|
298 |
+
|
299 |
+
11: **for j = 1 to N/b do** /* Mini-batch */
|
300 |
+
|
301 |
+
12: **if i ≤** _γ then_ /* Warm-up */
|
302 |
+
|
303 |
+
13: (x1, y1), . . ., (xb, yb) Randomly select next mini-batch samples;
|
304 |
+
_{_ _} ←_
|
305 |
+
|
306 |
+
14: **else /* Adaptive batch selection */**
|
307 |
+
|
308 |
+
15: (x1, y1), . . ., (xb, yb) Select next mini-batch samples based on p_table;
|
309 |
+
_{_ _} ���_
|
310 |
+
|
311 |
+
16: _losses, labels_ Inference_Step( (x1, y1), . . ., (xb, yb),θt); /* Forward */
|
312 |
+
_←_ _{_ _}_
|
313 |
+
|
314 |
+
17: _θt+1 ←_ SGD_Step(losses, θt); /* Backward */
|
315 |
+
|
316 |
+
18: Update_Label_History(labels); /* By Definition 3.1 */
|
317 |
+
|
318 |
+
19: _t ←_ _t + 1;_
|
319 |
+
|
320 |
+
20: return θt;
|
321 |
+
|
322 |
+
Algorithm 1 describes the overall procedure of Recency Bias. The algorithm requires a warm-up
|
323 |
+
period of γ epochs because the quantization index for each sample is not confirmed yet. During
|
324 |
+
the warm-up period, which should be at least q epochs (γ ≥ _q) to obtain the label history of size_
|
325 |
+
|
326 |
+
|
327 |
+
-----
|
328 |
+
|
329 |
+
_q, randomly selected mini-batch samples are used for the network update (Lines 12–13). After the_
|
330 |
+
warm-up period, the algorithm decays the selection pressure se and updates not only the quantization
|
331 |
+
index but also the sampling probability in a batch at the beginning of each epoch (Lines 4–9).
|
332 |
+
Subsequently, the uncertain samples are selected for the next mini-batch according to the updated
|
333 |
+
sampling probability (Line 14–15), and then the label history is updated along with the network
|
334 |
+
update (Lines 16–19).
|
335 |
+
|
336 |
+
Overall, the key technical novelty of Recency Bias is to incorporate the notion of a sliding win_dow (Line 8) rather than a growing window into adaptive batch selection, thereby improving both_
|
337 |
+
training speed and generalization error.
|
338 |
+
|
339 |
+
**Time Complexity: The main “additional” cost of Recency Bias is the derivation of the sampling**
|
340 |
+
probability for each sample (Lines 4–9). Because only simple mathematical operations are needed
|
341 |
+
per sample, its time complexity is linear to the number of samples (i.e., O(N )), which is negligible
|
342 |
+
compared with that of the forward and backward steps of a complex network (Lines 16–17). Therefore,
|
343 |
+
we contend that Recency Bias does not add the complexity of an underlying optimization algorithm.
|
344 |
+
|
345 |
+
5 EVALUATION
|
346 |
+
|
347 |
+
We empirically show the improvement of Recency Bias over not only Random Batch (baseline) but also
|
348 |
+
_Online Batch (Loshchilov & Hutter, 2016) and Active Bias (Chang et al., 2017), which are two state-_
|
349 |
+
of-the-art adaptive batch selections. In particular, we elaborate on the effect of the sliding window
|
350 |
+
approach (Recency Bias) compared with the growing window approach (Active Bias). Random Batch
|
351 |
+
selects next mini-batch samples uniformly at random from the entire dataset. Online Batch selects hard
|
352 |
+
samples based on the rank of the loss computed from previous epochs. Active Bias selects uncertain
|
353 |
+
samples with high variance of true label probabilities in the growing window. All the algorithms
|
354 |
+
were implemented using TensorFlow 1.8.0 and executed using a single NVIDIA Titan Volta GPU.
|
355 |
+
[For reproducibility, we provide the source code at https://github.com/anonymized.](https://github.com/anonymized)
|
356 |
+
|
357 |
+
Image classification and fine-tuning tasks were performed to validate the superiority of Recency Bias.
|
358 |
+
Because fine-tuning is used to quickly adapt to a new dataset, it is suitable to reap the benefit of fast
|
359 |
+
training speed. In support of reliable evaluation, we repeated every task thrice and reported the average
|
360 |
+
and standard error of the best test errors. The best test error in a given time has been widely used for
|
361 |
+
the studies on fast and accurate training (Katharopoulos & Fleuret, 2018; Loshchilov & Hutter, 2016).
|
362 |
+
|
363 |
+
5.1 ANALYSIS ON SELECTED MINI-BATCH SAMPLES
|
364 |
+
|
365 |
+
For an in-depth analysis on selected samples, we plot the loss distribution of mini-batch samples
|
366 |
+
selected from CIFAR-10 by four different strategies in Figure 3. (i) The distribution of Online Batch
|
367 |
+
is the most skewed toward high loss by the design principle of selecting hard samples. (ii) Active Bias
|
368 |
+
emphasizes moderately hard samples at an early training stage in considering that its loss distribution
|
369 |
+
lies between those of Random Batch and Online Batch. However, owing to the outdated predictions
|
370 |
+
caused by the growing window, the proportion of easy samples with low loss increases at a late
|
371 |
+
training stage. These easy samples, which are misclassified as uncertain at that stage, tend to make the
|
372 |
+
convergence of training slow down. (iii) In contrast to Active Bias, by virtue of the sliding window,
|
373 |
+
the distribution of Recency Bias lies between those of Random Batch and Online Batch regardless of
|
374 |
+
the training stage. Consequently, Recency Bias continues to highlight the moderately hard samples,
|
375 |
+
which are likely to be informative, during the training process.
|
376 |
+
|
377 |
+
Random Batch
|
378 |
+
Online Batch
|
379 |
+
Active Bias
|
380 |
+
(Growing window)
|
381 |
+
Recency Bias
|
382 |
+
(Sliding window)
|
383 |
+
|
384 |
+
Loss (Log-scale) Loss (Log-scale)
|
385 |
+
|
386 |
+
|
387 |
+
(a) Early Stage (30%). (b) Late Stage (70%).
|
388 |
+
|
389 |
+
Figure 3: The loss distribution of mini-batch samples selected by four batch selection strategies: (a)
|
390 |
+
and (b) show the loss distribution at the 30% and 70% of total training epochs, respectively.
|
391 |
+
|
392 |
+
|
393 |
+
-----
|
394 |
+
|
395 |
+
5.2 TASK I: IMAGE CLASSIFICATION
|
396 |
+
|
397 |
+
**Experiment Setting: We trained DenseNet (L=40, k=12) and ResNet (L=50) with a momentum**
|
398 |
+
optimizer and an SGD optimizer on three benchmark datasets: MNIST (10 classes)[2], classification
|
399 |
+
of handwritten digits (LeCun, 1998), and CIFAR-10 (10 classes)[3] and CIFAR-100 (100 classes)[3],
|
400 |
+
classification of a subset of 80 million categorical images (Krizhevsky et al., 2014). Specifically, we
|
401 |
+
used data augmentation, batch normalization, a momentum of 0.9, and a batch size of 128. As for the
|
402 |
+
algorithm parameters, we fixed the window size q = 10 and the initial selection pressure se0 = 100,[4]
|
403 |
+
which were the best values found by the grid search (see Appendix A for details). The warm-up
|
404 |
+
epoch γ was set to be 15. To reduce the performance variance caused by randomly initialized model
|
405 |
+
parameters, all parameters were shared by all algorithms during the warm-up period. Regarding
|
406 |
+
the training schedule, we trained the network for 40, 000 iterations and used an initial learning rate
|
407 |
+
of 0.1, which was divided by 10 at 50% and 75% of the total number of training iterations.
|
408 |
+
|
409 |
+
**Results: Figure 4 shows the convergence curves of training loss and test error for four batch selection**
|
410 |
+
strategies using DenseNet and a momentum optimizer. In order to highlight the improvement of
|
411 |
+
_Recency Bias over the baseline (Random Batch), their lines are dark colored. The best test errors in_
|
412 |
+
Figures 4(b), 4(d), and 4(f) are summarized on the left side of Table 1.
|
413 |
+
|
414 |
+
In general, Recency Bias achieved the most accurate network while accelerating the training process
|
415 |
+
on all datasets. The training loss of Recency Bias converged faster (Figures 4(a), 4(c), and 4(e))
|
416 |
+
without the increase in the generalization error, thereby achieving the lower test error (Figures 4(b),
|
417 |
+
4(d), and 4(f)). In contrast, the test error of Online Batch was not the best even if its training loss
|
418 |
+
converged the fastest among all strategies. As the training difficulty increased from CIFAR-10 to
|
419 |
+
CIFAR-100, the test error of Online Batch became even worse than that of Random Batch. That
|
420 |
+
is, emphasizing hard samples accelerated the training step but made the network overfit to hard
|
421 |
+
samples. Meanwhile, Active Bias was prone to make the network better generalized on test data.
|
422 |
+
In CIFAR-10, despite its highest training loss, the test error of Active Bias was better than that of
|
423 |
+
_Random Batch. However, Active Bias slowed down the training process because of the limitation_
|
424 |
+
of growing windows, as discussed in Section 5.1. We note that, although both Recency Bias and
|
425 |
+
_Active Bias exploited uncertain samples, only Recency Bias based on sliding windows succeeded_
|
426 |
+
to not only speed up the training process but also reduce the generalization error.
|
427 |
+
|
428 |
+
The results of the best test error for ResNet or an SGD optimizer are summarized in Tables 1 and
|
429 |
+
2 (see Appendix B for more details). Regardless of a neural network and an optimizer, Recency
|
430 |
+
_Bias achieved the lowest test error except in MNIST with an SGD optimizer. The improvement of_
|
431 |
+
_Recency Bias over the others was higher with an SGD optimizer than with a momentum optimizer._
|
432 |
+
|
433 |
+
Table 1: The best test errors (%) of four batch selection strategies using DenseNet.
|
434 |
+
|
435 |
+
|Optimizer|Momentum in Figure 4|Col3|Col4|SGD in Figure 7 (Appendix B.1)|Col6|Col7|
|
436 |
+
|---|---|---|---|---|---|---|
|
437 |
+
|Method|MNIST|CIFAR-10|CIFAR-100|MNIST|CIFAR-10|CIFAR-100|
|
438 |
+
|Random Batch|0.527 ± 0.03|7.33 ± 0.09|28.0 ± 0.16|1.23 ± 0.03|14.9 ± 0.09|40.2 ± 0.06|
|
439 |
+
|Online Batch|0.514 ± 0.01|7.00 ± 0.10|28.4 ± 0.25|0.765 ± 0.02|13.5 ± 0.02|40.7 ± 0.12|
|
440 |
+
|Active Bias|0.616 ± 0.03|7.07 ± 0.04|27.9 ± 0.11|0.679 ± 0.02|14.2 ± 0.25|42.9 ± 0.05|
|
441 |
+
|Recency Bias|0.490 ± 0.02|6.60 ± 0.02|27.1 ± 0.19|0.986 ± 0.06|13.2 ± 0.11|38.7 ± 0.11|
|
442 |
+
|
443 |
+
|
444 |
+
|
445 |
+
Table 2: The best test errors (%) of four batch selection strategies using ResNet.
|
446 |
+
|
447 |
+
|Optimizer|Momentum in Figure 8 (Appendix B.2)|Col3|Col4|SGD in Figure 9 (Appendix B.3)|Col6|Col7|
|
448 |
+
|---|---|---|---|---|---|---|
|
449 |
+
|Method|MNIST|CIFAR-10|CIFAR-100|MNIST|CIFAR-10|CIFAR-100|
|
450 |
+
|Random Batch|0.636 ± 0.04|10.2 ± 0.12|33.2 ± 0.07|1.16 ± 0.03|12.7 ± 0.09|40.1 ± 0.16|
|
451 |
+
|Online Batch|0.666 ± 0.05|10.1 ± 0.05|33.4 ± 0.01|0.890 ± 0.03|12.2 ± 0.08|40.7 ± 0.09|
|
452 |
+
|Active Bias|0.613 ± 0.04|10.6 ± 0.08|34.2 ± 0.07|0.804 ± 0.01|13.5 ± 0.07|45.6 ± 0.07|
|
453 |
+
|Recency Bias|0.607 ± 0.01|9.79 ± 0.04|32.4 ± 0.04|0.972 ± 0.03|11.6 ± 0.09|38.9 ± 0.14|
|
454 |
+
|
455 |
+
|
456 |
+
|
457 |
+
[2http://yann.lecun.com/exdb/mnist](http://yann.lecun.com/exdb/mnist)
|
458 |
+
[3https://www.cs.toronto.edu/~kriz/cifar.html](https://www.cs.toronto.edu/~kriz/cifar.html)
|
459 |
+
4Online Batch also used the same decaying selection pressure value.
|
460 |
+
|
461 |
+
|
462 |
+
-----
|
463 |
+
|
464 |
+
|Col1|Col2|Col3|Col4|Col5|Col6|Col7|Col8|Col9|Col10|
|
465 |
+
|---|---|---|---|---|---|---|---|---|---|
|
466 |
+
||Random Batch Online||||Batch Active Bias Recency Bias|||||
|
467 |
+
|E-01 E-02 E-03|||||3.6% Error 1.2% Test|||||
|
468 |
+
|||||||||||
|
469 |
+
|||||||||||
|
470 |
+
|||||||||||
|
471 |
+
|||||||||||
|
472 |
+
|
473 |
+
|
474 |
+
2125 4250 6375 8500
|
475 |
+
|
476 |
+
Time (s)
|
477 |
+
|
478 |
+
|
479 |
+
2125 4250 6375 8500
|
480 |
+
|
481 |
+
Time (s)
|
482 |
+
|
483 |
+
|
484 |
+
0.90.80.70.60.50.40.30.20.110
|
485 |
+
|
486 |
+
|
487 |
+
|
488 |
+
|
489 |
+
|Col1|Col2|Col3|Col4|
|
490 |
+
|---|---|---|---|
|
491 |
+
|||||
|
492 |
+
|||||
|
493 |
+
|
494 |
+
|0|Col2|Col3|Col4|
|
495 |
+
|---|---|---|---|
|
496 |
+
|||||
|
497 |
+
|||||
|
498 |
+
|
499 |
+
|
500 |
+
|(a) MNIST Training Loss. (b) MNIST Test Error.|Col2|Col3|Col4|Col5|
|
501 |
+
|---|---|---|---|---|
|
502 |
+
|(a) MNIST Training Loss. (b) MNIST Test Error. 16E-01 20 40 60 80 26.0%10 0 4E-01 Error 13.0% Test 4E-02 0E-03 6.5% 0 2500 5000 7500 10000 0 2500 5000 7500 100 Time (s) Time (s) (c) CIFAR-10 Training Loss. (d) CIFAR-10 Test Error. 4E+00 54.0% Error 0E-01 Test|||||
|
503 |
+
||||||
|
504 |
+
|
505 |
+
|
506 |
+
2500 5000 7500 10000
|
507 |
+
|
508 |
+
|
509 |
+
2.4E+00
|
510 |
+
|
511 |
+
6.0E-01
|
512 |
+
|
513 |
+
|
514 |
+
1.5E-01
|
515 |
+
|
516 |
+
|
517 |
+
27.0%
|
518 |
+
|
519 |
+
|
520 |
+
2500 5000 7500 10000
|
521 |
+
|
522 |
+
Time (s)
|
523 |
+
|
524 |
+
|
525 |
+
Time (s)
|
526 |
+
|
527 |
+
|
528 |
+
(e) CIFAR-100 Training Loss. (f) CIFAR-100 Test Error.
|
529 |
+
|
530 |
+
Figure 4: Convergence curves of four batch selection strategies using DenseNet with momentum.
|
531 |
+
|
532 |
+
|
533 |
+
5.3 TASK II: FINE-TUNING
|
534 |
+
|
535 |
+
**Experiment Setting: We prepared DenseNet (L=121, k=32) previously trained on ImageNet (Deng**
|
536 |
+
et al., 2009) and then fine-tuned the network on two benchmark datasets: MIT-67 (67 classes)[5],
|
537 |
+
classification of indoor scenes (Quattoni & Torralba, 2009), and Food-100 (100 classes)[6], classification of popular foods in Japan (Kawano & Yanai, 2014). After replacing the last classification
|
538 |
+
layer, the network was trained end-to-end for 50 epochs with a batch size 32 and a constant learning
|
539 |
+
rate 2 × 10[−][4]. Data augmentation was not applied here. The other configurations were the same
|
540 |
+
as those in Section 5.2.
|
541 |
+
|
542 |
+
**Results on Test Error: Figure 5 shows the convergence curves of training loss and test error for**
|
543 |
+
the fine-tuning task on MIT-67 and Food-100. Overall, all convergence curves showed similar trends
|
544 |
+
to those of the classification task in Figure 4. Only Recency Bias converged faster than Random
|
545 |
+
_Batch in both training loss and test error. Online Batch converged the fastest in training loss, but_
|
546 |
+
its test error was rather higher than Random Batch owing to the overfitting. Active Bias converged the
|
547 |
+
|
548 |
+
|
549 |
+
[5http://web.mit.edu/torralba/www/indoor.html](http://web.mit.edu/torralba/www/indoor.html)
|
550 |
+
[6http://foodcam.mobi/dataset100.html](http://foodcam.mobi/dataset100.html)
|
551 |
+
|
552 |
+
|
553 |
+
-----
|
554 |
+
|
555 |
+
|Col1|Col2|Col3|Col4|Col5|
|
556 |
+
|---|---|---|---|---|
|
557 |
+
||||||
|
558 |
+
|Time Redu|ction: 24.6|%|||
|
559 |
+
|
560 |
+
|
561 |
+
Random Batch Online Batch Active Bias Recency Bias
|
562 |
+
|
563 |
+
1.9E+00 39.0%
|
564 |
+
|
565 |
+
6.3E-012.1E-01 Test Error 35.0%31.0%
|
566 |
+
|
567 |
+
Training Loss
|
568 |
+
|
569 |
+
Time Reduction: 24.6%
|
570 |
+
|
571 |
+
7.0E-02 27.0%
|
572 |
+
|
573 |
+
0 1500 3000 4500 6000 0 1500 3000 4500 6000
|
574 |
+
|
575 |
+
0.90.80.70.60.50.40.30.20.110 (a) MIT-67 Training Loss.Time (s) (b) MIT-67 Test Error.Time (s)
|
576 |
+
|
577 |
+
1.6E+001 20 40 60 80 44.0%10
|
578 |
+
|
579 |
+
0
|
580 |
+
|
581 |
+
40.0%
|
582 |
+
|
583 |
+
|
584 |
+
0.90.80.70.60.50.40.30.20.110
|
585 |
+
|
586 |
+
|
587 |
+
8.0E-01
|
588 |
+
|
589 |
+
4.0E-01
|
590 |
+
|
591 |
+
|
592 |
+
36.0%
|
593 |
+
|
594 |
+
32.0%
|
595 |
+
|
596 |
+
|
597 |
+
2.0E-01
|
598 |
+
|
599 |
+
|20|4|0|60|Col5|
|
600 |
+
|---|---|---|---|---|
|
601 |
+
||||||
|
602 |
+
||||||
|
603 |
+
|
604 |
+
|0|Col2|Col3|Col4|Col5|
|
605 |
+
|---|---|---|---|---|
|
606 |
+
|0|||||
|
607 |
+
||||||
|
608 |
+
|Time Redu|ction: 26.1|%|||
|
609 |
+
|
610 |
+
|
611 |
+
2000 4000 6000 8000
|
612 |
+
|
613 |
+
Time (s)
|
614 |
+
|
615 |
+
|
616 |
+
2000 4000 6000 8000
|
617 |
+
|
618 |
+
Time (s)
|
619 |
+
|
620 |
+
|
621 |
+
(c) Food-100 Training Loss. (d) Food-100 Test Error.
|
622 |
+
|
623 |
+
Figure 5: Convergence curves for fine-tuning on two benchmark datasets.
|
624 |
+
|
625 |
+
|
626 |
+
Table 3: Recency Bias’s reduction in training time over other batch selection strategies.
|
627 |
+
|
628 |
+
|
629 |
+
|Method|MIT-67|FOOD-100|
|
630 |
+
|---|---|---|
|
631 |
+
|Random Batch|(5, 218 −3, 936)/5, 218 × 100 = 24.6%|(7, 263 −5, 365)/7, 263 × 100 = 26.1%|
|
632 |
+
|Online Batch|(6, 079 −3, 823)/6, 079 × 100 = 37.1%|(8, 333 −3, 685)/8, 333 × 100 = 55.8%|
|
633 |
+
|Active Bias|(5, 738 −3, 032)/5, 738 × 100 = 47.2%|(7, 933 −3, 227)/7, 933 × 100 = 59.3%|
|
634 |
+
|
635 |
+
|
636 |
+
slowest in both training loss and test error. Quantitatively, compared with Random Batch, Recency
|
637 |
+
_Bias reduced the test error by 2.88% and 1.81% in MIT-67 and Food-100, respectively._
|
638 |
+
|
639 |
+
**Results on Training Time: Moreover, to assess the performance gain in training time, we computed**
|
640 |
+
the reduction in the training time taken to reach the same error. For example, in Figure 5(b), the
|
641 |
+
best test error of 28.8% achieved in 5, 218 seconds by Random Batch could be achieved only in
|
642 |
+
3, 936 seconds by Recency Bias; thus, Recency Bias improved the training time by 24.6%. Table
|
643 |
+
3 summarizes the reduction in the training time of Recency Bias over three other batch selection
|
644 |
+
strategies. Notably, Recency Bias improved the training time by 24.6%–47.2% and 26.1%–59.3% in
|
645 |
+
fine-tuning MIT-67 and FOOD-100 datasets, respectively.
|
646 |
+
|
647 |
+
6 CONCLUSION
|
648 |
+
|
649 |
+
|
650 |
+
In this paper, we presented a novel adaptive batch selection algorithm called Recency Bias that
|
651 |
+
emphasizes predictively uncertain samples for accelerating the training of neural networks. Toward
|
652 |
+
this goal, the predictive uncertainty of each sample is evaluated using its recent label predictions
|
653 |
+
managed by a sliding window of a fixed size. Then, uncertain samples at the moment are selected with
|
654 |
+
high probability for the next mini-batch. We conducted extensive experiments on both classification
|
655 |
+
and fine-tuning tasks. The results showed that Recency Bias is effective in reducing the training
|
656 |
+
time as well as the best test error. It was worthwhile to note that using all historical observations to
|
657 |
+
estimate the uncertainty has the side effect of slowing down the training process. Overall, a merger of
|
658 |
+
uncertain samples and sliding windows greatly improves the power of adaptive batch selection.
|
659 |
+
|
660 |
+
|
661 |
+
-----
|
662 |
+
|
663 |
+
REFERENCES
|
664 |
+
|
665 |
+
Yoshua Bengio, Jérôme Louradour, Ronan Collobert, and Jason Weston. Curriculum learning. In
|
666 |
+
_ICML, pp. 41–48, 2009._
|
667 |
+
|
668 |
+
David Chandler. Introduction to modern statistical mechanics. Oxford University Press, 1987.
|
669 |
+
|
670 |
+
Haw-Shiuan Chang, Erik Learned-Miller, and Andrew McCallum. Active Bias: Training more
|
671 |
+
accurate neural networks by emphasizing high variance samples. In NeurIPS, pp. 1002–1012,
|
672 |
+
2017.
|
673 |
+
|
674 |
+
Brian Chen and Gregory W Wornell. Quantization index modulation: A class of provably good
|
675 |
+
methods for digital watermarking and information embedding. IEEE Trans. on Information Theory,
|
676 |
+
47(4):1423–1443, 2001.
|
677 |
+
|
678 |
+
Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale
|
679 |
+
hierarchical image database. In CVPR, pp. 248–255, 2009.
|
680 |
+
|
681 |
+
Yang Fan, Fei Tian, Tao Qin, and Tie-Yan Liu. Neural data filter for bootstrapping stochastic gradient
|
682 |
+
descent. In ICLR, 2017.
|
683 |
+
|
684 |
+
Bo Han, Quanming Yao, Xingrui Yu, Gang Niu, Miao Xu, Weihua Hu, Ivor Tsang, and Masashi
|
685 |
+
Sugiyama. Co-teaching: Robust training of deep neural networks with extremely noisy labels. In
|
686 |
+
_NeurIPS, pp. 8527–8537, 2018._
|
687 |
+
|
688 |
+
KJ Joseph, Krishnakant Singh, Vineeth N Balasubramanian, et al. Submodular batch selection for
|
689 |
+
training deep neural networks. In IJCAI, pp. 2677–3683, 2019.
|
690 |
+
|
691 |
+
Angelos Katharopoulos and François Fleuret. Not all samples are created equal: Deep learning with
|
692 |
+
importance sampling. In ICML, pp. 2525–2534, 2018.
|
693 |
+
|
694 |
+
Y. Kawano and K. Yanai. Food image recognition with deep convolutional features. In UbiComp,
|
695 |
+
2014.
|
696 |
+
|
697 |
+
Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. CIFAR-10 and CIFAR-100 datasets, 2014.
|
698 |
+
[https://www.cs.toronto.edu/~kriz/cifar.html.](https://www.cs.toronto.edu/~kriz/cifar.html)
|
699 |
+
|
700 |
+
M Pawan Kumar, Benjamin Packer, and Daphne Koller. Self-paced learning for latent variable
|
701 |
+
models. In NeurIPS, pp. 1189–1197, 2010.
|
702 |
+
|
703 |
+
[Yann LeCun. The MNIST database of handwritten digits, 1998. http://yann.lecun.com/](http://yann.lecun.com/exdb/mnist)
|
704 |
+
[exdb/mnist.](http://yann.lecun.com/exdb/mnist)
|
705 |
+
|
706 |
+
Ilya Loshchilov and Frank Hutter. Online batch selection for faster training of neural networks. In
|
707 |
+
_ICLR, 2016._
|
708 |
+
|
709 |
+
Ariadna Quattoni and Antonio Torralba. Recognizing indoor scenes. In CVPR, pp. 413–420, 2009.
|
710 |
+
|
711 |
+
Mrinmaya Sachan and Eric Xing. Easy questions first? A case study on curriculum learning for
|
712 |
+
question answering. In ACL, pp. 453–463, 2016.
|
713 |
+
|
714 |
+
Abhinav Shrivastava, Abhinav Gupta, and Ross Girshick. Training region-based object detectors
|
715 |
+
with online hard example mining. In CVPR, pp. 761–769, 2016.
|
716 |
+
|
717 |
+
Hwanjun Song, Minseok Kim, and Jae-Gil Lee. SELFIE: Refurbishing unclean samples for robust
|
718 |
+
deep learning. In ICML, pp. 5907–5915, 2019.
|
719 |
+
|
720 |
+
Luis Torgo. Data mining with R: learning with case studies. Chapman and Hall/CRC, 2011.
|
721 |
+
|
722 |
+
Yulia Tsvetkov, Manaal Faruqui, Wang Ling, Brian MacWhinney, and Chris Dyer. Learning the
|
723 |
+
curriculum with bayesian optimization for task-specific word representation learning. In ACL, pp.
|
724 |
+
130–139, 2016.
|
725 |
+
|
726 |
+
Shengjie Wang, Wenruo Bai, Chandrashekhar Lavania, and Jeff Bilmes. Fixing mini-batch sequences
|
727 |
+
with hierarchical robust partitioning. In AISTATS, pp. 3352–3361, 2019.
|
728 |
+
|
729 |
+
|
730 |
+
-----
|
731 |
+
|
732 |
+
Bernard Widrow, Istvan Kollar, and Ming-Chang Liu. Statistical theory of quantization. IEEE
|
733 |
+
_Transactions on instrumentation and measurement, 45(2):353–361, 1996._
|
734 |
+
|
735 |
+
Tianyi Zhou and Jeff Bilmes. Minimax curriculum learning: Machine teaching with desirable
|
736 |
+
difficulties and scheduled diversity. In ICLR, 2018.
|
737 |
+
|
738 |
+
|
739 |
+
-----
|
740 |
+
|
741 |
+
A HYPERPARAMETER SELECTION
|
742 |
+
|
743 |
+
_Recency Bias receives the two hyperparameters: (i) the initial selection pressure se0 that determines_
|
744 |
+
the sampling probability gap between the most and the least uncertain samples and (ii) the window
|
745 |
+
size q that determines how many recent label predictions are involved in predicting the uncertainty.
|
746 |
+
To decide the best hyperparameters, we trained ResNet (L=50) on CIFAR-10 and CIFAR-100 with a
|
747 |
+
momentum optimizer. For hyperparameters selection, the two hyperparameters were chosen in a grid
|
748 |
+
_se0_ 1, 10, 100, 1000 and q 5, 10, 15 .
|
749 |
+
_∈{_ _}_ _∈{_ _}_
|
750 |
+
|
751 |
+
|
752 |
+
|
753 |
+
|Window|Size|Col3|
|
754 |
+
|---|---|---|
|
755 |
+
|q=5 q=10 q=15|||
|
756 |
+
|
757 |
+
|
758 |
+
10.4% 33.5%
|
759 |
+
|
760 |
+
Window Size
|
761 |
+
|
762 |
+
10.1% 33.0% q=5
|
763 |
+
|
764 |
+
q=10
|
765 |
+
|
766 |
+
9.8% 32.5%
|
767 |
+
|
768 |
+
Best Test Error q=15
|
769 |
+
|
770 |
+
9.5% 32.0%
|
771 |
+
|
772 |
+
1 10 100 1000 1 10 100 1000
|
773 |
+
|
774 |
+
Initial Selection Pressure (𝑆𝑒0) Initial Selection Pressure (𝑆𝑒0)
|
775 |
+
|
776 |
+
|
777 |
+
(a) CIFAR-10. (b) CIFAR-100.
|
778 |
+
|
779 |
+
Figure 6: Grid search on CIFAR-10 and CIFAR-100 datasets using ResNet.
|
780 |
+
|
781 |
+
Figure 6 shows the test errors of Recency Bias obtained by the grid search on the two datasets.
|
782 |
+
Regarding the initial selection pressure se0, the lowest test error was typically achieved when the
|
783 |
+
_se0 value was 100. As for the window size q, the test error was almost always the lowest when the q_
|
784 |
+
value was 10. Similar trends were observed for the other combinations of a neural network and an
|
785 |
+
optimizer. Therefore, in all experiments, we set se0 to be 100 and q to be 10.
|
786 |
+
|
787 |
+
|
788 |
+
-----
|
789 |
+
|
790 |
+
|GENERALIZATION OF Recency Bias|Col2|Col3|Col4|Col5|Col6|Col7|Col8|Col9|Col10|Col11|Col12|
|
791 |
+
|---|---|---|---|---|---|---|---|---|---|---|---|
|
792 |
+
|CONVERGENCE CURVES USING DENSENET WITH SGD 7 shows the convergence curves of training loss and test error for four batch selection strate DenseNet and an SGD optimizer, which corresponds to the right side of Table 1.||||||||||||
|
793 |
+
||eNet and an SGD optimizer, whic|||||||||||
|
794 |
+
||Random Batch Online|||||Batch Active Bias Recency Bias||||||
|
795 |
+
|E-01 E-02 E-02||||||4.8% 2.4% Error Test 1.2%||||||
|
796 |
+
|||||||||||||
|
797 |
+
|||||||||||||
|
798 |
+
|||||||||||||
|
799 |
+
|
800 |
+
|
801 |
+
2000 4000 6000 8000
|
802 |
+
|
803 |
+
Time (s)
|
804 |
+
|
805 |
+
|
806 |
+
2000 4000 6000 8000
|
807 |
+
|
808 |
+
Time (s)
|
809 |
+
|
810 |
+
|
811 |
+
0.90.80.70.60.50.40.30.20.110
|
812 |
+
|
813 |
+
|
814 |
+
|
815 |
+
|
816 |
+
|(a) MNIST Training Loss. (b) MNIST Test Error.|Col2|Col3|Col4|Col5|
|
817 |
+
|---|---|---|---|---|
|
818 |
+
|(a) MNIST Training Loss. (b) MNIST Test Error. 10E+00 20 40 60 80 48.0%10 0 2E-01 Error 24.0% Test 6E-01|||||
|
819 |
+
||||||
|
820 |
+
||||||
|
821 |
+
|
822 |
+
|
823 |
+
3.2E+00
|
824 |
+
|
825 |
+
|
826 |
+
12.0%
|
827 |
+
|
828 |
+
|Col1|Col2|Col3|Col4|
|
829 |
+
|---|---|---|---|
|
830 |
+
|||||
|
831 |
+
|||||
|
832 |
+
|
833 |
+
|
834 |
+
2500 5000 7500 10000
|
835 |
+
|
836 |
+
Time (s)
|
837 |
+
|
838 |
+
|
839 |
+
2500 5000 7500 10000
|
840 |
+
|
841 |
+
|
842 |
+
(c) CIFAR-10 Training Loss. (d) CIFAR-10 Test Error.
|
843 |
+
|
844 |
+
70.0%
|
845 |
+
|
846 |
+
|
847 |
+
1.6E+00
|
848 |
+
|
849 |
+
8.0E-01
|
850 |
+
|
851 |
+
|
852 |
+
35.0%
|
853 |
+
|
854 |
+
|(c) CIFAR-10 Training Loss.|Col2|Col3|Col4|
|
855 |
+
|---|---|---|---|
|
856 |
+
|||||
|
857 |
+
|||||
|
858 |
+
|||||
|
859 |
+
|
860 |
+
|Time (s)|Col2|Col3|Col4|
|
861 |
+
|---|---|---|---|
|
862 |
+
|(d) CIFAR-10 Test Error.||||
|
863 |
+
|||||
|
864 |
+
|
865 |
+
|
866 |
+
2500 5000 7500 10000
|
867 |
+
|
868 |
+
Time (s)
|
869 |
+
|
870 |
+
|
871 |
+
2500 5000 7500 10000
|
872 |
+
|
873 |
+
Time (s)
|
874 |
+
|
875 |
+
|
876 |
+
(e) CIFAR-100 Training Loss. (f) CIFAR-100 Test Error.
|
877 |
+
|
878 |
+
Figure 7: Convergence curves of four batch selection strategies using DenseNet with SGD.
|
879 |
+
|
880 |
+
|
881 |
+
-----
|
882 |
+
|
883 |
+
|Col1|et and a momentum optimizer, w|Col3|Col4|Col5|
|
884 |
+
|---|---|---|---|---|
|
885 |
+
||Random Batch Online|||Batch Active Bias Recency Bias|
|
886 |
+
||||||
|
887 |
+
||||||
|
888 |
+
||||||
|
889 |
+
||||||
|
890 |
+
|
891 |
+
|
892 |
+
|
893 |
+
0.90.80.70.60.50.40.30.20.110
|
894 |
+
|
895 |
+
|
896 |
+
|
897 |
+
|
898 |
+
|20|40|60|
|
899 |
+
|---|---|---|
|
900 |
+
||||
|
901 |
+
||||
|
902 |
+
||||
|
903 |
+
|
904 |
+
|0|Col2|Col3|
|
905 |
+
|---|---|---|
|
906 |
+
||||
|
907 |
+
||||
|
908 |
+
|
909 |
+
|
910 |
+
|
911 |
+
|
912 |
+
|CONVERGENCE CURVES USING RESNET WITH MOMENTUM e 8 shows the convergence curves of training loss and test error for four batch selection strate ResNet and a momentum optimizer, which corresponds to the left side of Table 2. Random Batch Online Batch Active Bias Recency Bias 5E-01 4.5% 9E-02 Error 1.5% Test 4E-03 0E-04 0.5% 0 2300 4600 6900 0 2300 4600 69 Time (s) Time (s) (a) MNIST Training Loss. (b) MNIST Test Error. 17E+00 20 40 60 80 36.0%10 0 5E-01 Error 18.0% Test 5E-02 0E-03 9.0% 0 2900 5800 8700 0 2900 5800 870 Time (s) Time (s) (c) CIFAR-10 Training Loss. (d) CIFAR-10 Test Error. 3E+00 64.0% 2E-01 Error Test 2E-01|Col2|Col3|Col4|
|
913 |
+
|---|---|---|---|
|
914 |
+
|||||
|
915 |
+
|
916 |
+
|
917 |
+
2900 5800 8700
|
918 |
+
|
919 |
+
Time (s)
|
920 |
+
|
921 |
+
|
922 |
+
2900 5800 8700
|
923 |
+
|
924 |
+
Time (s)
|
925 |
+
|
926 |
+
|
927 |
+
(e) CIFAR-100 Training Loss. (f) CIFAR-100 Test Error.
|
928 |
+
|
929 |
+
Figure 8: Convergence curves of four batch selection strategies using ResNet with momentum.
|
930 |
+
|
931 |
+
|
932 |
+
-----
|
933 |
+
|
934 |
+
|CONVERGENCE CURVES USING RESNET WITH SGD 9 shows the convergence curves of training loss and test error for four batch selection strate ResNet and an SGD optimizer, which corresponds to the right side of Table 2.|Col2|Col3|Col4|Col5|Col6|Col7|Col8|
|
935 |
+
|---|---|---|---|---|---|---|---|
|
936 |
+
||ows the convergence curves of trai et and an SGD optimizer, which|||||||
|
937 |
+
||Random Batch Online|||Batch Active Bias Recency Bias||||
|
938 |
+
|E-01 E-02 E-02||||6.3% Error 2.1% Test||||
|
939 |
+
|||||||||
|
940 |
+
|||||||||
|
941 |
+
|||||||||
|
942 |
+
|||||||||
|
943 |
+
|
944 |
+
|
945 |
+
2300 4600 6900
|
946 |
+
|
947 |
+
|
948 |
+
2300 4600 6900
|
949 |
+
|
950 |
+
|Time (s) Time (s) (a) MNIST Training Loss. (b) MNIST Test Error. 1 20 40 60 80 44.0%10 4E+00 0 5E-01 Error 22.0% Test 5E-01|Col2|Col3|Col4|
|
951 |
+
|---|---|---|---|
|
952 |
+
|||||
|
953 |
+
|||||
|
954 |
+
|
955 |
+
|
956 |
+
|
957 |
+
2900 5800 8700
|
958 |
+
|
959 |
+
|
960 |
+
0.90.80.70.60.50.40.30.20.110
|
961 |
+
|
962 |
+
|
963 |
+
0
|
964 |
+
|
965 |
+
22.0%
|
966 |
+
|
967 |
+
Test Error
|
968 |
+
|
969 |
+
11.0%
|
970 |
+
|
971 |
+
|
972 |
+
1.4E+001
|
973 |
+
|
974 |
+
4.5E-01
|
975 |
+
|
976 |
+
|
977 |
+
1.5E-01
|
978 |
+
|
979 |
+
5.0E-02
|
980 |
+
|
981 |
+
|20|40|60|
|
982 |
+
|---|---|---|
|
983 |
+
||||
|
984 |
+
||||
|
985 |
+
||||
|
986 |
+
|
987 |
+
|
988 |
+
2900 5800 8700
|
989 |
+
|
990 |
+
|
991 |
+
|
992 |
+
(c) CIFAR-10 Training Loss. (d) CIFAR-10 Test Error.
|
993 |
+
|
994 |
+
76.0%
|
995 |
+
|
996 |
+
|
997 |
+
3.6E+00
|
998 |
+
|
999 |
+
1.2E+00
|
1000 |
+
|
1001 |
+
|
1002 |
+
4.0E-01
|
1003 |
+
|
1004 |
+
|
1005 |
+
38.0%
|
1006 |
+
|
1007 |
+
|Time (s)|Col2|Col3|
|
1008 |
+
|---|---|---|
|
1009 |
+
|(c) CIFAR-10 Training Loss.|||
|
1010 |
+
||||
|
1011 |
+
||||
|
1012 |
+
|
1013 |
+
|Time (s)|Col2|Col3|
|
1014 |
+
|---|---|---|
|
1015 |
+
|(d) CIFAR-10 Test Error.|||
|
1016 |
+
||||
|
1017 |
+
|
1018 |
+
|
1019 |
+
2900 5800 8700
|
1020 |
+
|
1021 |
+
Time (s)
|
1022 |
+
|
1023 |
+
|
1024 |
+
2900 5800 8700
|
1025 |
+
|
1026 |
+
Time (s)
|
1027 |
+
|
1028 |
+
|
1029 |
+
(e) CIFAR-100 Training Loss. (f) CIFAR-100 Test Error.
|
1030 |
+
|
1031 |
+
Figure 9: Convergence curves of four batch selection strategies using ResNet with SGD.
|
1032 |
+
|
1033 |
+
|
1034 |
+
-----
|
1035 |
+
|
ai_scientist/fewshot_examples/attention.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"review": "{\n \"Summary\": \"The paper proposes the Transformer, a novel neural network architecture that relies entirely on self-attention mechanisms, eschewing traditional recurrent and convolutional layers. This innovation allows the model to achieve state-of-the-art results in machine translation tasks with significant improvements in both training efficiency and translation quality. The paper includes detailed descriptions of the model architecture, including multi-head attention and positional encodings, as well as extensive experimental results to validate the model's performance.\",\n \"Questions\": [\n \"Could the authors provide more detailed comparisons with other recent models not included in Table 2?\",\n \"What is the impact of varying the number of layers (N) in both the encoder and decoder stacks?\",\n \"Can the authors provide more insights into the choice of hyperparameters, especially the learning rate schedule and warmup steps?\"\n ],\n \"Limitations\": [\n \"The paper does not explore the application of the Transformer to tasks beyond machine translation, such as image or audio processing.\",\n \"The discussion on the potential negative societal impacts of the model is minimal and could be expanded.\"\n ],\n \"Ethical Concerns\": false,\n \"Soundness\": 4,\n \"Presentation\": 3,\n \"Contribution\": 4,\n \"Overall\": 8,\n \"Confidence\": 5,\n \"Strengths\": [\n \"The Transformer model introduces a highly innovative use of self-attention mechanisms, replacing traditional recurrent and convolutional layers.\",\n \"Comprehensive experimental validation showing state-of-the-art performance in machine translation tasks.\",\n \"Clear and detailed description of the model architecture and its components, facilitating reproducibility and further research.\"\n ],\n \"Weaknesses\": [\n \"Limited discussion on the application of the model to other domains beyond machine translation.\",\n \"The paper could benefit from a deeper analysis of the potential negative societal impacts of the model.\"\n ],\n \"Originality\": 4,\n \"Quality\": 4,\n \"Clarity\": 4,\n \"Significance\": 4,\n \"Decision\": \"Accept\"\n}"
|
3 |
+
}
|
ai_scientist/fewshot_examples/attention.pdf
ADDED
Binary file (569 kB). View file
|
|
ai_scientist/fewshot_examples/attention.txt
ADDED
@@ -0,0 +1,662 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Attention Is All You Need
|
2 |
+
|
3 |
+
|
4 |
+
**Ashish Vaswani[∗]**
|
5 |
+
Google Brain
|
6 |
+
```
|
7 |
+
avaswani@google.com
|
8 |
+
|
9 |
+
```
|
10 |
+
**Llion Jones[∗]**
|
11 |
+
Google Research
|
12 |
+
```
|
13 |
+
llion@google.com
|
14 |
+
|
15 |
+
```
|
16 |
+
|
17 |
+
**Noam Shazeer[∗]**
|
18 |
+
Google Brain
|
19 |
+
```
|
20 |
+
noam@google.com
|
21 |
+
|
22 |
+
```
|
23 |
+
|
24 |
+
**Niki Parmar[∗]**
|
25 |
+
Google Research
|
26 |
+
```
|
27 |
+
nikip@google.com
|
28 |
+
|
29 |
+
```
|
30 |
+
|
31 |
+
**Jakob Uszkoreit[∗]**
|
32 |
+
Google Research
|
33 |
+
```
|
34 |
+
usz@google.com
|
35 |
+
|
36 |
+
```
|
37 |
+
|
38 |
+
**Aidan N. Gomez[∗†]**
|
39 |
+
University of Toronto
|
40 |
+
```
|
41 |
+
aidan@cs.toronto.edu
|
42 |
+
|
43 |
+
```
|
44 |
+
|
45 |
+
**Łukasz Kaiser[∗]**
|
46 |
+
Google Brain
|
47 |
+
```
|
48 |
+
lukaszkaiser@google.com
|
49 |
+
|
50 |
+
```
|
51 |
+
|
52 |
+
**Illia Polosukhin[∗‡]**
|
53 |
+
```
|
54 |
+
illia.polosukhin@gmail.com
|
55 |
+
|
56 |
+
```
|
57 |
+
**Abstract**
|
58 |
+
|
59 |
+
The dominant sequence transduction models are based on complex recurrent or
|
60 |
+
convolutional neural networks that include an encoder and a decoder. The best
|
61 |
+
performing models also connect the encoder and decoder through an attention
|
62 |
+
mechanism. We propose a new simple network architecture, the Transformer,
|
63 |
+
based solely on attention mechanisms, dispensing with recurrence and convolutions
|
64 |
+
entirely. Experiments on two machine translation tasks show these models to
|
65 |
+
be superior in quality while being more parallelizable and requiring significantly
|
66 |
+
less time to train. Our model achieves 28.4 BLEU on the WMT 2014 Englishto-German translation task, improving over the existing best results, including
|
67 |
+
ensembles, by over 2 BLEU. On the WMT 2014 English-to-French translation task,
|
68 |
+
our model establishes a new single-model state-of-the-art BLEU score of 41.0 after
|
69 |
+
training for 3.5 days on eight GPUs, a small fraction of the training costs of the
|
70 |
+
best models from the literature.
|
71 |
+
|
72 |
+
**1** **Introduction**
|
73 |
+
|
74 |
+
Recurrent neural networks, long short-term memory [12] and gated recurrent [7] neural networks
|
75 |
+
in particular, have been firmly established as state of the art approaches in sequence modeling and
|
76 |
+
transduction problems such as language modeling and machine translation [29, 2, 5]. Numerous
|
77 |
+
efforts have since continued to push the boundaries of recurrent language models and encoder-decoder
|
78 |
+
architectures [31, 21, 13].
|
79 |
+
|
80 |
+
_∗Equal contribution. Listing order is random. Jakob proposed replacing RNNs with self-attention and started_
|
81 |
+
the effort to evaluate this idea. Ashish, with Illia, designed and implemented the first Transformer models and
|
82 |
+
has been crucially involved in every aspect of this work. Noam proposed scaled dot-product attention, multi-head
|
83 |
+
attention and the parameter-free position representation and became the other person involved in nearly every
|
84 |
+
detail. Niki designed, implemented, tuned and evaluated countless model variants in our original codebase and
|
85 |
+
tensor2tensor. Llion also experimented with novel model variants, was responsible for our initial codebase, and
|
86 |
+
efficient inference and visualizations. Lukasz and Aidan spent countless long days designing various parts of and
|
87 |
+
implementing tensor2tensor, replacing our earlier codebase, greatly improving results and massively accelerating
|
88 |
+
our research.
|
89 |
+
_†Work performed while at Google Brain._
|
90 |
+
_‡Work performed while at Google Research._
|
91 |
+
|
92 |
+
31st Conference on Neural Information Processing Systems (NIPS 2017), Long Beach, CA, USA.
|
93 |
+
|
94 |
+
|
95 |
+
-----
|
96 |
+
|
97 |
+
Recurrent models typically factor computation along the symbol positions of the input and output
|
98 |
+
sequences. Aligning the positions to steps in computation time, they generate a sequence of hidden
|
99 |
+
states ht, as a function of the previous hidden state ht 1 and the input for position t. This inherently
|
100 |
+
_−_
|
101 |
+
sequential nature precludes parallelization within training examples, which becomes critical at longer
|
102 |
+
sequence lengths, as memory constraints limit batching across examples. Recent work has achieved
|
103 |
+
significant improvements in computational efficiency through factorization tricks [18] and conditional
|
104 |
+
computation [26], while also improving model performance in case of the latter. The fundamental
|
105 |
+
constraint of sequential computation, however, remains.
|
106 |
+
|
107 |
+
Attention mechanisms have become an integral part of compelling sequence modeling and transduction models in various tasks, allowing modeling of dependencies without regard to their distance in
|
108 |
+
the input or output sequences [2, 16]. In all but a few cases [22], however, such attention mechanisms
|
109 |
+
are used in conjunction with a recurrent network.
|
110 |
+
|
111 |
+
In this work we propose the Transformer, a model architecture eschewing recurrence and instead
|
112 |
+
relying entirely on an attention mechanism to draw global dependencies between input and output.
|
113 |
+
The Transformer allows for significantly more parallelization and can reach a new state of the art in
|
114 |
+
translation quality after being trained for as little as twelve hours on eight P100 GPUs.
|
115 |
+
|
116 |
+
**2** **Background**
|
117 |
+
|
118 |
+
The goal of reducing sequential computation also forms the foundation of the Extended Neural GPU
|
119 |
+
|
120 |
+
[20], ByteNet [15] and ConvS2S [8], all of which use convolutional neural networks as basic building
|
121 |
+
block, computing hidden representations in parallel for all input and output positions. In these models,
|
122 |
+
the number of operations required to relate signals from two arbitrary input or output positions grows
|
123 |
+
in the distance between positions, linearly for ConvS2S and logarithmically for ByteNet. This makes
|
124 |
+
it more difficult to learn dependencies between distant positions [11]. In the Transformer this is
|
125 |
+
reduced to a constant number of operations, albeit at the cost of reduced effective resolution due
|
126 |
+
to averaging attention-weighted positions, an effect we counteract with Multi-Head Attention as
|
127 |
+
described in section 3.2.
|
128 |
+
|
129 |
+
Self-attention, sometimes called intra-attention is an attention mechanism relating different positions
|
130 |
+
of a single sequence in order to compute a representation of the sequence. Self-attention has been
|
131 |
+
used successfully in a variety of tasks including reading comprehension, abstractive summarization,
|
132 |
+
textual entailment and learning task-independent sentence representations [4, 22, 23, 19].
|
133 |
+
|
134 |
+
End-to-end memory networks are based on a recurrent attention mechanism instead of sequencealigned recurrence and have been shown to perform well on simple-language question answering and
|
135 |
+
language modeling tasks [28].
|
136 |
+
|
137 |
+
To the best of our knowledge, however, the Transformer is the first transduction model relying
|
138 |
+
entirely on self-attention to compute representations of its input and output without using sequencealigned RNNs or convolution. In the following sections, we will describe the Transformer, motivate
|
139 |
+
self-attention and discuss its advantages over models such as [14, 15] and [8].
|
140 |
+
|
141 |
+
**3** **Model Architecture**
|
142 |
+
|
143 |
+
Most competitive neural sequence transduction models have an encoder-decoder structure [5, 2, 29].
|
144 |
+
Here, the encoder maps an input sequence of symbol representations (x1, ..., xn) to a sequence
|
145 |
+
of continuous representations z = (z1, ..., zn). Given z, the decoder then generates an output
|
146 |
+
sequence (y1, ..., ym) of symbols one element at a time. At each step the model is auto-regressive
|
147 |
+
|
148 |
+
[9], consuming the previously generated symbols as additional input when generating the next.
|
149 |
+
|
150 |
+
The Transformer follows this overall architecture using stacked self-attention and point-wise, fully
|
151 |
+
connected layers for both the encoder and decoder, shown in the left and right halves of Figure 1,
|
152 |
+
respectively.
|
153 |
+
|
154 |
+
**3.1** **Encoder and Decoder Stacks**
|
155 |
+
|
156 |
+
**Encoder:** The encoder is composed of a stack of N = 6 identical layers. Each layer has two
|
157 |
+
sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position
|
158 |
+
|
159 |
+
-----
|
160 |
+
|
161 |
+
Figure 1: The Transformer - model architecture.
|
162 |
+
|
163 |
+
wise fully connected feed-forward network. We employ a residual connection [10] around each of
|
164 |
+
the two sub-layers, followed by layer normalization [1]. That is, the output of each sub-layer is
|
165 |
+
LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by the sub-layer
|
166 |
+
itself. To facilitate these residual connections, all sub-layers in the model, as well as the embedding
|
167 |
+
layers, produce outputs of dimension dmodel = 512.
|
168 |
+
|
169 |
+
**Decoder:** The decoder is also composed of a stack of N = 6 identical layers. In addition to the two
|
170 |
+
sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head
|
171 |
+
attention over the output of the encoder stack. Similar to the encoder, we employ residual connections
|
172 |
+
around each of the sub-layers, followed by layer normalization. We also modify the self-attention
|
173 |
+
sub-layer in the decoder stack to prevent positions from attending to subsequent positions. This
|
174 |
+
masking, combined with fact that the output embeddings are offset by one position, ensures that the
|
175 |
+
predictions for position i can depend only on the known outputs at positions less than i.
|
176 |
+
|
177 |
+
**3.2** **Attention**
|
178 |
+
|
179 |
+
An attention function can be described as mapping a query and a set of key-value pairs to an output,
|
180 |
+
where the query, keys, values, and output are all vectors. The output is computed as a weighted sum
|
181 |
+
of the values, where the weight assigned to each value is computed by a compatibility function of the
|
182 |
+
query with the corresponding key.
|
183 |
+
|
184 |
+
**3.2.1** **Scaled Dot-Product Attention**
|
185 |
+
|
186 |
+
We call our particular attention "Scaled Dot-Product Attention" (Figure 2). The input consists of
|
187 |
+
queries and keys of dimension dk, and values of dimension dv. We compute the dot products of the
|
188 |
+
|
189 |
+
|
190 |
+
-----
|
191 |
+
|
192 |
+
Scaled Dot-Product Attention Multi-Head Attention
|
193 |
+
|
194 |
+
Figure 2: (left) Scaled Dot-Product Attention. (right) Multi-Head Attention consists of several
|
195 |
+
attention layers running in parallel.
|
196 |
+
|
197 |
+
query with all keys, divide each by _dk, and apply a softmax function to obtain the weights on the_
|
198 |
+
|
199 |
+
_[√]_
|
200 |
+
values.
|
201 |
+
|
202 |
+
In practice, we compute the attention function on a set of queries simultaneously, packed together
|
203 |
+
into a matrix Q. The keys and values are also packed together into matrices K and V . We compute
|
204 |
+
the matrix of outputs as:
|
205 |
+
|
206 |
+
Attention(Q, K, V ) = softmax( _[QK]√dk[T]_ )V (1)
|
207 |
+
|
208 |
+
The two most commonly used attention functions are additive attention [2], and dot-product (multiplicative) attention. Dot-product attention is identical to our algorithm, except for the scaling factor
|
209 |
+
of _√1dk . Additive attention computes the compatibility function using a feed-forward network with_
|
210 |
+
a single hidden layer. While the two are similar in theoretical complexity, dot-product attention is
|
211 |
+
much faster and more space-efficient in practice, since it can be implemented using highly optimized
|
212 |
+
matrix multiplication code.
|
213 |
+
|
214 |
+
While for small values of dk the two mechanisms perform similarly, additive attention outperforms
|
215 |
+
dot product attention without scaling for larger values of dk [3]. We suspect that for large values of
|
216 |
+
_dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has_
|
217 |
+
extremely small gradients [4]. To counteract this effect, we scale the dot products by _√1dk ._
|
218 |
+
|
219 |
+
**3.2.2** **Multi-Head Attention**
|
220 |
+
|
221 |
+
Instead of performing a single attention function with dmodel-dimensional keys, values and queries,
|
222 |
+
we found it beneficial to linearly project the queries, keys and values h times with different, learned
|
223 |
+
linear projections to dk, dk and dv dimensions, respectively. On each of these projected versions of
|
224 |
+
queries, keys and values we then perform the attention function in parallel, yielding dv-dimensional
|
225 |
+
output values. These are concatenated and once again projected, resulting in the final values, as
|
226 |
+
depicted in Figure 2.
|
227 |
+
|
228 |
+
Multi-head attention allows the model to jointly attend to information from different representation
|
229 |
+
subspaces at different positions. With a single attention head, averaging inhibits this.
|
230 |
+
|
231 |
+
4To illustrate why the dot products get large, assume that the components of q and k are independent random
|
232 |
+
variables with mean 0 and variance 1. Then their dot product, q · k = _i=1_ _[q][i][k][i][, has mean][ 0][ and variance][ d][k][.]_
|
233 |
+
|
234 |
+
[P][d][k]
|
235 |
+
|
236 |
+
|
237 |
+
-----
|
238 |
+
|
239 |
+
MultiHead(Q, K, V ) = Concat(head1, ..., headh)W _[O]_
|
240 |
+
|
241 |
+
where headi = Attention(QWi[Q][, KW][ K]i _[, V W][ V]i_ [)]
|
242 |
+
|
243 |
+
Where the projections are parameter matrices Wi[Q] R[d][model][×][d][k], Wi[K] R[d][model][×][d][k], Wi[V] R[d][model][×][d][v]
|
244 |
+
_∈_ _∈_ _∈_
|
245 |
+
and W _[O]_ _∈_ R[hd][v][×][d][model].
|
246 |
+
|
247 |
+
In this work we employ h = 8 parallel attention layers, or heads. For each of these we use
|
248 |
+
_dk = dv = dmodel/h = 64. Due to the reduced dimension of each head, the total computational cost_
|
249 |
+
is similar to that of single-head attention with full dimensionality.
|
250 |
+
|
251 |
+
**3.2.3** **Applications of Attention in our Model**
|
252 |
+
|
253 |
+
The Transformer uses multi-head attention in three different ways:
|
254 |
+
|
255 |
+
_• In "encoder-decoder attention" layers, the queries come from the previous decoder layer,_
|
256 |
+
and the memory keys and values come from the output of the encoder. This allows every
|
257 |
+
position in the decoder to attend over all positions in the input sequence. This mimics the
|
258 |
+
typical encoder-decoder attention mechanisms in sequence-to-sequence models such as
|
259 |
+
|
260 |
+
[31, 2, 8].
|
261 |
+
|
262 |
+
_• The encoder contains self-attention layers. In a self-attention layer all of the keys, values_
|
263 |
+
and queries come from the same place, in this case, the output of the previous layer in the
|
264 |
+
encoder. Each position in the encoder can attend to all positions in the previous layer of the
|
265 |
+
encoder.
|
266 |
+
|
267 |
+
_• Similarly, self-attention layers in the decoder allow each position in the decoder to attend to_
|
268 |
+
all positions in the decoder up to and including that position. We need to prevent leftward
|
269 |
+
information flow in the decoder to preserve the auto-regressive property. We implement this
|
270 |
+
inside of scaled dot-product attention by masking out (setting to −∞) all values in the input
|
271 |
+
of the softmax which correspond to illegal connections. See Figure 2.
|
272 |
+
|
273 |
+
**3.3** **Position-wise Feed-Forward Networks**
|
274 |
+
|
275 |
+
In addition to attention sub-layers, each of the layers in our encoder and decoder contains a fully
|
276 |
+
connected feed-forward network, which is applied to each position separately and identically. This
|
277 |
+
consists of two linear transformations with a ReLU activation in between.
|
278 |
+
|
279 |
+
FFN(x) = max(0, xW1 + b1)W2 + b2 (2)
|
280 |
+
|
281 |
+
While the linear transformations are the same across different positions, they use different parameters
|
282 |
+
from layer to layer. Another way of describing this is as two convolutions with kernel size 1.
|
283 |
+
The dimensionality of input and output is dmodel = 512, and the inner-layer has dimensionality
|
284 |
+
_dff = 2048._
|
285 |
+
|
286 |
+
**3.4** **Embeddings and Softmax**
|
287 |
+
|
288 |
+
Similarly to other sequence transduction models, we use learned embeddings to convert the input
|
289 |
+
tokens and output tokens to vectors of dimension dmodel. We also use the usual learned linear transformation and softmax function to convert the decoder output to predicted next-token probabilities. In
|
290 |
+
our model, we share the same weight matrix between the two embedding layers and the pre-softmax
|
291 |
+
linear transformation, similar to [24]. In the embedding layers, we multiply those weights by _dmodel._
|
292 |
+
|
293 |
+
_[√]_
|
294 |
+
|
295 |
+
**3.5** **Positional Encoding**
|
296 |
+
|
297 |
+
Since our model contains no recurrence and no convolution, in order for the model to make use of the
|
298 |
+
order of the sequence, we must inject some information about the relative or absolute position of the
|
299 |
+
tokens in the sequence. To this end, we add "positional encodings" to the input embeddings at the
|
300 |
+
|
301 |
+
|
302 |
+
-----
|
303 |
+
|
304 |
+
Table 1: Maximum path lengths, per-layer complexity and minimum number of sequential operations
|
305 |
+
for different layer types. n is the sequence length, d is the representation dimension, k is the kernel
|
306 |
+
size of convolutions and r the size of the neighborhood in restricted self-attention.
|
307 |
+
|
308 |
+
Layer Type Complexity per Layer Sequential Maximum Path Length
|
309 |
+
Operations
|
310 |
+
|
311 |
+
Self-Attention _O(n[2]_ _· d)_ _O(1)_ _O(1)_
|
312 |
+
Recurrent _O(n · d[2])_ _O(n)_ _O(n)_
|
313 |
+
Convolutional _O(k_ _n_ _d[2])_ _O(1)_ _O(logk(n))_
|
314 |
+
_·_ _·_
|
315 |
+
Self-Attention (restricted) _O(r · n · d)_ _O(1)_ _O(n/r)_
|
316 |
+
|
317 |
+
bottoms of the encoder and decoder stacks. The positional encodings have the same dimension dmodel
|
318 |
+
as the embeddings, so that the two can be summed. There are many choices of positional encodings,
|
319 |
+
learned and fixed [8].
|
320 |
+
|
321 |
+
In this work, we use sine and cosine functions of different frequencies:
|
322 |
+
|
323 |
+
_PE(pos,2i) = sin(pos/10000[2][i/d][model])_
|
324 |
+
|
325 |
+
_PE(pos,2i+1) = cos(pos/10000[2][i/d][model])_
|
326 |
+
|
327 |
+
where pos is the position and i is the dimension. That is, each dimension of the positional encoding
|
328 |
+
corresponds to a sinusoid. The wavelengths form a geometric progression from 2π to 10000 · 2π. We
|
329 |
+
chose this function because we hypothesized it would allow the model to easily learn to attend by
|
330 |
+
relative positions, since for any fixed offset k, PEpos+k can be represented as a linear function of
|
331 |
+
_PEpos._
|
332 |
+
|
333 |
+
We also experimented with using learned positional embeddings [8] instead, and found that the two
|
334 |
+
versions produced nearly identical results (see Table 3 row (E)). We chose the sinusoidal version
|
335 |
+
because it may allow the model to extrapolate to sequence lengths longer than the ones encountered
|
336 |
+
during training.
|
337 |
+
|
338 |
+
**4** **Why Self-Attention**
|
339 |
+
|
340 |
+
In this section we compare various aspects of self-attention layers to the recurrent and convolutional layers commonly used for mapping one variable-length sequence of symbol representations
|
341 |
+
(layer in a typical sequence transduction encoder or decoder. Motivating our use of self-attention wex1, ..., xn) to another sequence of equal length (z1, ..., zn), with xi, zi ∈ R[d], such as a hidden
|
342 |
+
consider three desiderata.
|
343 |
+
|
344 |
+
One is the total computational complexity per layer. Another is the amount of computation that can
|
345 |
+
be parallelized, as measured by the minimum number of sequential operations required.
|
346 |
+
|
347 |
+
The third is the path length between long-range dependencies in the network. Learning long-range
|
348 |
+
dependencies is a key challenge in many sequence transduction tasks. One key factor affecting the
|
349 |
+
ability to learn such dependencies is the length of the paths forward and backward signals have to
|
350 |
+
traverse in the network. The shorter these paths between any combination of positions in the input
|
351 |
+
and output sequences, the easier it is to learn long-range dependencies [11]. Hence we also compare
|
352 |
+
the maximum path length between any two input and output positions in networks composed of the
|
353 |
+
different layer types.
|
354 |
+
|
355 |
+
As noted in Table 1, a self-attention layer connects all positions with a constant number of sequentially
|
356 |
+
executed operations, whereas a recurrent layer requires O(n) sequential operations. In terms of
|
357 |
+
computational complexity, self-attention layers are faster than recurrent layers when the sequence
|
358 |
+
length n is smaller than the representation dimensionality d, which is most often the case with
|
359 |
+
sentence representations used by state-of-the-art models in machine translations, such as word-piece
|
360 |
+
|
361 |
+
[31] and byte-pair [25] representations. To improve computational performance for tasks involving
|
362 |
+
very long sequences, self-attention could be restricted to considering only a neighborhood of size r in
|
363 |
+
|
364 |
+
|
365 |
+
-----
|
366 |
+
|
367 |
+
the input sequence centered around the respective output position. This would increase the maximum
|
368 |
+
path length to O(n/r). We plan to investigate this approach further in future work.
|
369 |
+
|
370 |
+
A single convolutional layer with kernel width k < n does not connect all pairs of input and output
|
371 |
+
positions. Doing so requires a stack of O(n/k) convolutional layers in the case of contiguous kernels,
|
372 |
+
or O(logk(n)) in the case of dilated convolutions [15], increasing the length of the longest paths
|
373 |
+
between any two positions in the network. Convolutional layers are generally more expensive than
|
374 |
+
recurrent layers, by a factor of k. Separable convolutions [6], however, decrease the complexity
|
375 |
+
considerably, to O(k · n · d + n · d[2]). Even with k = n, however, the complexity of a separable
|
376 |
+
convolution is equal to the combination of a self-attention layer and a point-wise feed-forward layer,
|
377 |
+
the approach we take in our model.
|
378 |
+
|
379 |
+
As side benefit, self-attention could yield more interpretable models. We inspect attention distributions
|
380 |
+
from our models and present and discuss examples in the appendix. Not only do individual attention
|
381 |
+
heads clearly learn to perform different tasks, many appear to exhibit behavior related to the syntactic
|
382 |
+
and semantic structure of the sentences.
|
383 |
+
|
384 |
+
**5** **Training**
|
385 |
+
|
386 |
+
This section describes the training regime for our models.
|
387 |
+
|
388 |
+
**5.1** **Training Data and Batching**
|
389 |
+
|
390 |
+
We trained on the standard WMT 2014 English-German dataset consisting of about 4.5 million
|
391 |
+
sentence pairs. Sentences were encoded using byte-pair encoding [3], which has a shared sourcetarget vocabulary of about 37000 tokens. For English-French, we used the significantly larger WMT
|
392 |
+
2014 English-French dataset consisting of 36M sentences and split tokens into a 32000 word-piece
|
393 |
+
vocabulary [31]. Sentence pairs were batched together by approximate sequence length. Each training
|
394 |
+
batch contained a set of sentence pairs containing approximately 25000 source tokens and 25000
|
395 |
+
target tokens.
|
396 |
+
|
397 |
+
**5.2** **Hardware and Schedule**
|
398 |
+
|
399 |
+
We trained our models on one machine with 8 NVIDIA P100 GPUs. For our base models using
|
400 |
+
the hyperparameters described throughout the paper, each training step took about 0.4 seconds. We
|
401 |
+
trained the base models for a total of 100,000 steps or 12 hours. For our big models,(described on the
|
402 |
+
bottom line of table 3), step time was 1.0 seconds. The big models were trained for 300,000 steps
|
403 |
+
(3.5 days).
|
404 |
+
|
405 |
+
**5.3** **Optimizer**
|
406 |
+
|
407 |
+
We used the Adam optimizer [17] with β1 = 0.9, β2 = 0.98 and ϵ = 10[−][9]. We varied the learning
|
408 |
+
rate over the course of training, according to the formula:
|
409 |
+
|
410 |
+
_lrate = d[−]model[0][.][5]_ (3)
|
411 |
+
|
412 |
+
_[·][ min(][step][_][num][−][0][.][5][, step][_][num][ ·][ warmup][_][steps][−][1][.][5][)]_
|
413 |
+
|
414 |
+
This corresponds to increasing the learning rate linearly for the first warmup_steps training steps,
|
415 |
+
and decreasing it thereafter proportionally to the inverse square root of the step number. We used
|
416 |
+
_warmup_steps = 4000._
|
417 |
+
|
418 |
+
**5.4** **Regularization**
|
419 |
+
|
420 |
+
We employ three types of regularization during training:
|
421 |
+
|
422 |
+
**Residual Dropout** We apply dropout [27] to the output of each sub-layer, before it is added to the
|
423 |
+
sub-layer input and normalized. In addition, we apply dropout to the sums of the embeddings and the
|
424 |
+
positional encodings in both the encoder and decoder stacks. For the base model, we use a rate of
|
425 |
+
_Pdrop = 0.1._
|
426 |
+
|
427 |
+
|
428 |
+
-----
|
429 |
+
|
430 |
+
Table 2: The Transformer achieves better BLEU scores than previous state-of-the-art models on the
|
431 |
+
English-to-German and English-to-French newstest2014 tests at a fraction of the training cost.
|
432 |
+
|
433 |
+
BLEU Training Cost (FLOPs)
|
434 |
+
Model
|
435 |
+
|
436 |
+
EN-DE EN-FR EN-DE EN-FR
|
437 |
+
|
438 |
+
ByteNet [15] 23.75
|
439 |
+
Deep-Att + PosUnk [32] 39.2 1.0 · 10[20]
|
440 |
+
|
441 |
+
GNMT + RL [31] 24.6 39.92 2.3 · 10[19] 1.4 · 10[20]
|
442 |
+
|
443 |
+
ConvS2S [8] 25.16 40.46 9.6 · 10[18] 1.5 · 10[20]
|
444 |
+
|
445 |
+
MoE [26] 26.03 40.56 2.0 · 10[19] 1.2 · 10[20]
|
446 |
+
|
447 |
+
Deep-Att + PosUnk Ensemble [32] 40.4 8.0 · 10[20]
|
448 |
+
|
449 |
+
GNMT + RL Ensemble [31] 26.30 41.16 1.8 · 10[20] 1.1 · 10[21]
|
450 |
+
|
451 |
+
ConvS2S Ensemble [8] 26.36 **41.29** 7.7 · 10[19] 1.2 · 10[21]
|
452 |
+
|
453 |
+
Transformer (base model) 27.3 38.1 **3.3 · 10[18]**
|
454 |
+
|
455 |
+
Transformer (big) **28.4** **41.0** 2.3 · 10[19]
|
456 |
+
|
457 |
+
|
458 |
+
**Label Smoothing** During training, we employed label smoothing of value ϵls = 0.1 [30]. This
|
459 |
+
hurts perplexity, as the model learns to be more unsure, but improves accuracy and BLEU score.
|
460 |
+
|
461 |
+
**6** **Results**
|
462 |
+
|
463 |
+
**6.1** **Machine Translation**
|
464 |
+
|
465 |
+
On the WMT 2014 English-to-German translation task, the big transformer model (Transformer (big)
|
466 |
+
in Table 2) outperforms the best previously reported models (including ensembles) by more than 2.0
|
467 |
+
BLEU, establishing a new state-of-the-art BLEU score of 28.4. The configuration of this model is
|
468 |
+
listed in the bottom line of Table 3. Training took 3.5 days on 8 P100 GPUs. Even our base model
|
469 |
+
surpasses all previously published models and ensembles, at a fraction of the training cost of any of
|
470 |
+
the competitive models.
|
471 |
+
|
472 |
+
On the WMT 2014 English-to-French translation task, our big model achieves a BLEU score of 41.0,
|
473 |
+
outperforming all of the previously published single models, at less than 1/4 the training cost of the
|
474 |
+
previous state-of-the-art model. The Transformer (big) model trained for English-to-French used
|
475 |
+
dropout rate Pdrop = 0.1, instead of 0.3.
|
476 |
+
|
477 |
+
For the base models, we used a single model obtained by averaging the last 5 checkpoints, which
|
478 |
+
were written at 10-minute intervals. For the big models, we averaged the last 20 checkpoints. We
|
479 |
+
used beam search with a beam size of 4 and length penalty α = 0.6 [31]. These hyperparameters
|
480 |
+
were chosen after experimentation on the development set. We set the maximum output length during
|
481 |
+
inference to input length + 50, but terminate early when possible [31].
|
482 |
+
|
483 |
+
Table 2 summarizes our results and compares our translation quality and training costs to other model
|
484 |
+
architectures from the literature. We estimate the number of floating point operations used to train a
|
485 |
+
model by multiplying the training time, the number of GPUs used, and an estimate of the sustained
|
486 |
+
single-precision floating-point capacity of each GPU [5].
|
487 |
+
|
488 |
+
**6.2** **Model Variations**
|
489 |
+
|
490 |
+
To evaluate the importance of different components of the Transformer, we varied our base model
|
491 |
+
in different ways, measuring the change in performance on English-to-German translation on the
|
492 |
+
development set, newstest2013. We used beam search as described in the previous section, but no
|
493 |
+
checkpoint averaging. We present these results in Table 3.
|
494 |
+
|
495 |
+
In Table 3 rows (A), we vary the number of attention heads and the attention key and value dimensions,
|
496 |
+
keeping the amount of computation constant, as described in Section 3.2.2. While single-head
|
497 |
+
attention is 0.9 BLEU worse than the best setting, quality also drops off with too many heads.
|
498 |
+
|
499 |
+
5We used values of 2.8, 3.7, 6.0 and 9.5 TFLOPS for K80, K40, M40 and P100, respectively.
|
500 |
+
|
501 |
+
|
502 |
+
-----
|
503 |
+
|
504 |
+
Table 3: Variations on the Transformer architecture. Unlisted values are identical to those of the base
|
505 |
+
model. All metrics are on the English-to-German translation development set, newstest2013. Listed
|
506 |
+
perplexities are per-wordpiece, according to our byte-pair encoding, and should not be compared to
|
507 |
+
per-word perplexities.
|
508 |
+
|
509 |
+
|Col1|train N d d h d d P ϵ model ff k v drop ls steps|PPL BLEU params (dev) (dev) 106 ×|
|
510 |
+
|---|---|---|
|
511 |
+
|base|6 512 2048 8 64 64 0.1 0.1 100K|4.92 25.8 65|
|
512 |
+
|(A)|1 512 512 4 128 128 16 32 32 32 16 16|5.29 24.9 5.00 25.5 4.91 25.8 5.01 25.4|
|
513 |
+
|(B)|16 32|5.16 25.1 58 5.01 25.4 60|
|
514 |
+
|(C)|2 4 8 256 32 32 1024 128 128 1024 4096|6.11 23.7 36 5.19 25.3 50 4.88 25.5 80 5.75 24.5 28 4.66 26.0 168 5.12 25.4 53 4.75 26.2 90|
|
515 |
+
|(D)|0.0 0.2 0.0 0.2|5.77 24.6 4.95 25.5 4.67 25.3 5.47 25.7|
|
516 |
+
|(E)|positional embedding instead of sinusoids|4.92 25.7|
|
517 |
+
|big|6 1024 4096 16 0.3 300K|4.33 26.4 213|
|
518 |
+
|
519 |
+
|
520 |
+
|
521 |
+
In Table 3 rows (B), we observe that reducing the attention key size dk hurts model quality. This
|
522 |
+
suggests that determining compatibility is not easy and that a more sophisticated compatibility
|
523 |
+
function than dot product may be beneficial. We further observe in rows (C) and (D) that, as expected,
|
524 |
+
bigger models are better, and dropout is very helpful in avoiding over-fitting. In row (E) we replace our
|
525 |
+
sinusoidal positional encoding with learned positional embeddings [8], and observe nearly identical
|
526 |
+
results to the base model.
|
527 |
+
|
528 |
+
**7** **Conclusion**
|
529 |
+
|
530 |
+
In this work, we presented the Transformer, the first sequence transduction model based entirely on
|
531 |
+
attention, replacing the recurrent layers most commonly used in encoder-decoder architectures with
|
532 |
+
multi-headed self-attention.
|
533 |
+
|
534 |
+
For translation tasks, the Transformer can be trained significantly faster than architectures based
|
535 |
+
on recurrent or convolutional layers. On both WMT 2014 English-to-German and WMT 2014
|
536 |
+
English-to-French translation tasks, we achieve a new state of the art. In the former task our best
|
537 |
+
model outperforms even all previously reported ensembles.
|
538 |
+
|
539 |
+
We are excited about the future of attention-based models and plan to apply them to other tasks. We
|
540 |
+
plan to extend the Transformer to problems involving input and output modalities other than text and
|
541 |
+
to investigate local, restricted attention mechanisms to efficiently handle large inputs and outputs
|
542 |
+
such as images, audio and video. Making generation less sequential is another research goals of ours.
|
543 |
+
|
544 |
+
[The code we used to train and evaluate our models is available at https://github.com/](https://github.com/tensorflow/tensor2tensor)
|
545 |
+
```
|
546 |
+
tensorflow/tensor2tensor.
|
547 |
+
|
548 |
+
```
|
549 |
+
**Acknowledgements** We are grateful to Nal Kalchbrenner and Stephan Gouws for their fruitful
|
550 |
+
comments, corrections and inspiration.
|
551 |
+
|
552 |
+
|
553 |
+
-----
|
554 |
+
|
555 |
+
**References**
|
556 |
+
|
557 |
+
[1] Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization. arXiv preprint
|
558 |
+
_arXiv:1607.06450, 2016._
|
559 |
+
|
560 |
+
[2] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly
|
561 |
+
learning to align and translate. CoRR, abs/1409.0473, 2014.
|
562 |
+
|
563 |
+
[3] Denny Britz, Anna Goldie, Minh-Thang Luong, and Quoc V. Le. Massive exploration of neural
|
564 |
+
machine translation architectures. CoRR, abs/1703.03906, 2017.
|
565 |
+
|
566 |
+
[4] Jianpeng Cheng, Li Dong, and Mirella Lapata. Long short-term memory-networks for machine
|
567 |
+
reading. arXiv preprint arXiv:1601.06733, 2016.
|
568 |
+
|
569 |
+
[5] Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Fethi Bougares, Holger Schwenk,
|
570 |
+
and Yoshua Bengio. Learning phrase representations using rnn encoder-decoder for statistical
|
571 |
+
machine translation. CoRR, abs/1406.1078, 2014.
|
572 |
+
|
573 |
+
[6] Francois Chollet. Xception: Deep learning with depthwise separable convolutions. arXiv
|
574 |
+
_preprint arXiv:1610.02357, 2016._
|
575 |
+
|
576 |
+
[7] Junyoung Chung, Çaglar Gülçehre, Kyunghyun Cho, and Yoshua Bengio. Empirical evaluation
|
577 |
+
of gated recurrent neural networks on sequence modeling. CoRR, abs/1412.3555, 2014.
|
578 |
+
|
579 |
+
[8] Jonas Gehring, Michael Auli, David Grangier, Denis Yarats, and Yann N. Dauphin. Convolutional sequence to sequence learning. arXiv preprint arXiv:1705.03122v2, 2017.
|
580 |
+
|
581 |
+
[9] Alex Graves. Generating sequences with recurrent neural networks. _arXiv preprint_
|
582 |
+
_arXiv:1308.0850, 2013._
|
583 |
+
|
584 |
+
[10] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern
|
585 |
+
_Recognition, pages 770–778, 2016._
|
586 |
+
|
587 |
+
[11] Sepp Hochreiter, Yoshua Bengio, Paolo Frasconi, and Jürgen Schmidhuber. Gradient flow in
|
588 |
+
recurrent nets: the difficulty of learning long-term dependencies, 2001.
|
589 |
+
|
590 |
+
[12] Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural computation,
|
591 |
+
9(8):1735–1780, 1997.
|
592 |
+
|
593 |
+
[13] Rafal Jozefowicz, Oriol Vinyals, Mike Schuster, Noam Shazeer, and Yonghui Wu. Exploring
|
594 |
+
the limits of language modeling. arXiv preprint arXiv:1602.02410, 2016.
|
595 |
+
|
596 |
+
[14] Łukasz Kaiser and Ilya Sutskever. Neural GPUs learn algorithms. In International Conference
|
597 |
+
_on Learning Representations (ICLR), 2016._
|
598 |
+
|
599 |
+
[15] Nal Kalchbrenner, Lasse Espeholt, Karen Simonyan, Aaron van den Oord, Alex Graves, and Koray Kavukcuoglu. Neural machine translation in linear time. arXiv preprint arXiv:1610.10099v2,
|
600 |
+
2017.
|
601 |
+
|
602 |
+
[16] Yoon Kim, Carl Denton, Luong Hoang, and Alexander M. Rush. Structured attention networks.
|
603 |
+
In International Conference on Learning Representations, 2017.
|
604 |
+
|
605 |
+
[17] Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In ICLR, 2015.
|
606 |
+
|
607 |
+
[18] Oleksii Kuchaiev and Boris Ginsburg. Factorization tricks for LSTM networks. arXiv preprint
|
608 |
+
_arXiv:1703.10722, 2017._
|
609 |
+
|
610 |
+
[19] Zhouhan Lin, Minwei Feng, Cicero Nogueira dos Santos, Mo Yu, Bing Xiang, Bowen
|
611 |
+
Zhou, and Yoshua Bengio. A structured self-attentive sentence embedding. arXiv preprint
|
612 |
+
_arXiv:1703.03130, 2017._
|
613 |
+
|
614 |
+
[20] Samy Bengio Łukasz Kaiser. Can active memory replace attention? In Advances in Neural
|
615 |
+
_Information Processing Systems, (NIPS), 2016._
|
616 |
+
|
617 |
+
|
618 |
+
-----
|
619 |
+
|
620 |
+
[21] Minh-Thang Luong, Hieu Pham, and Christopher D Manning. Effective approaches to attentionbased neural machine translation. arXiv preprint arXiv:1508.04025, 2015.
|
621 |
+
|
622 |
+
[22] Ankur Parikh, Oscar Täckström, Dipanjan Das, and Jakob Uszkoreit. A decomposable attention
|
623 |
+
model. In Empirical Methods in Natural Language Processing, 2016.
|
624 |
+
|
625 |
+
[23] Romain Paulus, Caiming Xiong, and Richard Socher. A deep reinforced model for abstractive
|
626 |
+
summarization. arXiv preprint arXiv:1705.04304, 2017.
|
627 |
+
|
628 |
+
[24] Ofir Press and Lior Wolf. Using the output embedding to improve language models. arXiv
|
629 |
+
_preprint arXiv:1608.05859, 2016._
|
630 |
+
|
631 |
+
[25] Rico Sennrich, Barry Haddow, and Alexandra Birch. Neural machine translation of rare words
|
632 |
+
with subword units. arXiv preprint arXiv:1508.07909, 2015.
|
633 |
+
|
634 |
+
[26] Noam Shazeer, Azalia Mirhoseini, Krzysztof Maziarz, Andy Davis, Quoc Le, Geoffrey Hinton,
|
635 |
+
and Jeff Dean. Outrageously large neural networks: The sparsely-gated mixture-of-experts
|
636 |
+
layer. arXiv preprint arXiv:1701.06538, 2017.
|
637 |
+
|
638 |
+
[27] Nitish Srivastava, Geoffrey E Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: a simple way to prevent neural networks from overfitting. Journal of Machine
|
639 |
+
_Learning Research, 15(1):1929–1958, 2014._
|
640 |
+
|
641 |
+
[28] Sainbayar Sukhbaatar, arthur szlam, Jason Weston, and Rob Fergus. End-to-end memory
|
642 |
+
networks. In C. Cortes, N. D. Lawrence, D. D. Lee, M. Sugiyama, and R. Garnett, editors,
|
643 |
+
_Advances in Neural Information Processing Systems 28, pages 2440–2448. Curran Associates,_
|
644 |
+
Inc., 2015.
|
645 |
+
|
646 |
+
[29] Ilya Sutskever, Oriol Vinyals, and Quoc VV Le. Sequence to sequence learning with neural
|
647 |
+
networks. In Advances in Neural Information Processing Systems, pages 3104–3112, 2014.
|
648 |
+
|
649 |
+
[30] Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, and Zbigniew Wojna.
|
650 |
+
Rethinking the inception architecture for computer vision. CoRR, abs/1512.00567, 2015.
|
651 |
+
|
652 |
+
[31] Yonghui Wu, Mike Schuster, Zhifeng Chen, Quoc V Le, Mohammad Norouzi, Wolfgang
|
653 |
+
Macherey, Maxim Krikun, Yuan Cao, Qin Gao, Klaus Macherey, et al. Google’s neural machine
|
654 |
+
translation system: Bridging the gap between human and machine translation. arXiv preprint
|
655 |
+
_arXiv:1609.08144, 2016._
|
656 |
+
|
657 |
+
[32] Jie Zhou, Ying Cao, Xuguang Wang, Peng Li, and Wei Xu. Deep recurrent models with
|
658 |
+
fast-forward connections for neural machine translation. CoRR, abs/1606.04199, 2016.
|
659 |
+
|
660 |
+
|
661 |
+
-----
|
662 |
+
|
ai_scientist/generate_ideas.py
ADDED
@@ -0,0 +1,637 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import os.path as osp
|
4 |
+
import time
|
5 |
+
from typing import Dict, List, Union
|
6 |
+
|
7 |
+
import backoff
|
8 |
+
import requests
|
9 |
+
from strictjson import strict_json
|
10 |
+
|
11 |
+
from ai_scientist.llm import (
|
12 |
+
allchoices,
|
13 |
+
extract_json_between_markers,
|
14 |
+
get_response_from_llm,
|
15 |
+
llm_json_auto_correct,
|
16 |
+
)
|
17 |
+
|
18 |
+
S2_API_KEY = os.getenv("S2_API_KEY")
|
19 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
20 |
+
|
21 |
+
|
22 |
+
idea_first_prompt = """{task_description}
|
23 |
+
<experiment.py>
|
24 |
+
{code}
|
25 |
+
</experiment.py>
|
26 |
+
|
27 |
+
Here are the ideas that you have already generated:
|
28 |
+
|
29 |
+
'''
|
30 |
+
{prev_ideas_string}
|
31 |
+
'''
|
32 |
+
|
33 |
+
Come up with the next impactful and creative idea for research experiments and directions you can feasibly investigate with the code provided.
|
34 |
+
Note that you will not have access to any additional resources or datasets.
|
35 |
+
Make sure any idea is not overfit the specific training dataset or model, and has wider significance.
|
36 |
+
|
37 |
+
Respond in the following format:
|
38 |
+
|
39 |
+
THOUGHT:
|
40 |
+
<THOUGHT>
|
41 |
+
|
42 |
+
NEW IDEA JSON:
|
43 |
+
```json
|
44 |
+
<JSON>
|
45 |
+
```
|
46 |
+
|
47 |
+
In <THOUGHT>, first briefly discuss your intuitions and motivations for the idea. Detail your high-level plan, necessary design choices and ideal outcomes of the experiments. Justify how the idea is different from the existing ones.
|
48 |
+
|
49 |
+
Add '```json' before the <JSON> and '```' after the <JSON> as above. In <JSON>, provide the new idea in JSON format with the following keys and values:
|
50 |
+
- "Name": A shortened descriptor of the idea. Lowercase, no spaces, underscores allowed.
|
51 |
+
- "Title": A title for the idea, will be used for the report writing.
|
52 |
+
- "Experiment": An outline of the implementation. E.g. which functions need to be added or modified, how results will be obtained, ...
|
53 |
+
- "Interestingness": A rating from 1 to 10 (lowest to highest).
|
54 |
+
- "Feasibility": A rating from 1 to 10 (lowest to highest).
|
55 |
+
- "Novelty": A rating from 1 to 10 (lowest to highest).
|
56 |
+
|
57 |
+
Be cautious and realistic on your ratings.
|
58 |
+
This JSON will be automatically parsed, so ensure the format is precise.
|
59 |
+
You will have {num_reflections} rounds to iterate on the idea, but do not need to use them all.
|
60 |
+
"""
|
61 |
+
|
62 |
+
idea_reflection_prompt = """Round {current_round}/{num_reflections}.
|
63 |
+
In your thoughts, first carefully consider the quality, novelty, and feasibility of the idea you just created.
|
64 |
+
Include any other factors that you think are important in evaluating the idea.
|
65 |
+
Ensure the idea is clear and concise, and the JSON is the correct format.
|
66 |
+
Do not make things overly complicated.
|
67 |
+
In the next attempt, try and refine and improve your idea.
|
68 |
+
Stick to the spirit of the original idea unless there are glaring issues.
|
69 |
+
|
70 |
+
Respond in the exactly the same format as before:
|
71 |
+
THOUGHT:
|
72 |
+
<THOUGHT>
|
73 |
+
|
74 |
+
NEW IDEA JSON:
|
75 |
+
```json
|
76 |
+
<JSON>
|
77 |
+
```
|
78 |
+
|
79 |
+
If there is nothing to improve, simply repeat the previous JSON EXACTLY after the thought and include "I am done" at the end of the thoughts but before the JSON.
|
80 |
+
ONLY INCLUDE "I am done" IF YOU ARE MAKING NO MORE CHANGES.
|
81 |
+
"""
|
82 |
+
|
83 |
+
|
84 |
+
# Format the content in JSON
|
85 |
+
def format_idea_json(text):
|
86 |
+
json_start_marker = "```json"
|
87 |
+
json_end_marker = "```"
|
88 |
+
start_index = text.find(json_start_marker)
|
89 |
+
if start_index != -1:
|
90 |
+
start_index += len(json_start_marker) # Move past the marker
|
91 |
+
end_index = text.find(json_end_marker, start_index)
|
92 |
+
json_string = text[start_index:end_index].strip()
|
93 |
+
res = strict_json(
|
94 |
+
system_prompt="You are a JSON formatter",
|
95 |
+
user_prompt=json_string,
|
96 |
+
output_format={
|
97 |
+
"Name": "A shortened descriptor of the idea",
|
98 |
+
"Title": "A title for the idea, will be used for the report writing",
|
99 |
+
"Experiment": "An outline of the implementation, type: list",
|
100 |
+
"Interestingness": "A rating from 1 to 10 (lowest to highest), type: int",
|
101 |
+
"Feasibility": "A rating from 1 to 10 (lowest to highest), type: int",
|
102 |
+
"Novelty": "A rating from 1 to 10 (lowest to highest), type: int",
|
103 |
+
},
|
104 |
+
llm=llm_json_auto_correct,
|
105 |
+
)
|
106 |
+
text = "```json\n" + json.dumps(res) + "```\n"
|
107 |
+
return text
|
108 |
+
|
109 |
+
|
110 |
+
def format_novelty_json(text):
|
111 |
+
json_start_marker = "```json"
|
112 |
+
json_end_marker = "```"
|
113 |
+
start_index = text.find(json_start_marker)
|
114 |
+
if start_index != -1:
|
115 |
+
start_index += len(json_start_marker) # Move past the marker
|
116 |
+
end_index = text.find(json_end_marker, start_index)
|
117 |
+
json_string = text[start_index:end_index].strip()
|
118 |
+
res = strict_json(
|
119 |
+
system_prompt="You are a JSON formatter",
|
120 |
+
user_prompt=json_string,
|
121 |
+
output_format={
|
122 |
+
"Query": "An optional search query to search the literature (e.g. attention is all you need)",
|
123 |
+
},
|
124 |
+
llm=llm_json_auto_correct,
|
125 |
+
)
|
126 |
+
text = "```json\n" + json.dumps(res) + "```\n"
|
127 |
+
return text
|
128 |
+
|
129 |
+
|
130 |
+
# GENERATE IDEAS
|
131 |
+
def generate_ideas(
|
132 |
+
base_dir,
|
133 |
+
client,
|
134 |
+
model,
|
135 |
+
skip_generation=False,
|
136 |
+
max_num_generations=20,
|
137 |
+
num_reflections=5,
|
138 |
+
):
|
139 |
+
if skip_generation:
|
140 |
+
# Load existing ideas from file
|
141 |
+
try:
|
142 |
+
with open(osp.join(base_dir, "ideas.json"), "r") as f:
|
143 |
+
ideas = json.load(f)
|
144 |
+
print("Loaded existing ideas:")
|
145 |
+
for idea in ideas:
|
146 |
+
print(idea)
|
147 |
+
return ideas
|
148 |
+
except FileNotFoundError:
|
149 |
+
print("No existing ideas found. Generating new ideas.")
|
150 |
+
except json.JSONDecodeError:
|
151 |
+
print("Error decoding existing ideas. Generating new ideas.")
|
152 |
+
|
153 |
+
idea_str_archive = []
|
154 |
+
with open(osp.join(base_dir, "seed_ideas.json"), "r") as f:
|
155 |
+
seed_ideas = json.load(f)
|
156 |
+
for seed_idea in seed_ideas:
|
157 |
+
idea_str_archive.append(json.dumps(seed_idea))
|
158 |
+
|
159 |
+
with open(osp.join(base_dir, "experiment.py"), "r") as f:
|
160 |
+
code = f.read()
|
161 |
+
|
162 |
+
with open(osp.join(base_dir, "prompt.json"), "r") as f:
|
163 |
+
prompt = json.load(f)
|
164 |
+
|
165 |
+
idea_system_prompt = prompt["system"]
|
166 |
+
|
167 |
+
for _ in range(max_num_generations):
|
168 |
+
print()
|
169 |
+
print(f"Generating idea {_ + 1}/{max_num_generations}")
|
170 |
+
import traceback
|
171 |
+
try:
|
172 |
+
prev_ideas_string = "\n\n".join(idea_str_archive)
|
173 |
+
|
174 |
+
msg_history = []
|
175 |
+
print(f"Iteration 1/{num_reflections}")
|
176 |
+
text, msg_history = get_response_from_llm(
|
177 |
+
idea_first_prompt.format(
|
178 |
+
task_description=prompt["task_description"],
|
179 |
+
code=code,
|
180 |
+
prev_ideas_string=prev_ideas_string,
|
181 |
+
num_reflections=num_reflections,
|
182 |
+
),
|
183 |
+
client=client,
|
184 |
+
model=model,
|
185 |
+
system_message=idea_system_prompt,
|
186 |
+
msg_history=msg_history,
|
187 |
+
)
|
188 |
+
## Format the content in JSON
|
189 |
+
text = format_idea_json(text)
|
190 |
+
|
191 |
+
## PARSE OUTPUT
|
192 |
+
json_output = extract_json_between_markers(text)
|
193 |
+
assert json_output is not None, "Failed to extract JSON from LLM output"
|
194 |
+
# print(json_output)
|
195 |
+
|
196 |
+
# Iteratively improve task.
|
197 |
+
if num_reflections > 1:
|
198 |
+
for j in range(num_reflections - 1):
|
199 |
+
print(f"Iteration {j + 2}/{num_reflections}")
|
200 |
+
text, msg_history = get_response_from_llm(
|
201 |
+
idea_reflection_prompt.format(
|
202 |
+
current_round=j + 2, num_reflections=num_reflections
|
203 |
+
),
|
204 |
+
client=client,
|
205 |
+
model=model,
|
206 |
+
system_message=idea_system_prompt,
|
207 |
+
msg_history=msg_history,
|
208 |
+
)
|
209 |
+
## Format the content in JSON if using weak LLM
|
210 |
+
text = format_idea_json(text)
|
211 |
+
## PARSE OUTPUT
|
212 |
+
json_output = extract_json_between_markers(text)
|
213 |
+
assert (
|
214 |
+
json_output is not None
|
215 |
+
), "Failed to extract JSON from LLM output"
|
216 |
+
# print(json_output)
|
217 |
+
|
218 |
+
if "I am done" in text:
|
219 |
+
print(f"Idea generation converged after {j + 2} iterations.")
|
220 |
+
break
|
221 |
+
|
222 |
+
idea_str_archive.append(json.dumps(json_output))
|
223 |
+
except Exception as e:
|
224 |
+
print(f"Failed to generate idea: {e}")
|
225 |
+
traceback.print_exc()
|
226 |
+
continue
|
227 |
+
|
228 |
+
## SAVE IDEAS
|
229 |
+
ideas = []
|
230 |
+
for idea_str in idea_str_archive:
|
231 |
+
ideas.append(json.loads(idea_str))
|
232 |
+
|
233 |
+
with open(osp.join(base_dir, "ideas.json"), "w") as f:
|
234 |
+
json.dump(ideas, f, indent=4)
|
235 |
+
|
236 |
+
return ideas
|
237 |
+
|
238 |
+
|
239 |
+
# GENERATE IDEAS OPEN-ENDED
|
240 |
+
def generate_next_idea(
|
241 |
+
base_dir,
|
242 |
+
client,
|
243 |
+
model,
|
244 |
+
prev_idea_archive=[],
|
245 |
+
num_reflections=5,
|
246 |
+
max_attempts=10,
|
247 |
+
):
|
248 |
+
idea_archive = prev_idea_archive
|
249 |
+
original_archive_size = len(idea_archive)
|
250 |
+
|
251 |
+
print(f"Generating idea {original_archive_size + 1}")
|
252 |
+
|
253 |
+
if len(prev_idea_archive) == 0:
|
254 |
+
print(f"First iteration, taking seed ideas")
|
255 |
+
# seed the archive on the first run with pre-existing ideas
|
256 |
+
with open(osp.join(base_dir, "seed_ideas.json"), "r") as f:
|
257 |
+
seed_ideas = json.load(f)
|
258 |
+
for seed_idea in seed_ideas[:1]:
|
259 |
+
idea_archive.append(seed_idea)
|
260 |
+
else:
|
261 |
+
with open(osp.join(base_dir, "experiment.py"), "r") as f:
|
262 |
+
code = f.read()
|
263 |
+
with open(osp.join(base_dir, "prompt.json"), "r") as f:
|
264 |
+
prompt = json.load(f)
|
265 |
+
idea_system_prompt = prompt["system"]
|
266 |
+
|
267 |
+
for _ in range(max_attempts):
|
268 |
+
import traceback
|
269 |
+
try:
|
270 |
+
idea_strings = []
|
271 |
+
for idea in idea_archive:
|
272 |
+
idea_strings.append(json.dumps(idea))
|
273 |
+
prev_ideas_string = "\n\n".join(idea_strings)
|
274 |
+
|
275 |
+
msg_history = []
|
276 |
+
print(f"Iteration 1/{num_reflections}")
|
277 |
+
text, msg_history = get_response_from_llm(
|
278 |
+
idea_first_prompt.format(
|
279 |
+
task_description=prompt["task_description"],
|
280 |
+
code=code,
|
281 |
+
prev_ideas_string=prev_ideas_string,
|
282 |
+
num_reflections=num_reflections,
|
283 |
+
)
|
284 |
+
+ """
|
285 |
+
Completed ideas have an additional "Score" field which indicates the assessment by an expert ML reviewer.
|
286 |
+
This is on a standard 1-10 ML conference scale.
|
287 |
+
Scores of 0 indicate the idea failed either during experimentation, writeup or reviewing.
|
288 |
+
""",
|
289 |
+
client=client,
|
290 |
+
model=model,
|
291 |
+
system_message=idea_system_prompt,
|
292 |
+
msg_history=msg_history,
|
293 |
+
)
|
294 |
+
## Format the content in JSON if using weak LLM
|
295 |
+
text = format_idea_json(text)
|
296 |
+
## PARSE OUTPUT
|
297 |
+
json_output = extract_json_between_markers(text)
|
298 |
+
assert json_output is not None, "Failed to extract JSON from LLM output"
|
299 |
+
# print(json_output)
|
300 |
+
|
301 |
+
# Iteratively improve task.
|
302 |
+
if num_reflections > 1:
|
303 |
+
for j in range(num_reflections - 1):
|
304 |
+
print(f"Iteration {j + 2}/{num_reflections}")
|
305 |
+
text, msg_history = get_response_from_llm(
|
306 |
+
idea_reflection_prompt.format(
|
307 |
+
current_round=j + 2, num_reflections=num_reflections
|
308 |
+
),
|
309 |
+
client=client,
|
310 |
+
model=model,
|
311 |
+
system_message=idea_system_prompt,
|
312 |
+
msg_history=msg_history,
|
313 |
+
)
|
314 |
+
## Format the content in JSON if using weak LLM
|
315 |
+
text = format_idea_json(text)
|
316 |
+
## PARSE OUTPUT
|
317 |
+
json_output = extract_json_between_markers(text)
|
318 |
+
assert (
|
319 |
+
json_output is not None
|
320 |
+
), "Failed to extract JSON from LLM output"
|
321 |
+
# print(json_output)
|
322 |
+
|
323 |
+
if "I am done" in text:
|
324 |
+
print(
|
325 |
+
f"Idea generation converged after {j + 2} iterations."
|
326 |
+
)
|
327 |
+
break
|
328 |
+
|
329 |
+
idea_archive.append(json_output)
|
330 |
+
break
|
331 |
+
except Exception as e:
|
332 |
+
print(f"Failed to generate idea: {e}")
|
333 |
+
traceback.print_exc()
|
334 |
+
continue
|
335 |
+
|
336 |
+
## SAVE IDEAS
|
337 |
+
with open(osp.join(base_dir, "ideas.json"), "w") as f:
|
338 |
+
json.dump(idea_archive, f, indent=4)
|
339 |
+
|
340 |
+
return idea_archive
|
341 |
+
|
342 |
+
|
343 |
+
def on_backoff(details):
|
344 |
+
print(
|
345 |
+
f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries "
|
346 |
+
f"calling function {details['target'].__name__} at {time.strftime('%X')}"
|
347 |
+
)
|
348 |
+
|
349 |
+
|
350 |
+
@backoff.on_exception(
|
351 |
+
backoff.expo, requests.exceptions.HTTPError, on_backoff=on_backoff
|
352 |
+
)
|
353 |
+
def search_for_papers(query, result_limit=10) -> Union[None, List[Dict]]:
|
354 |
+
if not query:
|
355 |
+
return None
|
356 |
+
rsp = requests.get(
|
357 |
+
"https://api.semanticscholar.org/graph/v1/paper/search",
|
358 |
+
headers={"X-API-KEY": S2_API_KEY},
|
359 |
+
params={
|
360 |
+
"query": query,
|
361 |
+
"limit": result_limit,
|
362 |
+
"fields": "title,authors,venue,year,abstract,citationStyles,citationCount",
|
363 |
+
},
|
364 |
+
)
|
365 |
+
print(f"Response Status Code: {rsp.status_code}")
|
366 |
+
print(
|
367 |
+
f"Response Content: {rsp.text[:500]}"
|
368 |
+
) # Print the first 500 characters of the response content
|
369 |
+
rsp.raise_for_status()
|
370 |
+
results = rsp.json()
|
371 |
+
total = results["total"]
|
372 |
+
if not total:
|
373 |
+
return None
|
374 |
+
time.sleep(2)
|
375 |
+
papers = results["data"]
|
376 |
+
return papers
|
377 |
+
|
378 |
+
|
379 |
+
novelty_system_msg = """You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field.
|
380 |
+
You have an idea and you want to check if it is novel or not. I.e., not overlapping significantly with existing literature or already well explored.
|
381 |
+
Be a harsh critic for novelty, ensure there is a sufficient contribution in the idea for a new conference or workshop paper.
|
382 |
+
You will be given access to the Semantic Scholar API, which you may use to survey the literature and find relevant papers to help you make your decision.
|
383 |
+
The top 10 results for any search query will be presented to you with the abstracts.
|
384 |
+
|
385 |
+
You will be given {num_rounds} to decide on the paper, but you do not need to use them all.
|
386 |
+
At any round, you may exit early and decide on the novelty of the idea.
|
387 |
+
Decide a paper idea is novel if after sufficient searching, you have not found a paper that significantly overlaps with your idea.
|
388 |
+
Decide a paper idea is not novel, if you have found a paper that significantly overlaps with your idea.
|
389 |
+
|
390 |
+
{task_description}
|
391 |
+
<experiment.py>
|
392 |
+
{code}
|
393 |
+
</experiment.py>
|
394 |
+
"""
|
395 |
+
|
396 |
+
novelty_prompt = '''Round {current_round}/{num_rounds}.
|
397 |
+
You have this idea:
|
398 |
+
|
399 |
+
"""
|
400 |
+
{idea}
|
401 |
+
"""
|
402 |
+
|
403 |
+
The results of the last query are (empty on first round):
|
404 |
+
"""
|
405 |
+
{last_query_results}
|
406 |
+
"""
|
407 |
+
|
408 |
+
Respond in the following format:
|
409 |
+
|
410 |
+
THOUGHT:
|
411 |
+
<THOUGHT>
|
412 |
+
|
413 |
+
RESPONSE:
|
414 |
+
```json
|
415 |
+
<JSON>
|
416 |
+
```
|
417 |
+
|
418 |
+
In <THOUGHT>, first briefly reason over the idea and identify any query that could help you make your decision.
|
419 |
+
If you have made your decision, add "Decision made: novel." or "Decision made: not novel." to your thoughts.
|
420 |
+
|
421 |
+
In <JSON>, respond in JSON format with ONLY the following field:
|
422 |
+
- "Query": An optional search query to search the literature (e.g. attention is all you need). You must make a query if you have not decided this round.
|
423 |
+
|
424 |
+
A query will work best if you are able to recall the exact name of the paper you are looking for, or the authors.
|
425 |
+
This JSON will be automatically parsed, so ensure the format is precise.
|
426 |
+
'''
|
427 |
+
|
428 |
+
|
429 |
+
def check_idea_novelty(
|
430 |
+
ideas,
|
431 |
+
base_dir,
|
432 |
+
client,
|
433 |
+
model,
|
434 |
+
max_num_iterations=10,
|
435 |
+
):
|
436 |
+
with open(osp.join(base_dir, "experiment.py"), "r") as f:
|
437 |
+
code = f.read()
|
438 |
+
with open(osp.join(base_dir, "prompt.json"), "r") as f:
|
439 |
+
prompt = json.load(f)
|
440 |
+
task_description = prompt["task_description"]
|
441 |
+
|
442 |
+
for idx, idea in enumerate(ideas):
|
443 |
+
if "novel" in idea:
|
444 |
+
print(f"Skipping idea {idx}, already checked.")
|
445 |
+
continue
|
446 |
+
|
447 |
+
print(f"\nChecking novelty of idea {idx}: {idea['Name']}")
|
448 |
+
|
449 |
+
novel = False
|
450 |
+
msg_history = []
|
451 |
+
papers_str = ""
|
452 |
+
|
453 |
+
for j in range(max_num_iterations):
|
454 |
+
try:
|
455 |
+
text, msg_history = get_response_from_llm(
|
456 |
+
novelty_prompt.format(
|
457 |
+
current_round=j + 1,
|
458 |
+
num_rounds=max_num_iterations,
|
459 |
+
idea=idea,
|
460 |
+
last_query_results=papers_str,
|
461 |
+
),
|
462 |
+
client=client,
|
463 |
+
model=model,
|
464 |
+
system_message=novelty_system_msg.format(
|
465 |
+
num_rounds=max_num_iterations,
|
466 |
+
task_description=task_description,
|
467 |
+
code=code,
|
468 |
+
),
|
469 |
+
msg_history=msg_history,
|
470 |
+
)
|
471 |
+
if "decision made: novel" in text.lower():
|
472 |
+
print("Decision made: novel after round", j)
|
473 |
+
novel = True
|
474 |
+
break
|
475 |
+
if "decision made: not novel" in text.lower():
|
476 |
+
print("Decision made: not novel after round", j)
|
477 |
+
break
|
478 |
+
|
479 |
+
## Format the content in JSON
|
480 |
+
text = format_novelty_json(text)
|
481 |
+
print("text after formating\n", text)
|
482 |
+
## PARSE OUTPUT
|
483 |
+
json_output = extract_json_between_markers(text)
|
484 |
+
assert json_output is not None, "Failed to extract JSON from LLM output"
|
485 |
+
|
486 |
+
## SEARCH FOR PAPERS
|
487 |
+
query = json_output["Query"]
|
488 |
+
papers = search_for_papers(query, result_limit=10)
|
489 |
+
if papers is None:
|
490 |
+
papers_str = "No papers found."
|
491 |
+
|
492 |
+
paper_strings = []
|
493 |
+
for i, paper in enumerate(papers):
|
494 |
+
paper_strings.append(
|
495 |
+
"""{i}: {title}. {authors}. {venue}, {year}.\nNumber of citations: {cites}\nAbstract: {abstract}""".format(
|
496 |
+
i=i,
|
497 |
+
title=paper["title"],
|
498 |
+
authors=paper["authors"],
|
499 |
+
venue=paper["venue"],
|
500 |
+
year=paper["year"],
|
501 |
+
cites=paper["citationCount"],
|
502 |
+
abstract=paper["abstract"],
|
503 |
+
)
|
504 |
+
)
|
505 |
+
papers_str = "\n\n".join(paper_strings)
|
506 |
+
|
507 |
+
except Exception as e:
|
508 |
+
print(f"Error: {e}")
|
509 |
+
continue
|
510 |
+
|
511 |
+
idea["novel"] = novel
|
512 |
+
|
513 |
+
# Save results to JSON file
|
514 |
+
results_file = osp.join(base_dir, "ideas.json")
|
515 |
+
with open(results_file, "w") as f:
|
516 |
+
json.dump(ideas, f, indent=4)
|
517 |
+
|
518 |
+
return ideas
|
519 |
+
|
520 |
+
|
521 |
+
if __name__ == "__main__":
|
522 |
+
MAX_NUM_GENERATIONS = 32
|
523 |
+
NUM_REFLECTIONS = 5
|
524 |
+
import argparse
|
525 |
+
|
526 |
+
parser = argparse.ArgumentParser(description="Generate AI scientist ideas")
|
527 |
+
# add type of experiment (nanoGPT, Boston, etc.)
|
528 |
+
parser.add_argument(
|
529 |
+
"--experiment",
|
530 |
+
type=str,
|
531 |
+
default="nanoGPT",
|
532 |
+
help="Experiment to run AI Scientist on.",
|
533 |
+
)
|
534 |
+
parser.add_argument(
|
535 |
+
"--model",
|
536 |
+
type=str,
|
537 |
+
default="deepseek-ai/DeepSeek-V2.5",
|
538 |
+
choices=allchoices,
|
539 |
+
help="Model to use for AI Scientist.",
|
540 |
+
)
|
541 |
+
parser.add_argument(
|
542 |
+
"--skip-idea-generation",
|
543 |
+
action="store_true",
|
544 |
+
help="Skip idea generation and use existing ideas.",
|
545 |
+
)
|
546 |
+
parser.add_argument(
|
547 |
+
"--check-novelty",
|
548 |
+
action="store_true",
|
549 |
+
help="Check novelty of ideas.",
|
550 |
+
)
|
551 |
+
args = parser.parse_args()
|
552 |
+
|
553 |
+
# Create client
|
554 |
+
|
555 |
+
# ------------------------------------------------------------------------------------------------------
|
556 |
+
|
557 |
+
if args.model == "Qwen/Qwen2.5-72B-Instruct":
|
558 |
+
# elif args.model.startswith("hyperbolic"):
|
559 |
+
print(f"Welcome to the PARADISE of debug <generate_scientist.py> {args.model}.")
|
560 |
+
|
561 |
+
import openai
|
562 |
+
import os
|
563 |
+
# client_model = args.model[11:]
|
564 |
+
client_model = args.model
|
565 |
+
client = openai.OpenAI(
|
566 |
+
api_key=os.environ["OPENAI_API_KEY"], base_url="https://api.hyperbolic.xyz/v1"
|
567 |
+
)
|
568 |
+
|
569 |
+
# ------------------------------------------------------------------------------------------------------
|
570 |
+
|
571 |
+
|
572 |
+
elif args.model == "claude-3-5-sonnet-20240620":
|
573 |
+
import anthropic
|
574 |
+
|
575 |
+
print(f"Using Anthropic API with model {args.model}.")
|
576 |
+
client_model = "claude-3-5-sonnet-20240620"
|
577 |
+
client = anthropic.Anthropic()
|
578 |
+
elif args.model.startswith("bedrock") and "claude" in args.model:
|
579 |
+
import anthropic
|
580 |
+
|
581 |
+
# Expects: bedrock/<MODEL_ID>
|
582 |
+
client_model = args.model.split("/")[-1]
|
583 |
+
|
584 |
+
print(f"Using Amazon Bedrock with model {client_model}.")
|
585 |
+
client = anthropic.AnthropicBedrock()
|
586 |
+
elif args.model == "gpt-4o-2024-05-13" or args.model == "hybrid":
|
587 |
+
import openai
|
588 |
+
|
589 |
+
print(f"Using OpenAI API with model {args.model}.")
|
590 |
+
client_model = "gpt-4o-2024-05-13"
|
591 |
+
client = openai.OpenAI()
|
592 |
+
elif args.model == "deepseek-coder-v2-0724":
|
593 |
+
import openai
|
594 |
+
|
595 |
+
print(f"Using OpenAI API with {args.model}.")
|
596 |
+
client_model = "deepseek-coder-v2-0724"
|
597 |
+
client = openai.OpenAI(
|
598 |
+
api_key=os.environ["DEEPSEEK_API_KEY"], base_url="https://api.hyperbolic.xyz/v1"
|
599 |
+
)
|
600 |
+
elif args.model == "llama3.1-405b":
|
601 |
+
import openai
|
602 |
+
|
603 |
+
print(f"Using OpenAI API with {args.model}.")
|
604 |
+
client_model = "meta-llama/llama-3.1-405b-instruct"
|
605 |
+
client = openai.OpenAI(
|
606 |
+
api_key=os.environ["OPENROUTER_API_KEY"],
|
607 |
+
base_url="https://openrouter.ai/api/v1",
|
608 |
+
)
|
609 |
+
elif args.model.startswith("ollama"):
|
610 |
+
import openai
|
611 |
+
|
612 |
+
print(f"Using Ollama with {args.model}.")
|
613 |
+
client_model = args.model.split("/")[-1]
|
614 |
+
# client_model = args.model
|
615 |
+
client = openai.OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
|
616 |
+
|
617 |
+
else:
|
618 |
+
raise ValueError(f"Model {args.model} not supported.")
|
619 |
+
|
620 |
+
base_dir = osp.join("templates", args.experiment)
|
621 |
+
results_dir = osp.join("results", args.experiment)
|
622 |
+
print("going into line 623...")
|
623 |
+
ideas = generate_ideas(
|
624 |
+
base_dir,
|
625 |
+
client=client,
|
626 |
+
model=client_model,
|
627 |
+
skip_generation=args.skip_idea_generation,
|
628 |
+
max_num_generations=MAX_NUM_GENERATIONS,
|
629 |
+
num_reflections=NUM_REFLECTIONS,
|
630 |
+
)
|
631 |
+
if args.check_novelty:
|
632 |
+
ideas = check_idea_novelty(
|
633 |
+
ideas,
|
634 |
+
base_dir=base_dir,
|
635 |
+
client=client,
|
636 |
+
model=client_model,
|
637 |
+
)
|
ai_scientist/llm.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import backoff
|
4 |
+
import openai
|
5 |
+
|
6 |
+
# Ollama
|
7 |
+
ollama_choices = [
|
8 |
+
"mistral-nemo",
|
9 |
+
"llama3.1",
|
10 |
+
"qwen2.5:32b"
|
11 |
+
]
|
12 |
+
|
13 |
+
# hyperbolic
|
14 |
+
hyperbolic_choices = [
|
15 |
+
"Qwen/Qwen2.5-72B-Instruct",
|
16 |
+
"meta-llama/Meta-Llama-3.1-70B-Instruct",
|
17 |
+
]
|
18 |
+
|
19 |
+
|
20 |
+
allchoices = [
|
21 |
+
"Qwen/Qwen2.5-72B-Instruct",
|
22 |
+
"deepseek-ai/DeepSeek-V2.5",
|
23 |
+
"claude-3-5-sonnet-20240620",
|
24 |
+
"gpt-4o-2024-05-13",
|
25 |
+
"deepseek-coder-v2-0724",
|
26 |
+
"llama3.1-405b",
|
27 |
+
# Anthropic Claude models via Amazon Bedrock
|
28 |
+
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
|
29 |
+
"bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
|
30 |
+
"bedrock/anthropic.claude-3-haiku-20240307-v1:0",
|
31 |
+
"bedrock/anthropic.claude-3-opus-20240229-v1:0",
|
32 |
+
]
|
33 |
+
|
34 |
+
for item in ollama_choices:
|
35 |
+
allchoices.append("ollama/" + item)
|
36 |
+
|
37 |
+
for item in hyperbolic_choices:
|
38 |
+
allchoices.append("hyperbolic/" + item)
|
39 |
+
|
40 |
+
|
41 |
+
# Get N responses from a single message, used for ensembling.
|
42 |
+
@backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
|
43 |
+
def get_batch_responses_from_llm(
|
44 |
+
msg,
|
45 |
+
client,
|
46 |
+
model,
|
47 |
+
system_message,
|
48 |
+
print_debug=False,
|
49 |
+
msg_history=None,
|
50 |
+
temperature=0.75,
|
51 |
+
n_responses=1,
|
52 |
+
):
|
53 |
+
if msg_history is None:
|
54 |
+
msg_history = []
|
55 |
+
|
56 |
+
if model in [
|
57 |
+
"gpt-4o-2024-05-13",
|
58 |
+
"gpt-4o-mini-2024-07-18",
|
59 |
+
"gpt-4o-2024-08-06",
|
60 |
+
"Qwen/Qwen2.5-72B-Instruct"
|
61 |
+
]:
|
62 |
+
new_msg_history = msg_history + [{"role": "user", "content": msg}]
|
63 |
+
response = client.chat.completions.create(
|
64 |
+
model=model,
|
65 |
+
messages=[
|
66 |
+
{"role": "system", "content": system_message},
|
67 |
+
*new_msg_history,
|
68 |
+
],
|
69 |
+
temperature=temperature,
|
70 |
+
max_tokens=3000,
|
71 |
+
n=n_responses,
|
72 |
+
stop=None,
|
73 |
+
seed=0,
|
74 |
+
)
|
75 |
+
content = [r.message.content for r in response.choices]
|
76 |
+
new_msg_history = [
|
77 |
+
new_msg_history + [{"role": "assistant", "content": c}] for c in content
|
78 |
+
]
|
79 |
+
elif model == "deepseek-coder-v2-0724":
|
80 |
+
new_msg_history = msg_history + [{"role": "user", "content": msg}]
|
81 |
+
response = client.chat.completions.create(
|
82 |
+
model="deepseek-coder",
|
83 |
+
messages=[
|
84 |
+
{"role": "system", "content": system_message},
|
85 |
+
*new_msg_history,
|
86 |
+
],
|
87 |
+
temperature=temperature,
|
88 |
+
max_tokens=3000,
|
89 |
+
n=n_responses,
|
90 |
+
stop=None,
|
91 |
+
)
|
92 |
+
content = [r.message.content for r in response.choices]
|
93 |
+
new_msg_history = [
|
94 |
+
new_msg_history + [{"role": "assistant", "content": c}] for c in content
|
95 |
+
]
|
96 |
+
|
97 |
+
# ------------------------------------------------------------------------------------------------------
|
98 |
+
|
99 |
+
elif model == "Qwen/Qwen2.5-72B-Instruct":
|
100 |
+
new_msg_history = msg_history + [{"role": "user", "content": msg}]
|
101 |
+
response = client.chat.completions.create(
|
102 |
+
model="Qwen/Qwen2.5-72B-Instruct",
|
103 |
+
messages=[
|
104 |
+
{"role": "system", "content": system_message},
|
105 |
+
*new_msg_history,
|
106 |
+
],
|
107 |
+
temperature=temperature,
|
108 |
+
max_tokens=3000,
|
109 |
+
n=n_responses,
|
110 |
+
stop=None,
|
111 |
+
)
|
112 |
+
content = [r.message.content for r in response.choices]
|
113 |
+
new_msg_history = [
|
114 |
+
new_msg_history + [{"role": "assistant", "content": c}] for c in content
|
115 |
+
]
|
116 |
+
|
117 |
+
# elif model in hyperbolic_choices:
|
118 |
+
# content, new_msg_history = [], []
|
119 |
+
# for i in range(n_responses):
|
120 |
+
# print(f"Getting {i+1}/{n_responses} response from {model}")
|
121 |
+
# c, hist = get_response_from_llm(
|
122 |
+
# msg,
|
123 |
+
# client,
|
124 |
+
# model,
|
125 |
+
# system_message,
|
126 |
+
# print_debug=False,
|
127 |
+
# msg_history=None,
|
128 |
+
# temperature=temperature,
|
129 |
+
# )
|
130 |
+
# content.append(c)
|
131 |
+
# new_msg_history.append(hist)
|
132 |
+
|
133 |
+
# ------------------------------------------------------------------------------------------------------
|
134 |
+
|
135 |
+
elif model == "llama-3-1-405b-instruct":
|
136 |
+
new_msg_history = msg_history + [{"role": "user", "content": msg}]
|
137 |
+
response = client.chat.completions.create(
|
138 |
+
model="meta-llama/llama-3.1-405b-instruct",
|
139 |
+
messages=[
|
140 |
+
{"role": "system", "content": system_message},
|
141 |
+
*new_msg_history,
|
142 |
+
],
|
143 |
+
temperature=temperature,
|
144 |
+
max_tokens=3000,
|
145 |
+
n=n_responses,
|
146 |
+
stop=None,
|
147 |
+
)
|
148 |
+
content = [r.message.content for r in response.choices]
|
149 |
+
new_msg_history = [
|
150 |
+
new_msg_history + [{"role": "assistant", "content": c}] for c in content
|
151 |
+
]
|
152 |
+
elif model == "claude-3-5-sonnet-20240620":
|
153 |
+
content, new_msg_history = [], []
|
154 |
+
for _ in range(n_responses):
|
155 |
+
c, hist = get_response_from_llm(
|
156 |
+
msg,
|
157 |
+
client,
|
158 |
+
model,
|
159 |
+
system_message,
|
160 |
+
print_debug=False,
|
161 |
+
msg_history=None,
|
162 |
+
temperature=temperature,
|
163 |
+
)
|
164 |
+
content.append(c)
|
165 |
+
new_msg_history.append(hist)
|
166 |
+
|
167 |
+
# ollama models
|
168 |
+
elif model in ollama_choices:
|
169 |
+
content, new_msg_history = [], []
|
170 |
+
for i in range(n_responses):
|
171 |
+
print(f"Getting {i+1}/{n_responses} response from {model}")
|
172 |
+
c, hist = get_response_from_llm(
|
173 |
+
msg,
|
174 |
+
client,
|
175 |
+
model,
|
176 |
+
system_message,
|
177 |
+
print_debug=False,
|
178 |
+
msg_history=None,
|
179 |
+
temperature=temperature,
|
180 |
+
)
|
181 |
+
content.append(c)
|
182 |
+
new_msg_history.append(hist)
|
183 |
+
else:
|
184 |
+
raise ValueError(f"Model {model} not supported.")
|
185 |
+
|
186 |
+
if print_debug:
|
187 |
+
# Just print the first one.
|
188 |
+
print()
|
189 |
+
print("*" * 20 + " LLM START " + "*" * 20)
|
190 |
+
for j, msg in enumerate(new_msg_history[0]):
|
191 |
+
print(f'{j}, {msg["role"]}: {msg["content"]}')
|
192 |
+
print(content)
|
193 |
+
print("*" * 21 + " LLM END " + "*" * 21)
|
194 |
+
print()
|
195 |
+
|
196 |
+
return content, new_msg_history
|
197 |
+
|
198 |
+
|
199 |
+
@backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
|
200 |
+
def get_response_from_llm(
|
201 |
+
msg,
|
202 |
+
client,
|
203 |
+
model,
|
204 |
+
system_message,
|
205 |
+
print_debug=False,
|
206 |
+
msg_history=None,
|
207 |
+
temperature=0.75,
|
208 |
+
):
|
209 |
+
if msg_history is None:
|
210 |
+
msg_history = []
|
211 |
+
|
212 |
+
if model == "claude-3-5-sonnet-20240620":
|
213 |
+
new_msg_history = msg_history + [
|
214 |
+
{
|
215 |
+
"role": "user",
|
216 |
+
"content": [
|
217 |
+
{
|
218 |
+
"type": "text",
|
219 |
+
"text": msg,
|
220 |
+
}
|
221 |
+
],
|
222 |
+
}
|
223 |
+
]
|
224 |
+
response = client.messages.create(
|
225 |
+
model="claude-3-5-sonnet-20240620",
|
226 |
+
max_tokens=3000,
|
227 |
+
temperature=temperature,
|
228 |
+
system=system_message,
|
229 |
+
messages=new_msg_history,
|
230 |
+
)
|
231 |
+
content = response.content[0].text
|
232 |
+
new_msg_history = new_msg_history + [
|
233 |
+
{
|
234 |
+
"role": "assistant",
|
235 |
+
"content": [
|
236 |
+
{
|
237 |
+
"type": "text",
|
238 |
+
"text": content,
|
239 |
+
}
|
240 |
+
],
|
241 |
+
}
|
242 |
+
]
|
243 |
+
# ------------------------------------------------------------------------------------------------------
|
244 |
+
|
245 |
+
elif model in [
|
246 |
+
"gpt-4o-2024-05-13",
|
247 |
+
"gpt-4o-mini-2024-07-18",
|
248 |
+
"gpt-4o-2024-08-06",
|
249 |
+
"Qwen/Qwen2.5-72B-Instruct"
|
250 |
+
]:
|
251 |
+
new_msg_history = msg_history + [{"role": "user", "content": msg}]
|
252 |
+
response = client.chat.completions.create(
|
253 |
+
model=model,
|
254 |
+
messages=[
|
255 |
+
{"role": "system", "content": system_message},
|
256 |
+
*new_msg_history,
|
257 |
+
],
|
258 |
+
temperature=temperature,
|
259 |
+
max_tokens=3000,
|
260 |
+
n=1,
|
261 |
+
stop=None,
|
262 |
+
seed=0,
|
263 |
+
)
|
264 |
+
content = response.choices[0].message.content
|
265 |
+
new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
|
266 |
+
|
267 |
+
|
268 |
+
# ------------------------------------------------------------------------------------------------------
|
269 |
+
|
270 |
+
|
271 |
+
elif model in ["meta-llama/llama-3.1-405b-instruct", "llama-3-1-405b-instruct"]:
|
272 |
+
new_msg_history = msg_history + [{"role": "user", "content": msg}]
|
273 |
+
response = client.chat.completions.create(
|
274 |
+
model="meta-llama/llama-3.1-405b-instruct",
|
275 |
+
messages=[
|
276 |
+
{"role": "system", "content": system_message},
|
277 |
+
*new_msg_history,
|
278 |
+
],
|
279 |
+
temperature=temperature,
|
280 |
+
max_tokens=3000,
|
281 |
+
n=1,
|
282 |
+
stop=None,
|
283 |
+
)
|
284 |
+
content = response.choices[0].message.content
|
285 |
+
new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
|
286 |
+
|
287 |
+
|
288 |
+
elif model in ollama_choices:
|
289 |
+
new_msg_history = msg_history + [{"role": "user", "content": msg}]
|
290 |
+
response = client.chat.completions.create(
|
291 |
+
model=model,
|
292 |
+
messages=[
|
293 |
+
{"role": "system", "content": system_message},
|
294 |
+
*new_msg_history,
|
295 |
+
],
|
296 |
+
temperature=temperature,
|
297 |
+
max_tokens=6000,
|
298 |
+
n=1,
|
299 |
+
stop=None,
|
300 |
+
seed=0,
|
301 |
+
)
|
302 |
+
content = response.choices[0].message.content
|
303 |
+
# print("\nget_response_from_llm\n")
|
304 |
+
# print(content)
|
305 |
+
new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
|
306 |
+
|
307 |
+
else:
|
308 |
+
raise ValueError(f"Model {model} not supported.")
|
309 |
+
|
310 |
+
if print_debug:
|
311 |
+
print()
|
312 |
+
print("*" * 20 + " LLM START " + "*" * 20)
|
313 |
+
for j, msg in enumerate(new_msg_history):
|
314 |
+
print(f'{j}, {msg["role"]}: {msg["content"]}')
|
315 |
+
print(content)
|
316 |
+
print("*" * 21 + " LLM END " + "*" * 21)
|
317 |
+
print()
|
318 |
+
|
319 |
+
return content, new_msg_history
|
320 |
+
|
321 |
+
|
322 |
+
def llm_json_auto_correct(system_prompt: str, user_prompt: str) -> str:
|
323 |
+
import os
|
324 |
+
client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"], base_url="https://api.hyperbolic.xyz/v1")
|
325 |
+
response = client.chat.completions.create(
|
326 |
+
model="Qwen/Qwen2.5-72B-Instruct",
|
327 |
+
temperature=0,
|
328 |
+
messages=[
|
329 |
+
{"role": "system", "content": system_prompt},
|
330 |
+
{"role": "user", "content": user_prompt},
|
331 |
+
],
|
332 |
+
)
|
333 |
+
return response.choices[0].message.content
|
334 |
+
|
335 |
+
|
336 |
+
def extract_json_between_markers(llm_output):
|
337 |
+
json_start_marker = "```json"
|
338 |
+
json_end_marker = "```"
|
339 |
+
|
340 |
+
# Find the start and end indices of the JSON string
|
341 |
+
start_index = llm_output.find(json_start_marker)
|
342 |
+
if start_index != -1:
|
343 |
+
start_index += len(json_start_marker) # Move past the marker
|
344 |
+
end_index = llm_output.find(json_end_marker, start_index)
|
345 |
+
else:
|
346 |
+
return None # JSON markers not found
|
347 |
+
|
348 |
+
if end_index == -1:
|
349 |
+
return None # End marker not found
|
350 |
+
|
351 |
+
# Extract the JSON string
|
352 |
+
json_string = llm_output[start_index:end_index].strip()
|
353 |
+
# print(json_string)
|
354 |
+
try:
|
355 |
+
parsed_json = json.loads(json_string)
|
356 |
+
|
357 |
+
return parsed_json
|
358 |
+
except json.JSONDecodeError:
|
359 |
+
return None # Invalid JSON format
|
ai_scientist/perform_experiments.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import shutil
|
2 |
+
import os.path as osp
|
3 |
+
import subprocess
|
4 |
+
from subprocess import TimeoutExpired
|
5 |
+
import sys
|
6 |
+
import json
|
7 |
+
|
8 |
+
MAX_ITERS = 4
|
9 |
+
MAX_RUNS = 5
|
10 |
+
MAX_STDERR_OUTPUT = 1500
|
11 |
+
|
12 |
+
coder_prompt = """Your goal is to implement the following idea: {title}.
|
13 |
+
The proposed experiment is as follows: {idea}.
|
14 |
+
You are given a total of up to {max_runs} runs to complete the necessary experiments. You do not need to use all {max_runs}.
|
15 |
+
|
16 |
+
First, plan the list of experiments you would like to run. For example, if you are sweeping over a specific hyperparameter, plan each value you would like to test for each run.
|
17 |
+
|
18 |
+
Note that we already provide the vanilla baseline results, so you do not need to re-run it.
|
19 |
+
|
20 |
+
For reference, the baseline results are as follows:
|
21 |
+
|
22 |
+
{baseline_results}
|
23 |
+
|
24 |
+
After you complete each change, we will run the command `python experiment.py --out_dir=run_i' where i is the run number and evaluate the results.
|
25 |
+
YOUR PROPOSED CHANGE MUST USE THIS COMMAND FORMAT, DO NOT ADD ADDITIONAL COMMAND LINE ARGS.
|
26 |
+
You can then implement the next thing on your list."""
|
27 |
+
|
28 |
+
|
29 |
+
# RUN EXPERIMENT
|
30 |
+
def run_experiment(folder_name, run_num, timeout=7200):
|
31 |
+
cwd = osp.abspath(folder_name)
|
32 |
+
# COPY CODE SO WE CAN SEE IT.
|
33 |
+
shutil.copy(
|
34 |
+
osp.join(folder_name, "experiment.py"),
|
35 |
+
osp.join(folder_name, f"run_{run_num}.py"),
|
36 |
+
)
|
37 |
+
|
38 |
+
# LAUNCH COMMAND
|
39 |
+
command = [
|
40 |
+
"python",
|
41 |
+
"experiment.py",
|
42 |
+
f"--out_dir=run_{run_num}",
|
43 |
+
]
|
44 |
+
try:
|
45 |
+
result = subprocess.run(
|
46 |
+
command, cwd=cwd, stderr=subprocess.PIPE, text=True, timeout=timeout
|
47 |
+
)
|
48 |
+
|
49 |
+
if result.stderr:
|
50 |
+
print(result.stderr, file=sys.stderr)
|
51 |
+
|
52 |
+
if result.returncode != 0:
|
53 |
+
print(f"Run {run_num} failed with return code {result.returncode}")
|
54 |
+
if osp.exists(osp.join(cwd, f"run_{run_num}")):
|
55 |
+
shutil.rmtree(osp.join(cwd, f"run_{run_num}"))
|
56 |
+
print(f"Run failed with the following error {result.stderr}")
|
57 |
+
stderr_output = result.stderr
|
58 |
+
if len(stderr_output) > MAX_STDERR_OUTPUT:
|
59 |
+
stderr_output = "..." + stderr_output[-MAX_STDERR_OUTPUT:]
|
60 |
+
next_prompt = f"Run failed with the following error {stderr_output}"
|
61 |
+
else:
|
62 |
+
with open(osp.join(cwd, f"run_{run_num}", "final_info.json"), "r") as f:
|
63 |
+
results = json.load(f)
|
64 |
+
results = {k: v["means"] for k, v in results.items()}
|
65 |
+
|
66 |
+
next_prompt = f"""Run {run_num} completed. Here are the results:
|
67 |
+
{results}
|
68 |
+
|
69 |
+
Decide if you need to re-plan your experiments given the result (you often will not need to).
|
70 |
+
|
71 |
+
Someone else will be using `notes.txt` to perform a writeup on this in the future.
|
72 |
+
Please include *all* relevant information for the writeup on Run {run_num}, including an experiment description and the run number. Be as verbose as necessary.
|
73 |
+
|
74 |
+
Then, implement the next thing on your list.
|
75 |
+
We will then run the command `python experiment.py --out_dir=run_{run_num + 1}'.
|
76 |
+
YOUR PROPOSED CHANGE MUST USE THIS COMMAND FORMAT, DO NOT ADD ADDITIONAL COMMAND LINE ARGS.
|
77 |
+
If you are finished with experiments, respond with 'ALL_COMPLETED'."""
|
78 |
+
return result.returncode, next_prompt
|
79 |
+
except TimeoutExpired:
|
80 |
+
print(f"Run {run_num} timed out after {timeout} seconds")
|
81 |
+
if osp.exists(osp.join(cwd, f"run_{run_num}")):
|
82 |
+
shutil.rmtree(osp.join(cwd, f"run_{run_num}"))
|
83 |
+
next_prompt = f"Run timed out after {timeout} seconds"
|
84 |
+
return 1, next_prompt
|
85 |
+
|
86 |
+
|
87 |
+
# RUN PLOTTING
|
88 |
+
def run_plotting(folder_name, timeout=600):
|
89 |
+
cwd = osp.abspath(folder_name)
|
90 |
+
# LAUNCH COMMAND
|
91 |
+
command = [
|
92 |
+
"python",
|
93 |
+
"plot.py",
|
94 |
+
]
|
95 |
+
try:
|
96 |
+
result = subprocess.run(
|
97 |
+
command, cwd=cwd, stderr=subprocess.PIPE, text=True, timeout=timeout
|
98 |
+
)
|
99 |
+
|
100 |
+
if result.stderr:
|
101 |
+
print(result.stderr, file=sys.stderr)
|
102 |
+
|
103 |
+
if result.returncode != 0:
|
104 |
+
print(f"Plotting failed with return code {result.returncode}")
|
105 |
+
next_prompt = f"Plotting failed with the following error {result.stderr}"
|
106 |
+
else:
|
107 |
+
next_prompt = ""
|
108 |
+
return result.returncode, next_prompt
|
109 |
+
except TimeoutExpired:
|
110 |
+
print(f"Plotting timed out after {timeout} seconds")
|
111 |
+
next_prompt = f"Plotting timed out after {timeout} seconds"
|
112 |
+
return 1, next_prompt
|
113 |
+
|
114 |
+
|
115 |
+
# PERFORM EXPERIMENTS
|
116 |
+
def perform_experiments(idea, folder_name, coder, baseline_results) -> bool:
|
117 |
+
## RUN EXPERIMENT
|
118 |
+
current_iter = 0
|
119 |
+
run = 1
|
120 |
+
next_prompt = coder_prompt.format(
|
121 |
+
title=idea["Title"],
|
122 |
+
idea=idea["Experiment"],
|
123 |
+
max_runs=MAX_RUNS,
|
124 |
+
baseline_results=baseline_results,
|
125 |
+
)
|
126 |
+
while run < MAX_RUNS + 1:
|
127 |
+
if current_iter >= MAX_ITERS:
|
128 |
+
print("Max iterations reached")
|
129 |
+
break
|
130 |
+
coder_out = coder.run(next_prompt)
|
131 |
+
print(coder_out)
|
132 |
+
if "ALL_COMPLETED" in coder_out:
|
133 |
+
break
|
134 |
+
return_code, next_prompt = run_experiment(folder_name, run)
|
135 |
+
if return_code == 0:
|
136 |
+
run += 1
|
137 |
+
current_iter = 0
|
138 |
+
current_iter += 1
|
139 |
+
if current_iter >= MAX_ITERS:
|
140 |
+
print("Not all experiments completed.")
|
141 |
+
return False
|
142 |
+
|
143 |
+
current_iter = 0
|
144 |
+
next_prompt = """
|
145 |
+
Great job! Please modify `plot.py` to generate the most relevant plots for the final writeup.
|
146 |
+
|
147 |
+
In particular, be sure to fill in the "labels" dictionary with the correct names for each run that you want to plot.
|
148 |
+
|
149 |
+
Only the runs in the `labels` dictionary will be plotted, so make sure to include all relevant runs.
|
150 |
+
|
151 |
+
We will be running the command `python plot.py` to generate the plots.
|
152 |
+
"""
|
153 |
+
while True:
|
154 |
+
coder_out = coder.run(next_prompt)
|
155 |
+
return_code, next_prompt = run_plotting(folder_name)
|
156 |
+
current_iter += 1
|
157 |
+
if return_code == 0 or current_iter >= MAX_ITERS:
|
158 |
+
break
|
159 |
+
next_prompt = """
|
160 |
+
Please modify `notes.txt` with a description of what each plot shows along with the filename of the figure. Please do so in-depth.
|
161 |
+
|
162 |
+
Somebody else will be using `notes.txt` to write a report on this in the future.
|
163 |
+
"""
|
164 |
+
coder.run(next_prompt)
|
165 |
+
|
166 |
+
return True
|
ai_scientist/perform_review.py
ADDED
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import pymupdf
|
6 |
+
import pymupdf4llm
|
7 |
+
from pypdf import PdfReader
|
8 |
+
from strictjson import strict_json
|
9 |
+
|
10 |
+
from ai_scientist.llm import (
|
11 |
+
extract_json_between_markers,
|
12 |
+
get_batch_responses_from_llm,
|
13 |
+
get_response_from_llm,
|
14 |
+
llm_json_auto_correct,
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
# Format the content in JSON
|
19 |
+
def format_llm_review_json(text):
|
20 |
+
res = strict_json(
|
21 |
+
system_prompt="You are a JSON formatter",
|
22 |
+
user_prompt=text,
|
23 |
+
return_as_json=True,
|
24 |
+
output_format={
|
25 |
+
"Summary": "A summary of the paper content and its contributions.",
|
26 |
+
"Strengths": "A list of strengths of the paper, type: list",
|
27 |
+
"Weaknesses": "A list of weaknesses of the paper, type: list",
|
28 |
+
"Originality": "A rating from 1 to 4 (low, medium, high, very high), type: int",
|
29 |
+
"Quality": "A rating from 1 to 4 (low, medium, high, very high), type: int",
|
30 |
+
"Clarity": "A rating from 1 to 4 (low, medium, high, very high), type: int",
|
31 |
+
"Significance": "A rating from 1 to 4 (low, medium, high, very high), type: int",
|
32 |
+
"Questions": "A set of clarifying questions to be answered by the paper authors, type: list",
|
33 |
+
"Limitations": "A set of limitations and potential negative societal impacts of the work, type: str",
|
34 |
+
"Ethical Concerns": "A boolean value indicating whether there are ethical concerns, type: bool",
|
35 |
+
"Soundness": "A rating from 1 to 4 (poor, fair, good, excellent), type: int",
|
36 |
+
"Presentation": "A rating from 1 to 4 (poor, fair, good, excellent), type: int",
|
37 |
+
"Contribution": "A rating from 1 to 4 (poor, fair, good, excellent), type: int",
|
38 |
+
"Overall": "A rating from 1 to 10 (very strong reject to award quality), type: int",
|
39 |
+
"Confidence": "A rating from 1 to 5 (low, medium, high, very high, absolute), type: int",
|
40 |
+
"Decision": "A decision that has to be Accept or Reject, type: str",
|
41 |
+
},
|
42 |
+
llm=llm_json_auto_correct,
|
43 |
+
)
|
44 |
+
text = json.loads(res)
|
45 |
+
return text
|
46 |
+
|
47 |
+
|
48 |
+
reviewer_system_prompt_base = (
|
49 |
+
"You are an AI researcher who is reviewing a paper that was submitted to a prestigious ML venue."
|
50 |
+
"Be critical and cautious in your decision."
|
51 |
+
)
|
52 |
+
|
53 |
+
reviewer_system_prompt_neg = (
|
54 |
+
reviewer_system_prompt_base
|
55 |
+
+ "If a paper is bad or you are unsure, give it bad scores and reject it."
|
56 |
+
)
|
57 |
+
reviewer_system_prompt_pos = (
|
58 |
+
reviewer_system_prompt_base
|
59 |
+
+ "If a paper is good or you are unsure, give it good scores and accept it."
|
60 |
+
)
|
61 |
+
|
62 |
+
template_instructions = """
|
63 |
+
Respond in the following format:
|
64 |
+
|
65 |
+
THOUGHT:
|
66 |
+
<THOUGHT>
|
67 |
+
|
68 |
+
REVIEW JSON:
|
69 |
+
```json
|
70 |
+
<JSON>
|
71 |
+
```
|
72 |
+
|
73 |
+
In <THOUGHT>, first briefly discuss your intuitions and reasoning for the evaluation.
|
74 |
+
Detail your high-level arguments, necessary choices and desired outcomes of the review.
|
75 |
+
Do not make generic comments here, but be specific to your current paper.
|
76 |
+
Treat this as the note-taking phase of your review.
|
77 |
+
|
78 |
+
In <JSON>, provide the review in JSON format with the following fields in the order:
|
79 |
+
- "Summary": A summary of the paper content and its contributions.
|
80 |
+
- "Strengths": A list of strengths of the paper.
|
81 |
+
- "Weaknesses": A list of weaknesses of the paper.
|
82 |
+
- "Originality": A rating from 1 to 4 (low, medium, high, very high).
|
83 |
+
- "Quality": A rating from 1 to 4 (low, medium, high, very high).
|
84 |
+
- "Clarity": A rating from 1 to 4 (low, medium, high, very high).
|
85 |
+
- "Significance": A rating from 1 to 4 (low, medium, high, very high).
|
86 |
+
- "Questions": A set of clarifying questions to be answered by the paper authors.
|
87 |
+
- "Limitations": A set of limitations and potential negative societal impacts of the work.
|
88 |
+
- "Ethical Concerns": A boolean value indicating whether there are ethical concerns.
|
89 |
+
- "Soundness": A rating from 1 to 4 (poor, fair, good, excellent).
|
90 |
+
- "Presentation": A rating from 1 to 4 (poor, fair, good, excellent).
|
91 |
+
- "Contribution": A rating from 1 to 4 (poor, fair, good, excellent).
|
92 |
+
- "Overall": A rating from 1 to 10 (very strong reject to award quality).
|
93 |
+
- "Confidence": A rating from 1 to 5 (low, medium, high, very high, absolute).
|
94 |
+
- "Decision": A decision that has to be one of the following: Accept, Reject.
|
95 |
+
|
96 |
+
For the "Decision" field, don't use Weak Accept, Borderline Accept, Borderline Reject, or Strong Reject. Instead, only use Accept or Reject.
|
97 |
+
This JSON will be automatically parsed, so ensure the format is precise.
|
98 |
+
"""
|
99 |
+
|
100 |
+
neurips_form = (
|
101 |
+
"""
|
102 |
+
## Review Form
|
103 |
+
Below is a description of the questions you will be asked on the review form for each paper and some guidelines on what to consider when answering these questions.
|
104 |
+
When writing your review, please keep in mind that after decisions have been made, reviews and meta-reviews of accepted papers and opted-in rejected papers will be made public.
|
105 |
+
|
106 |
+
1. Summary: Briefly summarize the paper and its contributions. This is not the place to critique the paper; the authors should generally agree with a well-written summary.
|
107 |
+
- Strengths and Weaknesses: Please provide a thorough assessment of the strengths and weaknesses of the paper, touching on each of the following dimensions:
|
108 |
+
- Originality: Are the tasks or methods new? Is the work a novel combination of well-known techniques? (This can be valuable!) Is it clear how this work differs from previous contributions? Is related work adequately cited
|
109 |
+
- Quality: Is the submission technically sound? Are claims well supported (e.g., by theoretical analysis or experimental results)? Are the methods used appropriate? Is this a complete piece of work or work in progress? Are the authors careful and honest about evaluating both the strengths and weaknesses of their work
|
110 |
+
- Clarity: Is the submission clearly written? Is it well organized? (If not, please make constructive suggestions for improving its clarity.) Does it adequately inform the reader? (Note that a superbly written paper provides enough information for an expert reader to reproduce its results.)
|
111 |
+
- Significance: Are the results important? Are others (researchers or practitioners) likely to use the ideas or build on them? Does the submission address a difficult task in a better way than previous work? Does it advance the state of the art in a demonstrable way? Does it provide unique data, unique conclusions about existing data, or a unique theoretical or experimental approach?
|
112 |
+
|
113 |
+
2. Questions: Please list up and carefully describe any questions and suggestions for the authors. Think of the things where a response from the author can change your opinion, clarify a confusion or address a limitation. This can be very important for a productive rebuttal and discussion phase with the authors.
|
114 |
+
|
115 |
+
3. Limitations: Have the authors adequately addressed the limitations and potential negative societal impact of their work? If not, please include constructive suggestions for improvement.
|
116 |
+
In general, authors should be rewarded rather than punished for being up front about the limitations of their work and any potential negative societal impact. You are encouraged to think through whether any critical points are missing and provide these as feedback for the authors.
|
117 |
+
|
118 |
+
4. Ethical concerns: If there are ethical issues with this paper, please flag the paper for an ethics review. For guidance on when this is appropriate, please review the NeurIPS ethics guidelines.
|
119 |
+
|
120 |
+
5. Soundness: Please assign the paper a numerical rating on the following scale to indicate the soundness of the technical claims, experimental and research methodology and on whether the central claims of the paper are adequately supported with evidence.
|
121 |
+
4: excellent
|
122 |
+
3: good
|
123 |
+
2: fair
|
124 |
+
1: poor
|
125 |
+
|
126 |
+
6. Presentation: Please assign the paper a numerical rating on the following scale to indicate the quality of the presentation. This should take into account the writing style and clarity, as well as contextualization relative to prior work.
|
127 |
+
4: excellent
|
128 |
+
3: good
|
129 |
+
2: fair
|
130 |
+
1: poor
|
131 |
+
|
132 |
+
7. Contribution: Please assign the paper a numerical rating on the following scale to indicate the quality of the overall contribution this paper makes to the research area being studied. Are the questions being asked important? Does the paper bring a significant originality of ideas and/or execution? Are the results valuable to share with the broader NeurIPS community.
|
133 |
+
4: excellent
|
134 |
+
3: good
|
135 |
+
2: fair
|
136 |
+
1: poor
|
137 |
+
|
138 |
+
8. Overall: Please provide an "overall score" for this submission. Choices:
|
139 |
+
10: Award quality: Technically flawless paper with groundbreaking impact on one or more areas of AI, with exceptionally strong evaluation, reproducibility, and resources, and no unaddressed ethical considerations.
|
140 |
+
9: Very Strong Accept: Technically flawless paper with groundbreaking impact on at least one area of AI and excellent impact on multiple areas of AI, with flawless evaluation, resources, and reproducibility, and no unaddressed ethical considerations.
|
141 |
+
8: Strong Accept: Technically strong paper with, with novel ideas, excellent impact on at least one area of AI or high-to-excellent impact on multiple areas of AI, with excellent evaluation, resources, and reproducibility, and no unaddressed ethical considerations.
|
142 |
+
7: Accept: Technically solid paper, with high impact on at least one sub-area of AI or moderate-to-high impact on more than one area of AI, with good-to-excellent evaluation, resources, reproducibility, and no unaddressed ethical considerations.
|
143 |
+
6: Weak Accept: Technically solid, moderate-to-high impact paper, with no major concerns with respect to evaluation, resources, reproducibility, ethical considerations.
|
144 |
+
5: Borderline accept: Technically solid paper where reasons to accept outweigh reasons to reject, e.g., limited evaluation. Please use sparingly.
|
145 |
+
4: Borderline reject: Technically solid paper where reasons to reject, e.g., limited evaluation, outweigh reasons to accept, e.g., good evaluation. Please use sparingly.
|
146 |
+
3: Reject: For instance, a paper with technical flaws, weak evaluation, inadequate reproducibility and incompletely addressed ethical considerations.
|
147 |
+
2: Strong Reject: For instance, a paper with major technical flaws, and/or poor evaluation, limited impact, poor reproducibility and mostly unaddressed ethical considerations.
|
148 |
+
1: Very Strong Reject: For instance, a paper with trivial results or unaddressed ethical considerations
|
149 |
+
|
150 |
+
9. Confidence: Please provide a "confidence score" for your assessment of this submission to indicate how confident you are in your evaluation. Choices:
|
151 |
+
5: You are absolutely certain about your assessment. You are very familiar with the related work and checked the math/other details carefully.
|
152 |
+
4: You are confident in your assessment, but not absolutely certain. It is unlikely, but not impossible, that you did not understand some parts of the submission or that you are unfamiliar with some pieces of related work.
|
153 |
+
3: You are fairly confident in your assessment. It is possible that you did not understand some parts of the submission or that you are unfamiliar with some pieces of related work. Math/other details were not carefully checked.
|
154 |
+
2: You are willing to defend your assessment, but it is quite likely that you did not understand the central parts of the submission or that you are unfamiliar with some pieces of related work. Math/other details were not carefully checked.
|
155 |
+
1: Your assessment is an educated guess. The submission is not in your area or the submission was difficult to understand. Math/other details were not carefully checked.
|
156 |
+
"""
|
157 |
+
+ template_instructions
|
158 |
+
)
|
159 |
+
|
160 |
+
|
161 |
+
def perform_review(
|
162 |
+
text,
|
163 |
+
model,
|
164 |
+
client,
|
165 |
+
num_reflections=1,
|
166 |
+
num_fs_examples=1,
|
167 |
+
num_reviews_ensemble=1,
|
168 |
+
temperature=0.75,
|
169 |
+
msg_history=None,
|
170 |
+
return_msg_history=False,
|
171 |
+
reviewer_system_prompt=reviewer_system_prompt_neg,
|
172 |
+
review_instruction_form=neurips_form,
|
173 |
+
):
|
174 |
+
if num_fs_examples > 0:
|
175 |
+
fs_prompt = get_review_fewshot_examples(num_fs_examples)
|
176 |
+
base_prompt = review_instruction_form + fs_prompt
|
177 |
+
else:
|
178 |
+
base_prompt = review_instruction_form
|
179 |
+
|
180 |
+
base_prompt += f"""
|
181 |
+
Here is the paper you are asked to review:
|
182 |
+
```
|
183 |
+
{text}
|
184 |
+
```"""
|
185 |
+
|
186 |
+
if num_reviews_ensemble > 1:
|
187 |
+
llm_review, msg_histories = get_batch_responses_from_llm(
|
188 |
+
base_prompt,
|
189 |
+
model=model,
|
190 |
+
client=client,
|
191 |
+
system_message=reviewer_system_prompt,
|
192 |
+
print_debug=False,
|
193 |
+
msg_history=msg_history,
|
194 |
+
# Higher temperature to encourage diversity.
|
195 |
+
temperature=0.75,
|
196 |
+
n_responses=num_reviews_ensemble,
|
197 |
+
)
|
198 |
+
parsed_reviews = []
|
199 |
+
for idx, rev in enumerate(llm_review):
|
200 |
+
try:
|
201 |
+
parsed_reviews.append(format_llm_review_json(rev))
|
202 |
+
except Exception as e:
|
203 |
+
print(f"Ensemble review {idx} failed: {e}")
|
204 |
+
parsed_reviews = [r for r in parsed_reviews if r is not None]
|
205 |
+
review = get_meta_review(model, client, temperature, parsed_reviews)
|
206 |
+
## Format the content in JSON
|
207 |
+
review = format_llm_review_json(review)
|
208 |
+
# take first valid in case meta-reviewer fails
|
209 |
+
if review is None:
|
210 |
+
review = parsed_reviews[0]
|
211 |
+
# print(parsed_reviews, "\n\n\n", review) # debug
|
212 |
+
# Replace numerical scores with the average of the ensemble.
|
213 |
+
for score, limits in [
|
214 |
+
("Originality", (1, 4)),
|
215 |
+
("Quality", (1, 4)),
|
216 |
+
("Clarity", (1, 4)),
|
217 |
+
("Significance", (1, 4)),
|
218 |
+
("Soundness", (1, 4)),
|
219 |
+
("Presentation", (1, 4)),
|
220 |
+
("Contribution", (1, 4)),
|
221 |
+
("Overall", (1, 10)),
|
222 |
+
("Confidence", (1, 5)),
|
223 |
+
]:
|
224 |
+
scores = []
|
225 |
+
for r in parsed_reviews:
|
226 |
+
if score in r and limits[1] >= r[score] >= limits[0]:
|
227 |
+
scores.append(r[score])
|
228 |
+
review[score] = int(round(np.mean(scores)))
|
229 |
+
|
230 |
+
# Rewrite the message history with the valid one and new aggregated review.
|
231 |
+
msg_history = msg_histories[0][:-1]
|
232 |
+
msg_history += [
|
233 |
+
{
|
234 |
+
"role": "assistant",
|
235 |
+
"content": f"""
|
236 |
+
THOUGHT:
|
237 |
+
I will start by aggregating the opinions of {num_reviews_ensemble} reviewers that I previously obtained.
|
238 |
+
|
239 |
+
REVIEW JSON:
|
240 |
+
```json
|
241 |
+
{json.dumps(review)}
|
242 |
+
```
|
243 |
+
""",
|
244 |
+
}
|
245 |
+
]
|
246 |
+
else:
|
247 |
+
llm_review, msg_history = get_response_from_llm(
|
248 |
+
base_prompt,
|
249 |
+
model=model,
|
250 |
+
client=client,
|
251 |
+
system_message=reviewer_system_prompt,
|
252 |
+
print_debug=False,
|
253 |
+
msg_history=msg_history,
|
254 |
+
temperature=temperature,
|
255 |
+
)
|
256 |
+
review = format_llm_review_json(llm_review)
|
257 |
+
|
258 |
+
if num_reflections > 1:
|
259 |
+
for j in range(num_reflections - 1):
|
260 |
+
# print(f"Relection: {j + 2}/{num_reflections}")
|
261 |
+
text, msg_history = get_response_from_llm(
|
262 |
+
reviewer_reflection_prompt,
|
263 |
+
client=client,
|
264 |
+
model=model,
|
265 |
+
system_message=reviewer_system_prompt,
|
266 |
+
msg_history=msg_history,
|
267 |
+
temperature=temperature,
|
268 |
+
)
|
269 |
+
review = format_llm_review_json(text)
|
270 |
+
assert review is not None, "Failed to extract JSON from LLM output"
|
271 |
+
|
272 |
+
if "I am done" in text:
|
273 |
+
# print(f"Review generation converged after {j + 2} iterations.")
|
274 |
+
break
|
275 |
+
|
276 |
+
if return_msg_history:
|
277 |
+
return review, msg_history
|
278 |
+
else:
|
279 |
+
return review
|
280 |
+
|
281 |
+
|
282 |
+
reviewer_reflection_prompt = """Round {current_round}/{num_reflections}.
|
283 |
+
In your thoughts, first carefully consider the accuracy and soundness of the review you just created.
|
284 |
+
Include any other factors that you think are important in evaluating the paper.
|
285 |
+
Ensure the review is clear and concise, and the JSON is in the correct format.
|
286 |
+
Do not make things overly complicated.
|
287 |
+
In the next attempt, try and refine and improve your review.
|
288 |
+
Stick to the spirit of the original review unless there are glaring issues.
|
289 |
+
|
290 |
+
Respond in the same format as before:
|
291 |
+
THOUGHT:
|
292 |
+
<THOUGHT>
|
293 |
+
|
294 |
+
REVIEW JSON:
|
295 |
+
```json
|
296 |
+
<JSON>
|
297 |
+
```
|
298 |
+
|
299 |
+
If there is nothing to improve, simply repeat the previous JSON EXACTLY after the thought and include "I am done" at the end of the thoughts but before the JSON.
|
300 |
+
ONLY INCLUDE "I am done" IF YOU ARE MAKING NO MORE CHANGES."""
|
301 |
+
|
302 |
+
|
303 |
+
def load_paper(pdf_path, num_pages=None, min_size=100):
|
304 |
+
try:
|
305 |
+
if num_pages is None:
|
306 |
+
text = pymupdf4llm.to_markdown(pdf_path)
|
307 |
+
else:
|
308 |
+
reader = PdfReader(pdf_path)
|
309 |
+
min_pages = min(len(reader.pages), num_pages)
|
310 |
+
text = pymupdf4llm.to_markdown(pdf_path, pages=list(range(min_pages)))
|
311 |
+
if len(text) < min_size:
|
312 |
+
raise Exception("Text too short")
|
313 |
+
except Exception as e:
|
314 |
+
print(f"Error with pymupdf4llm, falling back to pymupdf: {e}")
|
315 |
+
try:
|
316 |
+
doc = pymupdf.open(pdf_path) # open a document
|
317 |
+
if num_pages:
|
318 |
+
doc = doc[:num_pages]
|
319 |
+
text = ""
|
320 |
+
for page in doc: # iterate the document pages
|
321 |
+
text = text + page.get_text() # get plain text encoded as UTF-8
|
322 |
+
if len(text) < min_size:
|
323 |
+
raise Exception("Text too short")
|
324 |
+
except Exception as e:
|
325 |
+
print(f"Error with pymupdf, falling back to pypdf: {e}")
|
326 |
+
reader = PdfReader(pdf_path)
|
327 |
+
if num_pages is None:
|
328 |
+
text = "".join(page.extract_text() for page in reader.pages)
|
329 |
+
else:
|
330 |
+
text = "".join(page.extract_text() for page in reader.pages[:num_pages])
|
331 |
+
if len(text) < min_size:
|
332 |
+
raise Exception("Text too short")
|
333 |
+
|
334 |
+
return text
|
335 |
+
|
336 |
+
|
337 |
+
def load_review(path):
|
338 |
+
with open(path, "r") as json_file:
|
339 |
+
loaded = json.load(json_file)
|
340 |
+
return loaded["review"]
|
341 |
+
|
342 |
+
|
343 |
+
# get directory of this file
|
344 |
+
dir_path = os.path.dirname(os.path.realpath(__file__))
|
345 |
+
|
346 |
+
fewshot_papers = [
|
347 |
+
os.path.join(dir_path, "fewshot_examples/132_automated_relational.pdf"),
|
348 |
+
os.path.join(dir_path, "fewshot_examples/attention.pdf"),
|
349 |
+
os.path.join(dir_path, "fewshot_examples/2_carpe_diem.pdf"),
|
350 |
+
]
|
351 |
+
|
352 |
+
fewshot_reviews = [
|
353 |
+
os.path.join(dir_path, "fewshot_examples/132_automated_relational.json"),
|
354 |
+
os.path.join(dir_path, "fewshot_examples/attention.json"),
|
355 |
+
os.path.join(dir_path, "fewshot_examples/2_carpe_diem.json"),
|
356 |
+
]
|
357 |
+
|
358 |
+
|
359 |
+
def get_review_fewshot_examples(num_fs_examples=1):
|
360 |
+
fewshot_prompt = """
|
361 |
+
Below are some sample reviews, copied from previous machine learning conferences.
|
362 |
+
Note that while each review is formatted differently according to each reviewer's style, the reviews are well-structured and therefore easy to navigate.
|
363 |
+
"""
|
364 |
+
for paper, review in zip(
|
365 |
+
fewshot_papers[:num_fs_examples], fewshot_reviews[:num_fs_examples]
|
366 |
+
):
|
367 |
+
txt_path = paper.replace(".pdf", ".txt")
|
368 |
+
if os.path.exists(txt_path):
|
369 |
+
with open(txt_path, "r") as f:
|
370 |
+
paper_text = f.read()
|
371 |
+
else:
|
372 |
+
paper_text = load_paper(paper)
|
373 |
+
review_text = load_review(review)
|
374 |
+
fewshot_prompt += f"""
|
375 |
+
Paper:
|
376 |
+
|
377 |
+
```
|
378 |
+
{paper_text}
|
379 |
+
```
|
380 |
+
|
381 |
+
Review:
|
382 |
+
|
383 |
+
```
|
384 |
+
{review_text}
|
385 |
+
```
|
386 |
+
"""
|
387 |
+
|
388 |
+
return fewshot_prompt
|
389 |
+
|
390 |
+
|
391 |
+
meta_reviewer_system_prompt = """You are an Area Chair at a machine learning conference.
|
392 |
+
You are in charge of meta-reviewing a paper that was reviewed by {reviewer_count} reviewers.
|
393 |
+
Your job is to aggregate the reviews into a single meta-review in the same format.
|
394 |
+
Be critical and cautious in your decision, find consensus, and respect the opinion of all the reviewers."""
|
395 |
+
|
396 |
+
|
397 |
+
def get_meta_review(model, client, temperature, reviews):
|
398 |
+
# Write a meta-review from a set of individual reviews
|
399 |
+
review_text = ""
|
400 |
+
for i, r in enumerate(reviews):
|
401 |
+
review_text += f"""
|
402 |
+
Review {i + 1}/{len(reviews)}:
|
403 |
+
```
|
404 |
+
{json.dumps(r)}
|
405 |
+
```
|
406 |
+
"""
|
407 |
+
base_prompt = neurips_form + review_text
|
408 |
+
|
409 |
+
llm_review, msg_history = get_response_from_llm(
|
410 |
+
base_prompt,
|
411 |
+
model=model,
|
412 |
+
client=client,
|
413 |
+
system_message=meta_reviewer_system_prompt.format(reviewer_count=len(reviews)),
|
414 |
+
print_debug=False,
|
415 |
+
msg_history=None,
|
416 |
+
temperature=temperature,
|
417 |
+
)
|
418 |
+
meta_review = format_llm_review_json(llm_review)
|
419 |
+
return meta_review
|
420 |
+
|
421 |
+
|
422 |
+
def perform_improvement(review, coder):
|
423 |
+
improvement_prompt = '''The following review has been created for your research paper:
|
424 |
+
"""
|
425 |
+
{review}
|
426 |
+
"""
|
427 |
+
|
428 |
+
Improve the text using the review.'''.format(
|
429 |
+
review=json.dumps(review)
|
430 |
+
)
|
431 |
+
coder_out = coder.run(improvement_prompt)
|
ai_scientist/perform_writeup.py
ADDED
@@ -0,0 +1,707 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
import re
|
6 |
+
import shutil
|
7 |
+
import subprocess
|
8 |
+
from typing import Optional, Tuple
|
9 |
+
|
10 |
+
from strictjson import strict_json
|
11 |
+
|
12 |
+
from ai_scientist.generate_ideas import search_for_papers
|
13 |
+
from ai_scientist.llm import (
|
14 |
+
allchoices,
|
15 |
+
extract_json_between_markers,
|
16 |
+
get_response_from_llm,
|
17 |
+
llm_json_auto_correct,
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
def format_citation_first_json(text):
|
22 |
+
res = strict_json(
|
23 |
+
system_prompt="You are a JSON formatter",
|
24 |
+
user_prompt=text,
|
25 |
+
return_as_json=True,
|
26 |
+
output_format={
|
27 |
+
"Description": "A precise description of the required edit, along with the proposed text and location where it should be made",
|
28 |
+
"Query": "The search query to find the paper (e.g. attention is all you need)",
|
29 |
+
},
|
30 |
+
llm=llm_json_auto_correct,
|
31 |
+
)
|
32 |
+
text = json.loads(res)
|
33 |
+
return text
|
34 |
+
|
35 |
+
|
36 |
+
def format_citation_second_json(text):
|
37 |
+
res = strict_json(
|
38 |
+
system_prompt="You are a JSON formatter",
|
39 |
+
user_prompt=text,
|
40 |
+
return_as_json=True,
|
41 |
+
output_format={
|
42 |
+
"Selected": "A list of the indices of the selected papers to be cited, e.g. '[0, 1]'. Can be '[]' if no papers are selected. This must be a string",
|
43 |
+
"Description": "Update the previous description of the required edit if needed. Ensure that any cites precisely match the name in the bibtex",
|
44 |
+
},
|
45 |
+
llm=llm_json_auto_correct,
|
46 |
+
)
|
47 |
+
text = json.loads(res)
|
48 |
+
return text
|
49 |
+
|
50 |
+
|
51 |
+
# GENERATE LATEX
|
52 |
+
def generate_latex(coder, folder_name, pdf_file, timeout=30, num_error_corrections=5):
|
53 |
+
folder = osp.abspath(folder_name)
|
54 |
+
cwd = osp.join(folder, "latex") # Fixed potential issue with path
|
55 |
+
writeup_file = osp.join(cwd, "template.tex")
|
56 |
+
|
57 |
+
# Check all references are valid and in the references.bib file
|
58 |
+
with open(writeup_file, "r") as f:
|
59 |
+
tex_text = f.read()
|
60 |
+
cites = re.findall(r"\\cite[a-z]*{([^}]*)}", tex_text)
|
61 |
+
references_bib = re.search(
|
62 |
+
r"\\begin{filecontents}{references.bib}(.*?)\\end{filecontents}",
|
63 |
+
tex_text,
|
64 |
+
re.DOTALL,
|
65 |
+
)
|
66 |
+
if references_bib is None:
|
67 |
+
print("No references.bib found in template.tex")
|
68 |
+
return
|
69 |
+
bib_text = references_bib.group(1)
|
70 |
+
cites = [cite.strip() for item in cites for cite in item.split(",")]
|
71 |
+
for cite in cites:
|
72 |
+
if cite not in bib_text:
|
73 |
+
print(f"Reference {cite} not found in references.")
|
74 |
+
prompt = f"""Reference {cite} not found in references.bib. Is this included under a different name?
|
75 |
+
If so, please modify the citation in template.tex to match the name in references.bib at the top. Otherwise, remove the cite."""
|
76 |
+
coder.run(prompt)
|
77 |
+
|
78 |
+
# Check all included figures are actually in the directory.
|
79 |
+
with open(writeup_file, "r") as f:
|
80 |
+
tex_text = f.read()
|
81 |
+
referenced_figs = re.findall(r"\\includegraphics.*?{(.*?)}", tex_text)
|
82 |
+
all_figs = [f for f in os.listdir(folder) if f.endswith(".png")]
|
83 |
+
for figure in referenced_figs:
|
84 |
+
if figure not in all_figs:
|
85 |
+
print(f"Figure {figure} not found in directory.")
|
86 |
+
prompt = f"""The image {figure} not found in the directory. The images in the directory are: {all_figs}.
|
87 |
+
Please ensure that the figure is in the directory and that the filename is correct. Check the notes to see what each figure contains."""
|
88 |
+
coder.run(prompt)
|
89 |
+
|
90 |
+
# Remove duplicate figures.
|
91 |
+
with open(writeup_file, "r") as f:
|
92 |
+
tex_text = f.read()
|
93 |
+
referenced_figs = re.findall(r"\\includegraphics.*?{(.*?)}", tex_text)
|
94 |
+
duplicates = {x for x in referenced_figs if referenced_figs.count(x) > 1}
|
95 |
+
if duplicates:
|
96 |
+
for dup in duplicates:
|
97 |
+
print(f"Duplicate figure found: {dup}.")
|
98 |
+
prompt = f"""Duplicate figures found: {dup}. Ensure any figure is only included once.
|
99 |
+
If duplicated, identify the best location for the figure and remove any other."""
|
100 |
+
coder.run(prompt)
|
101 |
+
|
102 |
+
# Remove duplicate section headers.
|
103 |
+
with open(writeup_file, "r") as f:
|
104 |
+
tex_text = f.read()
|
105 |
+
sections = re.findall(r"\\section{([^}]*)}", tex_text)
|
106 |
+
duplicates = {x for x in sections if sections.count(x) > 1}
|
107 |
+
if duplicates:
|
108 |
+
for dup in duplicates:
|
109 |
+
print(f"Duplicate section header found: {dup}")
|
110 |
+
prompt = f"""Duplicate section header found: {dup}. Ensure any section header is declared once.
|
111 |
+
If duplicated, identify the best location for the section header and remove any other."""
|
112 |
+
coder.run(prompt)
|
113 |
+
|
114 |
+
# Iteratively fix any LaTeX bugs
|
115 |
+
for i in range(num_error_corrections):
|
116 |
+
# Filter trivial bugs in chktex
|
117 |
+
check_output = os.popen(f"chktex {writeup_file} -q -n2 -n24 -n13 -n1").read()
|
118 |
+
if check_output:
|
119 |
+
prompt = f"""Please fix the following LaTeX errors in `template.tex` guided by the output of `chktek`:
|
120 |
+
{check_output}.
|
121 |
+
|
122 |
+
Make the minimal fix required and do not remove or change any packages.
|
123 |
+
Pay attention to any accidental uses of HTML syntax, e.g. </end instead of \\end.
|
124 |
+
"""
|
125 |
+
coder.run(prompt)
|
126 |
+
else:
|
127 |
+
break
|
128 |
+
compile_latex(cwd, pdf_file, timeout=timeout)
|
129 |
+
|
130 |
+
|
131 |
+
def compile_latex(cwd, pdf_file, timeout=30):
|
132 |
+
print("GENERATING LATEX")
|
133 |
+
|
134 |
+
commands = [
|
135 |
+
["pdflatex", "-interaction=nonstopmode", "template.tex"],
|
136 |
+
["bibtex", "template"],
|
137 |
+
["pdflatex", "-interaction=nonstopmode", "template.tex"],
|
138 |
+
["pdflatex", "-interaction=nonstopmode", "template.tex"],
|
139 |
+
]
|
140 |
+
|
141 |
+
for command in commands:
|
142 |
+
try:
|
143 |
+
result = subprocess.run(
|
144 |
+
command,
|
145 |
+
cwd=cwd,
|
146 |
+
stdout=subprocess.PIPE,
|
147 |
+
stderr=subprocess.PIPE,
|
148 |
+
text=True,
|
149 |
+
timeout=timeout,
|
150 |
+
)
|
151 |
+
print("Standard Output:\n", result.stdout)
|
152 |
+
print("Standard Error:\n", result.stderr)
|
153 |
+
except subprocess.TimeoutExpired:
|
154 |
+
print(f"Latex timed out after {timeout} seconds")
|
155 |
+
except subprocess.CalledProcessError as e:
|
156 |
+
print(f"Error running command {' '.join(command)}: {e}")
|
157 |
+
|
158 |
+
print("FINISHED GENERATING LATEX")
|
159 |
+
|
160 |
+
# Attempt to move the PDF to the desired location
|
161 |
+
try:
|
162 |
+
shutil.move(osp.join(cwd, "template.pdf"), pdf_file)
|
163 |
+
except FileNotFoundError:
|
164 |
+
print("Failed to rename PDF.")
|
165 |
+
|
166 |
+
|
167 |
+
per_section_tips = {
|
168 |
+
"Abstract": """
|
169 |
+
- TL;DR of the paper
|
170 |
+
- What are we trying to do and why is it relevant?
|
171 |
+
- Why is this hard?
|
172 |
+
- How do we solve it (i.e. our contribution!)
|
173 |
+
- How do we verify that we solved it (e.g. Experiments and results)
|
174 |
+
|
175 |
+
Please make sure the abstract reads smoothly and is well-motivated. This should be one continuous paragraph with no breaks between the lines.
|
176 |
+
""",
|
177 |
+
"Introduction": """
|
178 |
+
- Longer version of the Abstract, i.e. of the entire paper
|
179 |
+
- What are we trying to do and why is it relevant?
|
180 |
+
- Why is this hard?
|
181 |
+
- How do we solve it (i.e. our contribution!)
|
182 |
+
- How do we verify that we solved it (e.g. Experiments and results)
|
183 |
+
- New trend: specifically list your contributions as bullet points
|
184 |
+
- Extra space? Future work!
|
185 |
+
""",
|
186 |
+
"Related Work": """
|
187 |
+
- Academic siblings of our work, i.e. alternative attempts in literature at trying to solve the same problem.
|
188 |
+
- Goal is to “Compare and contrast” - how does their approach differ in either assumptions or method? If their method is applicable to our Problem Setting I expect a comparison in the experimental section. If not, there needs to be a clear statement why a given method is not applicable.
|
189 |
+
- Note: Just describing what another paper is doing is not enough. We need to compare and contrast.
|
190 |
+
""",
|
191 |
+
"Background": """
|
192 |
+
- Academic Ancestors of our work, i.e. all concepts and prior work that are required for understanding our method.
|
193 |
+
- Usually includes a subsection, Problem Setting, which formally introduces the problem setting and notation (Formalism) for our method. Highlights any specific assumptions that are made that are unusual.
|
194 |
+
- Note: If our paper introduces a novel problem setting as part of its contributions, it's best to have a separate Section.
|
195 |
+
""",
|
196 |
+
"Method": """
|
197 |
+
- What we do. Why we do it. All described using the general Formalism introduced in the Problem Setting and building on top of the concepts / foundations introduced in Background.
|
198 |
+
""",
|
199 |
+
"Experimental Setup": """
|
200 |
+
- How do we test that our stuff works? Introduces a specific instantiation of the Problem Setting and specific implementation details of our Method for this Problem Setting.
|
201 |
+
- Do not imagine unknown hardware details.
|
202 |
+
- Includes a description of the dataset, evaluation metrics, important hyperparameters, and implementation details.
|
203 |
+
""",
|
204 |
+
"Results": """
|
205 |
+
- Shows the results of running Method on our problem described in Experimental Setup.
|
206 |
+
- Includes statements on hyperparameters and other potential issues of fairness.
|
207 |
+
- Only includes results that have actually been run and saved in the logs. Do not hallucinate results that don't exist.
|
208 |
+
- If results exist: compares to baselines and includes statistics and confidence intervals.
|
209 |
+
- If results exist: includes ablation studies to show that specific parts of the method are relevant.
|
210 |
+
- Discusses limitations of the method.
|
211 |
+
- Make sure to include all the results from the experiments, and include all relevant figures.
|
212 |
+
""",
|
213 |
+
"Conclusion": """
|
214 |
+
- Brief recap of the entire paper.
|
215 |
+
- To keep going with the analogy, you can think of future work as (potential) academic offspring.
|
216 |
+
""",
|
217 |
+
}
|
218 |
+
|
219 |
+
error_list = """- Unenclosed math symbols
|
220 |
+
- Only reference figures that exist in our directory
|
221 |
+
- LaTeX syntax errors
|
222 |
+
- Numerical results that do not come from explicit experiments and logs
|
223 |
+
- Repeatedly defined figure labels
|
224 |
+
- References to papers that are not in the .bib file, DO NOT ADD ANY NEW CITATIONS!
|
225 |
+
- Unnecessary verbosity or repetition, unclear text
|
226 |
+
- Results or insights in the `notes.txt` that have not yet need included
|
227 |
+
- Any relevant figures that have not yet been included in the text
|
228 |
+
- Closing any \\begin{{figure}} with a \\end{{figure}} and \\begin{{table}} with a \\end{{table}}, etc.
|
229 |
+
- Duplicate headers, e.g. duplicated \\section{{Introduction}} or \\end{{document}}
|
230 |
+
- Unescaped symbols, e.g. shakespeare_char should be shakespeare\\_char in text
|
231 |
+
- Incorrect closing of environments, e.g. </end{{figure}}> instead of \\end{{figure}}
|
232 |
+
"""
|
233 |
+
|
234 |
+
refinement_prompt = (
|
235 |
+
"""Great job! Now criticize and refine only the {section} that you just wrote.
|
236 |
+
Make this complete in this pass, do not leave any placeholders.
|
237 |
+
|
238 |
+
Pay particular attention to fixing any errors such as:
|
239 |
+
"""
|
240 |
+
+ error_list
|
241 |
+
)
|
242 |
+
|
243 |
+
second_refinement_prompt = (
|
244 |
+
"""Criticize and refine the {section} only. Recall the advice:
|
245 |
+
{tips}
|
246 |
+
Make this complete in this pass, do not leave any placeholders.
|
247 |
+
|
248 |
+
Pay attention to how it fits in with the rest of the paper.
|
249 |
+
Identify any redundancies (e.g. repeated figures or repeated text), if there are any, decide where in the paper things should be cut.
|
250 |
+
Identify where we can save space, and be more concise without weakening the message of the text.
|
251 |
+
Fix any remaining errors as before:
|
252 |
+
"""
|
253 |
+
+ error_list
|
254 |
+
)
|
255 |
+
|
256 |
+
# CITATION HELPERS
|
257 |
+
citation_system_msg = """You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field.
|
258 |
+
You have already written an initial draft of the paper and now you are looking to add missing citations to related papers throughout the paper.
|
259 |
+
The related work section already has some initial comments on which papers to add and discuss.
|
260 |
+
|
261 |
+
Focus on completing the existing write-up and do not add entirely new elements unless necessary.
|
262 |
+
Ensure every point in the paper is substantiated with sufficient evidence.
|
263 |
+
Feel free to add more cites to a particular point if there is only one or two references.
|
264 |
+
Ensure no paper is cited without a corresponding reference in the `references.bib` file.
|
265 |
+
Ensure each paragraph of the related work has sufficient background, e.g. a few papers cited.
|
266 |
+
You will be given access to the Semantic Scholar API, only add citations that you have found using the API.
|
267 |
+
Aim to discuss a broad range of relevant papers, not just the most popular ones.
|
268 |
+
Make sure not to copy verbatim from prior literature to avoid plagiarism.
|
269 |
+
|
270 |
+
You will be prompted to give a precise description of where and how to add the cite, and a search query for the paper to be cited.
|
271 |
+
Finally, you will select the most relevant cite from the search results (top 10 results will be shown).
|
272 |
+
You will have {total_rounds} rounds to add to the references, but do not need to use them all.
|
273 |
+
|
274 |
+
DO NOT ADD A CITATION THAT ALREADY EXISTS!"""
|
275 |
+
|
276 |
+
citation_first_prompt = '''Round {current_round}/{total_rounds}:
|
277 |
+
|
278 |
+
You have written this LaTeX draft so far:
|
279 |
+
|
280 |
+
"""
|
281 |
+
{draft}
|
282 |
+
"""
|
283 |
+
|
284 |
+
Identify the most important citation that you still need to add, and the query to find the paper.
|
285 |
+
|
286 |
+
Respond in the following format:
|
287 |
+
|
288 |
+
THOUGHT:
|
289 |
+
<THOUGHT>
|
290 |
+
|
291 |
+
RESPONSE:
|
292 |
+
```json
|
293 |
+
<JSON>
|
294 |
+
```
|
295 |
+
|
296 |
+
In <THOUGHT>, first briefly reason over the paper and identify where citations should be added.
|
297 |
+
If no more citations are needed, add "No more citations needed" to your thoughts.
|
298 |
+
Do not add "No more citations needed" if you are adding citations this round.
|
299 |
+
|
300 |
+
In <JSON>, respond in JSON format with the following fields:
|
301 |
+
- "Description": A precise description of the required edit, along with the proposed text and location where it should be made.
|
302 |
+
- "Query": The search query to find the paper (e.g. attention is all you need).
|
303 |
+
|
304 |
+
Ensure the description is sufficient to make the change without further context. Someone else will make the change.
|
305 |
+
The query will work best if you are able to recall the exact name of the paper you are looking for, or the authors.
|
306 |
+
This JSON will be automatically parsed, so ensure the format is precise.'''
|
307 |
+
|
308 |
+
citation_second_prompt = """Search has recovered the following articles:
|
309 |
+
|
310 |
+
{papers}
|
311 |
+
|
312 |
+
Respond in the following format:
|
313 |
+
|
314 |
+
THOUGHT:
|
315 |
+
<THOUGHT>
|
316 |
+
|
317 |
+
RESPONSE:
|
318 |
+
```json
|
319 |
+
<JSON>
|
320 |
+
```
|
321 |
+
|
322 |
+
In <THOUGHT>, first briefly reason over the search results and identify which citation best fits your paper and the location is to be added at.
|
323 |
+
If none are appropriate, add "Do not add any" to your thoughts.
|
324 |
+
|
325 |
+
In <JSON>, respond in JSON format with the following fields:
|
326 |
+
- "Selected": A list of the indices of the selected papers to be cited, e.g. "[0, 1]". Can be "[]" if no papers are selected. This must be a string.
|
327 |
+
- "Description": Update the previous description of the required edit if needed. Ensure that any cites precisely match the name in the bibtex!!!
|
328 |
+
|
329 |
+
Do not select papers that are already in the `references.bib` file at the top of the draft, or if the same citation exists under a different name.
|
330 |
+
This JSON will be automatically parsed, so ensure the format is precise."""
|
331 |
+
|
332 |
+
|
333 |
+
def get_citation_aider_prompt(
|
334 |
+
client, model, draft, current_round, total_rounds
|
335 |
+
) -> Tuple[Optional[str], bool]:
|
336 |
+
msg_history = []
|
337 |
+
try:
|
338 |
+
text, msg_history = get_response_from_llm(
|
339 |
+
citation_first_prompt.format(
|
340 |
+
draft=draft, current_round=current_round, total_rounds=total_rounds
|
341 |
+
),
|
342 |
+
client=client,
|
343 |
+
model=model,
|
344 |
+
system_message=citation_system_msg.format(total_rounds=total_rounds),
|
345 |
+
msg_history=msg_history,
|
346 |
+
)
|
347 |
+
if "No more citations needed" in text:
|
348 |
+
print("No more citations needed.")
|
349 |
+
return None, True
|
350 |
+
|
351 |
+
## PARSE OUTPUT
|
352 |
+
json_output = format_citation_first_json(text)
|
353 |
+
assert json_output is not None, "Failed to extract JSON from LLM output"
|
354 |
+
query = json_output["Query"]
|
355 |
+
papers = search_for_papers(query)
|
356 |
+
except Exception as e:
|
357 |
+
print(f"Error: {e}")
|
358 |
+
return None, False
|
359 |
+
|
360 |
+
if papers is None:
|
361 |
+
print("No papers found.")
|
362 |
+
return None, False
|
363 |
+
|
364 |
+
paper_strings = []
|
365 |
+
for i, paper in enumerate(papers):
|
366 |
+
paper_strings.append(
|
367 |
+
"""{i}: {title}. {authors}. {venue}, {year}.\nAbstract: {abstract}""".format(
|
368 |
+
i=i,
|
369 |
+
title=paper["title"],
|
370 |
+
authors=paper["authors"],
|
371 |
+
venue=paper["venue"],
|
372 |
+
year=paper["year"],
|
373 |
+
abstract=paper["abstract"],
|
374 |
+
)
|
375 |
+
)
|
376 |
+
papers_str = "\n\n".join(paper_strings)
|
377 |
+
|
378 |
+
try:
|
379 |
+
text, msg_history = get_response_from_llm(
|
380 |
+
citation_second_prompt.format(
|
381 |
+
papers=papers_str,
|
382 |
+
current_round=current_round,
|
383 |
+
total_rounds=total_rounds,
|
384 |
+
),
|
385 |
+
client=client,
|
386 |
+
model=model,
|
387 |
+
system_message=citation_system_msg.format(total_rounds=total_rounds),
|
388 |
+
msg_history=msg_history,
|
389 |
+
)
|
390 |
+
if "Do not add any" in text:
|
391 |
+
print("Do not add any.")
|
392 |
+
return None, False
|
393 |
+
## PARSE OUTPUT
|
394 |
+
json_output = format_citation_second_json(text)
|
395 |
+
assert json_output is not None, "Failed to extract JSON from LLM output"
|
396 |
+
desc = json_output["Description"]
|
397 |
+
selected_papers = json_output["Selected"]
|
398 |
+
selected_papers = str(selected_papers)
|
399 |
+
|
400 |
+
# convert to list
|
401 |
+
if selected_papers != "[]":
|
402 |
+
selected_papers = list(map(int, selected_papers.strip("[]").split(",")))
|
403 |
+
assert all(
|
404 |
+
[0 <= i < len(papers) for i in selected_papers]
|
405 |
+
), "Invalid paper index"
|
406 |
+
bibtexs = [papers[i]["citationStyles"]["bibtex"] for i in selected_papers]
|
407 |
+
bibtex_string = "\n".join(bibtexs)
|
408 |
+
else:
|
409 |
+
return None, False
|
410 |
+
|
411 |
+
except Exception as e:
|
412 |
+
print(f"Error: {e}")
|
413 |
+
return None, False
|
414 |
+
|
415 |
+
# Add citation to draft
|
416 |
+
aider_format = '''The following citations have just been added to the end of the `references.bib` file definition at the top of the file:
|
417 |
+
"""
|
418 |
+
{bibtex}
|
419 |
+
"""
|
420 |
+
You do not need to add them yourself.
|
421 |
+
ABSOLUTELY DO NOT ADD IT AGAIN!!!
|
422 |
+
|
423 |
+
Make the proposed change to the draft incorporating these new cites:
|
424 |
+
{description}
|
425 |
+
|
426 |
+
Use your judgment for whether these should be cited anywhere else.
|
427 |
+
Make sure that any citation precisely matches the name in `references.bib`. Change its name to the correct name in the bibtex if needed.
|
428 |
+
Ensure the citation is well-integrated into the text.'''
|
429 |
+
|
430 |
+
aider_prompt = (
|
431 |
+
aider_format.format(bibtex=bibtex_string, description=desc)
|
432 |
+
+ """\n You must use \cite or \citet to reference papers, do not manually type out author names."""
|
433 |
+
)
|
434 |
+
return aider_prompt, False
|
435 |
+
|
436 |
+
|
437 |
+
# PERFORM WRITEUP
|
438 |
+
def perform_writeup(
|
439 |
+
idea, folder_name, coder, cite_client, cite_model, num_cite_rounds=20
|
440 |
+
):
|
441 |
+
# CURRENTLY ASSUMES LATEX
|
442 |
+
abstract_prompt = f"""We've provided the `latex/template.tex` file to the project. We will be filling it in section by section.
|
443 |
+
|
444 |
+
First, please fill in the "Title" and "Abstract" sections of the writeup.
|
445 |
+
|
446 |
+
Some tips are provided below:
|
447 |
+
{per_section_tips["Abstract"]}
|
448 |
+
|
449 |
+
Before every paragraph, please include a brief description of what you plan to write in that paragraph in a comment.
|
450 |
+
|
451 |
+
Be sure to first name the file and use *SEARCH/REPLACE* blocks to perform these edits.
|
452 |
+
"""
|
453 |
+
coder_out = coder.run(abstract_prompt)
|
454 |
+
coder_out = coder.run(
|
455 |
+
refinement_prompt.format(section="Abstract")
|
456 |
+
.replace(r"{{", "{")
|
457 |
+
.replace(r"}}", "}")
|
458 |
+
)
|
459 |
+
for section in [
|
460 |
+
"Introduction",
|
461 |
+
"Background",
|
462 |
+
"Method",
|
463 |
+
"Experimental Setup",
|
464 |
+
"Results",
|
465 |
+
"Conclusion",
|
466 |
+
]:
|
467 |
+
section_prompt = f"""Please fill in the {section} of the writeup. Some tips are provided below:
|
468 |
+
{per_section_tips[section]}
|
469 |
+
|
470 |
+
Be sure to use \cite or \citet where relevant, referring to the works provided in the file.
|
471 |
+
Do not cite anything that is not already in `references.bib`. Do not add any new entries to this.
|
472 |
+
|
473 |
+
Keep the experimental results (figures and tables) only in the Results section, and make sure that any captions are filled in.
|
474 |
+
In this pass, do not reference anything in later sections of the paper.
|
475 |
+
|
476 |
+
Before every paragraph, please include a brief description of what you plan to write in that paragraph in a comment.
|
477 |
+
|
478 |
+
Be sure to first name the file and use *SEARCH/REPLACE* blocks to perform these edits.
|
479 |
+
"""
|
480 |
+
coder_out = coder.run(section_prompt)
|
481 |
+
coder_out = coder.run(
|
482 |
+
refinement_prompt.format(section=section)
|
483 |
+
.replace(r"{{", "{")
|
484 |
+
.replace(r"}}", "}")
|
485 |
+
)
|
486 |
+
|
487 |
+
# SKETCH THE RELATED WORK
|
488 |
+
section_prompt = f"""Please fill in the Related Work of the writeup. Some tips are provided below:
|
489 |
+
|
490 |
+
{per_section_tips["Related Work"]}
|
491 |
+
|
492 |
+
For this section, very briefly sketch out the structure of the section, and clearly indicate what papers you intend to include.
|
493 |
+
Do this all in LaTeX comments using %.
|
494 |
+
The related work should be concise, only plan to discuss the most relevant work.
|
495 |
+
Do not modify `references.bib` to add any new citations, this will be filled in at a later stage.
|
496 |
+
|
497 |
+
Be sure to first name the file and use *SEARCH/REPLACE* blocks to perform these edits.
|
498 |
+
"""
|
499 |
+
coder_out = coder.run(section_prompt)
|
500 |
+
|
501 |
+
# Fill paper with cites.
|
502 |
+
for _ in range(num_cite_rounds):
|
503 |
+
with open(osp.join(folder_name, "latex", "template.tex"), "r") as f:
|
504 |
+
draft = f.read()
|
505 |
+
prompt, done = get_citation_aider_prompt(
|
506 |
+
cite_client, cite_model, draft, _, num_cite_rounds
|
507 |
+
)
|
508 |
+
if done:
|
509 |
+
break
|
510 |
+
if prompt is not None:
|
511 |
+
# extract bibtex string
|
512 |
+
bibtex_string = prompt.split('"""')[1]
|
513 |
+
# insert this into draft before the "\end{filecontents}" line
|
514 |
+
search_str = r"\end{filecontents}"
|
515 |
+
draft = draft.replace(search_str, f"{bibtex_string}{search_str}")
|
516 |
+
with open(osp.join(folder_name, "latex", "template.tex"), "w") as f:
|
517 |
+
f.write(draft)
|
518 |
+
coder_out = coder.run(prompt)
|
519 |
+
|
520 |
+
coder_out = coder.run(
|
521 |
+
refinement_prompt.format(section="Related Work")
|
522 |
+
.replace(r"{{", "{")
|
523 |
+
.replace(r"}}", "}")
|
524 |
+
)
|
525 |
+
|
526 |
+
## SECOND REFINEMENT LOOP
|
527 |
+
coder.run(
|
528 |
+
"""Great job! Now that there is a complete draft of the entire paper, let's refine each section again.
|
529 |
+
First, re-think the Title if necessary. Keep this concise and descriptive of the paper's concept, but try by creative with it."""
|
530 |
+
)
|
531 |
+
for section in [
|
532 |
+
"Abstract",
|
533 |
+
"Related Work",
|
534 |
+
"Introduction",
|
535 |
+
"Background",
|
536 |
+
"Method",
|
537 |
+
"Experimental Setup",
|
538 |
+
"Results",
|
539 |
+
"Conclusion",
|
540 |
+
]:
|
541 |
+
coder_out = coder.run(
|
542 |
+
second_refinement_prompt.format(
|
543 |
+
section=section, tips=per_section_tips[section]
|
544 |
+
)
|
545 |
+
.replace(r"{{", "{")
|
546 |
+
.replace(r"}}", "}")
|
547 |
+
)
|
548 |
+
|
549 |
+
generate_latex(coder, folder_name, f"{folder_name}/{idea['Name']}.pdf")
|
550 |
+
|
551 |
+
|
552 |
+
if __name__ == "__main__":
|
553 |
+
import json
|
554 |
+
|
555 |
+
from aider.coders import Coder
|
556 |
+
from aider.io import InputOutput
|
557 |
+
from aider.models import Model
|
558 |
+
|
559 |
+
parser = argparse.ArgumentParser(description="Perform writeup for a project")
|
560 |
+
parser.add_argument("--folder", type=str)
|
561 |
+
parser.add_argument("--no-writing", action="store_true", help="Only generate")
|
562 |
+
parser.add_argument(
|
563 |
+
"--model",
|
564 |
+
type=str,
|
565 |
+
default="gpt-4o-2024-05-13",
|
566 |
+
choices=allchoices,
|
567 |
+
help="Model to use for AI Scientist.",
|
568 |
+
)
|
569 |
+
args = parser.parse_args()
|
570 |
+
if args.model == "claude-3-5-sonnet-20240620":
|
571 |
+
import anthropic
|
572 |
+
|
573 |
+
print(f"Using Anthropic API with model {args.model}.")
|
574 |
+
client_model = "claude-3-5-sonnet-20240620"
|
575 |
+
client = anthropic.Anthropic()
|
576 |
+
elif args.model.startswith("bedrock") and "claude" in args.model:
|
577 |
+
import anthropic
|
578 |
+
|
579 |
+
# Expects: bedrock/<MODEL_ID>
|
580 |
+
client_model = args.model.split("/")[-1]
|
581 |
+
|
582 |
+
print(f"Using Amazon Bedrock with model {client_model}.")
|
583 |
+
client = anthropic.AnthropicBedrock()
|
584 |
+
elif args.model.startswith("vertex_ai") and "claude" in args.model:
|
585 |
+
import anthropic
|
586 |
+
|
587 |
+
# Expects: vertex_ai/<MODEL_ID>
|
588 |
+
client_model = args.model.split("/")[-1]
|
589 |
+
|
590 |
+
print(f"Using Vertex AI with model {client_model}.")
|
591 |
+
client = anthropic.AnthropicVertex()
|
592 |
+
elif args.model == "gpt-4o-2024-05-13":
|
593 |
+
import openai
|
594 |
+
|
595 |
+
print(f"Using OpenAI API with model {args.model}.")
|
596 |
+
client_model = "gpt-4o-2024-05-13"
|
597 |
+
client = openai.OpenAI()
|
598 |
+
elif args.model == "deepseek-coder-v2-0724":
|
599 |
+
import openai
|
600 |
+
|
601 |
+
print(f"Using OpenAI API with {args.model}.")
|
602 |
+
client_model = "deepseek-coder-v2-0724"
|
603 |
+
client = openai.OpenAI(
|
604 |
+
api_key=os.environ["DEEPSEEK_API_KEY"], base_url="https://api.deepseek.com"
|
605 |
+
)
|
606 |
+
|
607 |
+
# ----------------------------------------------------
|
608 |
+
|
609 |
+
elif args.model == "Qwen/Qwen2.5-72B-Instruct":
|
610 |
+
# elif args.model.startswith("hyperbolic"):
|
611 |
+
print(f"Welcome to the PARADISE of debug <launch_scientist.py> {args.model}.")
|
612 |
+
|
613 |
+
import openai
|
614 |
+
import os
|
615 |
+
# client_model = args.model[11:]
|
616 |
+
client_model = args.model
|
617 |
+
client = openai.OpenAI(
|
618 |
+
api_key=os.environ["OPENAI_API_KEY"], base_url="https://api.hyperbolic.xyz/v1"
|
619 |
+
)
|
620 |
+
# ----------------------------------------------------
|
621 |
+
elif args.model == "llama3.1-405b":
|
622 |
+
import openai
|
623 |
+
|
624 |
+
print(f"Using OpenAI API with {args.model}.")
|
625 |
+
client_model = "meta-llama/llama-3.1-405b-instruct"
|
626 |
+
client = openai.OpenAI(
|
627 |
+
api_key=os.environ["OPENROUTER_API_KEY"],
|
628 |
+
base_url="https://openrouter.ai/api/v1",
|
629 |
+
)
|
630 |
+
|
631 |
+
elif args.model.startswith("ollama"):
|
632 |
+
import openai
|
633 |
+
|
634 |
+
print(f"Using Ollama with {args.model}.")
|
635 |
+
client_model = args.model.split("/")[-1]
|
636 |
+
client = openai.OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
|
637 |
+
else:
|
638 |
+
raise ValueError(f"Model {args.model} not recognized.")
|
639 |
+
|
640 |
+
|
641 |
+
print("Make sure you cleaned the Aider logs if re-generating the writeup!")
|
642 |
+
folder_name = args.folder
|
643 |
+
idea_name = osp.basename(folder_name)
|
644 |
+
exp_file = osp.join(folder_name, "experiment.py")
|
645 |
+
vis_file = osp.join(folder_name, "plot.py")
|
646 |
+
notes = osp.join(folder_name, "notes.txt")
|
647 |
+
|
648 |
+
model = args.model
|
649 |
+
|
650 |
+
writeup_file = osp.join(folder_name, "latex", "template.tex")
|
651 |
+
ideas_file = osp.join(folder_name, "ideas.json")
|
652 |
+
with open(ideas_file, "r") as f:
|
653 |
+
ideas = json.load(f)
|
654 |
+
for idea in ideas:
|
655 |
+
if idea["Name"] in idea_name:
|
656 |
+
print(f"Found idea: {idea['Name']}")
|
657 |
+
break
|
658 |
+
if idea["Name"] not in idea_name:
|
659 |
+
raise ValueError(f"Idea {idea_name} not found")
|
660 |
+
fnames = [exp_file, writeup_file, notes]
|
661 |
+
io = InputOutput(yes=True, chat_history_file=f"{folder_name}/{idea_name}_aider.txt")
|
662 |
+
|
663 |
+
|
664 |
+
|
665 |
+
# AIDER CHAT INITIALIZATION CODE
|
666 |
+
|
667 |
+
if args.model == "deepseek-ai/DeepSeek-V2.5":
|
668 |
+
print("aider chosen deepseek")
|
669 |
+
main_model = Model("deepseek-chat")
|
670 |
+
|
671 |
+
elif args.model == "deepseek-coder-v2-0724":
|
672 |
+
main_model = Model("deepseek-ai/DeepSeek-V2.5")
|
673 |
+
|
674 |
+
elif args.model == "llama3.1-405b":
|
675 |
+
main_model = Model("openrouter/meta-llama/llama-3.1-405b-instruct")
|
676 |
+
|
677 |
+
# ----------------------------------------------------
|
678 |
+
|
679 |
+
elif args.model == "hyperbolic/Qwen/Qwen2.5-72B-Instruct":
|
680 |
+
print("aider model chosen")
|
681 |
+
# main_model = Model("fireworks_ai/accounts/fireworks/models/qwen2-72b-instruct")
|
682 |
+
main_model = Model("hyperbolic/Qwen/Qwen2.5-72B-Instruct")
|
683 |
+
|
684 |
+
elif args.model == "hyperbolic/meta-llama/Meta-Llama-3.1-70B-Instruct":
|
685 |
+
main_model = Model("hyperbolic/meta-llama/Meta-Llama-3.1-70B-Instruct")
|
686 |
+
|
687 |
+
# ----------------------------------------------------
|
688 |
+
|
689 |
+
else:
|
690 |
+
print("hello world")
|
691 |
+
main_model = Model(model)
|
692 |
+
|
693 |
+
coder = Coder.create(
|
694 |
+
main_model=main_model,
|
695 |
+
fnames=fnames,
|
696 |
+
io=io,
|
697 |
+
stream=False,
|
698 |
+
use_git=False,
|
699 |
+
edit_format="diff",
|
700 |
+
)
|
701 |
+
if args.no_writing:
|
702 |
+
generate_latex(coder, args.folder, f"{args.folder}/test.pdf")
|
703 |
+
else:
|
704 |
+
try:
|
705 |
+
perform_writeup(idea, folder_name, coder, client, client_model)
|
706 |
+
except Exception as e:
|
707 |
+
print(f"Failed to perform writeup: {e}")
|
cuda-keyring_1.0-1_all.deb
ADDED
Binary file (4.33 kB). View file
|
|
cuda-keyring_1.0-1_all.deb.1
ADDED
Binary file (4.33 kB). View file
|
|
cuda-repo-ubuntu2004-11-0-local_11.0.3-450.51.06-1_amd64.deb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f03712f53fbe7a39a8d82f8ba186e9dc44fece31b8b2230dded07153a57a27dc
|
3 |
+
size 2306981744
|
data/enwik8/enwik8
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2b49720ec4d78c3c9fabaee6e4179a5e997302b3a70029f30f2d582218c024a8
|
3 |
+
size 100000000
|
data/enwik8/enwik8.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:547994d9980ebed1288380d652999f38a14fe291a6247c157c3d33d4932534bc
|
3 |
+
size 36445475
|
data/enwik8/meta.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9e6ccdb5f07974d4d462bf8acbc1b5a7225f4da63453df9f67dd7bf6d976c386
|
3 |
+
size 2211
|
data/enwik8/prepare.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Prepare the enwik8 dataset for character-level language modeling.
|
3 |
+
So instead of encoding with GPT-2 BPE tokens, we just map characters to ints.
|
4 |
+
Will save train.bin, val.bin containing the ids, and meta.pkl containing the
|
5 |
+
encoder and decoder and some other related info.
|
6 |
+
"""
|
7 |
+
import os
|
8 |
+
import pickle
|
9 |
+
import requests
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
# download the enwik8 dataset
|
13 |
+
input_file_path = os.path.join(os.path.dirname(__file__), 'enwik8')
|
14 |
+
if not os.path.exists(input_file_path):
|
15 |
+
data_url = 'http://mattmahoney.net/dc/enwik8.zip'
|
16 |
+
r = requests.get(data_url)
|
17 |
+
with open(os.path.join(os.path.dirname(__file__), 'enwik8.zip'), 'wb') as f:
|
18 |
+
f.write(r.content)
|
19 |
+
|
20 |
+
# unzip the enwik8 dataset
|
21 |
+
import zipfile
|
22 |
+
with zipfile.ZipFile(os.path.join(os.path.dirname(__file__), 'enwik8.zip'), 'r') as zip_ref:
|
23 |
+
zip_ref.extractall(os.path.dirname(__file__))
|
24 |
+
|
25 |
+
with open(input_file_path, 'r', encoding='latin-1') as f:
|
26 |
+
data = f.read()
|
27 |
+
print(f"length of dataset in characters: {len(data):,}")
|
28 |
+
|
29 |
+
# get all the unique characters that occur in this text
|
30 |
+
chars = sorted(list(set(data)))
|
31 |
+
vocab_size = len(chars)
|
32 |
+
print("all the unique characters:", ''.join(chars))
|
33 |
+
print(f"vocab size: {vocab_size:,}")
|
34 |
+
|
35 |
+
# create a mapping from characters to integers
|
36 |
+
stoi = { ch:i for i,ch in enumerate(chars) }
|
37 |
+
itos = { i:ch for i,ch in enumerate(chars) }
|
38 |
+
def encode(s):
|
39 |
+
return [stoi[c] for c in s] # encoder: take a string, output a list of integers
|
40 |
+
def decode(l):
|
41 |
+
return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
|
42 |
+
|
43 |
+
# create the train, validation, and test splits
|
44 |
+
n = len(data)
|
45 |
+
num_test_chars = 5000000
|
46 |
+
train_data = data[: -2 * num_test_chars]
|
47 |
+
val_data = data[-2 * num_test_chars: -num_test_chars]
|
48 |
+
test_data = data[-num_test_chars:]
|
49 |
+
|
50 |
+
# encode all splits to integers
|
51 |
+
train_ids = encode(train_data)
|
52 |
+
val_ids = encode(val_data)
|
53 |
+
test_ids = encode(test_data)
|
54 |
+
|
55 |
+
print(f"train has {len(train_ids):,} tokens")
|
56 |
+
print(f"val has {len(val_ids):,} tokens")
|
57 |
+
print(f"test has {len(test_ids):,} tokens")
|
58 |
+
|
59 |
+
# export to bin files
|
60 |
+
train_ids = np.array(train_ids, dtype=np.uint16)
|
61 |
+
val_ids = np.array(val_ids, dtype=np.uint16)
|
62 |
+
test_ids = np.array(test_ids, dtype=np.uint16)
|
63 |
+
|
64 |
+
train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
|
65 |
+
val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))
|
66 |
+
test_ids.tofile(os.path.join(os.path.dirname(__file__), 'test.bin'))
|
67 |
+
|
68 |
+
# save the meta information as well, to help us encode/decode later
|
69 |
+
meta = {
|
70 |
+
'vocab_size': vocab_size,
|
71 |
+
'itos': itos,
|
72 |
+
'stoi': stoi,
|
73 |
+
}
|
74 |
+
with open(os.path.join(os.path.dirname(__file__), 'meta.pkl'), 'wb') as f:
|
75 |
+
pickle.dump(meta, f)
|
data/enwik8/test.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1627449a3615fb097ee0dedfe5a702bd574c3084c2c69338a1819a40c0a4eb6c
|
3 |
+
size 10000000
|
data/enwik8/train.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bd58505a13f96a801113b7578e886b550c8c74678f425d711559e6728c791f14
|
3 |
+
size 180000000
|
data/enwik8/val.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6f4b53738dd5eacb29568913815a43415eaaa195c33a2935580dc2055e13954c
|
3 |
+
size 10000000
|