|
|
|
import requests |
|
import base64 |
|
from PIL import Image, ImageFilter |
|
from io import BytesIO |
|
from transformers import pipeline |
|
import streamlit as st |
|
|
|
|
|
# Endpoint of the remote image-generation service.
# NOTE(review): plain HTTP to a hard-coded IP address; consider making this
# configurable and serving it over HTTPS.
url = "http://34.198.214.220:8000/generate/"

# Model name forwarded to the server in the request payload.
model_option = st.sidebar.selectbox("Select Model", ["Fluently-XL-Final", "Flux-Uncensored"], index=0)

st.title("Text to Image Generator")

prompt = st.text_input("Enter Prompt", "")


@st.cache_resource
def _load_nsfw_classifier():
    """Build the NSFW image-classification pipeline once and reuse it.

    Streamlit re-executes this script on every widget interaction; caching the
    pipeline as a shared resource avoids reloading the model on each rerun.
    """
    return pipeline("image-classification", model="giacomoarienti/nsfw-classifier")


# Module-level handle used by classify_image().
pipe = _load_nsfw_classifier()
|
|
|
def classify_image(image):
    """
    Classifies an image using the NSFW classifier.

    Args:
        image: The PIL image object to be classified.

    Returns:
        The output of the image-classification pipeline: a list of dicts,
        each carrying a 'label' and a 'score' key (this is how
        process_image consumes it). Returns None if the classifier raised an
        exception; in that case the error is also shown in the app.
    """
    try:
        return pipe(image)
    except Exception as e:
        # UI boundary: report the failure instead of crashing the Streamlit
        # script. Callers treat None as "classification failed".
        st.error(f"Error during classification: {e}")
        return None
|
|
|
def blur_image(image, radius=40):
    """
    Applies a Gaussian blur to an image and displays the result in the app.

    The blurred copy is rendered with ``st.image``; nothing is written to
    disk. ``Image.filter`` returns a new image, so the input is left unchanged.

    Args:
        image: The PIL image object to be blurred.
        radius: Radius passed to ``ImageFilter.GaussianBlur``. Defaults to 40,
            the value previously hard-coded here.
    """
    blurred_image = image.filter(ImageFilter.GaussianBlur(radius=radius))

    st.image(blurred_image, caption="Blurred Image", use_container_width=True)
|
|
|
def process_image(image):
    """
    Classify an image and show it, blurred when it is judged explicit.

    The image is blurred when the 'porn' score exceeds 0.7 or the 'sexy'
    score exceeds 0.85. Otherwise it is shown unmodified. If classification
    yields nothing, an error message is displayed instead.

    Args:
        image: The PIL image object.
    """
    results = classify_image(image)

    # Guard clause: None (or an empty result) means classification failed.
    if not results:
        st.error("Error: Image classification failed.")
        return

    def score_for(wanted):
        # Score of the first entry with the given label; 0 when it is absent.
        for entry in results:
            if entry['label'] == wanted:
                return entry['score']
        return 0

    # Both scores are looked up before the threshold test is applied.
    # NOTE(review): only 'porn' and 'sexy' are checked; confirm whether other
    # labels the model may emit should also trigger blurring.
    porn = score_for('porn')
    sexy = score_for('sexy')

    if porn > 0.7 or sexy > 0.85:
        blur_image(image)
    else:
        st.image(image, caption="Original Image", use_container_width=True)
|
|
|
# Runs when the user clicks the button:
# 1. Send the prompt and model choice to the generation server.
# 2. Decode the base64 image it returns.
# 3. Hand the image to the NSFW check / display step.
# Each failure mode is reported with st.error and st.stop() ends the rerun,
# instead of crashing the script with a traceback.
if st.button('Generate Image'):
    payload = {
        "prompt": prompt,
        "model": model_option,
    }

    try:
        # (connect, read) timeouts in seconds. requests waits indefinitely by
        # default, which would freeze the app if the server stops responding.
        # The read timeout is generous because image generation can be slow.
        response = requests.post(url, json=payload, timeout=(10, 300))
    except requests.RequestException as e:
        st.error(f"Failed to reach the image generation server: {e}")
        st.stop()

    if response.status_code != 200:
        st.error(f"Failed to generate image. Status code: {response.status_code}, Error: {response.text}")
        st.stop()

    try:
        response_data = response.json()
    except ValueError:
        # requests' JSON decode errors subclass ValueError.
        st.error("The server returned a response that is not valid JSON.")
        st.stop()

    # A non-object JSON body cannot carry the image field either.
    if not isinstance(response_data, dict) or "image_base64" not in response_data:
        st.error("No image data found in the response!")
        st.stop()

    try:
        image_data = base64.b64decode(response_data["image_base64"])
        image = Image.open(BytesIO(image_data))
        # Image.open is lazy; force decoding here so corrupt data is reported
        # by this handler rather than surfacing later in the classifier.
        image.load()
    except (TypeError, ValueError, OSError) as e:
        # binascii.Error (bad base64) is a ValueError; PIL's
        # UnidentifiedImageError and truncated-image errors are OSErrors.
        st.error(f"Could not decode the returned image: {e}")
        st.stop()

    process_image(image)
|
|