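"""Generate an image from a text prompt via an OpenAI-compatible image API.

The target server is chosen from the IMAGEGEN_OPENAI_BASE_URL environment
variable (the api.gpt.h2o.ai server, OpenAI, Azure OpenAI, or any other
OpenAI-compatible endpoint).  IMAGEGEN_OPENAI_API_KEY supplies the API key,
and IMAGEGEN_OPENAI_MODELS optionally overrides the allowed model list.
"""
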
import ast
import base64
import os
import argparse
import uuid


def main():
    parser = argparse.ArgumentParser(description="Generate images from text prompts")
    parser.add_argument("--prompt", "--query", type=str, required=True, help="User prompt or query")
    parser.add_argument("--model", type=str, required=False, help="Model name")
    parser.add_argument("--output", "--file", type=str, required=False, default="",
                        help="Name (unique) of the output file")
    parser.add_argument("--quality", type=str, required=False, choices=['standard', 'hd', 'quick', 'manual'],
                        default='standard',
                        help="Image quality")
    parser.add_argument("--size", type=str, required=False, default="1024x1024", help="Image size (height x width)")

    imagegen_url = os.getenv("IMAGEGEN_OPENAI_BASE_URL", '')
    # os.getenv() with a default never returns None, so check for an empty value instead
    assert imagegen_url, "IMAGEGEN_OPENAI_BASE_URL environment variable is not set"
    server_api_key = os.getenv('IMAGEGEN_OPENAI_API_KEY', 'EMPTY')

    generation_params = {}

    is_openai = False
    if imagegen_url == "https://api.gpt.h2o.ai/v1":
        parser.add_argument("--guidance_scale", type=float, help="Guidance scale for image generation")
        parser.add_argument("--num_inference_steps", type=int, help="Number of inference steps")
        args = parser.parse_args()
        from openai import OpenAI
        client = OpenAI(base_url=imagegen_url, api_key=server_api_key)
        available_models = ['flux.1-schnell', 'playv2']
        if os.getenv('IMAGEGEN_OPENAI_MODELS'):
            # allow override
            available_models = ast.literal_eval(os.getenv('IMAGEGEN_OPENAI_MODELS'))
        if not args.model:
            args.model = available_models[0]
        if args.model not in available_models:
            args.model = available_models[0]
    elif imagegen_url == "https://api.openai.com/v1" or 'openai.azure.com' in imagegen_url:
        is_openai = True
        parser.add_argument("--style", type=str, choices=['vivid', 'natural', 'artistic'], default='vivid',
                            help="Image style")
        args = parser.parse_args()
        # https://platform.openai.com/docs/api-reference/images/create
        available_models = ['dall-e-3', 'dall-e-2']
        # assumes deployment name matches model name, unless override
        if os.getenv('IMAGEGEN_OPENAI_MODELS'):
            # allow override
            available_models = ast.literal_eval(os.getenv('IMAGEGEN_OPENAI_MODELS'))
        if not args.model:
            args.model = available_models[0]
        if args.model not in available_models:
            args.model = available_models[0]

        if 'openai.azure.com' in imagegen_url:
            # https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line%2Ctypescript&pivots=programming-language-python
            from openai import AzureOpenAI
            client = AzureOpenAI(
                api_version="2024-02-01" if args.model == 'dall-e-3' else '2023-06-01-preview',
                api_key=os.environ["IMAGEGEN_OPENAI_API_KEY"],
                # like base_url, but Azure endpoint like https://PROJECT.openai.azure.com/
                azure_endpoint=os.environ['IMAGEGEN_OPENAI_BASE_URL']
            )
        else:
            from openai import OpenAI
            client = OpenAI(base_url=imagegen_url, api_key=server_api_key)

        dalle2aliases = ['dall-e-2', 'dalle2', 'dalle-2']
        max_chars = 1000 if args.model in dalle2aliases else 4000
        args.prompt = args.prompt[:max_chars]

        if args.model in dalle2aliases:
            valid_sizes = ['256x256', '512x512', '1024x1024']
        else:
            valid_sizes = ['1024x1024', '1792x1024', '1024x1792']

        if args.size not in valid_sizes:
            args.size = valid_sizes[0]

        args.quality = 'standard' if args.quality not in ['standard', 'hd'] else args.quality
        args.style = 'vivid' if args.style not in ['vivid', 'natural'] else args.style
        generation_params.update({
            "style": args.style,
        })
    else:
        parser.add_argument("--guidance_scale", type=float, help="Guidance scale for image generation")
        parser.add_argument("--num_inference_steps", type=int, help="Number of inference steps")
        args = parser.parse_args()

        from openai import OpenAI
        client = OpenAI(base_url=imagegen_url, api_key=server_api_key)
        assert os.getenv('IMAGEGEN_OPENAI_MODELS'), "IMAGEGEN_OPENAI_MODELS environment variable is not set"
        # must be a string representation of a list of model names, e.g. "['model-a', 'model-b']"
        available_models = ast.literal_eval(os.getenv('IMAGEGEN_OPENAI_MODELS'))
        assert available_models, "IMAGEGEN_OPENAI_MODELS must be a non-empty list of models for this server"
        if args.model is None:
            args.model = available_models[0]
        if args.model not in available_models:
            args.model = available_models[0]

    # For Azure, args.model is assumed to match the deployment name (i.e. dall-e-3, not dalle3) unless IMAGEGEN_OPENAI_MODELS is set
    generation_params.update({
        "prompt": args.prompt,
        "model": args.model,
        "quality": args.quality,
        "size": args.size,
        "response_format": "b64_json",
    })
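    # At this point generation_params might look like the following (illustrative values only):
    #   {"prompt": "a watercolor fox", "model": "dall-e-3", "quality": "standard",
    #    "size": "1024x1024", "response_format": "b64_json", "style": "vivid"}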

    if not is_openai:
        extra_body = {}
        # explicit None checks so that legitimate zero values are still forwarded
        if args.guidance_scale is not None:
            extra_body["guidance_scale"] = args.guidance_scale
        if args.num_inference_steps is not None:
            extra_body["num_inference_steps"] = args.num_inference_steps
        if extra_body:
            generation_params["extra_body"] = extra_body

    response = client.images.generate(**generation_params)

    if hasattr(response.data[0], 'revised_prompt') and response.data[0].revised_prompt:
        print("Image Generator revised the prompt (this is expected): %s" % response.data[0].revised_prompt)

    assert response.data[0].b64_json is not None or response.data[0].url is not None, "No image data returned"

    if response.data[0].b64_json:
        image_data_base64 = response.data[0].b64_json
        image_data = base64.b64decode(image_data_base64)
    else:
        from openai_server.agent_tools.common.utils import download_simple
        dest = download_simple(response.data[0].url, overwrite=True)
        with open(dest, "rb") as f:
            image_data = f.read()
        os.remove(dest)

    # Determine file type and name
    image_format = get_image_format(image_data)
    if not args.output:
        args.output = f"image_{str(uuid.uuid4())[:6]}.{image_format}"
    else:
        # If an output path is provided, ensure it has the correct extension
        base, ext = os.path.splitext(args.output)
        if ext.lower() != f".{image_format}":
            args.output = f"{base}.{image_format}"

    # Write the image data to a file
    with open(args.output, "wb") as img_file:
        img_file.write(image_data)

    full_path = os.path.abspath(args.output)
    print(f"Image successfully saved to the file: {full_path}")

    # NOTE: Could provide stats like image size, etc.


def get_image_format(image_data):
    from PIL import Image
    import io
    # Use PIL to determine the image format
    with Image.open(io.BytesIO(image_data)) as img:
        return img.format.lower()


if __name__ == "__main__":
    main()
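
# Example invocation (a sketch; assumes IMAGEGEN_OPENAI_BASE_URL and
# IMAGEGEN_OPENAI_API_KEY are exported, and that this file is saved as
# image_generation.py -- the script name and output filename are illustrative):
#
#   python image_generation.py --prompt "a watercolor fox at sunrise" \
#       --quality hd --size 1024x1024 --output fox.png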