"""Convert TFLite LLM weights into a safetensors checkpoint.

Reads the flatbuffer model, dequantizes INT4/INT8 weight tensors back to a
floating-point dtype, renames tensors to the Hugging Face convention, and
writes everything to a single safetensors file.
"""

import logging
import sys

import torch
import safetensors.torch as st

import tflite.Model
import tflite.SubGraph
from tflite.TensorType import TensorType


logger = logging.getLogger(__name__)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    level=logging.INFO
)


# TFLite TensorType enum values handled by this script.
name_of_tensor_type = {
    0: "FLOAT32",
    9: "INT8",
    17: "INT4",
}

dtype_for_tensor_type = {
    0: torch.float32,
    9: torch.int8,
    17: torch.uint8,  # packed INT4 payloads are read as raw uint8 bytes
}

# Bytes per element; INT4 packs two 4-bit values into each byte.
size_for_tensor_type = {
    0: 4,
    9: 1,
    17: 0.5,
}
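# Because of the 0.5 entry above, dividing an INT4 tensor's byte length by its
# per-element size recovers the element count: e.g. a (2048, 2048) INT4 weight
# is 4_194_304 elements stored in 2_097_152 bytes.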


def update_target_name(target_name: str) -> str:
    """Map a TFLite tensor name onto the Hugging Face naming convention."""
    # Each pair is (tflite substring, HF substring). Order matters: the
    # generic ".w" -> ".weight" rule runs first, and the "*.scale" rules run
    # before the layer-norm renames that build on them.
    replacements = [
        (".w", ".weight"),
        ("params.lm.transformer.x_layers_", "model.layers."),
        ("ff_layer.ffn_layer1_gate", "mlp.gate_proj"),
        ("ff_layer.ffn_layer1", "mlp.up_proj"),
        ("ff_layer.ffn_layer2", "mlp.down_proj"),
        ("post_layer_norm.scale", "post_layer_norm.weight"),
        ("post_layer_norm", "post_attention_layernorm"),
        ("pre_layer_norm.scale", "pre_layer_norm.weight"),
        ("pre_layer_norm", "input_layernorm"),
        ("self_attention.q", "self_attn.q_proj"),
        ("self_attention.k", "self_attn.k_proj"),
        ("self_attention.v", "self_attn.v_proj"),
        ("self_attention.post", "self_attn.o_proj"),
        ("params.lm.softmax.logits_ffn", "model.embed_tokens"),
        ("final_ln.scale", "final_ln.weight"),
        ("params.lm.final_ln", "model.norm"),
    ]
    for old, new in replacements:
        target_name = target_name.replace(old, new)
    return target_name
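# Illustrative examples of the mapping (hypothetical tensor names following
# the patterns matched below in main()):
#   params.lm.transformer.x_layers_0.self_attention.q.w
#       -> model.layers.0.self_attn.q_proj.weight
#   params.lm.transformer.x_layers_0.pre_layer_norm.scale
#       -> model.layers.0.input_layernorm.weight
#   params.lm.softmax.logits_ffn.w -> model.embed_tokens.weight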


def convert_quantized_int4_to_fp(quantized_data, scale_data, dims, dim_scale, dtype):
    """Dequantize packed INT4 weights to `dtype`; `dim_scale` is unused here."""
    zero_point = 8

    quantized_data = quantized_data.view(-1)

    # Each byte packs two 4-bit values: the low nibble is the earlier element.
    low_bits = (quantized_data & 0x0F).type(torch.int8)
    high_bits = (quantized_data >> 4).type(torch.int8)

    # Interleave (low, high) pairs to restore element order, then shift the
    # unsigned range [0, 15] down to the signed range [-8, 7].
    int4_values = torch.stack((low_bits, high_bits), dim=1).view(-1)
    int4_values = int4_values - zero_point

    # scale_data is pre-expanded by the caller to line up with the values.
    scaled_data = int4_values.type(dtype) * scale_data

    scaled_data = scaled_data.view(dims[0], dims[1])

    return scaled_data
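# Worked example of the unpacking above: a packed byte 0x2A yields low nibble
# 0x0A and high nibble 0x02, so with zero_point = 8 the two elements become
# (10 - 8) * scale = 2 * scale and (2 - 8) * scale = -6 * scale.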


def convert_quantized_int8_to_fp(quantized_data, scale_data, dims, dim_scale, dtype):
    """Dequantize INT8 weights to `dtype` with a symmetric zero point of 0."""
    zero_point = 0

    quantized_data = quantized_data.view(-1).type(torch.int8)

    # Expand the scales to line up with the flattened values; the expansion
    # is the same regardless of dim_scale.
    scale_data = scale_data.repeat_interleave(2)

    scale_data = scale_data.to(dtype=dtype)

    scaled_data = (quantized_data - zero_point).type(dtype) * scale_data

    scaled_data = scaled_data.view(dims[0], dims[1])

    return scaled_data
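# With zero_point = 0 the mapping above reduces to value = q * scale; for
# example q = -64 with scale = 0.01 dequantizes to -0.64.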


def main():
    if len(sys.argv) < 3:
        print("Usage: python converter.py <path_to_tflite_model> <output_safetensors_file> [fp32|fp16|bf16]")
        sys.exit(1)

    tflite_model_path = sys.argv[1]
    output_safetensors_path = sys.argv[2]
    dtype_arg = sys.argv[3] if len(sys.argv) >= 4 else "fp32"

    if dtype_arg == "fp32":
        TARGET_DTYPE = torch.float32
    elif dtype_arg == "fp16":
        TARGET_DTYPE = torch.float16
    elif dtype_arg == "bf16":
        TARGET_DTYPE = torch.bfloat16
    else:
        print("Unsupported dtype. Choose from fp32, fp16, bf16.")
        sys.exit(1)

    logger.info(f"Starting conversion with TARGET_DTYPE={TARGET_DTYPE}")

    # Read into a writable bytearray so torch.frombuffer can view it directly.
    with open(tflite_model_path, "rb") as input_file:
        buf = bytearray(input_file.read())

    model: tflite.Model.Model = tflite.Model.Model.GetRootAs(buf)
    graph: tflite.SubGraph.SubGraph = model.Subgraphs(0)
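    # Each tensor's Buffer() indexes buffer metadata whose Offset()/Size()
    # point back into these same bytes, so the weights below are materialized
    # as torch.frombuffer views over buf rather than copies.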

    i4_tensors = {}
    i8_tensors = {}
    fp32_tensors = {}
    scale_tensors = {}
    tensor_dims = {}

    # First pass: bucket tensors by type and infer shapes from their flat
    # sizes. In these exports, Shape(0) is the flat buffer length in bytes.
    for i in range(graph.TensorsLength()):
        tensor = graph.Tensors(i)
        tensor_name = tensor.Name().decode("utf-8")
        tensor_type: TensorType = tensor.Type()

        if tensor_name.endswith(".w_quantized_scale"):
            scale_tensors[tensor_name] = tensor
        elif tensor_type == TensorType.INT4:
            i4_tensors[tensor_name] = tensor
        elif tensor_type == TensorType.INT8:
            i8_tensors[tensor_name] = tensor
        elif tensor_type == TensorType.FLOAT32:
            fp32_tensors[tensor_name] = tensor

        element_size = size_for_tensor_type.get(tensor_type)
        if element_size is None:
            # Unhandled tensor type; leave its shape unknown.
            tensor_dims[tensor_name] = None
            continue

        tensor_buf_size = tensor.Shape(0)
        tensor_size = tensor_buf_size // element_size

        # The hard-coded sizes below describe one specific checkpoint layout
        # (hidden size 2048, FFN size 12288, vocab size 256128).
        shape = None
        if (".self_attention.q." in tensor_name
                or ".self_attention.post." in tensor_name) and tensor_size == 4_194_304:
            shape = (2048, 2048)
        elif (".self_attention.k." in tensor_name
                or ".self_attention.v." in tensor_name) and tensor_size == 524_288:
            shape = (256, 2048)
        elif (".ff_layer.ffn_layer1_gate." in tensor_name
                or ".ff_layer.ffn_layer1." in tensor_name) and tensor_size == 25_165_824:
            shape = (12_288, 2048)
        elif ".ff_layer.ffn_layer2." in tensor_name and tensor_size == 25_165_824:
            shape = (2048, 12_288)
        elif tensor_name == "params.lm.softmax.logits_ffn.w" and tensor_size == 524_550_144:
            shape = (256_128, 2048)
        elif "layer_norm" in tensor_name and tensor_size == 2048:
            shape = (1, 1, 2048)

        tensor_dims[tensor_name] = shape
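    # Size sanity check for the largest case: the logits table holds
    # 256_128 * 2_048 = 524_550_144 INT4 values in 262_275_072 bytes, and
    # 262_275_072 // 0.5 recovers the element count used above.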

    tensor_dict = {}

    # FLOAT32 tensors are copied through, with an optional dtype cast.
    for tensor_name, tensor in fp32_tensors.items():
        logger.info(f"Converting fp32 {tensor_name}...")
        buffer_meta = model.Buffers(tensor.Buffer())
        dims = tensor_dims.get(tensor_name)

        target_name = update_target_name(tensor_name)

        tensor_data = torch.frombuffer(buffer=buf,
                                       dtype=torch.float32,
                                       offset=buffer_meta.Offset(),
                                       count=buffer_meta.Size() // 4)

        if dims is not None:
            tensor_data = tensor_data.reshape(dims)

        if TARGET_DTYPE != torch.float32:
            tensor_data = tensor_data.to(dtype=TARGET_DTYPE)

        tensor_dict[target_name] = tensor_data

    del fp32_tensors

    # INT8 tensors: fetch the matching "<name>_quantized_scale" tensor, then
    # dequantize.
    for tensor_name, quantized_tensor in i8_tensors.items():
        buffer_meta = model.Buffers(quantized_tensor.Buffer())
        scale_tensor_name = tensor_name + "_quantized_scale"
        scale_buf_meta = model.Buffers(scale_tensors[scale_tensor_name].Buffer())
        dims = tensor_dims.get(tensor_name)

        logger.info(f"Dequantizing int8 {dims} {tensor_name}...")

        target_name = update_target_name(tensor_name)

        quantized_buf = torch.frombuffer(buffer=buf,
                                         dtype=torch.int8,
                                         offset=buffer_meta.Offset(),
                                         count=buffer_meta.Size())

        scale_buf = torch.frombuffer(buffer=buf,
                                     dtype=torch.float32,
                                     offset=scale_buf_meta.Offset(),
                                     count=scale_buf_meta.Size() // 4)

        tensor_data = convert_quantized_int8_to_fp(
            quantized_data=quantized_buf,
            scale_data=scale_buf,
            dims=dims,
            dim_scale=0,
            dtype=TARGET_DTYPE
        )

        tensor_dict[target_name] = tensor_data

        del quantized_buf, scale_buf

    del i8_tensors

    # INT4 tensors: unpack two values per byte, expand the scales, dequantize.
    for tensor_name, quantized_tensor in i4_tensors.items():
        buffer_meta = model.Buffers(quantized_tensor.Buffer())
        scale_tensor_name = tensor_name + "_quantized_scale"
        scale_buf_meta = model.Buffers(scale_tensors[scale_tensor_name].Buffer())
        dims = tensor_dims.get(tensor_name)

        logger.info(f"Dequantizing int4 {dims} {tensor_name}...")

        target_name = update_target_name(tensor_name)

        quantized_buf = torch.frombuffer(buffer=buf,
                                         dtype=torch.uint8,
                                         offset=buffer_meta.Offset(),
                                         count=buffer_meta.Size())

        scale_buf = torch.frombuffer(buffer=buf,
                                     dtype=torch.float32,
                                     offset=scale_buf_meta.Offset(),
                                     count=scale_buf_meta.Size() // 4)

        # Special case for the logits/embedding table: average adjacent scale
        # pairs before expanding. The check must run against scale_tensor_name,
        # since the weight tensor's own name never contains "_quantized_scale".
        if "logits_ffn.w_quantized_scale" in scale_tensor_name:
            if scale_buf.numel() % 2 != 0:
                logger.error(f"Scale data size for {scale_tensor_name} is not even. Cannot average.")
                sys.exit(1)
            scale_data = scale_buf.view(-1, 2).mean(dim=1)
            scale_data = scale_data.repeat_interleave(2)
        else:
            scale_data = scale_buf.repeat_interleave(2)

        tensor_data = convert_quantized_int4_to_fp(
            quantized_data=quantized_buf,
            scale_data=scale_data,
            dims=dims,
            dim_scale=0,
            dtype=TARGET_DTYPE
        )

        tensor_dict[target_name] = tensor_data

        del quantized_buf, scale_buf

    del i4_tensors
    del scale_tensors

    del buf, model, graph

    logger.info(f"Saving to {output_safetensors_path}...")
    st.save_file(tensor_dict, output_safetensors_path)
    logger.info(f"Success! Saved to {output_safetensors_path}")


if __name__ == "__main__":
    main()
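# Example invocation (paths are illustrative):
#   python converter.py model.tflite model.safetensors bf16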