from transformers import AutoModelForCausalLM, AutoProcessor, AutoConfig, QuantoConfig, GenerationConfig
import torch
import safetensors.torch
import argparse
import os
import json
from PIL import Image

"""
 usage:
    export SAFETENSORS_FAST_GPU=1
    python main.py --quant_type int8 --world_size 8 --model_id <model_path> --image_path <image_path>
"""

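# Build the optional quantization config. For int8 weight-only quantization,
# the vision tower, projector, embeddings, lm_head, and every layer's MoE gate
# and coefficient modules are excluded so they stay in bf16.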
def generate_quanto_config(hf_config: AutoConfig, quant_type: str):
    QUANT_TYPE_MAP = {
        "default": None,
        "int8": QuantoConfig(
            weights="int8",
            modules_to_not_convert=[
                "vision_tower",
                "image_newline",
                "multi_modal_projector",
                "lm_head",
                "embed_tokens",
            ] + [f"model.layers.{i}.coefficient" for i in range(hf_config.text_config.num_hidden_layers)]
            + [f"model.layers.{i}.block_sparse_moe.gate" for i in range(hf_config.text_config.num_hidden_layers)]
        ),
    }
    return QUANT_TYPE_MAP[quant_type]
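
# A minimal sketch of the result, assuming a hypothetical config with
# text_config.num_hidden_layers == 2: the returned QuantoConfig's
# modules_to_not_convert ends with
#   ["model.layers.0.coefficient", "model.layers.1.coefficient",
#    "model.layers.0.block_sparse_moe.gate", "model.layers.1.block_sparse_moe.gate"]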

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--quant_type", type=str, default="default", choices=["default", "int8"])
    parser.add_argument("--model_id", type=str, required=True)
    parser.add_argument("--world_size", type=int, required=True)
    parser.add_argument("--image_path", type=str, required=True)
    return parser.parse_args()

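# Validate the CLI arguments against the model configuration before any
# weights are loaded.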
def check_params(args, hf_config: AutoConfig):
    if args.quant_type == "int8":
        assert args.world_size >= 8, "int8 weight-only quantization requires at least 8 GPUs"

    assert hf_config.text_config.num_hidden_layers % args.world_size == 0, f"num_hidden_layers({hf_config.text_config.num_hidden_layers}) must be divisible by world_size({args.world_size})"

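# Inference only: no_grad avoids allocating autograd state during generation.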
@torch.no_grad()
def main():
    args = parse_args()
    print("\n=============== Argument ===============")
    for key in vars(args):
        print(f"{key}: {vars(args)[key]}")
    print("========================================")

    model_id = args.model_id

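    # MiniMax-VL-01 ships custom modeling code, hence trust_remote_code=True.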
    hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    quantization_config = generate_quanto_config(hf_config, args.quant_type)

    check_params(args, hf_config)

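    # Enumerate vision-related parameters from the safetensors index so they
    # can be pinned to one GPU; device_map keys are module names, hence the
    # .weight/.bias suffixes are stripped.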
    model_safetensors_index_path = os.path.join(model_id, "model.safetensors.index.json")
    with open(model_safetensors_index_path, "r") as f:
        model_safetensors_index = json.load(f)
    weight_map = model_safetensors_index['weight_map']
    vision_map = {}
    for key, value in weight_map.items():
        if 'vision_tower' in key or 'image_newline' in key or 'multi_modal_projector' in key:
            new_key = key.replace('.weight','').replace('.bias','')
            if new_key not in vision_map:
                vision_map[new_key] = value
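
    # Hand-built device map: token embeddings on the first GPU, final norm and
    # lm_head on the last, every vision module on cuda:0.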
    device_map = {
        'language_model.model.embed_tokens': 'cuda:0',
        'language_model.model.norm': f'cuda:{args.world_size - 1}',
        'language_model.lm_head': f'cuda:{args.world_size - 1}'
    }
    for key in vision_map:
        device_map[key] = 'cuda:0'
    device_map['vision_tower.vision_model.post_layernorm'] = 'cuda:0'
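    # Shard the decoder layers evenly across GPUs; divisibility by world_size
    # was asserted in check_params.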
    layers_per_device = hf_config.text_config.num_hidden_layers // args.world_size
    for i in range(args.world_size):
        for j in range(layers_per_device):
            device_map[f'language_model.model.layers.{i * layers_per_device + j}'] = f'cuda:{i}'

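    # Build a multimodal chat prompt; the "image" entry is a placeholder the
    # processor pairs with the PIL image passed below.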
    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant created by Minimax based on MiniMax-VL-01 model."}]},
        {"role": "user", "content": [{"type": "image", "image": "placeholder"},{"type": "text", "text": "Describe this image."}]},
    ]
    prompt = processor.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    print(f"prompt: \n{prompt}")
    raw_image = Image.open(args.image_path)
    model_inputs = processor(images=[raw_image], text=prompt, return_tensors='pt').to('cuda').to(torch.bfloat16)

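    # Load the checkpoint onto the hand-built device_map; with the int8
    # QuantoConfig, eligible weights are quantized on the fly as shards load.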
    quantized_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="bfloat16",
        device_map=device_map,
        quantization_config=quantization_config,
        trust_remote_code=True,
        offload_buffers=True,
    )
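    # Greedy decoding by default, with this checkpoint's hardcoded EOS token id.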
    generation_config = GenerationConfig(
        max_new_tokens=100,
        eos_token_id=200020,
        use_cache=True,
    )
    generated_ids = quantized_model.generate(**model_inputs, generation_config=generation_config)
    print(f"generated_ids: {generated_ids}")
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(response)
    # Example output (cut off at max_new_tokens=100):
    # The image depicts a single, whole apple with a rich, red color. The apple appears to be fresh, with a smooth, glossy skin that reflects light, indicating its juiciness. The surface of the apple is dotted with small, light-colored

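# Debug helper: print every tensor name and shape in one safetensors shard,
# e.g. query_safetensors("<model_path>/model-00001-of-XXXXX.safetensors")
# (shard name illustrative).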
def query_safetensors(path):
    safetensor = safetensors.torch.load_file(path)
    for key in safetensor.keys():
        print(key, safetensor[key].shape)


if __name__ == "__main__":
    main()