# File size: 5,304 bytes; revision cfde609.
import transformers
from transformers import AutoModelForCausalLM, AutoProcessor, AutoConfig, QuantoConfig, GenerationConfig
import torch
import safetensors
import argparse
import os
import json
from PIL import Image
"""
usage:
export SAFETENSORS_FAST_GPU=1
python main.py --quant_type int8 --world_size 8 --model_id <model_path> --image_path <image_path>
"""
def generate_quanto_config(hf_config: "AutoConfig", quant_type: str):
    """Return the quantization config to pass to ``from_pretrained`` for ``quant_type``.

    Args:
        hf_config: the model config. Only ``hf_config.text_config.num_hidden_layers``
            is read, and only for ``"int8"``, to list the per-layer modules that must
            stay unquantized.
        quant_type: ``"default"`` (no quantization) or ``"int8"`` (int8 weight-only
            quantization via ``QuantoConfig``).

    Returns:
        ``None`` for ``"default"``; a ``QuantoConfig`` for ``"int8"``.

    Raises:
        ValueError: if ``quant_type`` is not one of the supported values.
    """
    if quant_type == "default":
        return None

    if quant_type == "int8":
        num_layers = hf_config.text_config.num_hidden_layers
        # Vision side, embeddings and the output head stay in full precision.
        modules_to_not_convert = [
            "vision_tower",
            "image_newline",
            "multi_modal_projector",
            "lm_head",
            "embed_tokens",
        ]
        # Per-layer modules that are also excluded from quantization.
        # These names are matched as substrings of full module paths, so they
        # also cover the ``language_model.``-prefixed names used elsewhere.
        modules_to_not_convert += [f"model.layers.{i}.coefficient" for i in range(num_layers)]
        modules_to_not_convert += [
            f"model.layers.{i}.block_sparse_moe.gate" for i in range(num_layers)
        ]
        return QuantoConfig(weights="int8", modules_to_not_convert=modules_to_not_convert)

    raise ValueError(
        f"Unsupported quant_type {quant_type!r}; expected one of ['default', 'int8']"
    )
