import threading


# State shared between the background worker thread and its callers.
#   buffer     -- pending tasks; worker() pops them oldest-first (pop(0)).
#                 Presumably appended to by the UI thread -- confirm at callers.
#   outputs    -- messages appended by worker(): 'preview' and 'results' entries.
#   is_working -- True while worker() is handling a task, False otherwise.
buffer, outputs, is_working = [], [], False


def worker():
    """Run forever, executing tasks queued in the module-level ``buffer``.

    Intended as the target of a daemon thread. Every 10 ms it checks
    ``buffer``; when a task is present it pops the oldest one, sets
    ``is_working`` while handling it, and appends progress/result messages to
    the module-level ``outputs`` list.

    Each task is a 2-item sequence ``(prompt, style_selection)``. For a task
    that succeeds, ``outputs`` receives zero or more
    ``['preview', (percent, 'step/total', y)]`` entries followed by one
    ``['results', [img, img_path]]`` entry.

    A task that raises has its traceback printed and is skipped. The loop
    keeps serving later tasks, and ``is_working`` is always reset to False.
    """
    # buffer and outputs are only mutated in place; is_working is rebound.
    global is_working

    import time
    import shared
    import random
    import traceback
    import modules.default_pipeline as pipeline
    import modules.path
    import modules.patch

    from modules.sdxl_styles import apply_style, aspect_ratios
    from modules.private_logger import log

    # Best-effort startup banner. Any failure (e.g. shared.gradio_root not set
    # yet or missing an attribute -- presumably) is printed, not raised.
    try:
        async_gradio_app = shared.gradio_root
        flag = f'''App started successful. Use the app with {str(async_gradio_app.local_url)} or {str(async_gradio_app.server_name)}:{str(async_gradio_app.server_port)}'''
        if async_gradio_app.share:
            flag += f''' or {async_gradio_app.share_url}'''
        print(flag)
    except Exception as e:
        print(e)

    def handler(task):
        """Generate one image for ``task`` = ``(prompt, style_selection)``.

        Settings other than prompt and style are fixed here: 30 steps, switch
        point 20, 1280×768, a random seed and sharpness 10.0. Progress previews
        and the final result are appended to ``outputs``.
        """
        prompt, style_selection = task
        steps = 30
        switch = 20
        aspect_ratios_selection = '1280×768'
        seed = random.randint(1, int(1024*1024*1024))
        sharpness = 10.0

        # Five LoRA slots: the configured default plus four 'None' placeholders.
        loras = [
            (modules.path.default_lora_name, modules.path.default_lora_weight),
            ('None', 0.5),
            ('None', 0.5),
            ('None', 0.5),
            ('None', 0.5),
        ]
        modules.patch.sharpness = sharpness
        pipeline.refresh_base_model(modules.path.default_base_model_name)
        pipeline.refresh_refiner_model(modules.path.default_refiner_model_name)
        pipeline.refresh_loras(loras)
        pipeline.clean_prompt_cond_caches()

        p_txt, n_txt = apply_style(style_selection, prompt)
        width, height = aspect_ratios[aspect_ratios_selection]

        def callback(step, x0, x, total_steps, y):
            # Percentage is measured against this handler's fixed `steps`.
            outputs.append(['preview', (
                int(100.0 * float(step) / float(steps)),
                f'{step}/{total_steps}',
                y)])

        img = pipeline.process(p_txt, n_txt, steps, switch, width, height, seed, callback=callback)[0]

        # Metadata passed to log() with the image; placeholder LoRA slots are
        # skipped. log() returns the path reported in the 'results' message.
        d = [
            ('Prompt', prompt),
            ('Style', style_selection),
            ('Seed', seed)
        ]
        for n, w in loras:
            if n != 'None':
                d.append((f'LoRA [{n}] weight', w))
        img_path = log(img, d)

        outputs.append(['results', [img, img_path]])

    while True:
        time.sleep(0.01)
        if len(buffer) > 0:
            is_working = True
            task = buffer.pop(0)
            try:
                handler(task)
            except Exception:
                # Without this, one failing task would escape the loop, kill
                # the worker thread and leave is_working stuck at True.
                traceback.print_exc()
            finally:
                is_working = False


# Launch the single background consumer as soon as this module is imported.
# daemon=True lets the interpreter exit without joining worker(), whose
# `while True` loop never returns.
threading.Thread(daemon=True, target=worker).start()