#!/usr/bin/env python
"""Gradio demo for "Implicit Style-Content Separation using B-LoRA".

The user picks a content B-LoRA and a style B-LoRA by model ID. Each model's
instance prompt and example image are read from its Hub model card. The two
instance prompts are combined into a generation prompt, and generation is
delegated to ``InferencePipeline.run``.
"""

from __future__ import annotations

import os
import random
from typing import NoReturn, Optional, Tuple

import gradio as gr
from huggingface_hub import HfApi

from inf import InferencePipeline

SAMPLE_MODEL_IDS = [
    'lora-library/B-LoRA-teddybear',
    'lora-library/B-LoRA-bull',
    'lora-library/B-LoRA-wolf_plushie',
    'lora-library/B-LoRA-pen_sketch',
    'lora-library/B-LoRA-cartoon_line',
    'lora-library/B-LoRA-child',
    'lora-library/B-LoRA-multi-dog2',
]

css = """
.gradio-container {
    max-width: 1250px !important;
}
#title {
    text-align: center;
}
#title h1 {
    font-size: 250%;
}
.lora-title {
    background-image: linear-gradient(to right, #314755 0%, #26a0da 51%, #314755 100%);
    text-align: center;
    border-radius: 10px;
    display: block;
}
.lora-title h2 {
    color: white !important;
}
.gr-image {
    width: 512px;
    height: 512px;
    object-fit: contain;
    margin: auto;
}
.res-image {
    object-fit: contain;
    margin: auto;
}
.lora-column {
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
    border: none;
    background: none;
}
.gr-row {
    align-items: center;
    justify-content: center;
    margin-top: 5px;
}
"""


def _reraise(exc: Exception, where: str) -> NoReturn:
    """Re-raise *exc* with ``failed to <where>, due to: <exc>`` as the message.

    The original exception type is kept when it can be rebuilt from a single
    message argument. Otherwise a ``RuntimeError`` is raised, so the rebuild
    itself cannot fail with an unrelated ``TypeError``. The original exception
    is always chained as ``__cause__``.
    """
    message = f'failed to {where}, due to: {exc}'
    try:
        wrapped = type(exc)(message)
    except Exception:
        wrapped = RuntimeError(message)
    raise wrapped from exc


def _lower_first(text: str) -> str:
    """Return *text* with only its first character lower-cased ('' stays '')."""
    # Slicing instead of text[0] so an empty prompt cannot raise IndexError.
    return text[:1].lower() + text[1:]


def get_choices(hf_token):
    """Return the model IDs offered by the B-LoRA dropdowns.

    The order is: the sentinel ``'None'`` (no model selected), then
    ``SAMPLE_MODEL_IDS``, then every model listed under the ``lora-library``
    author on the Hub. Repeated IDs are dropped, keeping the first occurrence;
    the sample IDs live under ``lora-library`` and may therefore be listed
    twice.

    Args:
        hf_token: Hugging Face token passed to ``HfApi`` (may be None).
    """
    api = HfApi(token=hf_token)
    hub_ids = [info.id for info in api.list_models(author='lora-library')]
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(['None', *SAMPLE_MODEL_IDS, *hub_ids]))


def get_image_from_card(card, model_id) -> Optional[str]:
    """Return the URL of the first widget output image on a model card.

    Reads ``card.data['widget'][0]['output']['url']`` and resolves it against
    ``https://huggingface.co/<model_id>/resolve/main/``. This is best effort:
    it returns None when the widget list is missing or empty, when any key is
    absent, or on any other error while reading the card.
    """
    try:
        widget = card.data.get('widget')
        if not widget:
            return None
        output = widget[0].get('output')
        if output is None:
            return None
        url = output.get('url')
        if url is None:
            return None
        return f"https://huggingface.co/{model_id}/resolve/main/" + url
    except Exception:
        # Deliberately best effort: a malformed card only means "no preview".
        return None


