# EAGLE/model/utils_c.py
import torch
TOPK = 10 # topk for sparse tree
def pad_path(path, length, pad_value=-2):
    """
    Pad the given path list with a specific value up to a specified length.

    Parameters:
    - path (list): The original list that needs padding.
    - length (int): The desired length of the padded list.
    - pad_value (optional, default=-2): The value to use for padding.

    Returns:
    - list: A new list based on the original path but padded to the desired length.

    Example:
    >>> pad_path([1,2,3], 5)
    [1, 2, 3, -2, -2]

    Note:
    If the given path is already longer than the specified length,
    then no padding occurs, and the original path is returned.
    """
    # Calculate the number of padding values needed by subtracting the length
    # of the path from the desired length, then append that many pad values.
    return path + [pad_value] * (length - len(path))
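
# Illustration (not part of the original module): when the path is already at
# least `length` long, `length - len(path)` is zero or negative, so the list
# multiplication contributes nothing and the path comes back unchanged, e.g.
#   pad_path([1, 2, 3], 2)  ->  [1, 2, 3]
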
class node:
    def __init__(self, parent=None, value=None, dict_key=None):
        self.parent = parent
        self.value = value
        if parent:
            self.depth = parent.depth + 1
            parent.children.append(self)
        else:
            self.depth = 0
        self.children = []
        self.dict_key = dict_key

    def is_leaf(self):
        return len(self.children) == 0

    def all_index(self):
        # Collect the indices (assigned by Tree.indexnode) of this node and its
        # non-root ancestors, ordered from the shallowest ancestor down to this
        # node. Only non-leaf nodes receive an index, so this is intended for
        # nodes with children (whose ancestors are necessarily non-leaf too).
        if not self.parent.parent:
            return [self.index]
        else:
            return self.parent.all_index() + [self.index]
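
# A minimal sketch (hypothetical values, not taken from the module) of how
# `node` objects link up; Tree.__init__ below does the same from a choice list:
#
#   root = node()                                   # depth 0, no parent
#   a = node(parent=root, value=0, dict_key=(0,))   # depth 1, child of root
#   b = node(parent=a, value=1, dict_key=(0, 1))    # depth 2, so `a` is non-leaf
#
# Note that all_index() relies on the `index` attribute that Tree.indexnode()
# assigns to non-leaf nodes, so it is only meaningful on nodes built via Tree.
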
class Tree:
    def __init__(self, tree_list):
        # Sort choices by (length, content) so that every parent path is
        # created before any of its children.
        sorted_tree_list = sorted(tree_list, key=lambda x: (len(x), x))
        self.root = node()
        self.node_dic = {}
        for tree_node in sorted_tree_list:
            cur_value = tree_node[-1]
            if len(tree_node) == 1:
                cur_node = node(parent=self.root, value=cur_value, dict_key=tuple(tree_node))
            else:
                cur_parent = self.node_dic[tuple(tree_node[:-1])]
                cur_node = node(parent=cur_parent, value=cur_value, dict_key=tuple(tree_node))
            self.node_dic[tuple(tree_node)] = cur_node
        self.indexnode()

    def max_depth(self):
        return max([item.depth for item in self.node_dic.values()])

    def num_node_wchild(self):
        # Number of nodes that have at least one child.
        num_c = 0
        for item in self.node_dic.values():
            if not item.is_leaf():
                num_c += 1
        return num_c

    def get_node_wchild(self):
        # Nodes that have at least one child, in insertion (sorted-choice) order.
        ns = []
        for item in self.node_dic.values():
            if not item.is_leaf():
                ns.append(item)
        return ns

    def indexnode(self):
        # Assign a consecutive index to every non-leaf node, in insertion order.
        cur_index = 0
        for key in self.node_dic:
            cur_node = self.node_dic[key]
            if not cur_node.is_leaf():
                cur_node.index = cur_index
                cur_index += 1
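
# A minimal usage sketch (the choice list here is made up for illustration and
# is not one of the repo's predefined trees):
#
#   t = Tree([[0], [1], [0, 0], [0, 1], [1, 0], [0, 0, 0]])
#   t.max_depth()        # 3 -> the longest path has three choices
#   t.num_node_wchild()  # 3 -> (0,), (1,) and (0, 0) each have children
#   t.get_node_wchild()  # those three nodes, indexed 0, 1, 2 by indexnode()
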
def generate_tree_buffers(tree_choices, device="cuda"):
    tree = Tree(tree_choices)
    tree_len = tree.num_node_wchild()   # number of non-leaf nodes
    max_depth = tree.max_depth()
    nodes_wc = tree.get_node_wchild()   # non-leaf nodes in (path length, path) order

    # How many non-leaf nodes sit at each depth (depths 1 .. max_depth - 1),
    # plus the running (prefix) sums over those counts.
    depth_counts = [0 for _ in range(max_depth - 1)]
    for x in nodes_wc:
        depth_counts[x.depth - 1] += 1
    depth_counts_sum = [sum(depth_counts[:i + 1]) for i in range(len(depth_counts))]

    # Tree attention mask: row id (the node's own index) attends to itself and
    # to all of its non-root ancestors.
    tree_attn_mask = torch.eye(tree_len, tree_len)
    for id, x in enumerate(nodes_wc):
        tree_attn_mask[id, x.all_index()] = 1

    # Split the mask per depth level: level i keeps the last depth_counts[i]
    # rows of the top-left (depth_counts_sum[i] x depth_counts_sum[i]) block.
    tree_attn_mask_list0 = [tree_attn_mask[:ml, :ml] for ml in depth_counts_sum]
    tree_attn_mask_list = []
    for id, x in enumerate(tree_attn_mask_list0):
        x = x[-depth_counts[id]:]
        tree_attn_mask_list.append(x)

    # For each depth level, map every node to a flat index into the concatenated
    # top-K candidates: its choice value plus TOPK times the rank of its parent
    # among the distinct parents seen at that level. repeat_nums[i] records how
    # many consecutive nodes at level i share the same parent.
    tree_indices_list = [torch.zeros(ml, dtype=torch.long) for ml in depth_counts]
    repeat_nums = [[] for _ in depth_counts]
    start = 0
    for i in range(len(depth_counts)):
        bias = 0
        repeat_j = 0
        for j in range(depth_counts[i]):
            cur_node = nodes_wc[start + j]
            cur_parent = cur_node.parent
            if j != 0:
                if cur_parent != parent:
                    # A new parent starts here: bump the offset and record how
                    # many children the previous parent contributed.
                    bias += 1
                    parent = cur_parent
                    repeat_nums[i].append(j - repeat_j)
                    repeat_j = j
            else:
                parent = cur_parent
            tree_indices_list[i][j] = cur_node.value + TOPK * bias
        repeat_nums[i].append(j - repeat_j + 1)
        start += depth_counts[i]

    # Position ids are left at zero here; the commented-out block below sketches
    # a per-level alternative.
    position_ids = [torch.zeros(ml, dtype=torch.long) for ml in depth_counts]
    # start = 0
    # for i in range(len(depth_counts)):
    #     position_ids[start: start + depth_counts[i]] = i
    #     start += depth_counts[i]

    tree_buffers = {
        "attn_mask": [i.unsqueeze(0).unsqueeze(0) for i in tree_attn_mask_list],
        "tree_indices": tree_indices_list,
        "position_ids": position_ids,
        "repeat_nums": repeat_nums,
    }

    # Move the tensors in the dictionary to the specified device; plain Python
    # lists (e.g. repeat_nums) are left untouched.
    tree_buffers = {
        k: [i.clone().to(device) for i in v]
        if isinstance(v[0], torch.Tensor)
        else (
            torch.tensor(v, device=device)
            if isinstance(v, torch.Tensor)
            else v
        )
        for k, v in tree_buffers.items()
    }
    return tree_buffers
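
# A minimal sketch of the resulting buffers for the same toy choice list used in
# the Tree example above, with device="cpu" to stay hardware-agnostic. The values
# were traced by hand, so treat them as illustrative rather than authoritative:
#
#   buffers = generate_tree_buffers(
#       [[0], [1], [0, 0], [0, 1], [1, 0], [0, 0, 0]], device="cpu")
#   # depth_counts is [2, 1] for this tree, so:
#   # buffers["attn_mask"][0].shape  -> (1, 1, 2, 2)
#   # buffers["attn_mask"][1].shape  -> (1, 1, 1, 3)
#   # buffers["tree_indices"]        -> [tensor([0, 1]), tensor([0])]
#   # buffers["repeat_nums"]         -> [[2], [1]]
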
def reset_past_key_values(passed_key_values):
    """
    Resets the current lengths in the passed key-values to zero.

    This function is designed to be used during the evaluation of a baseline model.
    It iterates through each layer's key and value caches and sets their current
    lengths to zero, effectively resetting their state.

    Args:
    - passed_key_values (list): Per-layer [key, value] caches, each of which
      exposes a `current_length` tensor tracking how many positions are filled.

    Returns:
    - passed_key_values (list): The same caches with their current lengths reset.
    """
    for i in range(len(passed_key_values)):
        for j in range(2):  # j == 0: key cache, j == 1: value cache
            passed_key_values[i][j].current_length.fill_(0)
    return passed_key_values
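
# Illustrative sketch only: the cache entries this function expects expose a
# `current_length` tensor (as in the project's KV-cache implementation, which is
# not shown here). The stand-in class below is hypothetical and exists purely to
# show the in-place reset:
#
#   class _FakeCacheEntry:
#       def __init__(self):
#           self.current_length = torch.tensor([5], dtype=torch.long)
#
#   past = [[_FakeCacheEntry(), _FakeCacheEntry()] for _ in range(2)]
#   reset_past_key_values(past)
#   past[0][0].current_length  # -> tensor([0])
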
if __name__ == "__main__":
    from choices import mc_sim_7b_63

    a = generate_tree_buffers(mc_sim_7b_63)
    print(a)