import contextlib
import copy
import os
import sys
from functools import partial

dir_path = os.path.dirname(os.path.realpath(__file__))
# Put this file's directory first on sys.path (presumably so sibling modules
# can also be imported by plain name).
sys.path.insert(0, dir_path)

import torch.utils.checkpoint
from torch import nn
from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image

from .modeling_perceive_sampler import BertConfig, BertLMHeadModel
from .modeling_vit import *
from .modeling_InternLM import *
from .modeling_utils import *

# Imported after the star imports on purpose so that this `logging`
# (transformers' wrapper, which provides get_logger) is the one in scope.
from transformers.utils import logging

logger = logging.get_logger(__name__)


class InternLMXComposerForCausalLM(PreTrainedModel):
    """Vision-language model: image encoder + Q-Former + InternLM.

    Images go through ``visual_encoder`` -> ``ln_vision`` -> ``Qformer`` ->
    ``internlm_proj`` and are bracketed by the fixed ``flag_image_start`` /
    ``flag_image_end`` embeddings before being spliced between text
    embeddings that are fed to ``internlm_model``.

    Names such as ``PreTrainedModel``, ``InternLMXComposerConfig``,
    ``InternLMForCausalLM``, ``create_eva_vit_g``, ``rank0_print``,
    ``all_logging_disabled``, ``StoppingCriteriaList`` and
    ``StoppingCriteriaSub`` presumably come from the star imports above.
    """

    config_class = InternLMXComposerConfig
    _auto_class = "AutoModelForCausalLM"

    # Token ids of the end-of-human / end-of-assistant markers and of EOS.
    # They are used both as generation stop words and for loss masking in
    # ``mask_human_targets``.
    EOH_TOKEN_ID = 103027
    EOA_TOKEN_ID = 103028
    EOS_TOKEN_ID = 2
    # Marker that ``forward`` appends to the training prompt and that
    # ``prompt_wrap`` splits on to find where the image embeddings go.
    IMAGE_PLACEHOLDER = '<ImageHere>'

    # System prompt prepended to the first turn of a conversation.
    # NOTE(review): line breaks inside this literal may have been collapsed
    # to spaces at some point — confirm the intended whitespace.
    meta_instruction = """meta instruction You are an AI assistant whose name is 浦语. - 浦语 is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless. - 浦语 can understand and communicate fluently in the language chosen by the user such as English and 中文. conversation """

    # Default keyword arguments for ``internlm_model.generate``. Each
    # instance replaces this with a private copy that adds stop criteria.
    gen_config = dict(
        num_beams=5,
        do_sample=False,
        min_length=1,
        repetition_penalty=1.5,
        length_penalty=1.0,
        temperature=1.0,
        max_new_tokens=500,
    )

    def __init__(self, config):
        """Build the vision tower, Q-Former, projection and InternLM LLM.

        Args:
            config: model config; reads ``max_length``, ``num_query_token``,
                ``internlm_lora``, ``hidden_size`` and ``device`` here, and is
                also passed to ``InternLMForCausalLM._from_config``.
        """
        super().__init__(config)
        self.max_length = config.max_length

        rank0_print('Init VIT ... ', end='')
        self.visual_encoder = create_eva_vit_g()
        self.ln_vision = LayerNorm(self.visual_encoder.num_features)
        rank0_print('Done')

        rank0_print('Init Perceive Sampler ... ', end='')
        with all_logging_disabled():
            self.Qformer, self.query_tokens = self.init_qformer(
                config.num_query_token, self.visual_encoder.num_features)
        # encode_img only feeds query embeddings (no text tokens) to the
        # Q-Former, so these submodules are dropped — presumably because they
        # serve the text-token path / LM head only.
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.Qformer.cls = None
        rank0_print('Done')

        rank0_print('Init InternLM ... ', end='')
        # Fixed (requires_grad=False) embeddings placed before/after the image
        # tokens; width must match the LLM hidden size they are concatenated
        # with in encode_img.
        hidden_size = config.hidden_size
        self.flag_image_start = nn.Parameter(torch.zeros([1, 1, hidden_size]))
        self.flag_image_end = nn.Parameter(torch.zeros([1, 1, hidden_size]))
        self.flag_image_start.requires_grad = False
        self.flag_image_end.requires_grad = False

        internlm_lora = config.internlm_lora
        self.internlm_lora = internlm_lora
        # NOTE: this sets a *class* attribute, so it affects every
        # InternLMForCausalLM built afterwards in this process (presumably it
        # is read while the LLM is constructed below).
        setattr(InternLMForCausalLM, 'lora_cfg', internlm_lora)

        torch_major = int(torch.__version__.split('.')[0])
        if torch_major < 2:
            self.internlm_model = InternLMForCausalLM._from_config(config).to(
                torch.float16)
        else:
            # Speed up LLM init: build it on the meta device (no allocation).
            with torch.device('meta'):
                self.internlm_model = InternLMForCausalLM._from_config(config)
            # NOTE(review): weights are left on the meta device here, so they
            # must be materialized later (e.g. by checkpoint loading); the
            # eager alternatives are kept below for reference.
            # self.internlm_model.to_empty(device=config.device).to(torch.float16)
            # self.internlm_model.to(config.device)
        # Keep LoRA submodules in fp32.
        for name, module in self.internlm_model.named_modules():
            if 'lora' in name:
                module.float()

        self.internlm_proj = nn.Linear(self.Qformer.config.hidden_size,
                                       self.internlm_model.config.hidden_size)
        rank0_print('Done')

        self.vis_processor = transforms.Compose([
            transforms.Resize((224, 224),
                              interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711)),
        ])

        # Must be assigned by the caller before any text method is used.
        self.tokenizer = None

        # NOTE(review): both markers are empty strings as written; they look
        # like special-token strings that were lost. They should be the text
        # forms of EOH_TOKEN_ID / EOA_TOKEN_ID — confirm and restore.
        self.eoh = ''  # end of human
        self.eoa = ''  # end of assistant

        stop_words_ids = [
            torch.tensor([self.EOH_TOKEN_ID]).to(config.device),
            torch.tensor([self.EOA_TOKEN_ID]).to(config.device),
        ]
        stopping_criteria = StoppingCriteriaList(
            [StoppingCriteriaSub(stops=stop_words_ids)])
        # Copy the class-level defaults instead of mutating them, so that
        # instances do not overwrite each other's (device-bound) stop words.
        self.gen_config = dict(self.gen_config,
                               stopping_criteria=stopping_criteria)
        self.supports_gradient_checkpointing = True

    def get_input_embeddings(self):
        """Return the InternLM input embedding layer."""
        return self.internlm_model.get_input_embeddings()

    def _set_gradient_checkpointing(self, module, value=False):
        """Propagate a gradient-checkpointing on/off flag to the LLM.

        ``module`` is ignored. ``value`` is applied to every InternLM
        submodule via ``internlm_model._set_gradient_checkpointing``.
        Presumably invoked by the transformers
        ``gradient_checkpointing_enable`` / ``_disable`` helpers.
        """
        self.internlm_model.apply(
            partial(self.internlm_model._set_gradient_checkpointing,
                    value=value))

    def maybe_autocast(self, dtype=torch.float16):
        """Return a CUDA autocast context unless the model is on CPU.

        Args:
            dtype: autocast dtype used when not on CPU.

        Returns:
            ``torch.cuda.amp.autocast(dtype=dtype)`` off-CPU, otherwise a
            ``contextlib.nullcontext()``.
        """
        enable_autocast = self.device != torch.device("cpu")
        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        return contextlib.nullcontext()

    @classmethod
    def init_qformer(cls,
                     num_query_token,
                     vision_width,
                     cross_attention_freq=2,
                     pretrain=True):
        """Create the Q-Former and its learnable query tokens.

        Args:
            num_query_token: number of query embeddings.
            vision_width: feature width of the vision encoder output, used as
                the cross-attention encoder width.
            cross_attention_freq: stored on the BERT config; the default 2
                inserts a cross-attention layer every other block.
            pretrain: unused.

        Returns:
            ``(Qformer, query_tokens)``: a ``BertLMHeadModel`` and an
            ``nn.Parameter`` of shape ``(1, num_query_token, hidden_size)``
            drawn from N(0, initializer_range).
        """
        encoder_config = BertConfig()
        encoder_config.encoder_width = vision_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size))
        query_tokens.data.normal_(mean=0.0,
                                  std=encoder_config.initializer_range)
        return Qformer, query_tokens

    def encode_img(self, image):
        """Encode an image into LLM-space embeddings.

        Args:
            image: ``None``, a file path (loaded, converted to RGB and run
                through ``vis_processor`` as a batch of one), or an
                already-preprocessed image tensor (assumed batched,
                ``(B, 3, H, W)`` — TODO confirm).

        Returns:
            ``None`` when ``image`` is ``None``; otherwise
            ``[flag_image_start, projected query outputs, flag_image_end]``
            concatenated along dim 1.

        Raises:
            TypeError: if ``image`` is neither a path nor a tensor.
        """
        if image is None:
            return None
        if isinstance(image, str):
            with Image.open(image) as pil_image:
                rgb = pil_image.convert("RGB")
            image = self.vis_processor(rgb).unsqueeze(0).to(self.device)
        elif not isinstance(image, torch.Tensor):
            raise TypeError(
                f'image must be a path or a tensor, got {type(image)!r}')
        device = image.device

        with self.maybe_autocast():
            image_embeds = self.ln_vision(
                self.visual_encoder(image)).to(device)
            image_atts = torch.ones(image_embeds.size()[:-1],
                                    dtype=torch.long,
                                    device=device)
            query_tokens = self.query_tokens.expand(image_embeds.shape[0],
                                                    -1, -1)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            inputs_internlm = self.internlm_proj(
                query_output.last_hidden_state)
            batch = inputs_internlm.shape[0]
            inputs_internlm = torch.cat([
                self.flag_image_start.expand(batch, -1, -1),
                inputs_internlm,
                self.flag_image_end.expand(batch, -1, -1),
            ], dim=1)
        return inputs_internlm

    def encode_text(self, text, add_special_tokens=False):
        """Tokenize ``text`` and return its InternLM token embeddings.

        Args:
            text: string (or list of strings) passed to ``self.tokenizer``.
            add_special_tokens: forwarded to the tokenizer.
        """
        text_token_ids = self.tokenizer(
            text,
            return_tensors='pt',
            add_special_tokens=add_special_tokens,
        ).input_ids.to(self.device)
        return self.internlm_model.model.embed_tokens(text_token_ids)

    def decode_text(self, out_embeds):
        """Decode the first generated sequence and cut it at ``self.eoa``.

        Args:
            out_embeds: token ids returned by ``internlm_model.generate``
                (despite the name, these are ids, not embeddings).
        """
        out_text = self.tokenizer.batch_decode(out_embeds,
                                               skip_special_tokens=True)[0]
        # str.split('') raises ValueError, so only truncate when a marker is
        # actually configured.
        if self.eoa:
            out_text = out_text.split(self.eoa)[0]
        return out_text

    def wrap_text(self, user_text, bot_text='', add_special=True):
        """Format one chat turn as ``' <|User|>:... \\n<eoh> <|Bot|>:...'``.

        The end-of-human marker is omitted when ``add_special`` is False.
        """
        eoh = self.eoh if add_special else ''
        return f' <|User|>:{user_text} \n{eoh} <|Bot|>:{bot_text}'

    def get_gen_args(self, **kwargs):
        """Return a deep copy of ``gen_config`` updated with ``kwargs``."""
        new_kargs = copy.deepcopy(self.gen_config)
        new_kargs.update(kwargs)
        return new_kargs

    def generate(self, text, image=None, **kwargs):
        """Single-turn generation; returns the decoded answer string.

        ``kwargs`` override the defaults from ``gen_config``.
        """
        text_embeds = self.encode_text(text)
        img_embeds = self.encode_img(image)
        prompt_embeds = self.wrap_prompt(text_embeds, img_embeds)
        out_embeds = self.internlm_model.generate(
            inputs_embeds=prompt_embeds, **self.get_gen_args(**kwargs))
        return self.decode_text(out_embeds)

    def chat(self, text, image=None, history=None, **kwargs):
        """Multi-turn generation.

        Args:
            text: user message.
            image: optional image (see ``encode_img``).
            history: list of embedding tensors from previous turns, or None.
                The list is appended to in place.
            **kwargs: overrides for ``gen_config``.

        Returns:
            ``(answer_text, history)``. The new history entry holds the
            embeddings of the prompt without the eoh marker / meta
            instruction, followed by the re-tokenized answer.
        """
        text_embeds = self.encode_text(text)
        img_embeds = self.encode_img(image)
        prompt_embeds = self.wrap_prompt(text_embeds,
                                         img_embeds,
                                         history=history)
        out_embeds = self.internlm_model.generate(
            inputs_embeds=prompt_embeds, **self.get_gen_args(**kwargs))
        out_text = self.decode_text(out_embeds)

        # Store a cleaned turn: the answer is already cut at eoa by
        # decode_text, and add_special=False drops eoh and the meta
        # instruction from the prompt.
        # NOTE(review): the tokenizer call below uses its default
        # add_special_tokens, so special tokens may end up mid-history.
        clean_out_text_token_ids = self.tokenizer(
            out_text, return_tensors='pt').input_ids.to(self.device)
        clean_out_text_embeds = self.internlm_model.model.embed_tokens(
            clean_out_text_token_ids)
        clean_prompt_embeds = self.wrap_prompt(text_embeds,
                                               img_embeds,
                                               add_special=False)
        cur_history = torch.cat([clean_prompt_embeds, clean_out_text_embeds],
                                dim=1)
        if history is None:
            history = []
        history.append(cur_history)
        return out_text, history

    def wrap_prompt(self,
                    text_embeds,
                    img_embeds=None,
                    history=None,
                    add_special=True):
        """Assemble ``[history..., seg0, image, text, seg1]`` along dim 1.

        Args:
            text_embeds: embeddings of the user message.
            img_embeds: image embeddings from ``encode_img``, or None.
            history: previous-turn embeddings (list), or None. An empty list
                is treated like None (a fresh conversation).
            add_special: when True, add the eoh marker (and, for a fresh
                conversation, the meta instruction); when False, use the bare
                ``' <|User|>:'`` / ``' <|Bot|>:'`` segments stored in history.
        """
        if history is not None and len(history) == 0:
            # No previous turns: still emit the meta instruction and the
            # tokenizer's special tokens for the first segment.
            history = None

        if add_special:
            if history is None:
                prompt_segs = [
                    self.meta_instruction + ' <|User|>:',
                    f'\n{self.eoh} <|Bot|>:'
                ]
            else:
                prompt_segs = [' <|User|>:', f'\n{self.eoh} <|Bot|>:']
        else:
            prompt_segs = [' <|User|>:', ' <|Bot|>:']  # used in wrap history

        prompt_seg_embeds = []
        for i, seg in enumerate(prompt_segs):
            # Special tokens only for the very first segment of a
            # conversation.
            add_special_tokens = history is None and i == 0
            prompt_seg_embeds.append(
                self.encode_text(seg, add_special_tokens=add_special_tokens))

        if img_embeds is None:
            img_embeds = text_embeds.new_empty(text_embeds.size(0), 0,
                                               text_embeds.size(-1))
        prompt_embeds = torch.cat([
            prompt_seg_embeds[0], img_embeds, text_embeds,
            prompt_seg_embeds[1]
        ], dim=1)

        if history is not None:
            prompt_embeds = torch.cat([*history, prompt_embeds], dim=1)
        return prompt_embeds

    ######################
    # code for training
    ######################

    def prompt_wrap(self, img_embeds, prompt):
        """Prefix image embeddings with the embedded prompt text.

        Args:
            img_embeds: batched image embeddings from ``encode_img``.
            prompt: text containing ``IMAGE_PLACEHOLDER`` exactly once; only
                the part before it is embedded (``forward`` puts the
                placeholder at the very end).

        Returns:
            ``(embeds, attention_mask, targets)``. The mask is all ones and
            the targets are all -100 over the wrapped sequence.
        """
        batch_size = img_embeds.shape[0]
        p_before, _p_after = prompt.split(self.IMAGE_PLACEHOLDER)
        p_before_tokens = self.tokenizer(p_before,
                                         return_tensors="pt",
                                         add_special_tokens=True).to(
                                             img_embeds.device)
        p_before_embeds = self.internlm_model.model.embed_tokens(
            p_before_tokens.input_ids).expand(batch_size, -1, -1)
        wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds], dim=1)

        seq_shape = wrapped_img_embeds.shape[:-1]
        wrapped_atts_img = torch.ones(seq_shape,
                                      dtype=torch.long,
                                      device=img_embeds.device)
        # The prompt and the image positions never contribute to the loss.
        wrapped_target = torch.full(seq_shape,
                                    -100,
                                    dtype=torch.long,
                                    device=img_embeds.device)
        return wrapped_img_embeds, wrapped_atts_img, wrapped_target

    def align_text(self, samples, has_img=False):
        """Insert eoh/eoa markers into each ``samples["text_input"]`` string.

        When ``has_img`` is True, everything up to and including the first
        ``"<|User|>:"`` is removed, because ``forward`` re-creates that
        header around the image embeddings.

        Returns:
            A list of strings with ``eoh`` before every ``<|Bot|>`` and
            ``eoa`` before every ``' <|User|>'`` and at the end.
        """
        if has_img:
            # NOTE(review): replace("", "") is a no-op as written; it looks
            # like an image-tag literal was lost here — confirm.
            text = [
                t.replace("", "").split("<|User|>:", 1)[-1].lstrip()
                for t in samples["text_input"]
            ]
        else:
            text = list(samples["text_input"])

        text_new = []
        for t in text:
            temp = t + self.eoa + ' '
            temp = temp.replace('<|Bot|>', self.eoh + ' <|Bot|>')
            temp = temp.replace(' <|User|>', self.eoa + ' <|User|>')
            # An eoa that comes before the first eoh was inserted in front of
            # the first user turn (no answer precedes it), so drop it.
            if temp.find(self.eoh) > temp.find(self.eoa):
                temp = temp.replace(self.eoa, '', 1)
            text_new.append(temp)
        return text_new

    def text2emb(self, text):
        """Tokenize a batch of strings and build the matching loss targets.

        Returns:
            ``(tokenized_batch, targets)``: padded to the longest sequence,
            truncated to ``self.max_length`` and moved to ``self.device``.
        """
        to_regress_tokens = self.tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=False).to(self.device)
        targets = self.mask_human_targets(to_regress_tokens.input_ids)
        targets = targets.to(self.device)
        return to_regress_tokens, targets

    def mask_human_targets(self, input_ids, pure=False):
        """Return a copy of ``input_ids`` with non-answer positions set to -100.

        For each row:

        * at every ``EOH_TOKEN_ID`` at index ``i``, positions from the end of
          the previous answer up to ``i + 5`` are masked;
        * at ``EOS_TOKEN_ID``, everything after it is masked (the loss stays
          on EOS itself) and scanning stops;
        * if no EOS was seen and the last marker was an eoa (the sequence was
          truncated inside a question), everything after that eoa is masked.

        Args:
            input_ids: 2-D tensor of token ids ``(batch, seq)``.
            pure: unused.
        """
        # NOTE(review): 6 presumably covers the eoh token plus the tokens of
        # the following " <|Bot|>:" prefix — confirm against the tokenizer.
        human_tail = 6
        target_batch = []
        for ids in input_ids:
            targets = ids.clone()
            cur_idx = 0
            last_eoa = 0
            last_eoh = 0
            hit_eos = False
            for i, token_id in enumerate(ids.tolist()):
                if token_id == self.EOH_TOKEN_ID:
                    targets[cur_idx:i + human_tail] = -100
                    cur_idx = i + human_tail
                    last_eoh = i
                elif token_id == self.EOA_TOKEN_ID:
                    cur_idx = i + 1
                    last_eoa = i
                elif token_id == self.EOS_TOKEN_ID:
                    # Loss on eos, but not on the padding after it.
                    targets[i + 1:] = -100
                    hit_eos = True
                    break
            if not hit_eos and last_eoa > last_eoh:
                # Truncated mid-question: mask all after the last answer.
                targets[last_eoa + 1:] = -100
            target_batch.append(targets)
        if not target_batch:
            return input_ids.clone()
        return torch.stack(target_batch, dim=0)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                **kwargs):
        """Training forward pass driven by ``kwargs['samples']``.

        The standard HF arguments are accepted for signature compatibility
        but ignored. ``samples`` must hold ``"text_input"`` (list of str) and
        optionally an image batch under ``"image"`` or ``"images"``.

        Returns:
            The ``internlm_model`` output (``return_dict=True``, with loss
            computed from the masked targets).

        Raises:
            ValueError: if no ``samples`` keyword argument is given.
        """
        samples = kwargs.get('samples')
        if samples is None:
            raise ValueError("forward() expects a 'samples' keyword argument")
        # Accept either key so the presence check and the lookup agree.
        image_key = next(
            (key for key in ('image', 'images') if key in samples), None)
        has_img = image_key is not None

        ### encode text
        text = self.align_text(samples, has_img=has_img)
        to_regress_tokens, targets = self.text2emb(text)

        to_regress_embeds = self.internlm_model.model.embed_tokens(
            to_regress_tokens.input_ids)
        attention_mask = to_regress_tokens.attention_mask

        if has_img:
            # The header (text before the first user turn) of the first
            # sample is used for the whole batch.
            header = samples["text_input"][0].split(' <|User|>:')[0]
            prompt = header + ' <|User|>:' + self.IMAGE_PLACEHOLDER

            ### encode image
            img_embeds = self.encode_img(samples[image_key])
            img_embeds, atts_img, wrapped_target = self.prompt_wrap(
                img_embeds, prompt)
            ### combine text and image
            to_regress_embeds = torch.cat([img_embeds, to_regress_embeds],
                                          dim=1)
            attention_mask = torch.cat([atts_img, attention_mask], dim=1)
            targets = torch.cat([wrapped_target, targets], dim=1)

        outputs = self.internlm_model(
            inputs_embeds=to_regress_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        return outputs