import logging
from abc import ABC, abstractmethod
from typing import List, Dict, Union, Optional

import torch
from transformers import PretrainedConfig, AutoConfig

IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
# Placeholder marking where an image goes in prompt text. It must be non-empty:
# the conversation formatters split prompt text on it, and `str.split("")`
# raises ValueError.
IMAGE_TOKEN = "<image>"


def _build_sub_config(config, arg_name):
    """Normalize a nested config argument to a ``PretrainedConfig``.

    Args:
        config: ``None``, a ``PretrainedConfig`` (both returned unchanged), or a
            dict with a ``'model_type'`` key that is resolved through
            ``AutoConfig.for_model``. A dict is copied before ``'model_type'``
            is removed, so the caller's dict is never mutated.
        arg_name: parameter name used in the assertion message.

    Returns:
        ``None`` or a ``PretrainedConfig`` instance.

    Raises:
        AssertionError: if ``config`` is neither a ``PretrainedConfig`` nor a dict.
        KeyError: if a dict ``config`` has no ``'model_type'`` key.
    """
    if config is None:
        return None
    assert isinstance(config, (PretrainedConfig, dict)), \
        (f"expect `{arg_name}` to be instance of PretrainedConfig or dict,"
         f" but got {type(config)} type")
    if isinstance(config, PretrainedConfig):
        return config
    config = dict(config)  # shallow copy: do not mutate the caller's dict
    model_type = config.pop('model_type')
    return AutoConfig.for_model(model_type, **config)


