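"""Gradio demo for "Iterative Forward Tuning Boosts In-context Learning in Language Models".

On each button press, the app runs two more meta-optimizer forward steps over the
demonstration's attention key/value cache ("deep thinking") with OPT-125M, then
re-scores every SST-2 test input and appends the predictions as a new table column.
"""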
import json
from pathlib import Path
import gradio as gr
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from common import setup_cpu
from models import build_tokenizer, build_model
from models.meta_optimizer import AttnOptimWrapper
from tasks import load_task
from tasks.loader import TokenizedForMCRightPad
DISPLAY_MAPPING = {
"sst2": {"positive": "Pos", "negative": "Neg"},
"trec": {},
}


@torch.no_grad()
def do_infer_probs(model, exemplar_attn_kv, exemplar_attn_mask, batched_choices_input):
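    """Score each answer choice for a batch of test inputs.

    The demonstration is supplied as a precomputed attention key/value cache
    (`exemplar_attn_kv`), so only the test input tokens are fed through the model.
    Returns, per choice and per example, the summed and length-normalized
    log-probabilities of the choice tokens.
    """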
batched_choices_logprobs = []
for batched_one_choice_input in batched_choices_input:
batch_input_ids, batch_attention_mask, batch_choice_start, batch_choice_end = batched_one_choice_input
bs = len(batch_input_ids)
merged_attn_mask = torch.cat((exemplar_attn_mask.expand(bs, -1), batch_attention_mask), dim=1)
        # Broadcast the cached demonstration key/values across the batch:
        # num_layers * [K, V], each of shape [B, num_heads, L, head_dim].
        expand_exemplar_attn_kv = [[layer_k.expand((bs, -1, -1, -1)), layer_v.expand((bs, -1, -1, -1))] for layer_k, layer_v in exemplar_attn_kv]
batched_logits = model(
input_ids=batch_input_ids, # [B, L']
attention_mask=merged_attn_mask, # [B, L + L']
past_key_values=expand_exemplar_attn_kv, # num_layers * 2 * [B, num_heads, L, H]
).logits
batched_output = F.log_softmax(batched_logits, dim=-1) # [B, L', Vocab]
batched_one_choice_logprobs = []
for input_ids, choice_start, choice_end, lm_logprobs in zip(batch_input_ids, batch_choice_start, batch_choice_end, batched_output):
choice_tokens = input_ids[choice_start:choice_end].unsqueeze(1) # [L, 1]
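            # The logits at position i predict token i + 1, so shift the window back by one.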
choice_logprobs = lm_logprobs[choice_start - 1 : choice_end - 1] # [L, Vocab]
extracted = torch.gather(choice_logprobs, -1, choice_tokens).squeeze(-1)
            choice_length = int(choice_end - choice_start)
            lm_log_p = torch.sum(extracted).item()
            norm_lm_log_p = lm_log_p / choice_length
choice_lm_info = {"lm_log_p": lm_log_p, "norm_lm_log_p": norm_lm_log_p}
batched_one_choice_logprobs.append(choice_lm_info)
batched_choices_logprobs.append(batched_one_choice_logprobs)
    return batched_choices_logprobs


@torch.no_grad()
def process_once(dataset_name, exemplar_str, forward_steps, raw_data):
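    """Refine the demonstration's KV cache for `forward_steps` meta-optimizer steps,
    then predict a label for every example in `raw_data`.

    Model, seed, and optimizer settings are fixed for this CPU demo.
    """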
    model_name, model_size = "opt", "125m"
    step_size, momentum = 0.01, 0.9
    seed, prompt_version = 0, "default"
    setup_cpu(seed=seed)
TaskHandler = load_task(dataset_name)
task_agent = TaskHandler(prompt_version)
tokenizer = build_tokenizer(model_name, model_size, padding_side="right")
model = build_model(model_name, model_size, False)
torch.autograd.set_grad_enabled(False)
processed_data = task_agent.dataset_preprocess(raw_data)
dataset = TokenizedForMCRightPad(processed_data, tokenizer, task_agent.multiple_choice_promptify)
exemplar_input_ids, exemplar_attn_mask = dataset.tokenize_demonstration(exemplar_str)
loader = DataLoader(dataset, shuffle=False, drop_last=False, batch_size=1)
meta_optim = AttnOptimWrapper(model, model_name, step_size=step_size, momentum=momentum)
meta_optim.init()
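    # Each step refines the demonstration's attention key/values; only the cache
    # from the final step is used for scoring below.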
for _ in range(forward_steps):
exemplar_kv = meta_optim.step(exemplar_input_ids)
generated_info = [] # question * [choice0_prob, choice1_prob]
for batch_input in loader:
batch_output = do_infer_probs(model, exemplar_kv, exemplar_attn_mask.unsqueeze(0), batch_input) # [batch_of_choice0, batch_of_choice1, ...]
zipped_logprobs = list(zip(*batch_output)) # batch * (choice0, choice1, ...)
generated_info.extend(zipped_logprobs)
all_predicted = []
    for data, choice_info in zip(processed_data, generated_info):
merged_choice_info = task_agent.merge_choice_info(choice_info)
merged_predictions_idx = task_agent.choice_info_to_predictions(merged_choice_info)["lm_log_p"]
predicted = task_agent.CHOICES[merged_predictions_idx]
ground_truth = task_agent.CHOICES[data["answer_idx"]]
res = f"{DISPLAY_MAPPING[dataset_name][predicted]}{'✅' if predicted == ground_truth else '❌'}"
all_predicted.append(res)
    return all_predicted


def transpose(rows):
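    """Transpose a list of rows into a list of columns: [[1, 2], [3, 4]] -> [[1, 3], [2, 4]]."""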
    return list(map(list, zip(*rows)))


def button_pressed(prev_state):
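    """Gradio callback: run two more forward steps and append the new predictions as a table column."""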
dataset_name = prev_state["dataset_name"]
exemplar_str = prev_state["exemplar_str"]
forward_steps = prev_state["step"] + 2
raw_data = prev_state["raw_data"]
prev_table_data = prev_state["table_data"]
current_output = process_once(dataset_name, exemplar_str, forward_steps, raw_data)
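    # Append the new predictions as a column: transpose to rows, add one row, transpose back.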
t_prev = transpose(prev_table_data)
t_prev.append([f"T={forward_steps}"] + current_output)
updated_table_data = transpose(t_prev)
ret = [
{
"dataset_name": dataset_name,
"exemplar_str": exemplar_str,
"raw_data": raw_data,
"step": forward_steps,
"table_data": updated_table_data,
},
f"Step + 2, Now: {forward_steps}",
updated_table_data,
]
return ret
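# Entry point: `python app.py` serves the demo on 0.0.0.0 (Gradio's default port is 7860).
# The example_sets/<dataset>/ directory must provide demos.txt and sample.pkl.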
if __name__ == "__main__":
dataset_name = "sst2"
print(f"Dataset: {dataset_name}")
task_root = Path("example_sets").joinpath(dataset_name)
with task_root.joinpath("demos.txt").open("r") as f:
demos = f.read()
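    # "sample.pkl" is parsed as JSON here, despite its extension.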
with task_root.joinpath("sample.pkl").open("r") as f:
data = json.load(f)
raw_data = [data[str(i)] for i in range(len(data))]
css = """ #the-table > div > div > div > table > thead {display: none}"""
title = "🤔 Iterative Forward Tuning Boosts In-context Learning in Language Models"
demo = gr.Blocks(css=css, title="🤔Deep-Thinking")
with demo:
gr.Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>")
with gr.Tab("SST-2"):
mapping = ["negative", "positive"]
init_columns = [[e["sentence"], f"*{DISPLAY_MAPPING['sst2'][mapping[e['label']]]}*"] for e in raw_data]
state = gr.State(
{
"dataset_name": "sst2",
"exemplar_str": demos,
"raw_data": raw_data,
"step": 0,
"table_data": [["**Test Input**", "**Golden**"], *init_columns],
}
)
prompt = gr.Textbox(label="Demonstrations (Prompt template formatted)", value=demos)
big_table = gr.DataFrame(
value=[["**Test Input**", "**Golden**"], *init_columns],
elem_id="the-table",
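            # Render markdown in up to 50 columns so repeated step presses keep their formatting.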
datatype=["markdown"] * 50,
headers=None,
)
step_button = gr.Button("Step + 2, Now: 0")
step_button.click(button_pressed, inputs=[state], outputs=[state, step_button, big_table])
demo.launch(server_name="0.0.0.0")