def parse_args():
    """Build the command-line interface and return the parsed options namespace."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--quant_type", type=str, default="default", choices=["default", "int8"])
    # Mandatory options: checkpoint location, number of GPUs, input image.
    required_options = (
        ("--model_id", str),
        ("--world_size", int),
        ("--image_path", str),
    )
    for flag, flag_type in required_options:
        cli.add_argument(flag, type=flag_type, required=True)
    return cli.parse_args()
def check_params(args, hf_config: "AutoConfig"):
    """Validate the parsed CLI arguments against the model config.

    Args:
        args: namespace with ``quant_type`` and ``world_size`` attributes.
        hf_config: model config; ``text_config.num_hidden_layers`` is read.

    Raises:
        ValueError: if ``world_size`` is not positive, if int8 quantization is
            requested with fewer than 8 GPUs, or if the decoder layers cannot be
            split evenly across ``world_size`` GPUs.
    """
    # Explicit raises rather than ``assert``: asserts are stripped under ``python -O``.
    # A non-positive world_size would otherwise crash with ZeroDivisionError below.
    if args.world_size < 1:
        raise ValueError(f"world_size must be a positive integer, got {args.world_size}")
    if args.quant_type == "int8" and args.world_size < 8:
        raise ValueError("int8 weight-only quantization requires at least 8 GPUs")
    num_layers = hf_config.text_config.num_hidden_layers
    if num_layers % args.world_size != 0:
        raise ValueError(
            f"num_hidden_layers({num_layers}) must be divisible by world_size({args.world_size})"
        )
def _load_weight_map(model_id):
    """Return the ``weight_map`` dict from the checkpoint's ``model.safetensors.index.json``.

    ``model_id`` is joined as a filesystem path, so it must be a local checkpoint
    directory (not a Hub repo id).
    """
    index_path = os.path.join(model_id, "model.safetensors.index.json")
    with open(index_path, "r", encoding="utf-8") as f:
        return json.load(f)["weight_map"]


def _strip_param_suffix(key):
    """Drop a trailing ``.weight`` / ``.bias`` from a tensor name; other names are returned unchanged."""
    # Only strip a trailing suffix: a plain str.replace would also mangle names
    # that merely contain ".weight"/".bias" (e.g. "conv.weight_g" -> "conv_g").
    for suffix in (".weight", ".bias"):
        if key.endswith(suffix):
            return key[: -len(suffix)]
    return key


def _build_device_map(weight_map, num_hidden_layers, world_size):
    """Build a ``device_map`` spreading the model over ``cuda:0`` .. ``cuda:{world_size - 1}``.

    Token embeddings and every vision-side entry of ``weight_map`` go on ``cuda:0``.
    The final norm and ``lm_head`` go on the last GPU. Decoder layers are split into
    contiguous chunks of ``num_hidden_layers // world_size`` layers per GPU. This
    assumes the division is exact, which ``check_params`` enforces before this runs.
    """
    last_device = f"cuda:{world_size - 1}"
    device_map = {
        "language_model.model.embed_tokens": "cuda:0",
        "language_model.model.norm": last_device,
        "language_model.lm_head": last_device,
    }
    # Substring match, as in the tensor names of the checkpoint index.
    vision_markers = ("vision_tower", "image_newline", "multi_modal_projector")
    for key in weight_map:
        if any(marker in key for marker in vision_markers):
            device_map[_strip_param_suffix(key)] = "cuda:0"
    # Mapped explicitly. Presumably the checkpoint has no tensors for this module,
    # so it would not be picked up from weight_map; NOTE(review): confirm.
    device_map["vision_tower.vision_model.post_layernorm"] = "cuda:0"
    layers_per_device = num_hidden_layers // world_size
    for rank in range(world_size):
        for offset in range(layers_per_device):
            layer_idx = rank * layers_per_device + offset
            device_map[f"language_model.model.layers.{layer_idx}"] = f"cuda:{rank}"
    return device_map


@torch.no_grad()
def main():
    """Describe one image with MiniMax-VL-01, optionally int8-quantized, sharded over GPUs.

    Steps:
    1. Parse and echo the CLI arguments.
    2. Load the config and processor, then validate the arguments.
    3. Build a manual device_map from the checkpoint index.
    4. Format a chat prompt containing the image.
    5. Load the model, generate up to 100 new tokens and print the decoded response.
    """
    args = parse_args()
    print("\n=============== Argument ===============")
    for key, value in vars(args).items():
        print(f"{key}: {value}")
    print("========================================")

    model_id = args.model_id
    hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    quantization_config = generate_quanto_config(hf_config, args.quant_type)
    check_params(args, hf_config)

    device_map = _build_device_map(
        _load_weight_map(model_id),
        hf_config.text_config.num_hidden_layers,
        args.world_size,
    )

    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant created by Minimax based on MiniMax-VL-01 model."}]},
        {"role": "user", "content": [{"type": "image", "image": "placeholder"},{"type": "text", "text": "Describe this image."}]},
    ]
    prompt = processor.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    print(f"prompt: \n{prompt}")

    raw_image = Image.open(args.image_path)
    # Inputs go to the default CUDA device, matching embed_tokens/vision on cuda:0.
    # The bf16 cast is presumably meant for the pixel tensors only; NOTE(review):
    # confirm this processor returns a BatchFeature, whose .to() leaves integer
    # tensors (token ids) uncast.
    model_inputs = processor(images=[raw_image], text=prompt, return_tensors='pt').to('cuda').to(torch.bfloat16)
    quantized_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="bfloat16",
        device_map=device_map,
        quantization_config=quantization_config,
        trust_remote_code=True,
        offload_buffers=True,
    )
    generation_config = GenerationConfig(
        max_new_tokens=100,
        # Presumably the tokenizer's end-of-turn token id; NOTE(review): confirm.
        eos_token_id=200020,
        use_cache=True,
    )
    generated_ids = quantized_model.generate(**model_inputs, generation_config=generation_config)
    print(f"generated_ids: {generated_ids}")
    # Keep only the newly generated tokens by dropping each sequence's prompt prefix.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(response)
    # Example output:
    # The image depicts a single, whole apple with a rich, red color. The apple appears to be fresh, with a smooth, glossy skin that reflects light, indicating its juiciness. The surface of the apple is dotted with small, light-colored
def query_safetensors(path):
    """Print the name and shape of every tensor in the safetensors file at ``path``.

    Only header metadata is read (``safe_open`` + ``get_slice``), so the tensors
    themselves are never loaded into memory.
    """
    # ``safe_open`` is exported by the top-level ``safetensors`` package. The
    # ``safetensors.torch`` submodule is not imported by ``import safetensors`` and
    # was previously only reachable if some other import had loaded it.
    with safetensors.safe_open(path, framework="pt") as f:
        for key in f.keys():
            # Wrap in torch.Size so the printed form matches ``tensor.shape``.
            print(key, torch.Size(f.get_slice(key).get_shape()))
# Script entry point (the stray trailing "|" from the original was a syntax error).
if __name__ == "__main__":
    main()