Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from huggingface_hub import snapshot_download | |
from peft import get_peft_model, LoraConfig | |
from transformers import LlamaForCausalLM, LlamaTokenizer | |
def disabled_train(self, mode=True):
    """No-op stand-in for ``nn.Module.train``.

    Assign this over a module's ``train`` attribute to pin its current
    train/eval mode: the requested ``mode`` is ignored and nothing is changed.

    Returns:
        The first argument, unchanged.
    """
    del mode  # intentionally unused: the mode switch is swallowed
    return self
class LeoAgentLLM(nn.Module):
    """LLaMA causal LM wrapper that consumes interleaved text / visual embeddings.

    The module owns a ``LlamaTokenizer`` and a ``LlamaForCausalLM`` whose
    parameters are frozen at construction time (optionally wrapped with LoRA
    via ``peft``). It assembles a single embedding sequence from text prompts
    plus pre-computed image/object token embeddings and runs generation on it.
    """

    def __init__(self, cfg):
        """Load tokenizer and LLM, freeze the LLM, and optionally attach LoRA.

        Args:
            cfg: Attribute-style config. Fields read here:
                ``launch_mode`` (``'hf'`` downloads the checkpoint with
                ``snapshot_download``; anything else uses the local path),
                ``model.llm.hf_cfg_path`` / ``model.llm.local_cfg_path``,
                ``model.llm.truncation_side``, ``model.llm.max_context_len``,
                and ``model.llm.lora.{flag, rank, alpha, target_modules, dropout}``.
        """
        super().__init__()

        # LLM
        if cfg.launch_mode == 'hf':
            llm_cfg_path = snapshot_download(cfg.model.llm.hf_cfg_path)
        else:
            llm_cfg_path = cfg.model.llm.local_cfg_path

        self.llm_tokenizer = LlamaTokenizer.from_pretrained(
            llm_cfg_path,
            use_fast=False,
            truncation_side=cfg.model.llm.truncation_side,
        )
        self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.llm_tokenizer.add_special_tokens({'bos_token': '<s>'})
        self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
        self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})

        self.llm_model = LlamaForCausalLM.from_pretrained(llm_cfg_path, torch_dtype=torch.float16)
        # The tokenizer may have grown (e.g. '[PAD]'), so match the embedding table to it.
        self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))

        # Freeze every base-LLM weight and pin the model in eval mode.
        for param in self.llm_model.parameters():
            param.requires_grad = False
        self.llm_model.eval()
        # Shadow `train` on the instance so later train()/eval() calls no longer
        # change the mode of the frozen LLM.
        self.llm_model.train = disabled_train

        # LoRA-based LLM fine-tuning
        if cfg.model.llm.lora.flag:
            lora_config = LoraConfig(
                r=cfg.model.llm.lora.rank,
                lora_alpha=cfg.model.llm.lora.alpha,
                target_modules=cfg.model.llm.lora.target_modules,
                lora_dropout=cfg.model.llm.lora.dropout,
                bias='none',
                modules_to_save=[],
            )
            self.llm_model = get_peft_model(self.llm_model, peft_config=lora_config)

        self.max_context_len = cfg.model.llm.max_context_len

    @property
    def device(self):
        """Device of this module's first parameter.

        Exposed as a property because the rest of the class uses it as an
        attribute (``tensor.to(self.device)``); as a plain method those calls
        would receive a bound method instead of a ``torch.device``.
        """
        return next(self.parameters()).device

    def build_right_justified_sequence(self, data_dict):
        """Assemble the LLM input embeddings, right-justified for a causal LM.

        Concatenates six pieces per sample, in this order:
        `prompt_before_obj`, `prompt_middle_1`, `img_tokens`, `prompt_middle_2`,
        `obj_tokens`, `prompt_after_obj`. The trailing prompt is tokenized with
        right padding; that padding is then moved to the front of each row so
        the real tokens end at the last position:
        <pad>, <role/situation>, <img>, <objs>, <instruction>.

        Args:
            data_dict: Mapping with
                - ``prompt_before_obj``, ``prompt_middle_1``, ``prompt_middle_2``,
                  ``prompt_after_obj``: batches of strings for the tokenizer.
                  The two middle prompts must tokenize without padding
                  (asserted below).
                - ``img_tokens`` / ``obj_tokens``: embedding tensors concatenated
                  with text embeddings along dim 1 (so batch-first, 3-D).
                - ``img_masks``: reshaped to one value per sample and broadcast
                  over the image tokens — assumes shape (B,), TODO confirm.
                - ``obj_masks``: per-token mask for ``obj_tokens``; its dtype is
                  used for the returned attention mask.

        Returns:
            Tuple ``(inputs_embeds, attention_mask)`` of shapes
            ``(B, L, D)`` and ``(B, L)``, where ``L`` is the summed length of
            the (padded) pre, middle and post segments.

        Note:
            Mutates ``self.llm_tokenizer.padding_side`` (left, then right) and
            forces ``truncation_side = 'left'``; those values persist after the call.
        """
        device = self.device
        tokenizer = self.llm_tokenizer
        embed = self.llm_model.get_input_embeddings()

        tokenizer.padding_side = 'left'
        tokens_pre = tokenizer(
            data_dict['prompt_before_obj'],
            return_tensors='pt',
            padding='longest',
        ).to(device)  # [PAD, BOS, tokens], (B, T1)

        tokens_mid1 = tokenizer(
            data_dict['prompt_middle_1'],
            return_tensors='pt',
            padding='longest',
        ).to(device)

        img_tokens = data_dict['img_tokens'].to(device)
        img_masks = data_dict['img_masks'].to(device)
        # One flag per sample, repeated across that sample's image tokens.
        img_masks = img_masks.reshape(-1, 1).repeat(1, img_tokens.size(1))

        tokens_mid2 = tokenizer(
            data_dict['prompt_middle_2'],
            return_tensors='pt',
            padding='longest',
        ).to(device)

        obj_tokens = data_dict['obj_tokens'].to(device)
        obj_masks = data_dict['obj_masks'].to(device)

        tokenizer.padding_side = 'right'   # no need to be 'left', as padding tokens will be shifted
        tokenizer.truncation_side = 'left'  # truncate history
        tokens_post = tokenizer(
            data_dict['prompt_after_obj'],
            return_tensors='pt',
            padding='longest',
            truncation=True,
            max_length=self.max_context_len,
        ).to(device)  # [BOS, tokens, PAD], (B, T3)

        # hardcode, remove bos, make "tokenize subseq and concat" equivalent to "tokenize the whole seq"
        assert tokens_mid1.attention_mask.all() and tokens_mid2.attention_mask.all(), \
            "prompt_middle should be the same and thus no padding"

        # Drop the leading BOS of the first middle prompt.
        mid1_ids = tokens_mid1.input_ids[:, 1:]
        mid1_mask = tokens_mid1.attention_mask[:, 1:]

        # no image input, also mask the text prompt for image tokens
        rows_without_image = ~img_masks.bool().any(dim=1)
        mid1_mask[rows_without_image] = 0

        # For the remaining segments the leading BOS is overwritten in place
        # rather than removed: 1 (bos) -> 869 (hard-coded replacement token id).
        tokens_mid2.input_ids[:, 0] = 869
        tokens_post.input_ids[:, 0] = 869

        embeds_pre = embed(tokens_pre.input_ids)
        embeds_mid1 = embed(mid1_ids)
        embeds_mid2 = embed(tokens_mid2.input_ids)
        embeds_post = embed(tokens_post.input_ids)

        # since img_tokens, prompt_mid, obj_tokens are fixed length without padding, we concat them first
        embeds_mid = torch.cat([embeds_mid1, img_tokens, embeds_mid2, obj_tokens], dim=1)
        attn_mask_mid = torch.cat(
            [mid1_mask, img_masks, tokens_mid2.attention_mask, obj_masks],
            dim=1,
        )

        # Number of right-pad positions in each row of the trailing prompt.
        post_pad_length = torch.logical_not(tokens_post.attention_mask).sum(-1)

        bs, l1, hidden_dim = embeds_pre.shape
        l2 = embeds_mid.shape[1]
        l3 = embeds_post.shape[1]
        total_len = l1 + l2 + l3

        inputs_embeds = torch.zeros(
            bs, total_len, hidden_dim, dtype=embeds_pre.dtype, device=device
        )
        attention_mask = torch.zeros(
            bs, total_len, dtype=obj_masks.dtype, device=device
        )

        # assign by chunks: [post padding | pre | mid | post text]
        for i, pad_len in enumerate(post_pad_length.tolist()):
            text_len = l3 - pad_len          # real (non-pad) tokens of the trailing prompt
            mid_start = pad_len + l1
            post_start = mid_start + l2

            # Padding embeddings move to the front; their mask stays 0 from the
            # zero initialisation. With pad_len == 0 both slices are empty.
            inputs_embeds[i, :pad_len] = embeds_post[i, text_len:]

            inputs_embeds[i, pad_len:mid_start] = embeds_pre[i]
            attention_mask[i, pad_len:mid_start] = tokens_pre.attention_mask[i]

            inputs_embeds[i, mid_start:post_start] = embeds_mid[i]
            attention_mask[i, mid_start:post_start] = attn_mask_mid[i]

            inputs_embeds[i, post_start:] = embeds_post[i, :text_len]
            attention_mask[i, post_start:] = 1

        return inputs_embeds, attention_mask

    def generate(
        self,
        data_dict,
        use_nucleus_sampling=False,
        num_beams=5,
        max_length=256,
        min_length=1,
        repetition_penalty=3.0,
        length_penalty=1,
        num_captions=1,
        temperature=1,
    ):
        """Generate text conditioned on the prompts and visual tokens in ``data_dict``.

        Args:
            data_dict: Same mapping as ``build_right_justified_sequence``; must
                already contain ``img_tokens`` and ``obj_tokens``.
            use_nucleus_sampling: Forwarded as ``do_sample``.
            num_beams, max_length, min_length, repetition_penalty,
            length_penalty, temperature: Forwarded unchanged to
                ``self.llm_model.generate``.
            num_captions: Forwarded as ``num_return_sequences``.

        Returns:
            List of decoded strings with special tokens skipped and surrounding
            whitespace stripped.

        Raises:
            AssertionError: If the visual token entries are missing from ``data_dict``.
        """
        assert 'img_tokens' in data_dict and 'obj_tokens' in data_dict, "Visual features should have been processed offline."

        inputs_embeds, attention_mask = self.build_right_justified_sequence(data_dict=data_dict)
        bs = inputs_embeds.shape[0]

        # give bos token as condition
        bos_tokens = self.llm_tokenizer(
            [self.llm_tokenizer.bos_token] * bs,
            return_tensors='pt',
        ).to(self.device)
        # Only the first position of each tokenized row is used.
        bos_tokens_ids = bos_tokens.input_ids[:, 0:1]        # (B, 1)
        bos_tokens_attn = bos_tokens.attention_mask[:, 0:1]  # (B, 1)

        # prepare a `bos_token`
        bos_embeds = self.llm_model.get_input_embeddings()(bos_tokens_ids)     # (B, 1, D)
        inputs_embeds = torch.cat([inputs_embeds, bos_embeds], dim=1)          # (B, T1+O+T2+1, D)
        attention_mask = torch.cat([attention_mask, bos_tokens_attn], dim=1)   # (B, T1+O+T2+1)

        outputs = self.llm_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            do_sample=use_nucleus_sampling,
            temperature=temperature,
            num_beams=num_beams,
            max_length=max_length,
            min_length=min_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            num_return_sequences=num_captions,
        )

        outputs[outputs == 0] = 2  # convert output id 0 (unk_token) to 2 (eos_token)
        output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return [text.strip() for text in output_text]