def demo_init():
    """Populate the UI on page load with two randomly chosen sample B-LoRAs.

    Returns one ``gr.update`` per output in this order:
    content model dropdown, content prompt, content image, style model
    dropdown, style prompt, style image, combined generation prompt.
    """
    try:
        choices = get_choices(app.hf_token)
        content_blora = random.choice(SAMPLE_MODEL_IDS)
        style_blora = random.choice(SAMPLE_MODEL_IDS)
        content_blora_prompt, content_blora_image = app.load_model_info(content_blora)
        style_blora_prompt, style_blora_image = app.load_model_info(style_blora)
        # load_model_info returns '' when a card cannot be loaded. Reuse the
        # guarded prompt builder instead of indexing style_blora_prompt[0].
        combined_prompt = handle_prompt_change(content_blora_prompt, style_blora_prompt)
        return (
            gr.update(choices=choices, value=content_blora),
            gr.update(value=content_blora_prompt),
            gr.update(value=content_blora_image),
            gr.update(choices=choices, value=style_blora),
            gr.update(value=style_blora_prompt),
            gr.update(value=style_blora_image),
            gr.update(value=combined_prompt),
        )
    except Exception as e:
        _reraise(e, 'demo_init')


def toggle_column(is_checked):
    """Return the new value for the *other* column's model dropdown.

    When the "use only" checkbox is checked, the other column is set to
    ``'None'`` (disabled). When it is unchecked, a random sample model is
    picked instead.
    """
    return 'None' if is_checked else random.choice(SAMPLE_MODEL_IDS)


def handle_prompt_change(content_blora_prompt, style_blora_prompt) -> str:
    """Build the generation prompt from the two instance prompts.

    - both set:      "<content> in <style> style" (style's first letter lowered)
    - content only:  "<content>"
    - style only:    "A dog in <style> style"
    - neither:       ""
    """
    try:
        if content_blora_prompt and style_blora_prompt:
            return f'{content_blora_prompt} in {_lower_first(style_blora_prompt)} style'
        if content_blora_prompt:
            return content_blora_prompt
        if style_blora_prompt:
            return f'A dog in {_lower_first(style_blora_prompt)} style'
        return ''
    except Exception as e:
        _reraise(e, 'handle_prompt_change')


class InferenceUtil:
    """Looks up B-LoRA model-card info using a stored Hugging Face token."""

    def __init__(self, hf_token: str | None):
        self.hf_token = hf_token

    def load_model_info(self, lora_model_id: str) -> Tuple[str, Optional[str]]:
        """Return ``(instance_prompt, example_image_url)`` for a model.

        Returns ``('', None)`` when the model card cannot be fetched. The
        prompt is ``''`` when the card has no ``instance_prompt`` (or has it
        set to a null value). The URL is None when the card has no widget
        image.
        """
        try:
            card = InferencePipeline.get_model_card(lora_model_id, self.hf_token)
        except Exception:
            # Best effort: an unreachable or missing card just leaves the preview empty.
            return '', None
        try:
            instance_prompt = getattr(card.data, 'instance_prompt', '') or ''
            image_url = get_image_from_card(card, lora_model_id)
            return instance_prompt, image_url
        except Exception as e:
            _reraise(e, 'load_model_info')

    def update_model_info(self, model_source: str):
        """Dropdown change handler: return ``(prompt, image_url)`` for *model_source*.

        ``'None'`` (and an empty or cleared selection) yields ``('', None)``.
        """
        try:
            if not model_source or model_source == 'None':
                return '', None
            return self.load_model_info(model_source)
        except Exception as e:
            _reraise(e, 'update_model_info')


hf_token = os.getenv('HF_TOKEN')
pipe = InferencePipeline(hf_token)
app = InferenceUtil(hf_token)

