# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import shutil
import time
from pathlib import Path

import click
import torch
import torch.nn as nn
import torch.nn.functional as F

from fish_speech.models.text2semantic.llama import find_multiple
from tools.llama.generate import load_model

##### Quantization Primitives ######


def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
    """Symmetrically quantize ``x`` with one scale per row.

    The min/max are reduced over ``dim=1``, so ``x`` is expected to be 2-D.

    Returns:
        ``(quant, scales, zero_points)`` where ``quant`` has dtype
        ``target_dtype``, ``scales`` has the dtype of ``x`` and
        ``zero_points`` is an all-zero int64 tensor (symmetric quantization).
    """
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scales and zero_points based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scales = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scales is the same dtype as the original tensor;
    # the clamp keeps the division below away from zero
    scales = torch.clamp(scales, min=eps).to(x.dtype)
    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scales/zp
    # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x / scales.unsqueeze(-1)
    x_round = torch.round(x_div)
    x_zp = x_round + zero_points.unsqueeze(-1)
    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)

    return quant, scales, zero_points


def get_group_qparams(w, n_bit=4, groupsize=128):
    """Compute per-group affine quantization parameters for a 2-D weight.

    Returns:
        ``(scales, zeros)``, both bfloat16 and reshaped to
        ``(w.shape[0], -1)``. ``zeros`` is the value that the mid-point
        integer ``2 ** (n_bit - 1)`` maps back to.
    """
    # needed for GPTQ with padding
    if groupsize > w.shape[-1]:
        groupsize = w.shape[-1]
    assert groupsize > 1
    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    to_quant = w.reshape(-1, groupsize)
    assert torch.isnan(to_quant).sum() == 0

    max_val = to_quant.amax(dim=1, keepdim=True)
    min_val = to_quant.amin(dim=1, keepdim=True)
    max_int = 2**n_bit - 1
    # clamp so that a constant group does not yield a zero scale
    scales = (max_val - min_val).clamp(min=1e-6) / max_int
    zeros = min_val + scales * (2 ** (n_bit - 1))
    return (
        scales.to(torch.bfloat16).reshape(w.shape[0], -1),
        zeros.to(torch.bfloat16).reshape(w.shape[0], -1),
    )


def pack_scales_and_zeros(scales, zeros):
    """Stack bfloat16 ``scales`` and ``zeros`` into one contiguous tensor.

    Inputs of shape ``(rows, groups)`` become a ``(groups, rows, 2)`` tensor
    with scales at index 0 and zeros at index 1 of the last dimension.
    """
    assert scales.shape == zeros.shape
    assert scales.dtype == torch.bfloat16
    assert zeros.dtype == torch.bfloat16
    stacked = torch.cat(
        [
            scales.reshape(scales.size(0), scales.size(1), 1),
            zeros.reshape(zeros.size(0), zeros.size(1), 1),
        ],
        2,
    )
    return stacked.transpose(0, 1).contiguous()


def unpack_scales_and_zeros(scales_and_zeros):
    """Split a packed ``(groups, rows, 2)`` tensor back into scales and zeros.

    Returns the tuple produced by ``torch.split``: two tensors of shape
    ``(rows, groups, 1)``.
    """
    assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
    # pack_scales_and_zeros only produces bfloat16, so bfloat16 must be
    # accepted here for the pack/unpack round trip to work; float32 is kept
    # for callers that relied on the previous check.
    assert scales_and_zeros.dtype in (torch.float, torch.bfloat16)
    return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)


def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
    """Quantize 2-D ``w`` to integers in ``[0, 2**n_bit - 1]`` per group.

    Returns an int32 tensor with the same shape as ``w``.
    """
    assert groupsize > 1
    # needed for GPTQ single column quantize
    if groupsize > w.shape[-1] and scales.shape[-1] == 1:
        groupsize = w.shape[-1]

    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    to_quant = w.reshape(-1, groupsize)
    assert torch.isnan(to_quant).sum() == 0

    scales = scales.reshape(-1, 1)
    zeros = zeros.reshape(-1, 1)
    # invert the mid-point offset applied in get_group_qparams
    min_val = zeros - scales * (2 ** (n_bit - 1))
    max_int = 2**n_bit - 1
    min_int = 0
    w_int32 = (
        to_quant.sub(min_val)
        .div(scales)
        .round()
        .clamp_(min_int, max_int)
        .to(torch.int32)
        .reshape_as(w)
    )

    return w_int32


