# converter.py
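#
# Converts a TFLite/MediaPipe on-device LLM checkpoint with INT4/INT8-quantized
# weights into a Hugging Face-style safetensors file: tensors are located in the
# flatbuffer, dequantized, renamed, and written out in the requested dtype.
#
# Usage:
#   python converter.py <path_to_tflite_model> <output_safetensors_file> [fp32|fp16|bf16]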
import sys
import torch
import safetensors.torch as st
import logging
import tflite.Model
import tflite.SubGraph
from tflite.TensorType import TensorType
# Set up logging
logger = logging.getLogger(__name__)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    level=logging.INFO
)
# Map TFLite TensorType enum values (FLOAT32 = 0, INT8 = 9, INT4 = 17) to a
# readable name, a torch dtype, and the element size in bytes. Only
# size_for_tensor_type is used below; the other two are informational.
name_of_tensor_type = {
    0: "FLOAT32",
    9: "INT8",
    17: "INT4",
}
dtype_for_tensor_type = {
    0: torch.float32,
    9: torch.int8,
    17: torch.uint8,  # torch has no int4 dtype, so packed int4 bytes stay uint8
}
size_for_tensor_type = {
    0: 4,
    9: 1,
    17: 0.5,
}
# Function to map TFLite tensor names to Hugging Face-style names
def update_target_name(target_name: str) -> str:
    """Renames a MediaPipe/TFLite tensor to the Hugging Face naming convention."""
    def reverse_replace(the_str: str, a, b):
        # Each call below is written as (HF name, TFLite name), so replace b with a.
        return the_str.replace(b, a)
    target_name = reverse_replace(target_name, ".weight", ".w")
    target_name = reverse_replace(
        target_name, "model.layers.", "params.lm.transformer.x_layers_"
    )
    target_name = reverse_replace(
        target_name, "mlp.gate_proj", "ff_layer.ffn_layer1_gate"
    )
    target_name = reverse_replace(target_name, "mlp.up_proj", "ff_layer.ffn_layer1")
    target_name = reverse_replace(target_name, "mlp.down_proj", "ff_layer.ffn_layer2")
    target_name = reverse_replace(
        target_name, "post_layer_norm.weight", "post_layer_norm.scale"
    )
    target_name = reverse_replace(
        target_name, "post_attention_layernorm", "post_layer_norm"
    )
    target_name = reverse_replace(
        target_name, "pre_layer_norm.weight", "pre_layer_norm.scale"
    )
    target_name = reverse_replace(target_name, "input_layernorm", "pre_layer_norm")
    target_name = reverse_replace(target_name, "self_attn.q_proj", "self_attention.q")
    target_name = reverse_replace(target_name, "self_attn.k_proj", "self_attention.k")
    target_name = reverse_replace(target_name, "self_attn.v_proj", "self_attention.v")
    target_name = reverse_replace(target_name, "self_attn.o_proj", "self_attention.post")
    target_name = reverse_replace(
        target_name, "model.embed_tokens", "params.lm.softmax.logits_ffn"
    )
    target_name = reverse_replace(target_name, "final_ln.weight", "final_ln.scale")
    target_name = reverse_replace(target_name, "model.norm", "params.lm.final_ln")
    return target_name
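# For example, update_target_name("params.lm.transformer.x_layers_0.ff_layer.ffn_layer1.w")
# returns "model.layers.0.mlp.up_proj.weight".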
# Dequantize packed INT4 weights (two 4-bit values per byte, zero point 8)
def convert_quantized_int4_to_fp(quantized_data, scale_data, dims, dim_scale, dtype):
    # dim_scale is accepted for symmetry with the int8 path but is not used here.
    zero_point = 8
    # Flatten the packed bytes to a 1D tensor
    quantized_data = quantized_data.view(-1)
    # Extract the low and high 4 bits of each byte
    low_bits = (quantized_data & 0x0F).type(torch.int8)
    high_bits = (quantized_data >> 4).type(torch.int8)
    # Interleave so every byte expands to (low, high) in order
    int4_values = torch.stack((low_bits, high_bits), dim=1).view(-1)
    int4_values = int4_values - zero_point  # Shift by the zero point
    # Apply scaling (the caller expands scale_data to line up with the unpacked values)
    scaled_data = int4_values.type(dtype) * scale_data
    # Reshape to original dimensions
    scaled_data = scaled_data.view(dims[0], dims[1])
    return scaled_data
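# Worked example: a packed byte 0x72 carries the low nibble 0x2 and the high
# nibble 0x7, which unpack to 2 - 8 = -6 and 7 - 8 = -1 before scaling.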
# Function to dequantize INT8 weights
def convert_quantized_int8_to_fp(quantized_data, scale_data, dims, dim_scale, dtype):
    zero_point = 0  # Assuming zero_point=0 for int8
    # Flatten the quantized data to a 1D int8 tensor
    quantized_data = quantized_data.view(-1).type(torch.int8)
    # Expand scale_data to line up with the quantized values.
    # Note: both branches are identical as written, so dim_scale has no effect here.
    if dim_scale:
        # Per-column scaling
        scale_data = scale_data.repeat_interleave(2)
    else:
        # Per-row scaling
        scale_data = scale_data.repeat_interleave(2)
    # Convert scale_data to the target dtype
    scale_data = scale_data.to(dtype=dtype)
    # Apply scaling
    scaled_data = (quantized_data - zero_point).type(dtype) * scale_data
    # Reshape to original dimensions
    scaled_data = scaled_data.view(dims[0], dims[1])
    return scaled_data
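# Worked example: a stored int8 value of 64 with a scale of 0.002 dequantizes
# to (64 - 0) * 0.002 = 0.128.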
def main():
    # Check command-line arguments
    if len(sys.argv) < 3:
        print("Usage: python converter.py <path_to_tflite_model> <output_safetensors_file> [fp32|fp16|bf16]")
        sys.exit(1)
    tflite_model_path = sys.argv[1]
    output_safetensors_path = sys.argv[2]
    dtype_arg = sys.argv[3] if len(sys.argv) >= 4 else "fp32"
    if dtype_arg == "fp32":
        TARGET_DTYPE = torch.float32
    elif dtype_arg == "fp16":
        TARGET_DTYPE = torch.float16
    elif dtype_arg == "bf16":
        TARGET_DTYPE = torch.bfloat16
    else:
        print("Unsupported dtype. Choose from fp32, fp16, bf16.")
        sys.exit(1)
    logger.info(f"Starting conversion with TARGET_DTYPE={TARGET_DTYPE}")
    # Read the TFLite model
    with open(tflite_model_path, "rb") as input_file:
        buf = bytearray(input_file.read())
    model: tflite.Model.Model = tflite.Model.Model.GetRootAs(buf)
    graph: tflite.SubGraph.SubGraph = model.Subgraphs(0)
    # Initialize dictionaries to hold tensors
    i4_tensors = {}
    i8_tensors = {}
    fp32_tensors = {}
    scale_tensors = {}
    tensor_dims = {}
    # Read and sort tensors
    for i in range(graph.TensorsLength()):
        tensor = graph.Tensors(i)
        tensor_name = tensor.Name().decode("utf-8")
        tensor_type: TensorType = tensor.Type()
        if tensor_name.endswith(".w_quantized_scale"):
            scale_tensors[tensor_name] = tensor
        elif tensor_type == TensorType.INT4:
            i4_tensors[tensor_name] = tensor
        elif tensor_type == TensorType.INT8:
            i8_tensors[tensor_name] = tensor
        elif tensor_type == TensorType.FLOAT32:
            fp32_tensors[tensor_name] = tensor
        tensor_buf_size = tensor.Shape(0)
        tensor_size = tensor_buf_size // size_for_tensor_type[tensor_type]
        shape = None
        if (".self_attention.q." in tensor_name
                or ".self_attention.post." in tensor_name) and tensor_size == 4_194_304:
            shape = (2048, 2048)
        elif (".self_attention.k." in tensor_name
                or ".self_attention.v." in tensor_name) and tensor_size == 524_288:
            shape = (256, 2048)
        elif (".ff_layer.ffn_layer1_gate." in tensor_name
                or ".ff_layer.ffn_layer1." in tensor_name) and tensor_size == 25_165_824:
            shape = (12_288, 2048)
        elif ".ff_layer.ffn_layer2." in tensor_name and tensor_size == 25_165_824:
            shape = (2048, 12_288)
        elif "params.lm.softmax.logits_ffn.w" == tensor_name and tensor_size == 524_550_144:
            shape = (256_128, 2048)
        # LayerNorm weights are of shape {1, 1, 2048}
        elif "layer_norm" in tensor_name and tensor_size == 2048:
            shape = (1, 1, 2048)
        else:
            # Default to 1D if shape is unknown
            pass
        tensor_dims[tensor_name] = shape
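    # Note: the hard-coded element counts above appear to assume a fixed model
    # configuration (hidden size 2048, FFN size 12288, k/v projection width 256,
    # vocab size 256128); tensors of any other size keep a flat 1-D shape.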
    # Dictionary to hold dequantized tensors
    tensor_dict = {}
    # Save FP32 tensors (casting to the target dtype if needed)
    for tensor_name, tensor in fp32_tensors.items():
        logger.info(f"Saving fp32 {tensor_name}...")
        buffer_meta = model.Buffers(tensor.Buffer())
        dims = tensor_dims.get(tensor_name)
        target_name = update_target_name(tensor_name)
        tensor_data = torch.frombuffer(buffer=buf,
                                       dtype=torch.float32,
                                       offset=buffer_meta.Offset(),
                                       count=buffer_meta.Size() // 4)
        # Reshape if the shape is known
        if dims is not None:
            tensor_data = tensor_data.reshape(dims)
        if TARGET_DTYPE != torch.float32:
            tensor_data = tensor_data.to(dtype=TARGET_DTYPE)
        tensor_dict[target_name] = tensor_data
    del fp32_tensors
    # Dequantize INT8 tensors
    for tensor_name, quantized_tensor in i8_tensors.items():
        buffer_meta = model.Buffers(quantized_tensor.Buffer())
        scale_tensor_name = tensor_name + "_quantized_scale"
        scale_buf_meta = model.Buffers(scale_tensors[scale_tensor_name].Buffer())
        dims = tensor_dims.get(tensor_name)
        logger.info(f"Dequantizing int8 {dims} {tensor_name}...")
        target_name = update_target_name(tensor_name)
        quantized_buf = torch.frombuffer(buffer=buf,
                                         dtype=torch.int8,
                                         offset=buffer_meta.Offset(),
                                         count=buffer_meta.Size())
        scale_buf = torch.frombuffer(buffer=buf,
                                     dtype=torch.float32,
                                     offset=scale_buf_meta.Offset(),
                                     count=scale_buf_meta.Size() // 4)
        # MediaPipe TfLiteWeightAccessor::BuildWeightsMapFromTfliteModel sets
        # dim_scale=0, so we do the same.
        tensor_data = convert_quantized_int8_to_fp(
            quantized_data=quantized_buf,
            scale_data=scale_buf,
            dims=dims,
            dim_scale=0,
            dtype=TARGET_DTYPE
        )
        tensor_dict[target_name] = tensor_data
        del quantized_buf, scale_buf
    del i8_tensors
    # Dequantize INT4 tensors
    for tensor_name, quantized_tensor in i4_tensors.items():
        buffer_meta = model.Buffers(quantized_tensor.Buffer())
        scale_tensor_name = tensor_name + "_quantized_scale"
        scale_buf_meta = model.Buffers(scale_tensors[scale_tensor_name].Buffer())
        dims = tensor_dims.get(tensor_name)
        logger.info(f"Dequantizing int4 {dims} {tensor_name}...")
        target_name = update_target_name(tensor_name)
        quantized_buf = torch.frombuffer(buffer=buf,
                                         dtype=torch.uint8,
                                         offset=buffer_meta.Offset(),
                                         count=buffer_meta.Size())
        scale_buf = torch.frombuffer(buffer=buf,
                                     dtype=torch.float32,
                                     offset=scale_buf_meta.Offset(),
                                     count=scale_buf_meta.Size() // 4)
        # Special handling for 'logits_ffn.w_quantized_scale'
        if 'logits_ffn.w_quantized_scale' in tensor_name:
            # Assuming two scale factors per row, average them
            if scale_buf.numel() % 2 != 0:
                logger.error(f"Scale data size for {tensor_name} is not even. Cannot average.")
                sys.exit(1)
            scale_data = scale_buf.view(-1, 2).mean(dim=1)  # Average every two scale factors
            # Repeat each scale factor twice to match the two int4 values
            scale_data = scale_data.repeat_interleave(2)
        else:
            # General handling: per-row scaling, repeat each scale factor twice
            scale_data = scale_buf.repeat_interleave(2)
        # Convert and reshape quantized_data
        tensor_data = convert_quantized_int4_to_fp(
            quantized_data=quantized_buf,
            scale_data=scale_data,
            dims=dims,
            dim_scale=0,
            dtype=TARGET_DTYPE
        )
        tensor_dict[target_name] = tensor_data
        del quantized_buf, scale_buf
    del i4_tensors
    del scale_tensors
    del buf, model, graph
    # Save all tensors to the safetensors file
    logger.info(f"Saving to {output_safetensors_path}...")
    st.save_file(tensor_dict, output_safetensors_path)
    logger.info(f"Success! Saved to {output_safetensors_path}")

if __name__ == "__main__":
    main()
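
# Quick sanity check (hypothetical file name): load the output and inspect a shape.
#   import safetensors.torch as st
#   tensors = st.load_file("model.safetensors")
#   print(tensors["model.embed_tokens.weight"].shape)  # expect (256128, 2048)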