deep-thinking / models /meta_optimizer.py
jx-yang's picture
<ADD> +app
9d21d47
raw
history blame
2.21 kB
import torch
class MomentumOptim:
    """Momentum optimizer that treats the change ``new - old`` as a gradient.

    Velocity rule: ``m = g + momentum * m_prev``. The first call after a reset
    (or while the velocity buffer is empty) simply uses ``m = g``.
    Parameter rule: ``x = x_old + step_size * m``.

    Elements may be anything supporting ``+``, ``-`` and scalar ``*``
    (for example torch tensors or plain floats).

    Args:
        step_size: Scale applied to the velocity when moving a parameter.
        momentum: Decay factor applied to the previous velocity.
    """

    def __init__(self, step_size=0.01, momentum=0.9):
        self.momentum = momentum
        self.step_size = step_size
        # Velocity buffer, one entry per element; None (or empty) means no history.
        self.m = None

    def init(self):
        """Forget the accumulated velocity."""
        self.m = None

    def upd_m(self, old_m, g):
        """Combine gradient ``g`` with the decayed previous velocity ``old_m``."""
        decayed = self.momentum * old_m
        return g + decayed

    def upd(self, old_x, m):
        """Step ``old_x`` along velocity ``m`` scaled by ``step_size``."""
        delta = self.step_size * m
        return old_x + delta

    def __call__(self, old_xs, new_xs):
        """Apply one momentum step element-wise and return the new values as a list.

        Args:
            old_xs: Current values.
            new_xs: Proposed values; ``new - old`` is used as the pseudo-gradient.
        """
        grads = []
        for prev, nxt in zip(old_xs, new_xs):
            grads.append(nxt - prev)

        if self.m:
            self.m = [self.upd_m(vel, grad) for vel, grad in zip(self.m, grads)]
        else:
            # No history yet (or it was reset): velocity starts at the gradient.
            self.m = grads

        return [self.upd(prev, vel) for prev, vel in zip(old_xs, self.m)]
class AttnOptimWrapper:
    """Iteratively refine a cached key/value context with per-tensor optimizers.

    Each call to :meth:`step` re-encodes ``ctx_ids`` on top of the current
    key/value cache. It keeps only the keys/values produced for the new tokens
    and blends them with the previous cache using ``optim_k`` / ``optim_v``.

    Args:
        llm: Callable invoked as ``llm(input_ids=..., attention_mask=...,
            past_key_values=..., use_cache=True)``. Its result must expose a
            ``past_key_values`` attribute that iterates as ``(key, value)``
            pairs per layer. (Presumably a Hugging Face causal LM; confirm.)
        model_type: Stored on the instance; not used within this class.
        optimizer: Name of the update rule. Only ``"momentum"`` is supported.
        **optimizer_args: Forwarded to ``MomentumOptim`` (e.g. ``step_size``,
            ``momentum``).

    Raises:
        ValueError: If ``optimizer`` is not supported.
    """

    def __init__(self, llm, model_type, optimizer="momentum", **optimizer_args):
        self.model = llm
        self.kv = None  # per-layer (key, value) cache; None until the first step
        self.model_type = model_type
        if optimizer == "momentum":
            # Keys and values keep independent optimizer (velocity) state.
            self.optim_k = MomentumOptim(**optimizer_args)
            self.optim_v = MomentumOptim(**optimizer_args)
        else:
            raise ValueError(f"Unsupported optimizer: {optimizer!r}")

    def init(self):
        """Reset the state of both the key and the value optimizer.

        NOTE(review): ``self.kv`` is not cleared here. Callers wanting a fresh
        context presumably reset it themselves; confirm.
        """
        self.optim_k.init()
        self.optim_v.init()

    @torch.no_grad()
    def step(self, ctx_ids):
        """Run one refinement step over ``ctx_ids`` and return the updated cache.

        Args:
            ctx_ids: Token ids of the context, assumed to be a 1-D tensor of
                length L (it is unsqueezed to ``[1, L]``). Confirm with callers.

        Returns:
            A list with one ``(key, value)`` pair per layer, holding only the
            positions produced for ``ctx_ids``.
        """
        new_len = len(ctx_ids)
        # Positions already held in the cache (0 on the first step). Cached
        # tensors are indexed as [batch, head, position, head_dim].
        past_len = self.kv[0][0].shape[-2] if self.kv else 0

        ctx_ids = ctx_ids.unsqueeze(0)  # [1, L]
        # The mask must cover cached + new positions. Derive its length from
        # the actual cache instead of assuming the cache holds exactly L tokens.
        mask = torch.ones(
            (1, past_len + new_len), dtype=ctx_ids.dtype, device=ctx_ids.device
        )
        next_kv = self.model(
            input_ids=ctx_ids,
            attention_mask=mask,
            past_key_values=self.kv,
            use_cache=True,
        ).past_key_values  # kv @ (old_ctx + new_ctx)

        # Keep only the keys/values of the newly encoded tokens. Slicing from
        # past_len (instead of -L) also stays correct when L == 0.
        cur_kv = [
            [layer_k[:, :, past_len:, :], layer_v[:, :, past_len:, :]]
            for layer_k, layer_v in next_kv
        ]

        if not self.kv:
            self.kv = cur_kv
        else:
            old_ks, old_vs = zip(*self.kv)
            cur_ks, cur_vs = zip(*cur_kv)
            upd_ks = self.optim_k(old_ks, cur_ks)
            upd_vs = self.optim_v(old_vs, cur_vs)
            self.kv = list(zip(upd_ks, upd_vs))
        return self.kv