def group_quantize_tensor(w, n_bit=4, groupsize=128):
    """Group-quantize ``w`` and return ``(w_int32, packed_scales_and_zeros)``."""
    scales, zeros = get_group_qparams(w, n_bit, groupsize)
    w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
    scales_and_zeros = pack_scales_and_zeros(scales, zeros)
    return w_int32, scales_and_zeros


def group_dequantize_tensor_from_qparams(
    w_int32, scales, zeros, n_bit=4, groupsize=128
):
    """Reverse ``group_quantize_tensor_from_qparams``.

    Returns a tensor shaped like ``w_int32`` holding
    ``(q - 2**(n_bit - 1)) * scale + zero`` for every element.
    """
    assert groupsize > 1
    # needed for GPTQ single column dequantize
    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
        groupsize = w_int32.shape[-1]
    assert w_int32.shape[-1] % groupsize == 0
    assert w_int32.dim() == 2

    w_int32_grouped = w_int32.reshape(-1, groupsize)
    scales = scales.reshape(-1, 1)
    zeros = zeros.reshape(-1, 1)

    w_dq = (
        w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
    )
    return w_dq


def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
    """Dequantize ``w_int32`` using packed ``scales_and_zeros``."""
    scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
    return group_dequantize_tensor_from_qparams(
        w_int32, scales, zeros, n_bit, groupsize
    )


class QuantHandler:
    """Interface skeleton for quantization handlers; both methods are no-ops."""

    def __init__(self, mod):
        self.mod = mod

    def create_quantized_state_dict(self) -> "StateDict":
        pass

    def convert_for_runtime(self) -> "nn.Module":
        pass


##### Weight-only int8 per-channel quantized code ######


def replace_linear_weight_only_int8_per_channel(module):
    """Recursively swap every ``nn.Linear`` child for ``WeightOnlyInt8Linear``.

    The replacement is done in place. The new module has only ``weight`` and
    ``scales`` buffers, so a bias on the original layer is not carried over.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(
                module,
                name,
                WeightOnlyInt8Linear(child.in_features, child.out_features),
            )
        else:
            replace_linear_weight_only_int8_per_channel(child)


class WeightOnlyInt8QuantHandler:
    """Produces int8 per-channel weights for all ``nn.Linear`` layers of a model."""

    def __init__(self, mod):
        self.mod = mod

    @torch.no_grad()
    def create_quantized_state_dict(self):
        """Return the model state dict with Linear weights quantized to int8.

        For each Linear ``fqn``, ``{fqn}.weight`` is replaced by the int8
        weight and ``{fqn}.scales`` is added in the original weight dtype.
        """
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, torch.nn.Linear):
                int8_weight, scales, _ = dynamically_quantize_per_channel(
                    mod.weight.float(), -128, 127, torch.int8
                )
                cur_state_dict[f"{fqn}.weight"] = int8_weight
                cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)

        return cur_state_dict

    def convert_for_runtime(self):
        """Replace the model's Linear layers in place and return the model."""
        replace_linear_weight_only_int8_per_channel(self.mod)
        return self.mod


