|
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

# Hugging Face Hub repo id of the Quiet-STaR checkpoint (org: "cognitivecomputations").
model_path = "cognitivecomputations/Quiet-STaR-Base"

# Thought-related sizes. Presumably the thought length and the talk look-ahead in
# tokens — confirm against the checkpoint's remote modeling code.
n_ahead = 8
n_ahead_talk = 4
merged_talk_heads = True

# The non-standard keyword arguments below are forwarded by `from_pretrained` to
# the checkpoint's own config/model classes. trust_remote_code=True downloads and
# executes that code from the Hub, so only enable it for repositories you trust.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    max_thoughts=n_ahead + n_ahead_talk + 1,
    merged_talk_heads=merged_talk_heads,
    merged_lm_and_talk_heads=False,
    merged_lm_and_think_heads=True,
    use_concat_talk_head=True,
    use_shallow_think=True,
    use_shallow_talk=False,
    use_complex_think_head=False,
    use_complex_talk_head=True,
    use_weighted_talk_head=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # let accelerate place the weights on the available devices
)
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_path)
# custom_generate reads the special-token ids through `model.tokenizer`.
model.tokenizer = tokenizer

# custom_generate pads with `pad_token_id` and finds each row's last real token
# by comparing against it, and the attention mask is built the same way, so it
# must be an int. Some causal-LM tokenizers ship without a pad token; fall back
# to EOS in that case (no-op when a pad token is already configured).
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

# Runtime switches read by the checkpoint's remote modeling code. Their exact
# semantics live in that code — NOTE(review): confirm each value there.
model.use_end_thought_token = True
model.use_start_thought_token = True
# NOTE(review): may make the remote forward pass log to Weights & Biases;
# confirm it is actually needed for inference.
model.wandb_enabled = True
model.n_ahead = n_ahead
model.n_passes = 2
model.eval_mode = True
model.first_run = False
model.kill_after = 100
model.rm_initialized = True
model.original_mode = False
|
|
|
def custom_generate(model, input_ids, attention_mask, max_new_tokens, streamer=None, **kwargs):
    """Sample continuations token by token, re-running the full forward pass each step.

    Each step feeds every still-active row (prompt plus tokens sampled so far) to
    ``model``, samples one token per row from the logits at that row's last
    non-pad position, and writes it into the following slot. When a row has no
    room left, a pad column is appended to ``input_ids`` and a zero column to
    ``attention_mask`` for the whole batch.

    A row stops after sampling EOS, BOS, the pad id, or ``<|/assistant|>``.

    Args:
        model: Callable returning a mapping whose ``"logits"`` entry is indexed
            as (row, position, token id). It must expose ``model.tokenizer``
            with ``vocab_size``, ``pad_token_id``, ``eos_token_id``,
            ``bos_token_id`` and ``convert_tokens_to_ids``. Logits for ids
            ``>= tokenizer.vocab_size`` are never sampled (presumably the added
            thought tokens — confirm against the model code).
        input_ids: (batch, seq) long tensor padded with ``pad_token_id``. It is
            written in place while it has room and may be replaced by a longer
            copy.
        attention_mask: (batch, seq) tensor, 1 for real tokens.
        max_new_tokens: Maximum number of sampling steps.
        streamer: Optional streamer with ``put``/``end`` (e.g. ``TextStreamer``).
            For a single-row batch the prompt is put first, as ``generate``
            does, so ``skip_prompt`` works as intended. Each step then puts the
            last token sampled in that step, except the step in which every row
            has finished. ``end()`` is called when generation stops.
        **kwargs: ``temperature`` (default 1.0). Values <= 0 decode greedily.

    Returns:
        ``(input_ids, attention_mask)`` including the generated tokens.

    Raises:
        ValueError: if ``model.tokenizer.pad_token_id`` is None.
    """
    tok = model.tokenizer
    pad_id = tok.pad_token_id
    if pad_id is None:
        raise ValueError(
            "model.tokenizer.pad_token_id must be set: it is used to pad rows "
            "and to locate each row's last token"
        )
    temperature = kwargs.get("temperature", 1.0)
    # Resolved once instead of after every sampled token. Unset special tokens
    # (None) can never equal a sampled id, so they are dropped.
    stop_ids = {
        tok.eos_token_id,
        tok.bos_token_id,
        pad_id,
        tok.convert_tokens_to_ids("<|/assistant|>"),
    }
    stop_ids.discard(None)

    if streamer is not None and len(input_ids) == 1:
        # Match `generate`: hand the prompt over first so the streamer's
        # skip_prompt flag drops the prompt rather than the first new token.
        streamer.put(input_ids)

    with torch.no_grad():
        finished = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
        for _ in range(max_new_tokens):
            active_rows = (~finished).nonzero(as_tuple=True)[0]
            logits = model(
                input_ids[~finished],
                attention_mask=attention_mask[~finished],
            )["logits"]
            logits[:, :, tok.vocab_size:] = -float("inf")

            for list_idx, answer_idx in enumerate(active_rows):
                row = input_ids[answer_idx]
                last_token_idx = (row != pad_id).nonzero(as_tuple=True)[0].max()
                next_logits = logits[list_idx][last_token_idx]
                if temperature > 0:
                    probs = torch.nn.functional.softmax(next_logits / temperature, dim=-1)
                    sampled = torch.multinomial(probs, 1)
                else:
                    sampled = next_logits.argmax(dim=-1, keepdim=True)

                if last_token_idx + 1 >= len(row):
                    # Out of room: grow every row by one pad column.
                    pad_col = torch.full(
                        (len(input_ids), 1), pad_id, dtype=torch.long, device=input_ids.device
                    )
                    input_ids = torch.cat([input_ids, pad_col], dim=-1)
                    attention_mask = torch.cat([attention_mask, torch.zeros_like(pad_col)], dim=-1)
                attention_mask[answer_idx, last_token_idx + 1] = 1
                input_ids[answer_idx, last_token_idx + 1] = sampled

                if sampled.item() in stop_ids:
                    finished[answer_idx] = True

            if finished.all():
                break
            if streamer is not None:
                streamer.put(sampled)

    if streamer is not None:
        # Flush text the streamer is still buffering (e.g. a trailing partial word).
        streamer.end()
    return input_ids, attention_mask
