# text_img / app.py
# Brij1808's picture
# Create app.py
# 500c32a verified
# raw
# history blame
# 3.28 kB
import requests
import base64
from PIL import Image, ImageFilter
from io import BytesIO
from transformers import pipeline
import streamlit as st
# Endpoint of the remote image-generation service.
url = "http://34.198.214.220:8000/generate/"

# Sidebar selector for the generation model; the first entry is the default.
model_option = st.sidebar.selectbox("Select Model", ["Fluently-XL-Final", "Flux-Uncensored"], index=0)

st.title("Text to Image Generator")

# Free-text prompt; empty until the user types something.
prompt = st.text_input("Enter Prompt", "")


@st.cache_resource
def _load_nsfw_classifier():
    """Build the NSFW image-classification pipeline once and reuse it.

    Streamlit re-executes this script on every widget interaction; caching
    the pipeline keeps the model from being reloaded on each rerun.
    """
    return pipeline("image-classification", model="giacomoarienti/nsfw-classifier")


# High-level image-classification helper used to screen generated images.
pipe = _load_nsfw_classifier()
def classify_image(image):
    """
    Run the module-level NSFW classification pipeline on an image.

    Args:
        image: The PIL image object to classify.

    Returns:
        Whatever the pipeline produces for the image, or None when the
        pipeline raised (the error is then reported in the Streamlit UI).
    """
    try:
        predictions = pipe(image)
    except Exception as e:
        # Report the failure in the UI rather than letting the app crash.
        st.error(f"Error during classification: {e}")
        return None
    else:
        return predictions
def blur_image(image):
    """
    Display a Gaussian-blurred version of an image in the Streamlit app.

    Args:
        image: The PIL image object to be blurred.
    """
    # A large radius so the content is no longer recognisable.
    heavy_blur = ImageFilter.GaussianBlur(radius=40)
    st.image(
        image.filter(heavy_blur),
        caption="Blurred Image",
        use_container_width=True,
    )
def process_image(image):
    """
    Classify an image and display it, blurred when it is flagged.

    The image is shown blurred when the 'porn' score is above 0.7 or the
    'sexy' score is above 0.85; otherwise it is shown unchanged. When
    classification yields no results, an error is shown instead of the image.

    Args:
        image: The PIL image object.
    """
    predictions = classify_image(image)
    if not predictions:
        st.error("Error: Image classification failed.")
        return

    def first_score(label):
        # Score of the first prediction carrying this label; 0 when absent.
        for entry in predictions:
            if entry['label'] == label:
                return entry['score']
        return 0

    porn_score = first_score('porn')
    sexy_score = first_score('sexy')

    if not (porn_score > 0.7 or sexy_score > 0.85):
        # Below both thresholds: show the image as generated.
        st.image(image, caption="Original Image", use_container_width=True)
        return

    # At least one threshold exceeded: only a blurred version is shown.
    blur_image(image)
def _request_generated_image(prompt_text, model_name):
    """Request an image from the generation API and return it as a PIL image.

    Every failure (network error, non-200 status, malformed JSON, missing or
    undecodable image data) is reported in the Streamlit UI.

    Args:
        prompt_text: The text prompt to send to the API.
        model_name: The name of the generation model to use.

    Returns:
        The decoded PIL image, or None when any step failed.
    """
    payload = {
        "prompt": prompt_text,  # User input prompt
        "model": model_name,  # Model selected by user
    }

    try:
        # (connect, read) timeouts: without them a stalled server would
        # block the app indefinitely.
        response = requests.post(url, json=payload, timeout=(10, 300))
    except requests.RequestException as e:
        st.error(f"Failed to reach the image generation service: {e}")
        return None

    if response.status_code != 200:
        st.error(f"Failed to generate image. Status code: {response.status_code}, Error: {response.text}")
        return None

    try:
        response_data = response.json()
    except ValueError as e:
        st.error(f"Image generation service returned invalid JSON: {e}")
        return None

    if not isinstance(response_data, dict) or "image_base64" not in response_data:
        st.error("No image data found in the response!")
        return None

    try:
        image_data = base64.b64decode(response_data["image_base64"])
        generated = Image.open(BytesIO(image_data))
        # Image.open is lazy; decode now so corrupt data is caught here.
        generated.load()
    except (ValueError, TypeError, OSError) as e:
        st.error(f"Could not decode the generated image: {e}")
        return None

    return generated


# Button to generate image
if st.button('Generate Image'):
    if not prompt.strip():
        # Nothing to generate from; avoid a pointless API round-trip.
        st.warning("Please enter a prompt before generating an image.")
    else:
        generated_image = _request_generated_image(prompt, model_option)
        if generated_image is not None:
            # Classify the result and show it (blurred if flagged).
            process_image(generated_image)