# Convert an AutoGPTQ checkpoint (4-bit, group size 128, symmetric, no act-order)
# into the Marlin kernel format.

import argparse
import copy
import gc

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear
from marlin import Layer as MarlinLayer

parser = argparse.ArgumentParser()
parser.add_argument("--model-id", type=str)
parser.add_argument("--save-path", type=str)
parser.add_argument("--do-generation", action="store_true")
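
# Example invocation (script name and paths are placeholders):
#   python convert_gptq_to_marlin.py --model-id <gptq-model-id-or-path> \
#       --save-path <marlin-output-dir> --do-generation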


def _validate_compatibility(model):
    if not hasattr(model.config, "quantization_config"):
        raise ValueError("Must be a quantized model to convert to Marlin format")
    quantization_config = model.config.quantization_config
    if quantization_config.quant_method != "gptq":
        raise ValueError(f"Only GPTQ models can be converted to Marlin format. You passed a model with quant_method={quantization_config.quant_method}")
    if quantization_config.bits != 4:
        raise ValueError(f"Only 4-bit quantized models can be converted to Marlin format. You passed a model with bits={quantization_config.bits}")
    if quantization_config.group_size != 128:
        raise ValueError(f"Only group size 128 models can be converted to Marlin format. You passed a model with group_size={quantization_config.group_size}")
    if not quantization_config.sym:
        raise ValueError(f"Only models with symmetric quantization can be converted to Marlin format. You passed a model with sym={quantization_config.sym}")
    if quantization_config.desc_act:
        raise ValueError(f"Models with act-order (desc_act) quantization cannot be converted to Marlin format. You passed a model with desc_act={quantization_config.desc_act}")


@torch.no_grad()
def unpack_4bit_to_32bit_signed(qweight, qzeros):
    # Each int32 packs eight 4-bit values: qweight along dim 0, qzeros along dim 1.
    unpacked_weights = torch.zeros((qweight.shape[0] * 8, qweight.shape[1]), dtype=torch.int8, device=qweight.device, requires_grad=False)
    unpacked_zeros = torch.zeros((qzeros.shape[0], qzeros.shape[1] * 8), dtype=torch.int8, device=qzeros.device, requires_grad=False)

    for row in range(unpacked_weights.shape[0]):
        i = row % 8
        unpacked_weights[row, :] = (qweight[row // 8, :] >> (4 * i)) & 0xF

    for col in range(unpacked_zeros.shape[1]):
        i = col % 8
        unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF

    # GPTQ packs zero points offset by one, so add 1 back after unpacking.
    return unpacked_weights, unpacked_zeros + 1
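

# A vectorized sketch of the same nibble extraction as the row loop above (kept
# only as an illustrative helper; it is not called anywhere in this script).
def _unpack_nibbles_along_dim0(packed):
    # (K // 8, N) int32 -> (K, N) int8, taking nibble i of each int32 for row 8*k + i.
    shifts = torch.arange(0, 32, 4, device=packed.device, dtype=packed.dtype)
    nibbles = (packed.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF
    return nibbles.to(torch.int8).reshape(packed.shape[0] * 8, packed.shape[1])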


@torch.no_grad()
def dequantize_weight(layer):
    qweight, qzeros, scales = layer.qweight, layer.qzeros, layer.scales
    unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(qweight, qzeros)
    # Expand the per-group scales and zero points to one row per input feature,
    # then dequantize: w = (q - z) * s.
    group_size = unpacked_qweight.shape[0] // scales.shape[0]
    scales = scales.repeat_interleave(group_size, dim=0)
    unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
    unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales

    # Transpose to the (out_features, in_features) layout used by nn.Linear.
    return unpacked_qweight.T


@torch.no_grad()
def convert_model(model, verbose=True):
    # Swap every AutoGPTQ QuantLinear in the model for a Marlin layer, repacking
    # the dequantized weights with the original per-group scales.
    for name, module in model.named_modules():
        if not isinstance(module, QuantLinear):
            continue

        if verbose:
            print(f"--- Converting Module: {name}")
        parent_name = ".".join(name.split(".")[:-1])
        layer_name = name[len(parent_name) + 1:]

        # Dequantize the GPTQ weights into a temporary fp16 nn.Linear on the GPU.
        dequantized_weight = dequantize_weight(module).to(torch.float16)
        linear_module = torch.nn.Linear(
            in_features=dequantized_weight.shape[1],
            out_features=dequantized_weight.shape[0],
            bias=False,
            dtype=torch.float16,
            device="cuda")
        linear_module.weight.data.copy_(dequantized_weight)

        # Repack into the Marlin layout, reusing the layer's per-group scales.
        new_module = MarlinLayer(
            infeatures=linear_module.in_features,
            outfeatures=linear_module.out_features,
            groupsize=model.config.quantization_config.group_size)
        new_module.pack(linear_module, scales=copy.deepcopy(module.scales.data.t()))

        # Replace the QuantLinear with the Marlin layer in its parent module.
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, layer_name, new_module)

        # Free the temporary dequantized tensors before converting the next layer.
        del dequantized_weight, module
        torch.cuda.empty_cache()
        gc.collect()

    return model


@torch.no_grad()
def dequantize_model(model, verbose=True):
    for name, module in model.named_modules():
        if not isinstance(module, QuantLinear):
            continue

        if verbose:
            print(f"--- Dequantizing Module: {name}")
        parent_name = ".".join(name.split(".")[:-1])
        layer_name = name[len(parent_name) + 1:]

        dequantized_weight = dequantize_weight(module)
        dequantized_weight_cpu = dequantized_weight.to("cpu")

        new_module = torch.nn.Linear(
            in_features=dequantized_weight_cpu.shape[1],
            out_features=dequantized_weight_cpu.shape[0],
            bias=False,
            dtype=torch.float16)
        new_module.weight.data.copy_(dequantized_weight_cpu)
        new_module.scales = torch.nn.Parameter(copy.deepcopy(module.scales.data))

        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, layer_name, new_module)

        del dequantized_weight, dequantized_weight_cpu, module
        torch.cuda.empty_cache()

    return model
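
# Note: dequantize_model() is a standalone helper and is not called in the
# __main__ flow below; it replaces each QuantLinear with a plain fp16 nn.Linear,
# e.g. (hypothetical usage): fp16_model = dequantize_model(model)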


if __name__ == "__main__":
    args = parser.parse_args()
    model_id = args.model_id
    save_path = args.save_path
    do_generation = args.do_generation

    print("Loading GPTQ model...")
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    print("Validating compatibility...")
    _validate_compatibility(model)

    print("Converting model...")
    model = convert_model(model).to("cpu")

    # Overwrite the quantization config so the saved checkpoint is tagged as Marlin.
    print("Saving Marlin model...")
    model.config.quantization_config = {
        "group_size": model.config.quantization_config.group_size,
        "quant_method": "marlin"
    }
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)

    if do_generation:
        print("Generating sample text...")
        model.to("cuda")
        prompt = "My favorite song is"
        inputs = tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
        outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)
        print(tokenizer.batch_decode(outputs)[0])