import json
from pathlib import Path

import gradio as gr
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader

from common import setup_cpu
from models import build_tokenizer, build_model
from models.meta_optimizer import AttnOptimWrapper
from tasks import load_task
from tasks.loader import TokenizedForMCRightPad

# Short display labels for each task's answer strings. Labels missing from a
# task's mapping (e.g. every TREC label) are shown verbatim by `display_label`.
DISPLAY_MAPPING = {
    "sst2": {"positive": "Pos", "negative": "Neg"},
    "trec": {},
}

# Defaults for `process_once`; these used to be read from globals that only
# existed when the file was run as a script.
DEFAULT_SEED = 0
DEFAULT_PROMPT_VERSION = "default"


def display_label(dataset_name, label):
    """Return the short display form of *label* for *dataset_name*.

    Falls back to *label* itself when no short form is registered, instead of
    raising ``KeyError``.
    """
    return DISPLAY_MAPPING.get(dataset_name, {}).get(label, label)


@torch.no_grad()
def do_infer_probs(model, exemplar_attn_kv, exemplar_attn_mask, batched_choices_input):
    """Score every answer choice of one batch given cached demonstration key/values.

    Args:
        model: causal LM called with ``input_ids``, ``attention_mask`` and
            ``past_key_values``; its output must expose ``.logits``.
        exemplar_attn_kv: iterable of per-layer ``(key, value)`` tensors for the
            demonstration. Each is expanded along dim 0 to the batch size, so
            presumably ``[1, num_heads, L, head_dim]`` -- TODO confirm.
        exemplar_attn_mask: attention mask of the demonstration, expanded along
            dim 0 to the batch size (the caller passes a ``[1, L]`` tensor).
        batched_choices_input: one entry per answer choice, each a 4-tuple
            ``(input_ids, attention_mask, choice_start, choice_end)`` covering
            the whole batch.

    Returns:
        A ``choices x batch`` nested list of dicts with keys ``lm_log_p``
        (summed log-probability of the choice tokens) and ``norm_lm_log_p``
        (that sum divided by the number of choice tokens).
    """
    batched_choices_logprobs = []
    for batched_one_choice_input in batched_choices_input:
        batch_input_ids, batch_attention_mask, batch_choice_start, batch_choice_end = batched_one_choice_input
        bs = len(batch_input_ids)

        # The demonstration acts as a prefix, so its mask precedes the query mask.
        merged_attn_mask = torch.cat((exemplar_attn_mask.expand(bs, -1), batch_attention_mask), dim=1)
        # [B, #Heads, Length, Hidden]
        expand_exemplar_attn_kv = [
            [layer_k.expand((bs, -1, -1, -1)), layer_v.expand((bs, -1, -1, -1))]
            for layer_k, layer_v in exemplar_attn_kv
        ]

        batched_logits = model(
            input_ids=batch_input_ids,  # [B, L']
            attention_mask=merged_attn_mask,  # [B, L + L']
            past_key_values=expand_exemplar_attn_kv,  # num_layers * 2 * [B, num_heads, L, H]
        ).logits
        batched_output = F.log_softmax(batched_logits, dim=-1)  # [B, L', Vocab]

        batched_one_choice_logprobs = []
        for input_ids, choice_start, choice_end, lm_logprobs in zip(
            batch_input_ids, batch_choice_start, batch_choice_end, batched_output
        ):
            choice_tokens = input_ids[choice_start:choice_end].unsqueeze(1)  # [L, 1]
            # The logits at position t predict token t + 1, hence the shift by one.
            choice_logprobs = lm_logprobs[choice_start - 1 : choice_end - 1]  # [L, Vocab]
            extracted = torch.gather(choice_logprobs, -1, choice_tokens).squeeze(-1)

            # int() accepts both plain ints and 0-d tensors; the division below is
            # then ordinary float arithmetic (a Python float has no ``.item()``).
            choice_length = int(choice_end - choice_start)
            lm_log_p = torch.sum(extracted).item()
            norm_lm_log_p = lm_log_p / choice_length
            batched_one_choice_logprobs.append({"lm_log_p": lm_log_p, "norm_lm_log_p": norm_lm_log_p})
        batched_choices_logprobs.append(batched_one_choice_logprobs)
    return batched_choices_logprobs


