# converter.py import sys import torch import safetensors.torch as st import logging import math import tflite.Model import tflite.SubGraph from tflite.TensorType import TensorType # Set up logging logger = logging.getLogger(__name__) logging.basicConfig( format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', level=logging.INFO ) # Define scale and size mappings name_of_tensor_type = { 0: "FLOAT32", 9: "INT8 ", 17: "INT4 ", } dtype_for_tensor_type = { 0: torch.float32, 9: torch.int8, 17: torch.uint8, # Because torch.int4 doesn't exist } size_for_tensor_type = { 0: 4, 9: 1, 17: 0.5, } # Function to update target tensor names def update_target_name(target_name: str) -> str: """Updates the target name to match the tensor name convention.""" def reverse_replace(theStr: str, a, b): return theStr.replace(b, a) target_name = reverse_replace(target_name, ".weight", ".w") target_name = reverse_replace(target_name, "model.layers.", "params.lm.transformer.x_layers_" ) target_name = reverse_replace(target_name, "mlp.gate_proj", "ff_layer.ffn_layer1_gate" ) target_name = reverse_replace(target_name, "mlp.up_proj", "ff_layer.ffn_layer1") target_name = reverse_replace(target_name, "mlp.down_proj", "ff_layer.ffn_layer2") target_name = reverse_replace(target_name, "post_layer_norm.weight", "post_layer_norm.scale" ) target_name = reverse_replace(target_name, "post_attention_layernorm", "post_layer_norm" ) target_name = reverse_replace(target_name, "pre_layer_norm.weight", "pre_layer_norm.scale" ) target_name = reverse_replace(target_name, "input_layernorm", "pre_layer_norm") target_name = reverse_replace(target_name, "self_attn.q_proj", "self_attention.q") target_name = reverse_replace(target_name, "self_attn.k_proj", "self_attention.k") target_name = reverse_replace(target_name, "self_attn.v_proj", "self_attention.v") target_name = reverse_replace(target_name, "self_attn.o_proj", "self_attention.post") target_name = reverse_replace(target_name, "model.embed_tokens", "params.lm.softmax.logits_ffn" ) target_name = reverse_replace(target_name, "final_ln.weight", "final_ln.scale") target_name = reverse_replace(target_name, "model.norm", "params.lm.final_ln") return target_name # Optimized dequantization for INT4 def convert_quantized_int4_to_fp(quantized_data, scale_data, dims, dim_scale, dtype): zero_point = 8 # Reshape quantized data to 1D tensor quantized_data = quantized_data.view(-1) # Extract low and high 4 bits low_bits = (quantized_data & 0x0F).type(torch.int8) high_bits = (quantized_data >> 4).type(torch.int8) # Concatenate low and high bits int4_values = torch.stack((low_bits, high_bits), dim=1).view(-1) int4_values = int4_values - zero_point # Adjust zero point # Apply scaling scaled_data = int4_values.type(dtype) * scale_data # Reshape to original dimensions scaled_data = scaled_data.view(dims[0], dims[1]) return scaled_data # Function to dequantize INT8 def convert_quantized_int8_to_fp(quantized_data, scale_data, dims, dim_scale, dtype): zero_point = 0 # Assuming zero_point=0 for int8 # Reshape quantized data to 1D tensor quantized_data = quantized_data.view(-1).type(torch.int8) # Handle scale_data based on dim_scale if dim_scale: # Per-column scaling scale_data = scale_data.repeat_interleave(2) else: # Per-row scaling scale_data = scale_data.repeat_interleave(2) # Convert scale_data to the same dtype scale_data = scale_data.to(dtype=dtype) # Apply scaling scaled_data = (quantized_data - zero_point).type(dtype) * scale_data # Reshape to original dimensions scaled_data = scaled_data.view(dims[0], dims[1]) return scaled_data def main(): # Check command-line arguments if len(sys.argv) < 3: print("Usage: python converter.py [fp32|fp16|bf16]") sys.exit(1) tflite_model_path = sys.argv[1] output_safetensors_path = sys.argv[2] dtype_arg = sys.argv[3] if len(sys.argv) >= 4 else "fp32" if dtype_arg == "fp32": TARGET_DTYPE = torch.float32 elif dtype_arg == "fp16": TARGET_DTYPE = torch.float16 elif dtype_arg == "bf16": TARGET_DTYPE = torch.bfloat16 else: print("Unsupported dtype. Choose from fp32, fp16, bf16.") sys.exit(1) logger.info(f"Starting conversion with TARGET_DTYPE={TARGET_DTYPE}") # Read the TFLite model with open(tflite_model_path, "rb") as input_file: buf = bytearray(input_file.read()) model: tflite.Model.Model = tflite.Model.Model.GetRootAs(buf) graph: tflite.SubGraph.SubGraph = model.Subgraphs(0) # Initialize dictionaries to hold tensors i4_tensors = {} i8_tensors = {} fp32_tensors = {} scale_tensors = {} tensor_dims = {} # Read and sort tensors for i in range(graph.TensorsLength()): tensor = graph.Tensors(i) tensor_name = tensor.Name().decode("utf-8") tensor_type: TensorType = tensor.Type() if tensor_name.endswith(".w_quantized_scale"): scale_tensors[tensor_name] = tensor elif tensor_type == TensorType.INT4: i4_tensors[tensor_name] = tensor elif tensor_type == TensorType.INT8: i8_tensors[tensor_name] = tensor elif tensor_type == TensorType.FLOAT32: fp32_tensors[tensor_name] = tensor tensor_buf_size = tensor.Shape(0) tensor_size = tensor_buf_size // size_for_tensor_type[tensor_type] shape = None if (".self_attention.q." in tensor_name or ".self_attention.post." in tensor_name) and tensor_size == 4_194_304: shape = (2048, 2048) elif (".self_attention.k." in tensor_name or ".self_attention.v." in tensor_name) and tensor_size == 524_288: shape = (256, 2048) elif (".ff_layer.ffn_layer1_gate." in tensor_name or ".ff_layer.ffn_layer1." in tensor_name) and tensor_size == 25_165_824: shape = (12_288, 2048) elif ".ff_layer.ffn_layer2." in tensor_name and tensor_size == 25_165_824: shape = (2048, 12_288) elif "params.lm.softmax.logits_ffn.w" == tensor_name and tensor_size == 524_550_144: shape = (256_128, 2048) # LayerNorm weights are of shape {1, 1, 2048} elif "layer_norm" in tensor_name and tensor_size == 2048: shape = (1, 1, 2048) else: # Default to 1D if shape is unknown pass tensor_dims[tensor_name] = shape # Dictionary to hold dequantized tensors tensor_dict = {} # Dequantize FP32 tensors for tensor_name, tensor in fp32_tensors.items(): logger.info(f"Saving fp32 {tensor_name}...") buffer_meta = model.Buffers(tensor.Buffer()) dims = tensor_dims.get(tensor_name) target_name = update_target_name(tensor_name) tensor_data = torch.frombuffer(buffer=buf, dtype=torch.float32, offset=buffer_meta.Offset(), count=buffer_meta.Size() // 4) # Assign reshaped tensor back if dims is not None: tensor_data = tensor_data.reshape(dims) if TARGET_DTYPE != torch.float32: tensor_data = tensor_data.to(dtype=TARGET_DTYPE) tensor_dict[target_name] = tensor_data del fp32_tensors # Dequantize INT8 tensors for tensor_name, quantized_tensor in i8_tensors.items(): buffer_meta = model.Buffers(quantized_tensor.Buffer()) scale_tensor_name = tensor_name + "_quantized_scale" scale_buf_meta = model.Buffers(scale_tensors[scale_tensor_name].Buffer()) dims = tensor_dims.get(tensor_name) logger.info(f"Dequantizing int8 {dims} {tensor_name}...") target_name = update_target_name(tensor_name) quantized_buf = torch.frombuffer(buffer=buf, dtype=torch.int8, offset=buffer_meta.Offset(), count=buffer_meta.Size()) scale_buf = torch.frombuffer(buffer=buf, dtype=torch.float32, offset=scale_buf_meta.Offset(), count=scale_buf_meta.Size() // 4) # MediaPipe TfLiteWeightAccessor::BuildWeightsMapFromTfliteModel sets # dim_scale=0, so we do the same. tensor_data = convert_quantized_int8_to_fp( quantized_data=quantized_buf, scale_data=scale_buf, dims=dims, dim_scale=0, dtype=TARGET_DTYPE ) tensor_dict[target_name] = tensor_data del quantized_buf, scale_buf del i8_tensors # Dequantize INT4 tensors for tensor_name, quantized_tensor in i4_tensors.items(): buffer_meta = model.Buffers(quantized_tensor.Buffer()) scale_tensor_name = tensor_name + "_quantized_scale" scale_buf_meta = model.Buffers(scale_tensors[scale_tensor_name].Buffer()) dims = tensor_dims.get(tensor_name) logger.info(f"Dequantizing int4 {dims} {tensor_name}...") target_name = update_target_name(tensor_name) quantized_buf = torch.frombuffer(buffer=buf, dtype=torch.uint8, offset=buffer_meta.Offset(), count=buffer_meta.Size()) scale_buf = torch.frombuffer(buffer=buf, dtype=torch.float32, offset=scale_buf_meta.Offset(), count=scale_buf_meta.Size() // 4) # Special handling for 'logits_ffn.w_quantized_scale' if 'logits_ffn.w_quantized_scale' in tensor_name: # Assuming two scale factors per row, average them if scale_buf.numel() % 2 != 0: logger.error(f"Scale data size for {tensor_name} is not even. Cannot average.") sys.exit(1) scale_data = scale_buf.view(-1, 2).mean(dim=1) # Average every two scale factors # Repeat each scale factor twice to match the two int4 values scale_data = scale_data.repeat_interleave(2) else: # General handling: per-row scaling, repeat each scale factor twice scale_data = scale_buf.repeat_interleave(2) # Convert and reshape quantized_data tensor_data = convert_quantized_int4_to_fp( quantized_data=quantized_buf, scale_data=scale_data, dims=dims, dim_scale=0, dtype=TARGET_DTYPE ) tensor_dict[target_name] = tensor_data del quantized_buf, scale_buf del i4_tensors del scale_tensors del buf, model, graph # Save all tensors to the safetensors file logger.info(f"Saving to {output_safetensors_path}...") st.save_file(tensor_dict, output_safetensors_path) logger.info(f"Success! Saved to {output_safetensors_path}") if __name__ == "__main__": main()