|
|
|
prompt = " How would a typical person answer each of the following questions about causation? Frank T., had an ongoing dispute with his neighbor over a stretch of land and one day decided to shoot his neighbor in the body. Frank T. had no experience with guns, his hand slipped on the barrel of the gun, and the shot went wild. Nonetheless, the bullet bounced off a large boulder several feet away and hit the neighbor's body, causing significant injury. Did Frank T. intentionally shoot his neighbor in the body?"

# Template wrapped around the question before tokenization. "{prompt}" passes the
# question through unchanged; swap in the checkpoint's prompt format if it has one.
prompt_template = "{prompt}"

# Encode once, directly on the device `model.device` reports (device_map="auto"
# decides placement), instead of hard-coding CUDA.
tokens = tokenizer(prompt_template.format(prompt=prompt), return_tensors="pt").input_ids.to(model.device)

# 1 for real tokens, 0 for padding — derived from the pad id, the same test
# custom_generate uses to find each row's last real token.
attention_mask = (tokens != tokenizer.pad_token_id).long()

# Prints decoded text to stdout as custom_generate feeds it tokens.
streamer = TextStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

output_ids, _ = custom_generate(
    model,
    input_ids=tokens,
    attention_mask=attention_mask,
    max_new_tokens=512,
    streamer=streamer,
    temperature=0.9,
)
|
|
|
# Decode only the continuation (everything after the prompt) of the single sequence.
generated_text = tokenizer.decode(output_ids[0, tokens.shape[-1]:], skip_special_tokens=True)

# Newline after the streamed output.
print()
|
|
|
|