# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import os
import copy
import json
import random
import pathlib
import traceback
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, List

# torch-related packages
# NOTE: torch must be imported before transformers. Otherwise, `Segmentation fault (core dumped)` will occur.
import torch
from torch.utils.data import Dataset

import transformers
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

import sys
sys.path.append('./')
from videollama2.model import *
from videollama2.constants import NUM_FRAMES, IGNORE_INDEX, MODAL_INDEX_MAP
from videollama2.mm_utils import tokenizer_multimodal_token, process_video, process_image
from videollama2.videollama2_trainer import (
    VideoLLaMA2Trainer,
    get_peft_state_maybe_zero_3, get_peft_state_non_lora_maybe_zero_3,
    find_all_linear_names, safe_save_model_for_hf_trainer
)

# NOTE: fast tokenizer warning issue: https://github.com/huggingface/transformers/issues/5486
os.environ["TOKENIZERS_PARALLELISM"] = "true"

local_rank = None


def rank0_print(*args):
    if local_rank == 0:
        print(*args)


def set_seed(seed=42):
    """
    Set the random seed for reproducible results.

    :param seed: An integer value to be used as the random seed.
""" torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # for multi-GPU setups torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False @dataclass class ModelArguments: # LLM Arguments model_type: Optional[str] = field(default="videollama2", metadata={"help": "Model type selected in the list: " + ", ".join(VLLMs.keys())}) model_path: Optional[str] = field(default="lmsys/vicuna-7b-v1.5") version: Optional[str] = field(default="v1", metadata={"help": "Version of the conversation template."}) freeze_backbone: bool = field(default=False, metadata={"help": "Whether to freeze the LLM backbone."}) # Connector Arguments mm_projector_type: Optional[str] = field(default='linear') tune_mm_mlp_adapter: bool = field(default=False) pretrain_mm_mlp_adapter: Optional[str] = field(default=None) # Vision tower Arguments vision_tower: Optional[str] = field(default=None) mm_vision_select_layer: Optional[int] = field(default=-1) mm_vision_select_feature: Optional[str] = field(default="patch") @dataclass class DataArguments: # Path Arguments data_path: str = field(default=None, metadata={"help": "Path to the training data."}) # image_folder: Optional[str] = field(default=None) # video_folder: Optional[str] = field(default=None) data_folder: Optional[str] = field(default=None) # Loading Arguments is_multimodal: bool = False lazy_preprocess: bool = False num_frames: Optional[int] = field(default=None) # Preprocess Arguments image_aspect_ratio: str = 'square' @dataclass class TrainingArguments(transformers.TrainingArguments): optim: str = field(default="adamw_torch") mm_projector_lr: Optional[float] = None freeze_mm_mlp_adapter: bool = field(default=False) remove_unused_columns: bool = field(default=False) cache_dir: Optional[str] = field(default=None) # Training Data Arguments group_by_modality_length: bool = field(default=False) model_max_length: int = field( default=512, metadata={ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." }, ) # Lora or Quant Arguments double_quant: bool = field( default=True, metadata={"help": "Compress the quantization statistics through double quantization."} ) quant_type: str = field( default="nf4", metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} ) bits: int = field( default=16, metadata={"help": "How many bits to use."} ) lora_enable: bool = False lora_r: int = 64 lora_alpha: int = 16 lora_dropout: float = 0.05 lora_weight_path: str = "" lora_bias: str = "none" def preprocess_plain( sources: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, modal_token: str = None, ) -> Dict: roles = {"human": "user", "gpt": "assistant"} conversations = [] input_ids = [] targets = [] for source in sources: # 1. apply chat template for input conversation assert len(source) == 2 assert modal_token in source[0]['value'] message = [ {'role': 'user', 'content': modal_token}, {'role': 'assistant', 'content': source[1]['value']} ] conversation = tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False) # 2. tokenize conversations input_ids.append(tokenizer_multimodal_token(conversation, tokenizer, modal_token, return_tensors='pt')) # 3. 
        # 3. make targets
        targets.append(copy.deepcopy(input_ids[-1]))
        instruction = tokenizer.apply_chat_template(message[:1], tokenize=False, add_generation_prompt=True)
        instruction_len = len(tokenizer_multimodal_token(instruction, tokenizer, modal_token, return_tensors='pt'))
        targets[-1][:instruction_len] = IGNORE_INDEX

        # print("instruction: ----------------")
        # print(instruction)
        # print("conversation: ----------------")
        # print(conversation)
        # print("training targets: ----------------")
        # print(tokenizer.decode(targets[-1][instruction_len:]))
        # print(input_ids[-1])
        # print(targets[-1])

    return dict(input_ids=input_ids, labels=targets)


def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    modal_token: str = None,
) -> Dict:
    """Preprocess multi-turn conversations: tokenize with the chat template and mask all
    instruction (user) tokens in the labels so loss is computed only on assistant responses."""
    roles = {"human": "user", "gpt": "assistant"}

    # Apply prompt templates
    conversations = []
    input_ids = []
    targets = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != "user":
            # Skip the first one if it is not from human
            source = source[1:]

        message = [{'role': roles[sentence['from']], 'content': sentence['value']} for sentence in source]
        conversation = tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False)
        input_ids.append(tokenizer_multimodal_token(conversation, tokenizer, modal_token, return_tensors='pt'))
        targets.append(copy.deepcopy(input_ids[-1]))

        assert len(source) % 2 == 0, f"Invalid conversation length {len(source)}."

        cur = 0
        message = []
        for idx, sentence in enumerate(source):
            if idx % 2 == 1:
                tmp_message = [
                    {'role': roles[source[idx-1]['from']], 'content': source[idx-1]['value']},
                    {'role': roles[sentence['from']], 'content': sentence['value']}
                ]

                instruction = tokenizer.apply_chat_template(message + tmp_message[:1], tokenize=False, add_generation_prompt=True)
                conversation = tokenizer.apply_chat_template(message + tmp_message, tokenize=False, add_generation_prompt=False)

                instruction_len = len(tokenizer_multimodal_token(instruction, tokenizer, modal_token, return_tensors='pt'))
                conversation_len = len(tokenizer_multimodal_token(conversation, tokenizer, modal_token, return_tensors='pt'))

                # Mask the instruction span of this turn; keep the assistant response for loss.
                targets[-1][cur:instruction_len] = IGNORE_INDEX

                cur = conversation_len
                message += tmp_message

    return dict(input_ids=input_ids, labels=targets)


def preprocess_multimodal(
    sources: Sequence[str],
    data_args: DataArguments,
    modal_token: str = None,
) -> Dict:
    """Prepend the modal token to the text of any sentence that contains it (multimodal data only)."""
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources

    assert modal_token in MODAL_INDEX_MAP, f"Unsupported modal token {modal_token}."

    for source in sources:
        for sentence in source:
            if modal_token in sentence['value']:
                sentence['value'] = sentence['value'].replace(modal_token, '').strip()
                sentence['value'] = modal_token + '\n' + sentence['value']
                sentence['value'] = sentence['value'].strip()

            replace_token = modal_token
            # TODO: fix this for multimedia, e.g.,