Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,809 Bytes
07a421e 23dca80 07a421e d2d56e8 b5b4791 07a421e a960bc2 b5b4791 7785249 c66e22e 85913ad 07a421e b5b4791 7785249 7ca8bcd b5b4791 7785249 3654a3e a59bcf0 07a421e a960bc2 b5b4791 a960bc2 b5b4791 a960bc2 b5b4791 a960bc2 4715e52 a960bc2 07a421e 23dca80 07a421e 4715e52 07a421e 7785249 4715e52 7785249 07a421e 7ca8bcd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import tempfile

import torch
from diffusers import FluxPipeline
from transformers import pipeline
import gradio as gr
import spaces
# Target device shared by the diffusion pipeline and the NSFW classifier.
device = torch.device("cuda")

# Load the FLUX.1-dev base model in bfloat16 and fuse the children's
# simple-sketch LoRA into its weights (scale 1.5) so no adapter is
# applied separately at inference time.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights(
    "Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch",
    weight_name="FLUX-dev-lora-children-simple-sketch.safetensors",
)
pipe.fuse_lora(lora_scale=1.5)
pipe.to(device)  # same device object the classifier below uses

# Image classifier used to screen every generated picture for NSFW content.
image_classifier = pipeline(
    "image-classification", model="Falconsai/nsfw_image_detection", device=device
)
# NOTE: screening the text prompt with
# "eliasalbouzidi/distilbert-nsfw-text-classifier" is currently disabled;
# only generated images are checked.

# An image is rejected when its "nsfw" label score is strictly above this.
NSFW_THRESHOLD = 0.3
# Define the function to generate the sketch
@spaces.GPU
def generate_sketch(prompt, num_inference_steps, guidance_scale):
# Classify the text for NSFW content
#text_classification = text_classifier(prompt)
#print(text_classification)
# Check the classification results
#for result in text_classification:
# if result['label'] == 'nsfw' and result['score'] > NSFW_THRESHOLD:
# return gr.update(visible=False),gr.Text(value="Inappropriate prompt detected. Please try another prompt.")
print(prompt)
image = pipe("sketched style, " + prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
).images[0]
# Classify the image for NSFW content
image_classification = image_classifier(image)
print(image_classification)
# Check the classification results
for result in image_classification:
if result['label'] == 'nsfw' and result['score'] > NSFW_THRESHOLD:
return None,"Inappropriate content detected. Please try another prompt." #return gr.update(visible=False),gr.Text(value="Inappropriate content detected. Please try another prompt.")
image_path = "generated_sketch.png"
image.save(image_path)
return image_path,None #gr.Image(value=image_path), gr.update(visible=False)
# Gradio UI: a prompt textbox plus sliders for num_inference_steps and
# guidance_scale. Outputs are the generated sketch and a status message,
# matching the (image_path, message) tuple returned by generate_sketch.
interface = gr.Interface(
    fn=generate_sketch,
    inputs=[
        "text",  # Prompt input
        gr.Slider(5, 50, value=24, step=1, label="Number of Inference Steps"),
        gr.Slider(1.0, 10.0, value=3.5, step=0.1, label="Guidance Scale"),
    ],
    outputs=[
        gr.Image(label="Generated Sketch"),
        gr.Textbox(label="Message"),
    ],
    title="Kids Sketch Generator",
    description="Enter a text prompt and generate a fun sketch for kids with customizable inference steps and guidance scale.",
)

# Launch the app.
interface.launch()