# ----------------------------------------------------------------------
# Visual Tokenizer Configuration
# ----------------------------------------------------------------------
class BaseVisualTokenizerConfig(PretrainedConfig):
    """Configuration shared by the visual tokenizer variants.

    Args:
        vocab_size: size of the visual vocabulary.
        tokenize_function: name of the tokenize function (default ``"softmax"``).
        tau: temperature value stored for the tokenize function.
        depths: list of ints, or a ``'|'``-separated string such as ``"12|12"``
            that is parsed into ``[12, 12]``.
        use_indicators: stored flag.
        drop_cls_token: stored flag (subclasses may override it).
        backbone_config: ``PretrainedConfig`` or a dict with a ``'model_type'``
            key; dicts are converted through ``AutoConfig.for_model``.
        hidden_stride: stored stride value.
        hd_booster: optional booster name.
        **kwargs: forwarded to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        vocab_size=16384,
        tokenize_function="softmax",
        tau=1.0,
        depths=None,
        use_indicators=False,
        drop_cls_token=False,
        backbone_config: Optional[Union[PretrainedConfig, dict]] = None,
        hidden_stride: int = 1,
        hd_booster: Optional[str] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.tokenize_function = tokenize_function
        self.tau = tau
        if isinstance(depths, str):
            depths = [int(x) for x in depths.split('|')]
        self.depths = depths
        # Subclasses put backbone overrides here (e.g. `num_hidden_layers`);
        # presumably consumed when the backbone is built — confirm at the caller.
        self.backbone_kwargs = {}
        self.use_indicators = use_indicators
        self.drop_cls_token = drop_cls_token
        self.backbone_config = _build_sub_config(backbone_config, "backbone_config")
        self.hidden_stride = hidden_stride
        self.hd_booster = hd_booster


class ClipVisualTokenizerConfig(BaseVisualTokenizerConfig):
    """Visual tokenizer config with a CLIP backbone."""
    model_type = "clip_visual_tokenizer"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if self.depths:
            # Only a single depth is supported; it overrides the backbone depth.
            assert len(self.depths) == 1
            self.backbone_kwargs['num_hidden_layers'] = self.depths[0]


class SiglipVisualTokenizerConfig(BaseVisualTokenizerConfig):
    """Visual tokenizer config with a SigLIP backbone (which has no CLS token)."""
    model_type = "siglip_visual_tokenizer"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if self.drop_cls_token:
            logging.warning(
                'SiglipVisionModel has no cls token,'
                ' so `drop_cls_token=True` is ignored and reset to `False`')
            self.drop_cls_token = False
        if self.depths:
            # Only a single depth is supported; it overrides the backbone depth.
            assert len(self.depths) == 1
            self.backbone_kwargs['num_hidden_layers'] = self.depths[0]


AutoConfig.register("clip_visual_tokenizer", ClipVisualTokenizerConfig)
AutoConfig.register("siglip_visual_tokenizer", SiglipVisualTokenizerConfig)


# ----------------------------------------------------------------------
# Ovis Configuration
# ----------------------------------------------------------------------
class OvisConfig(PretrainedConfig):
    """Top-level Ovis configuration bundling the LLM and visual tokenizer configs.

    Args:
        llm_config: ``PretrainedConfig`` or a dict with a ``'model_type'`` key.
        visual_tokenizer_config: ``PretrainedConfig`` or a dict with a
            ``'model_type'`` key.
        multimodal_max_length: stored maximum multimodal sequence length.
        hidden_size: stored hidden size (``None`` by default).
        conversation_formatter_class: stored formatter class identifier.
        **kwargs: forwarded to ``PretrainedConfig.__init__``.

    Dict sub-configs are copied before conversion; the caller's dicts are
    left untouched.
    """
    model_type = "ovis"

    def __init__(
        self,
        llm_config: Optional[Union[PretrainedConfig, dict]] = None,
        visual_tokenizer_config: Optional[Union[PretrainedConfig, dict]] = None,
        multimodal_max_length=2048,
        hidden_size=None,
        conversation_formatter_class=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.llm_config = _build_sub_config(llm_config, "llm_config")
        self.visual_tokenizer_config = _build_sub_config(
            visual_tokenizer_config, "visual_tokenizer_config")
        self.multimodal_max_length = multimodal_max_length
        self.hidden_size = hidden_size
        self.conversation_formatter_class = conversation_formatter_class
# ----------------------------------------------------------------------
# Conversation Formatter
# ----------------------------------------------------------------------
class ConversationFormatter(ABC):
    """Turn conversation turns into a prompt string, token ids and labels.

    Subclasses set ``support_tokenizer_types`` to the tokenizer class names
    they accept and implement ``format`` / ``format_query``. Occurrences of
    ``IMAGE_TOKEN`` in text are emitted as the sentinel id
    ``IMAGE_TOKEN_INDEX`` instead of being tokenized.
    """
    # Accepted values of ``type(tokenizer).__name__``; set by subclasses.
    support_tokenizer_types = None

    def __init__(self, tokenizer):
        tokenizer_type = type(tokenizer).__name__
        assert tokenizer_type in self.support_tokenizer_types, \
            (f'Invalid tokenizer type, expected one from `{self.support_tokenizer_types}`,'
             f' but got `{tokenizer_type}`')
        self.tokenizer = tokenizer
        self.image_symbol = IMAGE_TOKEN
        self.image_token_index = IMAGE_TOKEN_INDEX
        self.ignore_index = IGNORE_INDEX

    def _tokenize_with_image_symbol(self, text):
        """Tokenize ``text`` without special tokens, mapping each image symbol
        to ``self.image_token_index``.

        The text between image symbols is tokenized chunk by chunk, and one
        image id is inserted between consecutive chunks.
        """
        token_ids = []
        for i, chunk in enumerate(text.split(self.image_symbol)):
            if i > 0:
                token_ids.append(self.image_token_index)
            token_ids.extend(self.tokenizer(chunk, add_special_tokens=False).input_ids)
        return token_ids

    def _finalize(self, prompt, input_ids, labels):
        """Sanity-check the token streams and return ``(prompt, ids, labels)``.

        Tokenizing the whole prompt at once must yield the same ids as the
        turn-by-turn tokenization, and every id must have a label. The id and
        label lists are returned as ``torch.long`` tensors.
        """
        assert self._tokenize_with_image_symbol(prompt) == input_ids
        assert len(input_ids) == len(labels)
        return (prompt,
                torch.tensor(input_ids, dtype=torch.long),
                torch.tensor(labels, dtype=torch.long))

    @abstractmethod
    def format(self, conversations: List[Dict], generation_preface=None):
        """Return ``(prompt, input_ids, labels)`` for ``conversations``."""

    @abstractmethod
    def format_query(self, query, generation_preface=""):
        """Return ``(prompt, input_ids)`` for a single user ``query``."""


class QwenConversationFormatter(ConversationFormatter):
    """ChatML-style formatter (``<|im_start|>role\\n ... <|im_end|>\\n``) for Qwen."""
    support_tokenizer_types = ['QWenTokenizer', 'Qwen2TokenizerFast']

    def __init__(self, tokenizer):
        super().__init__(tokenizer)
        self.from2role = {
            "system": "<|im_start|>system\n",
            "human": "<|im_start|>user\n",
            "gpt": "<|im_start|>assistant\n",
        }
        # Token length of the "gpt" role header; computed lazily in `format`.
        self.gpt_token_num = None
        self.im_end = "<|im_end|>\n"
        self.default_system_prompt = "You are a helpful assistant."

    def format(self, conversations: List[Dict], generation_preface=None):
        """Build the prompt, token ids and training labels for a conversation.

        Args:
            conversations: turns as dicts with ``"from"`` (``"system"``,
                ``"human"`` or ``"gpt"``) and ``"value"`` keys. The list is not
                modified. A default system turn is prepended when the first
                turn is not ``"system"`` (or the list is empty).
            generation_preface: if not ``None``, an unterminated ``"gpt"`` turn
                with this text is appended so generation continues from it; no
                label is supervised in that case.

        Returns:
            ``(prompt, input_ids, labels)``. Labels are ``ignore_index`` except
            on ``"gpt"`` replies (role header and trailing ``\\n`` excluded).
        """
        if self.gpt_token_num is None:
            self.gpt_token_num = len(
                self.tokenizer(self.from2role["gpt"], add_special_tokens=False).input_ids)

        # Shallow copy so the caller's list is never mutated across calls.
        conversations = list(conversations)
        if not conversations or conversations[0]["from"] != "system":
            conversations.insert(0, {"from": "system", "value": self.default_system_prompt})
        if generation_preface is not None:
            conversations.append({"from": "gpt", "value": generation_preface})

        prompt = ""
        input_ids = []
        labels = []
        num_conversation = len(conversations)
        for i, conversation in enumerate(conversations):
            frm = conversation["from"]
            text = self.from2role[frm] + conversation["value"]
            # The trailing preface turn stays open so generation continues it.
            if i < num_conversation - 1 or generation_preface is None:
                text += self.im_end
            prompt += text
            token_ids = self._tokenize_with_image_symbol(text)
            input_ids.extend(token_ids)
            label_ids = [self.ignore_index] * len(token_ids)
            if frm == "gpt" and generation_preface is None:
                # learning `\n` following `im_end` is meaningless, so the last `\n` token is ignored in label
                label_ids[self.gpt_token_num:-1] = token_ids[self.gpt_token_num:-1]
            labels.extend(label_ids)

        return self._finalize(prompt, input_ids, labels)

    def format_query(self, query, generation_preface=""):
        """Format one human ``query`` followed by an open assistant turn.

        Returns ``(prompt, input_ids)``.
        """
        prompt, input_ids, _ = self.format(
            [{"from": "human", "value": query}], generation_preface=generation_preface)
        return prompt, input_ids


class Llama3ConversationFormatter(ConversationFormatter):
    """Llama-3 header/``<|eot_id|>`` chat formatter."""
    support_tokenizer_types = ['PreTrainedTokenizerFast']

    def __init__(self, tokenizer):
        super().__init__(tokenizer)
        self.from2role = {
            "system": "<|start_header_id|>system<|end_header_id|>\n\n",
            "human": "<|start_header_id|>user<|end_header_id|>\n\n",
            "gpt": "<|start_header_id|>assistant<|end_header_id|>\n\n",
        }
        # Token length of the "gpt" role header; computed lazily in `format`.
        self.gpt_token_num = None
        self.im_end = "<|eot_id|>"
        self.default_system_prompt = "You are a helpful and honest multimodal assistant."
        self.bos_token = "<|begin_of_text|>"
        # Token ids of `bos_token`; computed lazily in `format`.
        self.bos_token_ids = None

    def format(self, conversations: List[Dict], generation_preface=None):
        """Build the prompt, token ids and training labels for a conversation.

        Args:
            conversations: turns as dicts with ``"from"`` (``"system"``,
                ``"human"`` or ``"gpt"``) and ``"value"`` keys; values are
                stripped. The list is not modified. A default system turn is
                prepended when the first turn is not ``"system"`` (or the list
                is empty).
            generation_preface: if not ``None``, an unterminated ``"gpt"`` turn
                with this text is appended so generation continues from it.

        Returns:
            ``(prompt, input_ids, labels)``, prefixed with the BOS token. Each
            ``"gpt"`` turn is supervised after its role header, including the
            ``<|eot_id|>`` terminator.
        """
        if self.gpt_token_num is None:
            self.gpt_token_num = len(
                self.tokenizer(self.from2role["gpt"], add_special_tokens=False).input_ids)
        if self.bos_token_ids is None:
            self.bos_token_ids = self.tokenizer(self.bos_token, add_special_tokens=False).input_ids

        # Shallow copy so the caller's list is never mutated across calls.
        conversations = list(conversations)
        if not conversations or conversations[0]["from"] != "system":
            conversations.insert(0, {"from": "system", "value": self.default_system_prompt})
        if generation_preface is not None:
            conversations.append({"from": "gpt", "value": generation_preface})

        prompt = self.bos_token
        # Copy: the cached BOS ids must not be extended in place.
        input_ids = list(self.bos_token_ids)
        labels = [self.ignore_index] * len(input_ids)
        num_conversation = len(conversations)
        for i, conversation in enumerate(conversations):
            frm = conversation["from"]
            text = self.from2role[frm] + conversation["value"].strip()
            # The trailing preface turn stays open so generation continues it.
            if i < num_conversation - 1 or generation_preface is None:
                text += self.im_end
            prompt += text
            token_ids = self._tokenize_with_image_symbol(text)
            input_ids.extend(token_ids)
            label_ids = [self.ignore_index] * len(token_ids)
            if frm == "gpt":
                label_ids[self.gpt_token_num:] = token_ids[self.gpt_token_num:]
            labels.extend(label_ids)

        return self._finalize(prompt, input_ids, labels)

    def format_query(self, query, generation_preface=""):
        """Format one human ``query`` followed by an open assistant turn.

        Returns ``(prompt, input_ids)``.
        """
        prompt, input_ids, _ = self.format(
            [{"from": "human", "value": query}], generation_preface=generation_preface)
        return prompt, input_ids


class GemmaConversationFormatter(ConversationFormatter):
    """Gemma ``<start_of_turn>``/``<end_of_turn>`` chat formatter (no system turn)."""
    support_tokenizer_types = ['GemmaTokenizer', 'GemmaTokenizerFast']

    def __init__(self, tokenizer):
        super().__init__(tokenizer)
        # Gemma does not support system prompt
        self.from2role = {
            "human": "<start_of_turn>user\n",
            "gpt": "<start_of_turn>model\n",
        }
        # Token length of the "gpt" role header; computed lazily in `format`.
        self.gpt_token_num = None
        # The turn terminator is followed by `\n`; that `\n` is excluded from labels.
        self.im_end = "<end_of_turn>\n"
        self.bos_token = "<bos>"
        # Token ids of `bos_token`; computed lazily in `format`.
        self.bos_token_ids = None

    def format(self, conversations: List[Dict], generation_preface=None):
        """Build the prompt, token ids and training labels for a conversation.

        Args:
            conversations: turns as dicts with ``"from"`` (``"human"`` or
                ``"gpt"``) and ``"value"`` keys; values are stripped. The list
                is not modified.
            generation_preface: if not ``None``, an unterminated ``"gpt"`` turn
                with this text is appended so generation continues from it.

        Returns:
            ``(prompt, input_ids, labels)``, prefixed with the BOS token.

        Raises:
            ValueError: if the first turn is a system turn.
        """
        if self.gpt_token_num is None:
            self.gpt_token_num = len(
                self.tokenizer(self.from2role["gpt"], add_special_tokens=False).input_ids)
        if self.bos_token_ids is None:
            self.bos_token_ids = self.tokenizer(self.bos_token, add_special_tokens=False).input_ids
        if conversations and conversations[0]["from"] == "system":
            raise ValueError("Gemma does not support system prompt")

        # Shallow copy so the caller's list is never mutated across calls.
        conversations = list(conversations)
        if generation_preface is not None:
            conversations.append({"from": "gpt", "value": generation_preface})

        prompt = self.bos_token
        # Copy: the cached BOS ids must not be extended in place.
        input_ids = list(self.bos_token_ids)
        labels = [self.ignore_index] * len(input_ids)
        num_conversation = len(conversations)
        for i, conversation in enumerate(conversations):
            frm = conversation["from"]
            text = self.from2role[frm] + conversation["value"].strip()
            # The trailing preface turn stays open so generation continues it.
            if i < num_conversation - 1 or generation_preface is None:
                text += self.im_end
            prompt += text
            token_ids = self._tokenize_with_image_symbol(text)
            input_ids.extend(token_ids)
            label_ids = [self.ignore_index] * len(token_ids)
            if frm == "gpt":
                # learning `\n` following `im_end` is meaningless, so the last `\n` token is ignored in label
                label_ids[self.gpt_token_num:-1] = token_ids[self.gpt_token_num:-1]
            labels.extend(label_ids)

        return self._finalize(prompt, input_ids, labels)

    def format_query(self, query, generation_preface=""):
        """Format one human ``query`` followed by an open model turn.

        Returns ``(prompt, input_ids)``.
        """
        prompt, input_ids, _ = self.format(
            [{"from": "human", "value": query}], generation_preface=generation_preface)
        return prompt, input_ids