#!/usr/bin/env python
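"""Gradio demo for "Implicit Style-Content Separation using B-LoRA"
(https://arxiv.org/abs/2403.14572): pick a content B-LoRA and a style B-LoRA
and generate images combining the content of one with the style of the other."""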
from __future__ import annotations
import os
import random
from typing import Tuple, Optional
import gradio as gr
from huggingface_hub import HfApi
from inf import InferencePipeline
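# `inf` is this Space's local inference module; this file relies on
# `InferencePipeline.get_model_card(model_id, hf_token)` for card metadata and
# `InferencePipeline(...).run(...)` for generation (see the event wiring below).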
# Curated example B-LoRAs; they seed the dropdown choices and the random
# initial selection in `demo_init`.
SAMPLE_MODEL_IDS = [
    'lora-library/B-LoRA-teddybear',
    'lora-library/B-LoRA-bull',
    'lora-library/B-LoRA-wolf_plushie',
    'lora-library/B-LoRA-pen_sketch',
    'lora-library/B-LoRA-cartoon_line',
    'lora-library/B-LoRA-child',
    'lora-library/B-LoRA-vase',
    'lora-library/B-LoRA-scary_mug',
    'lora-library/B-LoRA-statue',
    'lora-library/B-LoRA-colorful_teapot',
    'lora-library/B-LoRA-grey_sloth_plushie',
    'lora-library/B-LoRA-teapot',
    'lora-library/B-LoRA-backpack_dog',
    'lora-library/B-LoRA-buddha',
    'lora-library/B-LoRA-dog6',
    'lora-library/B-LoRA-poop_emoji',
    'lora-library/B-LoRA-pot',
    'lora-library/B-LoRA-fat_bird',
    'lora-library/B-LoRA-elephant',
    'lora-library/B-LoRA-metal_bird',
    'lora-library/B-LoRA-cat',
    'lora-library/B-LoRA-dog2',
    'lora-library/B-LoRA-drawing1',
    'lora-library/B-LoRA-village_oil',
    'lora-library/B-LoRA-watercolor',
    'lora-library/B-LoRA-house_3d',
    'lora-library/B-LoRA-ink_sketch',
    'lora-library/B-LoRA-drawing3',
    'lora-library/B-LoRA-crayon_drawing',
    'lora-library/B-LoRA-kiss',
    'lora-library/B-LoRA-drawing4',
    'lora-library/B-LoRA-working_cartoon',
    'lora-library/B-LoRA-painting',
    'lora-library/B-LoRA-drawing2',  # comma added: without it, this and the next literal silently concatenate into one bogus ID
    'lora-library/B-LoRA-multi-dog2',
]
css = """
.gradio-container {
max-width: 900px !important;
}
#title {
text-align: center;
}
#title h1 {
font-size: 250%;
}
.lora-title {
background-image: linear-gradient(to right, #314755 0%, #26a0da 51%, #314755 100%);
text-align: center;
border-radius: 10px;
display: block;
}
.lora-title h2 {
color: white !important;
}
.gr-image {
width: 256px;
height: 256px;
object-fit: contain;
margin: auto;
}
.res-image {
object-fit: contain;
margin: auto;
}
.lora-column {
border: none;
background: none;
}
.gr-row {
align-items: center;
justify-content: center;
margin-top: 5px;
}
.svelte-iyf88w {
background: none;
}
"""

def get_choices(hf_token):
    """List all 'lora-library' models on the Hub, preceded by 'None' and the curated samples."""
    api = HfApi(token=hf_token)
    choices = [info.modelId for info in api.list_models(author='lora-library')]
    models_list = ['None'] + SAMPLE_MODEL_IDS + choices
    return models_list

def get_image_from_card(card, model_id) -> Optional[str]:
    """Extract a sample-image URL from a model card's widget metadata, if present."""
    try:
        card_path = f"https://huggingface.co/{model_id}/resolve/main/"
        widget = card.data.get('widget')
        # 'and' (not 'or'): guard against a missing widget before indexing into it.
        if widget is not None and len(widget) > 0:
            output = widget[0].get('output')
            if output is not None:
                url = output.get('url')
                if url is not None:
                    return card_path + url
        return None
    except Exception:
        return None

def demo_init():
    """Populate the dropdowns and pick a random content/style pair on page load."""
    try:
        choices = get_choices(app.hf_token)
        content_blora = random.choice(SAMPLE_MODEL_IDS)
        style_blora = random.choice(SAMPLE_MODEL_IDS)
        content_blora_prompt, content_blora_image = app.load_model_info(content_blora)
        style_blora_prompt, style_blora_image = app.load_model_info(style_blora)
        content_lora_model_id = gr.update(choices=choices, value=content_blora)
        content_prompt = gr.update(value=content_blora_prompt)
        content_image = gr.update(value=content_blora_image)
        style_lora_model_id = gr.update(choices=choices, value=style_blora)
        style_prompt = gr.update(value=style_blora_prompt)
        style_image = gr.update(value=style_blora_image)
        prompt = gr.update(
            value=f'{content_blora_prompt} in {style_blora_prompt[0].lower() + style_blora_prompt[1:]} style')
        return (content_lora_model_id, content_prompt, content_image,
                style_lora_model_id, style_prompt, style_image, prompt)
    except Exception as e:
        raise type(e)(f'demo_init failed: {e}')

def toggle_column(is_checked):
    try:
        return 'None' if is_checked else random.choice(SAMPLE_MODEL_IDS)
    except Exception as e:
        raise type(e)(f'toggle_column failed: {e}')

def handle_prompt_change(content_blora_prompt, style_blora_prompt) -> str:
    """Compose the generation prompt from whichever instance prompts are set."""
    try:
        if content_blora_prompt and style_blora_prompt:
            return f'{content_blora_prompt} in {style_blora_prompt[0].lower() + style_blora_prompt[1:]} style'
        if content_blora_prompt:
            return content_blora_prompt
        if style_blora_prompt:
            return f'A dog in {style_blora_prompt[0].lower() + style_blora_prompt[1:]} style'
        return ''
    except Exception as e:
        raise type(e)(f'handle_prompt_change failed: {e}')

class InferenceUtil:
    def __init__(self, hf_token: str | None):
        self.hf_token = hf_token

    def load_model_info(self, lora_model_id: str) -> Tuple[str, Optional[str]]:
        """Return (instance prompt, sample-image URL) for a model, or ('', None) if its card is unavailable."""
        try:
            try:
                card = InferencePipeline.get_model_card(lora_model_id,
                                                        self.hf_token)
            except Exception:
                return '', None
            instance_prompt = getattr(card.data, 'instance_prompt', '')
            image_url = get_image_from_card(card, lora_model_id)
            return instance_prompt, image_url
        except Exception as e:
            raise type(e)(f'load_model_info failed: {e}')

    def update_model_info(self, model_source: str):
        try:
            if model_source == 'None':
                return '', None
            new_prompt, new_image = self.load_model_info(model_source)
            return new_prompt, new_image
        except Exception as e:
            raise type(e)(f'update_model_info failed: {e}')
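
# Module-level state shared by every callback: the HF token (read from the
# environment), the inference pipeline, and the model-card helper.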
hf_token = os.getenv('HF_TOKEN')
pipe = InferencePipeline(hf_token)
app = InferenceUtil(hf_token)

with gr.Blocks(css=css) as demo:
    title = gr.HTML(
        '''<h1>Implicit Style-Content Separation using B-LoRA</h1>
        <p>This is a demo for our <a href="https://arxiv.org/abs/2403.14572">paper</a>: <b>"Implicit Style-Content Separation using B-LoRA"</b>.
        <br>
        Project page and code are available <a href="https://b-lora.github.io/B-LoRA/">here</a>.</p>
        ''',
        elem_id="title"
    )
    with gr.Row(elem_classes="gr-row"):
        with gr.Column():
            with gr.Group(elem_classes="lora-column"):
                content_sub_title = gr.HTML('''<h2>Content B-LoRA</h2>''', elem_classes="lora-title")
                content_checkbox = gr.Checkbox(label='Use Content Only', value=False)
                content_lora_model_id = gr.Dropdown(label='Model ID', choices=[])
                content_prompt = gr.Text(label='Content instance prompt', interactive=False, max_lines=1)
                content_image = gr.Image(label='Content Image', elem_classes="gr-image")
        with gr.Column():
            with gr.Group(elem_classes="lora-column"):
                style_sub_title = gr.HTML('''<h2>Style B-LoRA</h2>''', elem_classes="lora-title")
                style_checkbox = gr.Checkbox(label='Use Style Only', value=False)
                style_lora_model_id = gr.Dropdown(label='Model ID', choices=[])
                style_prompt = gr.Text(label='Style instance prompt', interactive=False, max_lines=1)
                style_image = gr.Image(label='Style Image', elem_classes="gr-image")
    with gr.Row(elem_classes="gr-row"):
        with gr.Column():
            with gr.Group():
                prompt = gr.Textbox(
                    label='Prompt',
                    max_lines=1,
                    placeholder='Example: "A [c] in [s] style"'
                )
                result = gr.Gallery(label='Result', elem_classes="res-image")
                with gr.Accordion('Other Parameters', open=False, elem_classes="gr-accordion"):
                    content_alpha = gr.Slider(label='Content B-LoRA alpha',
                                              minimum=0,
                                              maximum=2,
                                              step=0.05,
                                              value=1)
                    style_alpha = gr.Slider(label='Style B-LoRA alpha',
                                            minimum=0,
                                            maximum=2,
                                            step=0.05,
                                            value=1)
                    seed = gr.Slider(label='Seed',
                                     minimum=0,
                                     maximum=100000,
                                     step=1,
                                     value=8888)
                    num_steps = gr.Slider(label='Number of Steps',
                                          minimum=0,
                                          maximum=100,
                                          step=1,
                                          value=40)
                    guidance_scale = gr.Slider(label='CFG Scale',
                                               minimum=0,
                                               maximum=50,
                                               step=0.1,
                                               value=7.5)
                    num_images_per_prompt = gr.Slider(label='Number of Images per Prompt',
                                                      minimum=1,
                                                      maximum=4,
                                                      step=1,
                                                      value=2)
                run_button = gr.Button('Generate')
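
    # Event wiring: populate the dropdowns on load, sync the prompt/image previews
    # when a model selection changes, and let each "use only" checkbox clear the
    # opposite column by setting its model to 'None'.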
    demo.load(demo_init, inputs=[],
              outputs=[content_lora_model_id, content_prompt, content_image,
                       style_lora_model_id, style_prompt, style_image, prompt],
              queue=False, show_progress="hidden")
    content_lora_model_id.change(
        fn=app.update_model_info,
        inputs=content_lora_model_id,
        outputs=[
            content_prompt,
            content_image,
        ])
    style_lora_model_id.change(
        fn=app.update_model_info,
        inputs=style_lora_model_id,
        outputs=[
            style_prompt,
            style_image,
        ])
    style_prompt.change(
        fn=handle_prompt_change,
        inputs=[content_prompt, style_prompt],
        outputs=prompt,
    )
    content_prompt.change(
        fn=handle_prompt_change,
        inputs=[content_prompt, style_prompt],
        outputs=prompt,
    )
    content_checkbox.change(toggle_column, inputs=[content_checkbox],
                            outputs=[style_lora_model_id])
    style_checkbox.change(toggle_column, inputs=[style_checkbox],
                          outputs=[content_lora_model_id])
    inputs = [
        content_lora_model_id,
        style_lora_model_id,
        prompt,
        content_alpha,
        style_alpha,
        seed,
        num_steps,
        guidance_scale,
        num_images_per_prompt,
    ]
    prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
    run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
demo.queue(max_size=10).launch(share=False)
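
# Example local run (assumes the Space's dependencies, e.g. gradio and
# huggingface_hub, are installed):
#   HF_TOKEN=<your-token> python app_inference.py
# The token is optional for public models but needed for private or gated repos.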