#!/usr/bin/env python from __future__ import annotations import os import random from typing import Tuple, Optional import gradio as gr from huggingface_hub import HfApi from inf import InferencePipeline SAMPLE_MODEL_IDS = [ 'lora-library/B-LoRA-teddybear', 'lora-library/B-LoRA-bull', 'lora-library/B-LoRA-wolf_plushie', 'lora-library/B-LoRA-pen_sketch', 'lora-library/B-LoRA-cartoon_line', 'lora-library/B-LoRA-child', 'lora-library/B-LoRA-vase', 'lora-library/B-LoRA-scary_mug', 'lora-library/B-LoRA-statue', 'lora-library/B-LoRA-colorful_teapot', 'lora-library/B-LoRA-grey_sloth_plushie', 'lora-library/B-LoRA-teapot', 'lora-library/B-LoRA-backpack_dog', 'lora-library/B-LoRA-buddha', 'lora-library/B-LoRA-dog6', 'lora-library/B-LoRA-poop_emoji', 'lora-library/B-LoRA-pot', 'lora-library/B-LoRA-fat_bird', 'lora-library/B-LoRA-elephant', 'lora-library/B-LoRA-metal_bird', 'lora-library/B-LoRA-cat', 'lora-library/B-LoRA-dog2', 'lora-library/B-LoRA-drawing1', 'lora-library/B-LoRA-village_oil', 'lora-library/B-LoRA-watercolor', 'lora-library/B-LoRA-house_3d', 'lora-library/B-LoRA-ink_sketch', 'lora-library/B-LoRA-drawing3', 'lora-library/B-LoRA-crayon_drawing', 'lora-library/B-LoRA-kiss', 'lora-library/B-LoRA-drawing4', 'lora-library/B-LoRA-working_cartoon', 'lora-library/B-LoRA-painting', 'lora-library/B-LoRA-drawing2' 'lora-library/B-LoRA-multi-dog2', ] css = """ .gradio-container { max-width: 1250px !important; } #title { text-align: center; } #title h1 { font-size: 250%; } .lora-title { background-image: linear-gradient(to right, #314755 0%, #26a0da 51%, #314755 100%); text-align: center; border-radius: 10px; display: block; } .lora-title h2 { color: white !important; } .gr-image { width: 512px; height: 512px; object-fit: contain; margin: auto; } .res-image { object-fit: contain; margin: auto; } .lora-column { display: flex; flex-direction: column; align-items: center; justify-content: center; border: none; background: none; } .gr-row { align-items: center; justify-content: center; margin-top: 5px; } """ def get_choices(hf_token): api = HfApi(token=hf_token) choices = [ info.modelId for info in api.list_models(author='lora-library') ] models_list = ['None'] + SAMPLE_MODEL_IDS + choices return models_list def get_image_from_card(card, model_id) -> Optional[str]: try: card_path = f"https://huggingface.co/{model_id}/resolve/main/" widget = card.data.get('widget') if widget is not None or len(widget) > 0: output = widget[0].get('output') if output is not None: url = output.get('url') if url is not None: return card_path + url return None except Exception: return None def demo_init(): try: choices = get_choices(app.hf_token) content_blora = random.choice(SAMPLE_MODEL_IDS) style_blora = random.choice(SAMPLE_MODEL_IDS) content_blora_prompt, content_blora_image = app.load_model_info(content_blora) style_blora_prompt, style_blora_image = app.load_model_info(style_blora) content_lora_model_id = gr.update(choices=choices, value=content_blora) content_prompt = gr.update(value=content_blora_prompt) content_image = gr.update(value=content_blora_image) style_lora_model_id = gr.update(choices=choices, value=style_blora) style_prompt = gr.update(value=style_blora_prompt) style_image = gr.update(value=style_blora_image) prompt = gr.update( value=f'{content_blora_prompt} in {style_blora_prompt[0].lower() + style_blora_prompt[1:]} style') return content_lora_model_id, content_prompt, content_image, style_lora_model_id, style_prompt, style_image, prompt except Exception as e: raise type(e)(f'failed to demo_init, due to: {e}') def toggle_column(is_checked): try: return 'None' if is_checked else random.choice(SAMPLE_MODEL_IDS) except Exception as e: raise type(e)(f'failed to toggle_column, due to: {e}') def handle_prompt_change(content_blora_prompt, style_blora_prompt) -> str: try: if content_blora_prompt and style_blora_prompt: return f'{content_blora_prompt} in {style_blora_prompt[0].lower() + style_blora_prompt[1:]} style' if content_blora_prompt: return content_blora_prompt if style_blora_prompt: return f'A dog in {style_blora_prompt[0].lower() + style_blora_prompt[1:]} style' return '' except Exception as e: raise type(e)(f'failed to handle_prompt_change, due to: {e}') class InferenceUtil: def __init__(self, hf_token: str | None): self.hf_token = hf_token def load_model_info(self, lora_model_id: str) -> Tuple[str, Optional[str]]: try: try: card = InferencePipeline.get_model_card(lora_model_id, self.hf_token) except Exception: return '', None instance_prompt = getattr(card.data, 'instance_prompt', '') image_url = get_image_from_card(card, lora_model_id) return instance_prompt, image_url except Exception as e: raise type(e)(f'failed to load_model_info, due to: {e}') def update_model_info(self, model_source: str): try: if model_source == 'None': return '', None else: model_info = self.load_model_info(model_source) new_prompt, new_image = model_info[0], model_info[1] return new_prompt, new_image except Exception as e: raise type(e)(f'failed to update_model_info, due to: {e}') hf_token = os.getenv('HF_TOKEN') pipe = InferencePipeline(hf_token) app = InferenceUtil(hf_token) with gr.Blocks(css=css) as demo: title = gr.HTML( '''
This is a demo for our paper: ''Implicit Style-Content Separation using B-LoRA''.
Project page and code is available here.