class WeightOnlyInt8Linear(torch.nn.Module):
    """Linear layer holding an int8 weight and a per-output-channel scale.

    ``bias``, ``device`` and ``dtype`` are accepted for signature
    compatibility but are not used: no bias is stored or applied.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Placeholders; real values are expected to come from a quantized
        # state dict (see WeightOnlyInt8QuantHandler).
        self.register_buffer(
            "weight", torch.empty((out_features, in_features), dtype=torch.int8)
        )
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Cast the int8 weight to the activation dtype, then rescale the
        # output per channel.
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales


##### weight only int4 per channel groupwise quantized code ######


def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
    """Group-quantize a weight to 4 bits and pack it with the aten int4 packer.

    Returns ``(weight_int4pack, scales_and_zeros)``.
    """
    weight_int32, scales_and_zeros = group_quantize_tensor(
        weight_bf16, n_bit=4, groupsize=groupsize
    )
    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
        weight_int32, inner_k_tiles
    )
    return weight_int4pack, scales_and_zeros


def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
    """Apply an int4-packed linear layer to ``x``.

    ``x`` is flattened to 2-D for the matmul and the result is reshaped to
    the original leading dimensions plus ``out_features``.
    """
    origin_x_size = x.size()
    x = x.reshape(-1, origin_x_size[-1])
    c = torch.ops.aten._weight_int4pack_mm(
        x, weight_int4pack, groupsize, scales_and_zeros
    )
    new_shape = origin_x_size[:-1] + (out_features,)
    c = c.reshape(new_shape)
    return c


def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
    """Return True if ``k`` is divisible by ``groupsize`` and ``inner_k_tiles * 16``."""
    return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0


def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
    """Recursively swap ``nn.Linear`` children for ``WeightOnlyInt4Linear``.

    A Linear whose ``in_features`` passes ``_check_linear_int4_k`` is replaced
    without padding; otherwise it is replaced with padding enabled when
    ``padding`` is true, and left untouched when it is false.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            aligned = _check_linear_int4_k(
                child.in_features, groupsize, inner_k_tiles
            )
            if aligned or padding:
                setattr(
                    module,
                    name,
                    WeightOnlyInt4Linear(
                        child.in_features,
                        child.out_features,
                        bias=False,
                        groupsize=groupsize,
                        inner_k_tiles=inner_k_tiles,
                        # pad only the layers that actually need it
                        padding=not aligned,
                    ),
                )
        else:
            replace_linear_int4(child, groupsize, inner_k_tiles, padding)


class WeightOnlyInt4QuantHandler:
    """Produces int4 group-wise packed weights for a model's ``nn.Linear`` layers."""

    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
        self.mod = mod
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.padding = padding
        assert groupsize in [32, 64, 128, 256]
        assert inner_k_tiles in [2, 4, 8]

    @torch.no_grad()
    def create_quantized_state_dict(self):
        """Return the model state dict with Linear weights packed as int4.

        For each eligible Linear ``fqn``, ``{fqn}.weight`` is replaced by the
        packed weight and ``{fqn}.scales_and_zeros`` is added. Layers whose
        ``in_features`` fail ``_check_linear_int4_k`` are zero-padded up to a
        multiple of 1024 when ``self.padding`` is set, and skipped otherwise.

        Weights are moved to ``"cuda"`` for packing, so a CUDA device is
        required; results are stored back on the CPU.
        """
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, torch.nn.Linear):
                # `not mod.bias` would raise on a multi-element bias tensor
                # (ambiguous truth value); test for absence explicitly.
                assert mod.bias is None
                out_features = mod.out_features
                in_features = mod.in_features
                assert out_features % 8 == 0, "require out_features % 8 == 0"
                print(f"linear: {fqn}, in={in_features}, out={out_features}")

                weight = mod.weight.data
                if not _check_linear_int4_k(
                    in_features, self.groupsize, self.inner_k_tiles
                ):
                    if self.padding:
                        print(
                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
                        )
                        padded_in_features = find_multiple(in_features, 1024)
                        weight = F.pad(
                            weight, pad=(0, padded_in_features - in_features)
                        )
                    else:
                        print(
                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
                        )
                        continue
                (
                    weight_int4pack,
                    scales_and_zeros,
                ) = prepare_int4_weight_and_scales_and_zeros(
                    weight.to(torch.bfloat16).to("cuda"),
                    self.groupsize,
                    self.inner_k_tiles,
                )
                cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")

        return cur_state_dict

    def convert_for_runtime(self):
        """Replace the model's Linear layers in place and return the model."""
        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
        return self.mod