@torch.no_grad()
def process_once(
    dataset_name,
    exemplar_str,
    forward_steps,
    raw_data,
    seed=DEFAULT_SEED,
    prompt_version=DEFAULT_PROMPT_VERSION,
):
    """Run *forward_steps* meta-optimization steps and predict every example.

    A fresh OPT-125m model is built on every call.

    Args:
        dataset_name: task name understood by ``load_task``.
        exemplar_str: demonstration text used as the in-context prefix.
        forward_steps: number of ``AttnOptimWrapper.step`` calls; must be >= 1.
        raw_data: raw examples passed to the task's ``dataset_preprocess``.
        seed: seed forwarded to ``setup_cpu``.
        prompt_version: prompt template version passed to the task handler.

    Returns:
        One display string per example: the predicted label followed by a
        check mark if it matches the gold answer, or a cross otherwise.

    Raises:
        ValueError: if ``forward_steps`` is smaller than 1 (no key/values would
            be produced).
    """
    if forward_steps < 1:
        raise ValueError(f"forward_steps must be >= 1, got {forward_steps}")

    model_name, model_size = "opt", "125m"
    step_size, momentum = 0.01, 0.9
    setup_cpu(seed=seed)
    TaskHandler = load_task(dataset_name)
    task_agent = TaskHandler(prompt_version)
    tokenizer = build_tokenizer(model_name, model_size, padding_side="right")
    model = build_model(model_name, model_size, False)

    processed_data = task_agent.dataset_preprocess(raw_data)
    dataset = TokenizedForMCRightPad(processed_data, tokenizer, task_agent.multiple_choice_promptify)
    exemplar_input_ids, exemplar_attn_mask = dataset.tokenize_demonstration(exemplar_str)
    loader = DataLoader(dataset, shuffle=False, drop_last=False, batch_size=1)

    meta_optim = AttnOptimWrapper(model, model_name, step_size=step_size, momentum=momentum)
    meta_optim.init()
    # Only the key/values returned by the final step are used for inference.
    for _ in range(forward_steps):
        exemplar_kv = meta_optim.step(exemplar_input_ids)

    generated_info = []  # question * (choice0_info, choice1_info, ...)
    for batch_input in loader:
        # [batch_of_choice0, batch_of_choice1, ...]
        batch_output = do_infer_probs(model, exemplar_kv, exemplar_attn_mask.unsqueeze(0), batch_input)
        # Regroup as batch * (choice0, choice1, ...)
        generated_info.extend(zip(*batch_output))

    all_predicted = []
    for data, choice_info in zip(processed_data, generated_info):
        merged_choice_info = task_agent.merge_choice_info(choice_info)
        merged_predictions_idx = task_agent.choice_info_to_predictions(merged_choice_info)["lm_log_p"]
        predicted = task_agent.CHOICES[merged_predictions_idx]
        ground_truth = task_agent.CHOICES[data["answer_idx"]]
        mark = "✅" if predicted == ground_truth else "❌"
        all_predicted.append(f"{display_label(dataset_name, predicted)}{mark}")
    return all_predicted


def transpose(l):
    """Return the transpose of a rectangular list of rows as a list of lists."""
    return list(map(list, zip(*l)))


def button_pressed(prev_state):
    """Gradio callback: advance two forward steps and append a prediction column.

    Args:
        prev_state: dict with keys ``dataset_name``, ``exemplar_str``,
            ``raw_data``, ``step`` and ``table_data`` (rows of the table).

    Returns:
        ``[new_state, new_button_label, new_table_rows]``, matching the
        ``outputs=[state, step_button, big_table]`` wiring.
    """
    dataset_name = prev_state["dataset_name"]
    exemplar_str = prev_state["exemplar_str"]
    forward_steps = prev_state["step"] + 2
    raw_data = prev_state["raw_data"]
    prev_table_data = prev_state["table_data"]

    current_output = process_once(dataset_name, exemplar_str, forward_steps, raw_data)

    # Append the new predictions as a column: transpose, add a row, transpose back.
    t_prev = transpose(prev_table_data)
    t_prev.append([f"T={forward_steps}"] + current_output)
    updated_table_data = transpose(t_prev)

    return [
        {
            "dataset_name": dataset_name,
            "exemplar_str": exemplar_str,
            "raw_data": raw_data,
            "step": forward_steps,
            "table_data": updated_table_data,
        },
        f"Step + 2, Now: {forward_steps}",
        updated_table_data,
    ]


if __name__ == "__main__":
    dataset_name = "sst2"

    print(f"Dataset: {dataset_name}")
    task_root = Path("example_sets").joinpath(dataset_name)
    with task_root.joinpath("demos.txt").open("r", encoding="utf-8") as f:
        demos = f.read()
    # NOTE(review): despite the .pkl extension, this file is parsed as JSON.
    with task_root.joinpath("sample.pkl").open("r", encoding="utf-8") as f:
        data = json.load(f)
    raw_data = [data[str(i)] for i in range(len(data))]

    css = """ #the-table > div > div > div > table > thead {display: none}"""
    title = "🤔 Iterative Forward Tuning Boosts In-context Learning in Language Models"
    demo = gr.Blocks(css=css, title="🤔Deep-Thinking")
    with demo:
        gr.Markdown(f"\n\n{title}\n\n")
        with gr.Tab("SST-2"):
            # Index -> label name for the integer labels stored in the sample file.
            mapping = ["negative", "positive"]
            init_columns = [[e["sentence"], f"*{DISPLAY_MAPPING['sst2'][mapping[e['label']]]}*"] for e in raw_data]
            state = gr.State(
                {
                    "dataset_name": "sst2",
                    "exemplar_str": demos,
                    "raw_data": raw_data,
                    "step": 0,
                    "table_data": [["**Test Input**", "**Golden**"], *init_columns],
                }
            )
            # NOTE(review): edits to this textbox are not read by the callback;
            # the demonstrations come from the initial state only.
            prompt = gr.Textbox(label="Demonstrations (Prompt template formatted)", value=demos)
            big_table = gr.DataFrame(
                value=[["**Test Input**", "**Golden**"], *init_columns],
                elem_id="the-table",
                datatype=["markdown"] * 50,
                headers=None,
            )
            step_button = gr.Button("Step + 2, Now: 0")
            step_button.click(button_pressed, inputs=[state], outputs=[state, step_button, big_table])

    demo.launch(server_name="0.0.0.0")