# text_img / app.py
# Author: Brij1808 — commit "Create app.py" (500c32a, verified)
import base64
import os
from io import BytesIO

import requests
import streamlit as st
from PIL import Image, ImageFilter
from transformers import pipeline
# API endpoint for image generation. It can be overridden with the
# IMAGE_API_URL environment variable; the original hard-coded address is kept
# as the default so existing deployments behave the same.
url = os.environ.get("IMAGE_API_URL", "http://34.198.214.220:8000/generate/")

# Sidebar control selecting which generation model the API should use.
model_option = st.sidebar.selectbox(
    "Select Model", ["Fluently-XL-Final", "Flux-Uncensored"], index=0
)

st.title("Text to Image Generator")

# Free-text prompt sent to the generation API (the field starts empty).
prompt = st.text_input("Enter Prompt", "")


@st.cache_resource
def _load_nsfw_classifier():
    """Build the NSFW image-classification pipeline once per server process.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the Hugging Face model would be re-instantiated on each
    rerun.
    """
    return pipeline("image-classification", model="giacomoarienti/nsfw-classifier")


# Image-classification pipeline used by classify_image() to score images.
pipe = _load_nsfw_classifier()
def classify_image(image):
    """Classify an image with the module-level NSFW ``pipe`` pipeline.

    Args:
        image: The PIL image object to be classified.

    Returns:
        The pipeline output: a list of ``{"label": ..., "score": ...}``
        dicts, one per label. Returns ``None`` if the pipeline raised; in
        that case the error is also shown in the Streamlit UI.
    """
    try:
        return pipe(image)
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing
        # the Streamlit script run. Callers treat None as "classification
        # failed".
        st.error(f"Error during classification: {e}")
        return None
def blur_image(image, *, radius=40):
    """Blur an image with a Gaussian filter and display it in the app.

    Nothing is written to disk. The blurred copy is shown with ``st.image``
    and also returned.

    Args:
        image: The PIL image object to be blurred. It is not modified,
            because ``Image.filter`` returns a new image.
        radius: Standard deviation passed to PIL's ``GaussianBlur``
            (default 40, the previously hard-coded value). Larger values
            obscure more detail.

    Returns:
        The blurred PIL image.
    """
    blurred_image = image.filter(ImageFilter.GaussianBlur(radius=radius))
    st.image(blurred_image, caption="Blurred Image", use_container_width=True)
    return blurred_image
def _label_score(results, label):
    """Return the score of the first result whose label is *label*, else 0."""
    return next((item["score"] for item in results if item["label"] == label), 0)


def process_image(image, *, porn_threshold=0.7, sexy_threshold=0.85):
    """Classify an image and display it, blurring it if flagged as NSFW.

    The image is blurred when the classifier's ``porn`` score exceeds
    ``porn_threshold`` or its ``sexy`` score exceeds ``sexy_threshold``.
    Otherwise it is shown unmodified. A label missing from the results
    counts as a score of 0.

    Args:
        image: The PIL image object.
        porn_threshold: Score above which the ``porn`` label triggers a blur.
        sexy_threshold: Score above which the ``sexy`` label triggers a blur.
    """
    results = classify_image(image)
    if not results:
        # classify_image returns None when classification raised; an empty
        # result is treated the same way.
        st.error("Error: Image classification failed.")
        return

    porn_score = _label_score(results, "porn")
    sexy_score = _label_score(results, "sexy")
    # NOTE(review): only 'porn' and 'sexy' are checked. If the model emits
    # other explicit labels (e.g. 'hentai'), they are ignored. Confirm
    # against the model card.
    if porn_score > porn_threshold or sexy_score > sexy_threshold:
        blur_image(image)
    else:
        # Below both thresholds: show the image as generated.
        st.image(image, caption="Original Image", use_container_width=True)
# (connect, read) timeouts in seconds for the generation request. Generation
# can be slow, so the read timeout is generous.
# NOTE(review): tune to the API's actual latency.
REQUEST_TIMEOUT = (10, 300)


def _request_image(prompt_text, model_name):
    """Call the generation API and return the decoded PIL image, or None.

    Every failure is reported with ``st.error`` and yields None. This covers
    network errors, a non-200 status, a non-JSON body, missing image data,
    and undecodable base64 or image bytes.
    """
    payload = {
        "prompt": prompt_text,  # User input prompt
        "model": model_name,  # Model selected by user
    }
    try:
        response = requests.post(url, json=payload, timeout=REQUEST_TIMEOUT)
    except requests.RequestException as e:
        st.error(f"Failed to reach the image generation API: {e}")
        return None

    if response.status_code != 200:
        st.error(
            f"Failed to generate image. Status code: {response.status_code}, "
            f"Error: {response.text}"
        )
        return None

    try:
        response_data = response.json()
    except ValueError:
        # requests raises a ValueError subclass when the body is not JSON.
        st.error("The image generation API returned a non-JSON response.")
        return None

    if not isinstance(response_data, dict) or "image_base64" not in response_data:
        st.error("No image data found in the response!")
        return None

    try:
        image_data = base64.b64decode(response_data["image_base64"])
        image = Image.open(BytesIO(image_data))
        # Image.open is lazy, so force decoding here. Corrupt data then
        # fails now rather than later inside the classifier.
        image.load()
    except (TypeError, ValueError, OSError) as e:
        # binascii.Error (bad base64) subclasses ValueError, and PIL's
        # UnidentifiedImageError subclasses OSError.
        st.error(f"Could not decode the generated image: {e}")
        return None
    return image


# Button to generate an image from the current prompt and model.
if st.button('Generate Image'):
    if not prompt.strip():
        st.warning("Please enter a prompt before generating an image.")
    else:
        with st.spinner("Generating image..."):
            generated_image = _request_image(prompt, model_option)
        if generated_image is not None:
            # Classify, then show the image (blurred if flagged).
            process_image(generated_image)