# EAGLE/model/ea_model.py
import os

import torch
import torch.nn as nn
from transformers import AutoTokenizer, PreTrainedModel, PretrainedConfig
from huggingface_hub import hf_hub_download

from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from .utils import *
from .kv_cache import initialize_past_key_values
from .choices import mc_sim_7b_63
from .cnets import Model
from .configs import EConfig


class ResBlock(nn.Module):
    """
    A residual block.

    Applies a linear transformation followed by a SiLU activation, then adds
    the result back to the original input, forming a residual connection.

    Args:
        hidden_size (int): The size of the hidden layers in the block.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        # Zero-initialize the weights so the block starts out as an identity mapping.
        torch.nn.init.zeros_(self.linear.weight)
        # Use SiLU activation to stay consistent with the Llama model.
        self.act = nn.SiLU()

    def forward(self, x):
        """
        Forward pass of the ResBlock.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output after the residual connection and activation.
        """
        return x + self.act(self.linear(x))
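
# Quick illustration (not part of the module): with the weight matrix zeroed,
# the linear layer contributes only its small default-initialized bias, so a
# freshly constructed ResBlock is a near-identity mapping before training.
#
#     block = ResBlock(hidden_size=4096)
#     x = torch.randn(2, 16, 4096)
#     y = block(x)                # y stays close to x at initialization
#     print((y - x).abs().max())  # small residual contributed by the bias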


class EaModel(nn.Module):
    """Wraps a base Llama model with an EAGLE draft head for speculative decoding."""

    def __init__(
        self,
        base_model,
        base_model_name_or_path,
        ea_model_path,
    ):
        super().__init__()
        self.base_model = base_model
        self.config = base_model.config
        # lm_head.weight has shape (vocab_size, hidden_size).
        self.hidden_size = base_model.lm_head.weight.shape[-1]
        self.vocab_size = base_model.lm_head.weight.shape[0]
        self.base_model_name_or_path = base_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path)
        config = EConfig.from_pretrained(ea_model_path)
        self.ea_layer = Model(config)

        # Place the draft head on the same device as the last base-model layer.
        device = base_model.model.layers[-1].self_attn.q_proj.weight.device
        self.ea_layer.to(torch.float16).to(device)
        self.ea_layer.init_tree()

    def get_tokenizer(self):
        """Get the tokenizer of the base model.

        Returns:
            Tokenizer: The tokenizer of the base model.
        """
        return self.tokenizer

    @classmethod
    def from_pretrained(
        cls,
        base_model_path=None,
        ea_model_path=None,
        **kwargs,
    ):
        # Load the base model with the KV-cache-aware Llama implementation.
        base_model = KVLlamaForCausalLM.from_pretrained(
            base_model_path, **kwargs
        )
        model = cls(
            base_model,
            base_model_path,
            ea_model_path,
        )
        # Load the EAGLE draft-head weights from the local checkpoint directory.
        ea_layer_state_dict = torch.load(
            os.path.join(ea_model_path, "pytorch_model.bin"),
            map_location=base_model.device,
        )
        model.ea_layer.load_state_dict(ea_layer_state_dict, strict=False)
        return model
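
    # Usage sketch (illustrative; the paths below are placeholders, not values
    # shipped with this file):
    #
    #     model = EaModel.from_pretrained(
    #         base_model_path="meta-llama/Llama-2-7b-chat-hf",
    #         ea_model_path="/path/to/eagle-draft-head",
    #         torch_dtype=torch.float16,
    #         device_map="auto",
    #     )
    #     tokenizer = model.get_tokenizer()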

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        past_key_values=None,
        output_orig=False,
        position_ids=None,
        init=True,
        logits_processor=None,
    ):
        with torch.inference_mode():
            # Pass input through the base model.
            outputs = self.base_model.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                position_ids=position_ids,
            )
            if output_orig:
                orig = self.base_model.lm_head(outputs[0])
            # Clone the output hidden states for the draft head.
            hidden_states = outputs[0].clone()
        if init:
            # Note: init=True assumes output_orig=True, since `orig` is needed
            # to pick the next token before the draft head runs.
            if logits_processor is not None:
                logits = orig[:, -1]
                logits = logits_processor(None, logits)
                probabilities = torch.nn.functional.softmax(logits, dim=1)
                token = torch.multinomial(probabilities, 1)
            else:
                token = torch.argmax(orig[:, -1])
                token = token[None, None]
            input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
            # Let the draft head propose a tree of candidate tokens.
            ea_logits = self.ea_layer.topK_genrate(
                hidden_states, input_ids, self.base_model.lm_head, logits_processor
            )
            if output_orig:
                return ea_logits, outputs, orig, hidden_states, token
            return ea_logits, hidden_states, token
        else:
            if output_orig:
                return outputs, orig, hidden_states
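
    # Call-pattern sketch (an assumption about how utils.initialize_tree drives
    # this method, based on the arguments handled above):
    #
    #     ea_logits, outputs, orig, hidden_states, token = model(
    #         input_ids,
    #         past_key_values=past_key_values,
    #         output_orig=True,
    #         logits_processor=logits_processor,
    #     )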

    @torch.no_grad()
    def eagenerate(
        self,
        input_ids,
        temperature=0.0,
        top_p=0.0,
        top_k=0.0,
        max_new_tokens=512,
        max_length=2048,
        tree_choices=mc_sim_7b_63,
    ):
        # Sample only at non-negligible temperatures; otherwise decode greedily.
        if temperature > 1e-5:
            logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
        else:
            logits_processor = None
        assert input_ids.shape[0] == 1, "Only batch size 1 is supported for now!"
        # Avoid modifying the input_ids in-place.
        input_ids = input_ids.clone()
        self.ea_layer.reset_kv()

        # Reuse the cached tree buffers when the tree structure has not changed.
        if hasattr(self, "tree_choices") and self.tree_choices == tree_choices:
            tree_buffers = self.tree_buffers
        else:
            tree_buffers = generate_tree_buffers(
                tree_choices, device=self.base_model.model.layers[-1].self_attn.q_proj.weight.device
            )
            self.tree_buffers = tree_buffers
            self.tree_choices = tree_choices

        # Initialize (or reset) the past key and value states.
        if hasattr(self, "past_key_values"):
            past_key_values = self.past_key_values
            past_key_values_data = self.past_key_values_data
            current_length_data = self.current_length_data
            # Reset the current lengths so the cache is logically empty.
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(self.base_model)
            self.past_key_values = past_key_values
            self.past_key_values_data = past_key_values_data
            self.current_length_data = current_length_data

        input_len = input_ids.shape[1]
        reset_tree_mode(self)
        tree_logits, logits, hidden_state, sample_token = initialize_tree(
            input_ids, self, tree_buffers["tree_attn_mask"], past_key_values, logits_processor
        )
        new_token = 0

        for idx in range(max_length):
            # Expand the draft logits into candidate token sequences.
            candidates, cart_candidates_prob, tree_candidates = generate_candidates(
                tree_logits,
                tree_buffers["tree_indices"],
                tree_buffers["retrieve_indices"],
                sample_token,
                logits_processor,
            )
            # Verify the whole candidate tree with one base-model forward pass.
            logits, hidden_state_new, outputs = tree_decoding(
                self,
                tree_candidates,
                past_key_values,
                tree_buffers["tree_position_ids"],
                input_ids,
                tree_buffers["retrieve_indices"],
            )
            # Accept the longest candidate prefix consistent with the base model.
            best_candidate, accept_length, sample_p = evaluate_posterior(
                logits, candidates, logits_processor, cart_candidates_prob
            )
            input_ids, tree_logits, new_token, hidden_state, sample_token = update_inference_inputs(
                input_ids,
                candidates,
                best_candidate,
                accept_length,
                tree_buffers["retrieve_indices"],
                logits_processor,
                logits,
                tree_logits,
                new_token,
                past_key_values_data,
                current_length_data,
                self,
                hidden_state,
                hidden_state_new,
                sample_p,
            )

            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                return input_ids
            if new_token > max_new_tokens:
                return input_ids
            if input_ids.shape[1] > max_length:
                return input_ids
        # Fall back to returning whatever was generated if the loop is exhausted.
        return input_ids
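
    # Usage sketch (illustrative; assumes `model` and `tokenizer` were created
    # as in the from_pretrained() example above):
    #
    #     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.base_model.device)
    #     output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=256)
    #     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))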

    @torch.no_grad()
    def ea_generate(
        self,
        input_ids,
        temperature=0.0,
        top_p=0.0,
        top_k=0.0,
        max_steps=512,
        tree_choices=mc_sim_7b_63,
    ):
        # Streaming variant of eagenerate(): yields the growing sequence after
        # every accepted draft step.
        if temperature > 1e-5:
            logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
        else:
            logits_processor = None
        assert input_ids.shape[0] == 1, "Only batch size 1 is supported for now!"
        # Avoid modifying the input_ids in-place.
        input_ids = input_ids.clone()
        self.ea_layer.reset_kv()

        if hasattr(self, "tree_choices") and self.tree_choices == tree_choices:
            tree_buffers = self.tree_buffers
        else:
            tree_buffers = generate_tree_buffers(
                tree_choices, device=self.base_model.model.layers[-1].self_attn.q_proj.weight.device
            )
            self.tree_buffers = tree_buffers
            self.tree_choices = tree_choices

        # Initialize (or reset) the past key and value states.
        if hasattr(self, "past_key_values"):
            past_key_values = self.past_key_values
            past_key_values_data = self.past_key_values_data
            current_length_data = self.current_length_data
            # Reset the current lengths so the cache is logically empty.
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(self.base_model)
            self.past_key_values = past_key_values
            self.past_key_values_data = past_key_values_data
            self.current_length_data = current_length_data

        input_len = input_ids.shape[1]
        reset_tree_mode(self)
        tree_logits, logits, hidden_state, sample_token = initialize_tree(
            input_ids, self, tree_buffers["tree_attn_mask"], past_key_values, logits_processor
        )
        new_token = 0

        for idx in range(max_steps):
            candidates, cart_candidates_prob, tree_candidates = generate_candidates(
                tree_logits,
                tree_buffers["tree_indices"],
                tree_buffers["retrieve_indices"],
                sample_token,
                logits_processor,
            )
            logits, hidden_state_new, outputs = tree_decoding(
                self,
                tree_candidates,
                past_key_values,
                tree_buffers["tree_position_ids"],
                input_ids,
                tree_buffers["retrieve_indices"],
            )
            best_candidate, accept_length, sample_p = evaluate_posterior(
                logits, candidates, logits_processor, cart_candidates_prob
            )
            input_ids, tree_logits, new_token, hidden_state, sample_token = update_inference_inputs(
                input_ids,
                candidates,
                best_candidate,
                accept_length,
                tree_buffers["retrieve_indices"],
                logits_processor,
                logits,
                tree_logits,
                new_token,
                past_key_values_data,
                current_length_data,
                self,
                hidden_state,
                hidden_state_new,
                sample_p,
            )

            yield input_ids

            # Hard-coded stopping limits on new tokens and total length.
            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                break
            if new_token > 1024:
                break
            if input_ids.shape[1] > 1960:
                break
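
    # Streaming usage sketch (illustrative; same setup as the example above):
    #
    #     for output_ids in model.ea_generate(input_ids, temperature=0.0):
    #         text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    #         print(text)  # prints the progressively longer decoded sequence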

    @torch.no_grad()
    def naive_generate(
        self,
        input_ids,
        temperature=0.0,
        top_p=0.0,
        top_k=0.0,
        max_steps=512,
        tree_choices=mc_sim_7b_63,
    ):
        # Plain autoregressive decoding with the base model, used as a baseline
        # against EAGLE's tree-based speculative decoding.
        if temperature > 1e-5:
            logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
        else:
            logits_processor = None
        assert input_ids.shape[0] == 1, "Only batch size 1 is supported for now!"
        # Avoid modifying the input_ids in-place.
        input_ids = input_ids.clone()
        self.ea_layer.reset_kv()

        if hasattr(self, "tree_choices") and self.tree_choices == tree_choices:
            tree_buffers = self.tree_buffers
        else:
            tree_buffers = generate_tree_buffers(
                tree_choices, device=self.base_model.model.layers[-1].self_attn.q_proj.weight.device
            )
            self.tree_buffers = tree_buffers
            self.tree_choices = tree_choices

        # Initialize (or reset) the past key and value states.
        if hasattr(self, "past_key_values"):
            past_key_values = self.past_key_values
            past_key_values_data = self.past_key_values_data
            current_length_data = self.current_length_data
            # Reset the current lengths so the cache is logically empty.
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(self.base_model)
            self.past_key_values = past_key_values
            self.past_key_values_data = past_key_values_data
            self.current_length_data = current_length_data

        input_len = input_ids.shape[1]
        reset_tree_mode(self)
        outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True)
        new_token = 0

        for idx in range(max_steps):
            # Greedily pick the next token and feed it back one step at a time.
            input_id = outputs.logits[:, -1:].argmax(dim=-1)
            outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values)
            input_ids = torch.cat([input_ids, input_id], dim=-1)
            new_token += 1  # Count generated tokens so the limit below can trigger.

            yield input_ids

            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                break
            if new_token > 1024:
                break
            if input_ids.shape[1] > 1960:
                break
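
    # Baseline usage sketch (illustrative): timing this generator against
    # ea_generate() on the same prompt gives a rough wall-clock speedup number
    # for the speculative path.
    #
    #     for output_ids in model.naive_generate(input_ids):
    #         pass  # consume the generator; output_ids ends as the full sequence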