#!/usr/bin/env python
from __future__ import annotations

import os
import random
from typing import Optional, Tuple

import gradio as gr
from huggingface_hub import HfApi

from inf import InferencePipeline
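
# Gradio demo that blends a "content" B-LoRA with a "style" B-LoRA. Prompts and
# preview images are read from Hub model cards; generation itself is delegated
# to InferencePipeline from the local `inf` module.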
SAMPLE_MODEL_IDS = [
    'lora-library/B-LoRA-teddybear',
    'lora-library/B-LoRA-bull',
    'lora-library/B-LoRA-wolf_plushie',
    'lora-library/B-LoRA-pen_sketch',
    'lora-library/B-LoRA-cartoon_line',
    'lora-library/B-LoRA-multi-dog2',
]
css = """ | |
body { | |
font-size: 30px; | |
} | |
.gr-image { | |
width: 512px; | |
height: 512px; | |
object-fit: contain; | |
margin: auto; | |
} | |
.lora-column { | |
display: flex; | |
flex-direction: column; | |
align-items: center; /* Center align content vertically in columns */ | |
justify-content: center; /* Center content horizontally in columns */ | |
} | |
.gr-row { | |
align-items: center; | |
justify-content: center; | |
margin-top: 5px; | |
} | |
""" | |
def get_choices(hf_token: str | None) -> list[str]:
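    """Return dropdown choices: 'None', the sample IDs, then every model
    published under the `lora-library` organization on the Hub."""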
    api = HfApi(token=hf_token)
    hub_models = [info.modelId for info in api.list_models(author='lora-library')]
    return ['None'] + SAMPLE_MODEL_IDS + hub_models
def get_image_from_card(card, model_id) -> Optional[str]:
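    """Return the URL of the card's first widget output image, or None when
    the card metadata has no usable `widget` entry."""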
    try:
        card_path = f"https://huggingface.co/{model_id}/resolve/main/"
        widget = card.data.get('widget')
        if widget:  # guard against a missing or empty widget list
            output = widget[0].get('output')
            if output is not None:
                url = output.get('url')
                if url is not None:
                    return card_path + url
        return None
    except Exception:
        return None
def demo_init():
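    """Populate both dropdowns and seed them with a random sample
    content/style pair when the demo first loads."""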
    try:
        choices = get_choices(app.hf_token)
        content_blora = random.choice(SAMPLE_MODEL_IDS)
        style_blora = random.choice(SAMPLE_MODEL_IDS)
        content_blora_prompt, content_blora_image = app.load_model_info(content_blora)
        style_blora_prompt, style_blora_image = app.load_model_info(style_blora)
        content_lora_model_id = gr.update(choices=choices, value=content_blora)
        content_prompt = gr.update(value=content_blora_prompt)
        content_image = gr.update(value=content_blora_image)
        style_lora_model_id = gr.update(choices=choices, value=style_blora)
        style_prompt = gr.update(value=style_blora_prompt)
        style_image = gr.update(value=style_blora_image)
        prompt = gr.update(
            value=f'{content_blora_prompt} in '
                  f'{style_blora_prompt[0].lower() + style_blora_prompt[1:]} style')
        return (content_lora_model_id, content_prompt, content_image,
                style_lora_model_id, style_prompt, style_image, prompt)
    except Exception as e:
        raise type(e)(f'demo_init failed: {e}') from e
def toggle_column(is_checked):
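    """Force the opposite column's dropdown to 'None' while a 'use ... only'
    checkbox is checked; restore a random sample model when it is unchecked."""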
    try:
        return 'None' if is_checked else random.choice(SAMPLE_MODEL_IDS)
    except Exception as e:
        raise type(e)(f'toggle_column failed: {e}') from e
class InferenceUtil:
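    """Reads prompt and preview-image metadata from Hub model cards."""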
    def __init__(self, hf_token: str | None):
        self.hf_token = hf_token
    def load_model_info(self, lora_model_id: str) -> Tuple[str, Optional[str]]:
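        """Return (instance_prompt, preview_image_url) for a model, or
        ('', None) when its card cannot be loaded."""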
        try:
            try:
                card = InferencePipeline.get_model_card(lora_model_id,
                                                        self.hf_token)
            except Exception:
                return '', None
            instance_prompt = getattr(card.data, 'instance_prompt', '')
            image_url = get_image_from_card(card, lora_model_id)
            return instance_prompt, image_url
        except Exception as e:
            raise type(e)(f'load_model_info failed: {e}') from e
    def update_model_info(self, model_source: str):
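        """Dropdown change handler: clear the column for 'None', otherwise
        reload its prompt and preview image."""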
        try:
            if model_source == 'None':
                return '', None
            new_prompt, new_image = self.load_model_info(model_source)
            return new_prompt, new_image
        except Exception as e:
            raise type(e)(f'update_model_info failed: {e}') from e
def create_inference_demo(pipe: InferencePipeline,
                          hf_token: str | None = None) -> gr.Blocks:
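    """Build the Gradio Blocks UI and wire its events to `pipe.run`."""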
    with gr.Blocks(css=css) as demo:
        with gr.Row(elem_classes="gr-row"):
            with gr.Column():
                with gr.Group(elem_classes="lora-column"):
                    gr.Markdown('## Content B-LoRA')
                    content_checkbox = gr.Checkbox(label='Use Content Only', value=False)
                    content_lora_model_id = gr.Dropdown(label='Model ID', choices=[])
                    content_prompt = gr.Text(label='Content instance prompt', interactive=False, max_lines=1)
                    content_image = gr.Image(label='Content Image', elem_classes="gr-image")
            with gr.Column():
                with gr.Group(elem_classes="lora-column"):
                    gr.Markdown('## Style B-LoRA')
                    style_checkbox = gr.Checkbox(label='Use Style Only', value=False)
                    style_lora_model_id = gr.Dropdown(label='Model ID', choices=[])
                    style_prompt = gr.Text(label='Style instance prompt', interactive=False, max_lines=1)
                    style_image = gr.Image(label='Style Image', elem_classes="gr-image")
        with gr.Row(elem_classes="gr-row"):
            with gr.Column():
                with gr.Group():
                    prompt = gr.Textbox(
                        label='Prompt',
                        max_lines=1,
                        placeholder='Example: "A [c] in [s] style"'
                    )
                    result = gr.Image(label='Result')
                    with gr.Accordion('Other Parameters', open=False, elem_classes="gr-accordion"):
                        content_alpha = gr.Slider(label='Content B-LoRA alpha',
                                                  minimum=0,
                                                  maximum=2,
                                                  step=0.05,
                                                  value=1)
                        style_alpha = gr.Slider(label='Style B-LoRA alpha',
                                                minimum=0,
                                                maximum=2,
                                                step=0.05,
                                                value=1)
                        seed = gr.Slider(label='Seed',
                                         minimum=0,
                                         maximum=100000,
                                         step=1,
                                         value=8888)
                        num_steps = gr.Slider(label='Number of Steps',
                                              minimum=0,
                                              maximum=100,
                                              step=1,
                                              value=50)
                        guidance_scale = gr.Slider(label='CFG Scale',
                                                   minimum=0,
                                                   maximum=50,
                                                   step=0.1,
                                                   value=7.5)
                    run_button = gr.Button('Generate')
        demo.load(demo_init, inputs=[],
                  outputs=[content_lora_model_id, content_prompt, content_image,
                           style_lora_model_id, style_prompt, style_image, prompt],
                  queue=False, show_progress="hidden")
        content_lora_model_id.change(
            fn=app.update_model_info,
            inputs=content_lora_model_id,
            outputs=[
                content_prompt,
                content_image,
            ])
        style_lora_model_id.change(
            fn=app.update_model_info,
            inputs=style_lora_model_id,
            outputs=[
                style_prompt,
                style_image,
            ])
        # Rebuild the combined prompt whenever either instance prompt changes;
        # guard against empty prompts before lowercasing the style prompt.
        style_prompt.change(
            fn=lambda content, style: (
                f'{content} in {style[0].lower() + style[1:]} style'
                if style else content),
            inputs=[content_prompt, style_prompt],
            outputs=prompt,
        )
        content_prompt.change(
            fn=lambda content, style: (
                f'{content} in {style[0].lower() + style[1:]} style'
                if content and style else content or style),
            inputs=[content_prompt, style_prompt],
            outputs=prompt,
        )
        # 'Use ... Only' disables the opposite column by forcing its dropdown to 'None'.
        content_checkbox.change(toggle_column, inputs=[content_checkbox],
                                outputs=[style_lora_model_id])
        style_checkbox.change(toggle_column, inputs=[style_checkbox],
                              outputs=[content_lora_model_id])
        inputs = [
            content_lora_model_id,
            style_lora_model_id,
            prompt,
            content_alpha,
            style_alpha,
            seed,
            num_steps,
            guidance_scale,
        ]
        prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
        run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
    return demo
if __name__ == '__main__':
    hf_token = os.getenv('HF_TOKEN')  # Hub token used for listing models and reading cards
    pipe = InferencePipeline(hf_token)
    app = InferenceUtil(hf_token)
    demo = create_inference_demo(pipe, hf_token)
    demo.queue(max_size=10).launch(share=False)