|
import gradio as gr |
|
from attacks.mist import update_args_with_config, main |
|
|
|
''' |
|
TODO: |
|
- SDXL |
|
- model changing |
|
''' |
|
|
|
|
|
def process_image(eps, device, mode, resize, data_path, output_path, model_path, class_path, prompt,
                  class_prompt, max_train_steps, max_f_train_steps, max_adv_train_steps, lora_lr, pgd_lr,
                  rank, prior_loss_weight, fused_weight, constraint_mode, lpips_bound, lpips_weight):
    """Forward the values collected by the web UI to ``attacks.mist``.

    The arguments arrive positionally, in the same order as the ``inputs``
    list wired to the "Mist" button. They are packed, unchanged and in that
    same order, into a single tuple which is handed to
    ``update_args_with_config`` (with ``None`` as the starting args object);
    whatever that call returns is then passed to ``main``.

    Returns:
        None. Nothing is reported back to the UI.
    """
    # Keep this order in sync with the signature: the tuple is positional.
    ui_settings = (
        eps, device, mode, resize, data_path, output_path, model_path, class_path, prompt,
        class_prompt, max_train_steps, max_f_train_steps, max_adv_train_steps, lora_lr, pgd_lr,
        rank, prior_loss_weight, fused_weight, constraint_mode, lpips_bound, lpips_weight,
    )
    # No pre-existing args object is supplied, hence the explicit None.
    main(update_args_with_config(None, ui_settings))
|
|
|
if __name__ == "__main__":
    # Build the single-page Gradio UI. Every control created below ends up in
    # `inputs`, whose order must match the parameter order of `process_image`.
    with gr.Blocks() as demo:
        with gr.Column():
            gr.Image("MIST_logo.png", show_label=False)
            with gr.Row():
                with gr.Column():
                    # --- Basic settings -------------------------------------
                    eps = gr.Slider(0, 32, step=1, value=10, label='Strength',
                                    info="Larger strength results in stronger but more visible defense.")
                    device = gr.Radio(["cpu", "gpu"], value="gpu", label="Device",
                                      info="If you do not have good GPUs on your PC, choose 'CPU'.")

                    resize = gr.Checkbox(value=True, label="Resizing the output image to the original resolution")
                    # NOTE(review): the hint says "Two modes" although three
                    # options are offered; confirm the intended wording.
                    mode = gr.Radio(["Mode 1", "Mode 2", "Mode 3"], value="Mode 1", label="Mode",
                                    info="Two modes both work with different visualization.")

                    # --- Paths and prompt -----------------------------------
                    data_path = gr.Textbox(label="Data Path", lines=1, placeholder="Path to your images")
                    output_path = gr.Textbox(label="Output Path", lines=1, placeholder="Path to store the outputs")
                    model_path = gr.Textbox(label="Target Model Path", lines=1, placeholder="Path to the target model")
                    # Placeholder previously duplicated the model-path hint
                    # (copy-paste slip); it now describes this field.
                    class_path = gr.Textbox(label="Path to place contrast images ", lines=1,
                                            placeholder="Path to the contrast images")
                    prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Describe your images")

                    # --- Advanced settings, collapsed by default ------------
                    with gr.Accordion("Professional Setups", open=False):
                        class_prompt = gr.Textbox(label="Class prompt", lines=1,
                                                  placeholder="Prompt for contrast images.")
                        max_train_steps = gr.Slider(1, 20, step=1, value=5, label='Epochs',
                                                    info="Training epochs of Mist-V2")
                        max_f_train_steps = gr.Slider(0, 30, step=1, value=10, label='LoRA Steps',
                                                      info="Training steps of LoRA in one epoch")
                        max_adv_train_steps = gr.Slider(0, 100, step=5, value=30, label='Attacking Steps',
                                                        info="Training steps of attacking in one epoch")
                        lora_lr = gr.Number(minimum=0.0, maximum=1.0, label="The learning rate of LoRA",
                                            value=0.0001)
                        pgd_lr = gr.Number(minimum=0.0, maximum=1.0, label="The learning rate of PGD",
                                           value=0.005)
                        rank = gr.Slider(4, 32, step=4, value=4, label='LoRA Ranks',
                                         info="Ranks of LoRA (Bigger ranks need better GPUs)")
                        prior_loss_weight = gr.Number(minimum=0.0, maximum=1.0, label="The weight of prior loss",
                                                      value=0.1)
                        fused_weight = gr.Number(minimum=0.0, maximum=1.0, label="The weight of vae loss",
                                                 value=0.00001)
                        constraint_mode = gr.Radio(["Epsilon", "LPIPS"], value="Epsilon", label="Constraint Mode",
                                                   info="The mode to constraint the watermark")
                        lpips_bound = gr.Number(minimum=0.0, maximum=0.2, label="The LPIPS bound", value=0.1)
                        # Label typo fixed: "LPIPI" -> "LPIPS".
                        lpips_weight = gr.Number(minimum=0.0, maximum=1.0, label="The weight of LPIPS constraint",
                                                 value=0.5)

                    # Positional order here must mirror process_image's signature.
                    inputs = [
                        eps, device, mode, resize, data_path, output_path, model_path, class_path, prompt,
                        class_prompt, max_train_steps, max_f_train_steps, max_adv_train_steps, lora_lr, pgd_lr,
                        rank, prior_loss_weight, fused_weight, constraint_mode, lpips_bound, lpips_weight,
                    ]

                    image_button = gr.Button("Mist")
                    # No outputs are registered: process_image returns nothing
                    # to display in the UI.
                    image_button.click(process_image, inputs=inputs)

    # share=True asks Gradio for a publicly reachable link in addition to the
    # local server.
    demo.queue().launch(share=True)
|
|