File size: 4,378 Bytes
43b7e92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import argparse
import sys
import tensorrt as trt
def _unet_input_shapes(num_controlnet: int, sd_xl: bool) -> dict:
    """Return the fixed ControlNet-UNet input shapes, keyed by ONNX input name.

    The leading ``2 * batch_size`` dimension presumably holds the conditional and
    unconditional halves for classifier-free guidance (TODO confirm against the
    exporting pipeline).
    """
    batch_size = 1
    unet_in_channels = 4
    unet_sample_size = 64
    num_tokens = 77
    img_size = 512
    # SD XL's encoder hidden states are 2048 wide; the non-XL path uses 768.
    text_hidden_size = 2048 if sd_xl else 768

    shapes = {
        "sample": (2 * batch_size, unet_in_channels, unet_sample_size, unet_sample_size),
        "encoder_hidden_states": (2 * batch_size, num_tokens, text_hidden_size),
        "controlnet_conds": (num_controlnet, 2 * batch_size, 3, img_size, img_size),
    }
    if sd_xl:
        # Extra conditioning inputs that only the SD XL UNet exposes.
        shapes["text_embeds"] = (2 * batch_size, 1280)
        shapes["time_ids"] = (2 * batch_size, 6)
    return shapes


def convert_models(onnx_path: str, num_controlnet: int, output_path: str, fp16: bool = False, sd_xl: bool = False):
    """
    Convert a Stable Diffusion ControlNet UNet ONNX model into a serialized TensorRT engine.

    The engine gets a single optimization profile whose min/opt/max shapes are
    identical, so it only accepts exactly those input shapes (batch size 1,
    512x512 control images, ``num_controlnet`` stacked conditioning images).

    Example:
    python convert_stable_diffusion_controlnet_to_tensorrt.py
    --onnx_path path-to-models-stable_diffusion/RevAnimated-v1-2-2/unet/model.onnx
    --output_path path-to-models-stable_diffusion/RevAnimated-v1-2-2/unet/model.engine
    --fp16
    --num_controlnet 2

    Example for SD XL:
    python convert_stable_diffusion_controlnet_to_tensorrt.py
    --onnx_path path-to-models-stable_diffusion/stable-diffusion-xl-base-1.0/unet/model.onnx
    --output_path path-to-models-stable_diffusion/stable-diffusion-xl-base-1.0/unet/model.engine
    --fp16
    --num_controlnet 1
    --sd_xl

    Args:
        onnx_path: Path of the ONNX UNet model to parse.
        num_controlnet: Number of conditioning images stacked in the
            ``controlnet_conds`` input; must be >= 1.
        output_path: File the serialized engine is written to (e.g. unet/model.engine).
        fp16: Enable the TensorRT FP16 builder flag.
        sd_xl: Use the SD XL text hidden size (2048) and add the ``text_embeds``
            and ``time_ids`` inputs to the optimization profile.

    Returns:
        None. The engine is written to ``output_path``.

    Raises:
        ValueError: If ``num_controlnet`` is None or smaller than 1.
        SystemExit: If ONNX parsing or engine building fails.

    To test the result, run the test script in diffusers/examples/community:
    python test_onnx_controlnet.py
    --sd_model danbrown/RevAnimated-v1-2-2
    --onnx_model_dir path-to-models-stable_diffusion/RevAnimated-v1-2-2
    --unet_engine_path path-to-models-stable_diffusion/RevAnimated-v1-2-2/unet/model.engine
    --qr_img_path path-to-qr-code-image
    """
    # Validate before touching TensorRT; a bad value would otherwise only surface
    # deep inside profile.set_shape or produce a meaningless input shape.
    if num_controlnet is None or num_controlnet < 1:
        raise ValueError(f"num_controlnet must be a positive integer, got {num_controlnet!r}")

    trt_logger = trt.Logger(trt.Logger.VERBOSE)
    builder = trt.Builder(trt_logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    onnx_parser = trt.OnnxParser(network, trt_logger)

    parse_success = onnx_parser.parse_from_file(onnx_path)
    for idx in range(onnx_parser.num_errors):
        print(onnx_parser.get_error(idx))
    if not parse_success:
        sys.exit("ONNX model parsing failed")
    print("Load Onnx model done")

    profile = builder.create_optimization_profile()
    for name, shape in _unet_input_shapes(num_controlnet, sd_xl).items():
        # min == opt == max: the engine is built for exactly this static shape.
        profile.set_shape(name, shape, shape, shape)

    config = builder.create_builder_config()
    config.add_optimization_profile(profile)
    # NOTE(review): this preview feature may not be exposed by every TensorRT
    # version; only set it when the enum member exists.
    preview_feature = getattr(trt.PreviewFeature, "DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805", None)
    if preview_feature is not None:
        config.set_preview_feature(preview_feature, True)
    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    plan = builder.build_serialized_network(network, config)
    if plan is None:
        sys.exit("Failed building engine")
    print("Succeeded building engine")

    # build_serialized_network already returns the serialized engine; write it
    # directly instead of deserializing it into a runtime and re-serializing.
    with open(output_path, "wb") as f:
        f.write(plan)
def main(argv=None):
    """Parse command-line arguments and run ``convert_models``.

    Args:
        argv: Argument list to parse; when None, argparse reads ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--sd_xl", action="store_true", default=False, help="SD XL pipeline")
    parser.add_argument(
        "--onnx_path",
        type=str,
        required=True,
        help="Path to the onnx checkpoint to convert",
    )
    # Required: without it convert_models would receive None and fail only
    # after the ONNX model had already been parsed.
    parser.add_argument(
        "--num_controlnet",
        type=int,
        required=True,
        help="Number of conditioning images stacked in the `controlnet_conds` input (>= 1)",
    )
    parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.")
    parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode")
    args = parser.parse_args(argv)
    if args.num_controlnet < 1:
        parser.error("--num_controlnet must be >= 1")
    convert_models(args.onnx_path, args.num_controlnet, args.output_path, args.fp16, args.sd_xl)


if __name__ == "__main__":
    main()
|