"""Streamlit app: generate an image from a text prompt and blur NSFW results.

The app posts the prompt and the selected model name to a remote generation
API, decodes the base64 image it returns, runs an NSFW classifier on it and
shows either the original image or a heavily blurred version.
"""

import base64
from io import BytesIO

import requests
import streamlit as st
from PIL import Image, ImageFilter
from transformers import pipeline

# API endpoint for image generation.
# NOTE(review): hard-coded plain-HTTP address; consider moving to configuration.
url = "http://34.198.214.220:8000/generate/"

# (connect, read) timeouts in seconds. The read timeout is generous because
# image generation can be slow, but it must be finite so the app cannot hang.
REQUEST_TIMEOUT = (10, 300)

# Classifier score thresholds above which the image is blurred.
PORN_THRESHOLD = 0.7
SEXY_THRESHOLD = 0.85

# Radius of the Gaussian blur applied to flagged images.
BLUR_RADIUS = 40

# Streamlit sidebar for model selection.
model_option = st.sidebar.selectbox(
    "Select Model", ["Fluently-XL-Final", "Flux-Uncensored"], index=0
)

st.title("Text to Image Generator")

# Streamlit input field for the prompt (empty by default).
prompt = st.text_input("Enter Prompt", "")


@st.cache_resource
def _load_classifier():
    """Build the NSFW image-classification pipeline once per server process.

    Streamlit re-executes the whole script on every widget interaction;
    caching avoids reloading the model each time.
    """
    return pipeline("image-classification", model="giacomoarienti/nsfw-classifier")


pipe = _load_classifier()


def classify_image(image):
    """Classify an image using the NSFW classifier.

    Args:
        image: The PIL image object to be classified.

    Returns:
        The value returned by the classification pipeline (iterated by
        ``process_image`` as dicts with ``'label'`` and ``'score'`` keys),
        or ``None`` if classification raised. The error is reported in the UI.
    """
    try:
        return pipe(image)
    except Exception as e:  # UI boundary: report the failure instead of crashing.
        st.error(f"Error during classification: {e}")
        return None


def blur_image(image):
    """Apply a Gaussian blur to an image and display the result.

    Args:
        image: The PIL image object to be blurred.
    """
    blurred_image = image.filter(ImageFilter.GaussianBlur(radius=BLUR_RADIUS))
    st.image(blurred_image, caption="Blurred Image", use_container_width=True)


def _label_score(results, label):
    """Return the score reported for *label*, or 0 if the label is absent."""
    return next((item["score"] for item in results if item["label"] == label), 0)


def process_image(image):
    """Classify the image and display it, blurred if it is flagged.

    The image is blurred when the 'porn' score exceeds ``PORN_THRESHOLD`` or
    the 'sexy' score exceeds ``SEXY_THRESHOLD``. If classification fails or
    returns nothing, the image is not shown and an error is reported.

    Args:
        image: The PIL image object.
    """
    results = classify_image(image)
    if not results:
        st.error("Error: Image classification failed.")
        return

    porn_score = _label_score(results, "porn")
    sexy_score = _label_score(results, "sexy")

    if porn_score > PORN_THRESHOLD or sexy_score > SEXY_THRESHOLD:
        blur_image(image)
    else:
        st.image(image, caption="Original Image", use_container_width=True)


def _decode_image(base64_string):
    """Decode a base64 string into a fully loaded PIL image.

    Raises:
        ValueError: If the string is not valid base64 (``binascii.Error`` is
            a ``ValueError`` subclass).
        OSError: If the decoded bytes are not a readable image.
    """
    image_data = base64.b64decode(base64_string)
    image = Image.open(BytesIO(image_data))
    # Image.open is lazy; force decoding now so corrupt data fails here.
    image.load()
    return image


def _generate_and_show():
    """Request an image for the current prompt/model and display it."""
    if not prompt.strip():
        st.warning("Please enter a prompt before generating an image.")
        return

    payload = {
        "prompt": prompt,  # User input prompt
        "model": model_option,  # Model selected by user
    }

    try:
        response = requests.post(url, json=payload, timeout=REQUEST_TIMEOUT)
    except requests.RequestException as e:
        st.error(f"Failed to reach the image generation service: {e}")
        return

    if response.status_code != 200:
        st.error(
            f"Failed to generate image. Status code: {response.status_code}, "
            f"Error: {response.text}"
        )
        return

    try:
        response_data = response.json()
    except ValueError:
        st.error("The image generation service returned an invalid JSON response.")
        return

    if not isinstance(response_data, dict) or "image_base64" not in response_data:
        st.error("No image data found in the response!")
        return

    try:
        image = _decode_image(response_data["image_base64"])
    except (ValueError, OSError) as e:
        st.error(f"Could not decode the generated image: {e}")
        return

    process_image(image)


# Button to generate image.
if st.button('Generate Image'):
    _generate_and_show()