class WeightOnlyInt4Linear(torch.nn.Module):
    """Linear layer backed by an int4-packed weight and packed scales/zeros.

    With ``padding=True`` the stored ``in_features`` is rounded up to a
    multiple of 1024 and inputs are zero-padded to that width in ``forward``.
    ``bias`` must be falsy; ``device`` and ``dtype`` are accepted but unused.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=True,
        device=None,
        dtype=None,
        groupsize: int = 128,
        inner_k_tiles: int = 8,
        padding: bool = True,
    ) -> None:
        super().__init__()
        self.padding = padding
        if padding:
            # remember the caller-visible width so forward() knows how much
            # to pad
            self.origin_in_features = in_features
            in_features = find_multiple(in_features, 1024)

        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles

        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert (
            in_features % (inner_k_tiles * 16) == 0
        ), "require in_features % (innerKTiles * 16) == 0"
        # Placeholder buffers; real values are expected to come from a
        # quantized state dict (see WeightOnlyInt4QuantHandler).
        self.register_buffer(
            "weight",
            torch.empty(
                (
                    out_features // 8,
                    in_features // (inner_k_tiles * 16),
                    32,
                    inner_k_tiles // 2,
                ),
                dtype=torch.int32,
            ),
        )
        self.register_buffer(
            "scales_and_zeros",
            torch.empty(
                (in_features // groupsize, out_features, 2), dtype=torch.bfloat16
            ),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = input.to(torch.bfloat16)
        if self.padding:
            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
        return linear_forward_int4(
            input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
        )


def generate_folder_name():
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``."""
    now = datetime.datetime.now()
    folder_name = now.strftime("%Y%m%d_%H%M%S")
    return folder_name


@click.command()
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="checkpoints/fish-speech-1.4",
)
@click.option(
    "--mode", type=str, default="int8", help="type of quantization to perform"
)
@click.option(
    "--groupsize", type=int, default=128, help="Group size for int4 quantization."
)
@click.option("--timestamp", type=str, default="None", help="When to do quantization")
def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
    """Quantize a checkpoint and write it to a new directory under ``checkpoints/``.

    The checkpoint directory is copied to
    ``checkpoints/fs-1.2-<mode>[-g<groupsize>]-<timestamp>``, the VQ generator
    file is removed from the copy, and ``model.pth`` is overwritten with the
    quantized state dict.

    Args:
        checkpoint_path: Directory of the source checkpoint.
        mode: ``"int8"`` or ``"int4"``.
        groupsize: Group size, used by int4 mode only.
        timestamp: Suffix for the output directory; the literal string
            ``"None"`` means "use the current time".

    Raises:
        ValueError: If ``mode`` is not a supported quantization mode.
    """
    # Fail fast: reject a bad mode before paying for the model load.
    if mode not in ("int8", "int4"):
        raise ValueError(
            f"Invalid quantization mode {mode} needs to be one of [int8, int4]"
        )

    device = "cpu"
    precision = torch.bfloat16

    print("Loading model ...")
    t0 = time.time()

    model, _ = load_model(
        checkpoint_path=checkpoint_path,
        device=device,
        precision=precision,
        compile=False,
    )
    vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
    now = timestamp if timestamp != "None" else generate_folder_name()

    if mode == "int8":
        print(
            "Quantizing model weights for int8 weight-only symmetric per-channel quantization"
        )
        quant_handler = WeightOnlyInt8QuantHandler(model)
        dst_name = Path(f"checkpoints/fs-1.2-int8-{now}")
    else:
        print(
            "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
        )
        quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
        dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}")

    quantized_state_dict = quant_handler.create_quantized_state_dict()

    shutil.copytree(str(checkpoint_path.resolve()), str(dst_name.resolve()))
    # The VQ generator is dropped from the quantized copy.
    (dst_name / vq_model).unlink(missing_ok=True)
    quantize_path = dst_name / "model.pth"

    print(f"Writing quantized weights to {quantize_path}")
    quantize_path.unlink(missing_ok=True)  # remove existing file if one already there
    torch.save(quantized_state_dict, quantize_path)
    print(f"Quantization complete took {time.time() - t0:.02f} seconds")


if __name__ == "__main__":
    quantize()