with gr.Blocks(css=css) as demo:
    # NOTE(review): the CSS styles `#title h1` and `.lora-title h2`, but the
    # HTML strings below contain no tags, so those rules never match. The
    # markup (and the "here" link) may have been stripped; confirm.
    title = gr.HTML(
        '''

Implicit Style-Content Separation using B-LoRA

This is a demo for our paper: ''Implicit Style-Content Separation using B-LoRA''.
Project page and code is available here.

''',
        elem_id="title",
    )

    with gr.Row(elem_classes="gr-row"):
        with gr.Column():
            with gr.Group(elem_classes="lora-column"):
                content_sub_title = gr.HTML('''

Content B-LoRA

''', elem_classes="lora-title")
                content_checkbox = gr.Checkbox(label='Use Content Only', value=False)
                content_lora_model_id = gr.Dropdown(label='Model ID', choices=[])
                content_prompt = gr.Text(label='Content instance prompt', interactive=False, max_lines=1)
                content_image = gr.Image(label='Content Image', elem_classes="gr-image")
        with gr.Column():
            with gr.Group(elem_classes="lora-column"):
                style_sub_title = gr.HTML('''

Style B-LoRA

''', elem_classes="lora-title")
                style_checkbox = gr.Checkbox(label='Use Style Only', value=False)
                style_lora_model_id = gr.Dropdown(label='Model ID', choices=[])
                style_prompt = gr.Text(label='Style instance prompt', interactive=False, max_lines=1)
                style_image = gr.Image(label='Style Image', elem_classes="gr-image")

    with gr.Row(elem_classes="gr-row"):
        with gr.Column():
            with gr.Group():
                prompt = gr.Textbox(
                    label='Prompt',
                    max_lines=1,
                    placeholder='Example: "A [c] in [s] style"',
                )
                result = gr.Gallery(label='Result', elem_classes="res-image")
                with gr.Accordion('Other Parameters', open=False, elem_classes="gr-accordion"):
                    content_alpha = gr.Slider(label='Content B-LoRA alpha', minimum=0, maximum=2, step=0.05, value=1)
                    style_alpha = gr.Slider(label='Style B-LoRA alpha', minimum=0, maximum=2, step=0.05, value=1)
                    seed = gr.Slider(label='Seed', minimum=0, maximum=100000, step=1, value=8888)
                    num_steps = gr.Slider(label='Number of Steps', minimum=0, maximum=100, step=1, value=40)
                    guidance_scale = gr.Slider(label='CFG Scale', minimum=0, maximum=50, step=0.1, value=7.5)
                    num_images_per_prompt = gr.Slider(label='Number of Images per Prompt', minimum=1, maximum=4, step=1, value=2)
                run_button = gr.Button('Generate')

    # Fill the dropdowns, previews and prompt once the page loads.
    demo.load(
        demo_init,
        inputs=[],
        outputs=[content_lora_model_id, content_prompt, content_image,
                 style_lora_model_id, style_prompt, style_image, prompt],
        queue=False,
        show_progress="hidden",
    )

    # Selecting a model refreshes that column's instance prompt and preview.
    content_lora_model_id.change(
        fn=app.update_model_info,
        inputs=content_lora_model_id,
        outputs=[content_prompt, content_image],
    )
    style_lora_model_id.change(
        fn=app.update_model_info,
        inputs=style_lora_model_id,
        outputs=[style_prompt, style_image],
    )

    # Either instance prompt changing rebuilds the combined generation prompt.
    style_prompt.change(
        fn=handle_prompt_change,
        inputs=[content_prompt, style_prompt],
        outputs=prompt,
    )
    content_prompt.change(
        fn=handle_prompt_change,
        inputs=[content_prompt, style_prompt],
        outputs=prompt,
    )

    # "Use X Only" toggles the *other* column's model between 'None' and a random sample.
    content_checkbox.change(toggle_column, inputs=[content_checkbox], outputs=[style_lora_model_id])
    style_checkbox.change(toggle_column, inputs=[style_checkbox], outputs=[content_lora_model_id])

    # Positional order must match the signature of InferencePipeline.run.
    inputs = [
        content_lora_model_id,
        style_lora_model_id,
        prompt,
        content_alpha,
        style_alpha,
        seed,
        num_steps,
        guidance_scale,
        num_images_per_prompt,
    ]
    prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
    run_button.click(fn=pipe.run, inputs=inputs, outputs=result)

demo.queue(max_size=10).launch(share=False)