File size: 4,378 Bytes
43b7e92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import sys

import tensorrt as trt


def _input_shapes(num_controlnet: int, sd_xl: bool) -> dict:
    """Return the static shape of every network input, keyed by input name, in profile order."""
    batch_size = 1
    unet_in_channels = 4
    unet_sample_size = 64
    num_tokens = 77
    img_size = 512
    # Of the shared UNet inputs, only the text hidden size differs between SD XL and SD 1.x here.
    text_hidden_size = 2048 if sd_xl else 768

    # Every batch dimension is 2 * batch_size; presumably the conditional and
    # unconditional (classifier-free guidance) halves -- verify against the pipeline.
    doubled_batch = 2 * batch_size
    shapes = {
        "sample": (doubled_batch, unet_in_channels, unet_sample_size, unet_sample_size),
        "encoder_hidden_states": (doubled_batch, num_tokens, text_hidden_size),
        "controlnet_conds": (num_controlnet, doubled_batch, 3, img_size, img_size),
    }
    if sd_xl:
        # Extra conditioning inputs that only the SD XL UNet takes.
        shapes["text_embeds"] = (doubled_batch, 1280)
        shapes["time_ids"] = (doubled_batch, 6)
    return shapes


def convert_models(onnx_path: str, num_controlnet: int, output_path: str, fp16: bool = False, sd_xl: bool = False):
    """
    Convert the UNet ONNX model of a stable diffusion controlnet pipeline into a TensorRT engine.

    The engine is built with a single optimization profile whose min/opt/max shapes
    are identical, so it only accepts exactly those static input shapes.

    Args:
        onnx_path: Path to the ONNX model to convert.
        num_controlnet: Number of ControlNets; used as the leading dimension of the
            ``controlnet_conds`` input. Must be >= 1.
        output_path: File the serialized TensorRT engine is written to.
        fp16: If True, enable the TensorRT FP16 builder flag.
        sd_xl: If True, use the SD XL text hidden size and also profile the
            ``text_embeds`` and ``time_ids`` inputs.

    Returns:
        None. The engine is written to ``output_path``.

    Raises:
        ValueError: If ``num_controlnet`` is missing or smaller than 1.
        SystemExit: If ONNX parsing or engine building fails.

    Example:
    python convert_stable_diffusion_controlnet_to_tensorrt.py
    --onnx_path path-to-models-stable_diffusion/RevAnimated-v1-2-2/unet/model.onnx
    --output_path path-to-models-stable_diffusion/RevAnimated-v1-2-2/unet/model.engine
    --fp16
    --num_controlnet 2

    Example for SD XL:
    python convert_stable_diffusion_controlnet_to_tensorrt.py
    --onnx_path path-to-models-stable_diffusion/stable-diffusion-xl-base-1.0/unet/model.onnx
    --output_path path-to-models-stable_diffusion/stable-diffusion-xl-base-1.0/unet/model.engine
    --fp16
    --num_controlnet 1
    --sd_xl

    Afterwards, run the test script in diffusers/examples/community:
        python test_onnx_controlnet.py
        --sd_model danbrown/RevAnimated-v1-2-2
        --onnx_model_dir path-to-models-stable_diffusion/RevAnimated-v1-2-2
        --unet_engine_path path-to-models-stable_diffusion/RevAnimated-v1-2-2/unet/model.engine
        --qr_img_path path-to-qr-code-image
    """
    # Fail fast: a missing/non-positive count would otherwise only surface as an
    # opaque TensorRT error once it is baked into the controlnet_conds shape.
    if num_controlnet is None or num_controlnet < 1:
        raise ValueError(f"num_controlnet must be a positive integer, got {num_controlnet!r}")

    TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
    builder = trt.Builder(TRT_LOGGER)

    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    onnx_parser = trt.OnnxParser(network, TRT_LOGGER)

    parse_success = onnx_parser.parse_from_file(onnx_path)
    # Print every parser diagnostic, even on success, to aid debugging.
    for idx in range(onnx_parser.num_errors):
        print(onnx_parser.get_error(idx))
    if not parse_success:
        sys.exit("ONNX model parsing failed")
    print("Load Onnx model done")

    profile = builder.create_optimization_profile()
    for input_name, shape in _input_shapes(num_controlnet, sd_xl).items():
        # min == opt == max: the engine is built for exactly this static shape.
        profile.set_shape(input_name, shape, shape, shape)

    config = builder.create_builder_config()
    config.add_optimization_profile(profile)
    # This preview feature may be absent in some TensorRT releases; only set it where it exists.
    preview_feature = getattr(
        getattr(trt, "PreviewFeature", None), "DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805", None
    )
    if preview_feature is not None:
        config.set_preview_feature(preview_feature, True)
    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    plan = builder.build_serialized_network(network, config)
    if plan is None:
        sys.exit("Failed building engine")
    print("Succeeded building engine")

    # `plan` already is the serialized engine; write it directly rather than
    # deserializing it onto the GPU only to serialize it again.
    with open(output_path, "wb") as f:
        f.write(plan)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--sd_xl", action="store_true", default=False, help="SD XL pipeline")

    parser.add_argument(
        "--onnx_path",
        type=str,
        required=True,
        help="Path to the onnx checkpoint to convert",
    )

    # Required: the value becomes the leading dimension of the `controlnet_conds`
    # input shape, so omitting it can never produce a valid engine.
    parser.add_argument(
        "--num_controlnet",
        type=int,
        required=True,
        help="Number of ControlNets (leading dimension of the `controlnet_conds` input)",
    )

    parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.")

    parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode")

    args = parser.parse_args()

    convert_models(args.onnx_path, args.num_controlnet, args.output_path, args.fp16, args.sd_xl)