# Space status: Runtime error
import logging
import os
import random

import gradio as gr
import openai
import replicate
from huggingface_hub import InferenceClient
# --- Credentials -----------------------------------------------------------
# API keys are read from the environment (e.g. Space secrets) instead of being
# hard-coded: a key committed to source is exposed to everyone who can read the
# file. Any keys that were previously embedded here should be revoked.
#
# replicate.Client() falls back to the REPLICATE_API_TOKEN environment
# variable, so it only needs to be constructed once and no token is written
# into os.environ here.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
openai.api_key = OPENAI_API_KEY

rep_client = replicate.Client()

# Predefined defect descriptions offered in the "Current Defects" dropdown.
predefined_prompts = [
    "Missing bolts on railway track",
    "Cracks on railway track",
    "Overgrown vegetation near railway track",
    "Broken railings on railway bridge",
    "Debris on railway track",
    "Damaged railway platform",
]
def ask_rail_defect_question(question, model_name='ft:gpt-3.5-turbo-0125:personal::99NsSAeQ'):
    """Ask the (fine-tuned) chat model a question about rail defects.

    Supports both OpenAI SDK generations. ``openai>=1.0`` exposes
    ``openai.OpenAI`` and removed the module-level ``openai.ChatCompletion``
    API (calling it raises). The call is therefore dispatched on which API
    the installed package provides.

    Args:
        question: User message sent to the model.
        model_name: Chat model identifier; defaults to the fine-tuned model.

    Returns:
        The text content of the first choice of the model's reply. On
        ``openai>=1.0`` a reply without text content yields ``""``.
    """
    messages = [
        {
            "role": "system",
            "content": "The assistant is knowledgeable about rail defects and can answer questions related to them.",
        },
        {
            "role": "user",
            "content": question,
        },
    ]

    if hasattr(openai, "OpenAI"):
        # openai>=1.0: explicit client, attribute-style response objects.
        client = openai.OpenAI(api_key=OPENAI_API_KEY)
        response = client.chat.completions.create(model=model_name, messages=messages)
        return response.choices[0].message.content or ""

    # openai<1.0: module-level API, dict-style message access.
    openai.api_key = OPENAI_API_KEY
    response = openai.ChatCompletion.create(model=model_name, messages=messages)
    return response.choices[0].message['content']
# Function to generate variations enhanced by the GPT model
def generate_variations(base_prompt, number_of_variations, *, enhance=None):
    """Build randomized image prompts for one defect description.

    Each variation is ``enhance(base_prompt)`` followed by a randomly chosen
    defect size, location and weather condition.

    Args:
        base_prompt: Short defect description, e.g. "Cracks on railway track".
        number_of_variations: How many prompts to build. ``gr.Number`` delivers
            a float unless ``precision=0`` is set, so any numeric value is
            accepted and truncated to ``int``. ``None``, zero or negative
            values yield an empty list.
        enhance: Optional ``str -> str`` callable that expands the base prompt;
            defaults to ``ask_rail_defect_question``. It is called once per
            variation.

    Returns:
        A list of prompt strings.
    """
    if enhance is None:
        enhance = ask_rail_defect_question
    # range() rejects floats, and Gradio number inputs are floats by default.
    count = int(number_of_variations) if number_of_variations else 0

    locations = ["on the left side", "on the right side", "at the top", "at the bottom", "in the center"]
    sizes = ["small", "medium", "large", "tiny", "huge"]
    weather_conditions = ["under cold conditions", "during hot weather", "in dry weather", "in humid conditions", "under varying temperatures"]

    variations = []
    for _ in range(count):
        location = random.choice(locations)
        size = random.choice(sizes)
        weather = random.choice(weather_conditions)
        enhanced_prompt = enhance(base_prompt)
        variations.append(f"{enhanced_prompt}, with a {size} defect {location}, observed {weather}.")
    return variations
# Function to generate images from prompts
def generate_images(prompts, *, version="ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4"):
    """Run each prompt through a Replicate model and collect the outputs.

    Args:
        prompts: Iterable of text prompts.
        version: Replicate model version hash to run. The default is the
            version this app was written against.

    Returns:
        A list containing the first output of every successful prediction, in
        prompt order. Failed or erroring predictions are logged and omitted.
        The result is displayed by ``gr.Gallery``, which interprets every
        string as an image path or URL, so error text cannot be returned in
        the list.
    """
    log = logging.getLogger(__name__)
    images = []
    for prompt in prompts:
        try:
            prediction = rep_client.predictions.create(
                version=version,
                input={"prompt": prompt, "scheduler": "K_EULER"},
            )
            prediction.wait()
            if prediction.status == "succeeded" and prediction.output:
                images.append(prediction.output[0])
            else:
                log.warning("Prediction for prompt %r ended with status %r", prompt, prediction.status)
        except Exception:
            # Best effort: one failing prompt must not abort the whole batch.
            log.exception("Replicate prediction failed for prompt %r", prompt)
    return images
def process_railway_defects(prompt, number_of_images):
    """Generate defect images for one prompt: GPT-enhanced variations, then Replicate.

    Args:
        prompt: Defect description from the dropdown or the textbox. ``None``
            (dropdown left unselected) or blank text produces no images and
            makes no API calls.
        number_of_images: Requested image count. Gradio may deliver a float or
            ``None``; the value is truncated to ``int``, and non-positive
            counts produce no images.

    Returns:
        The list returned by ``generate_images`` for the generated variations,
        or ``[]`` when there is nothing to generate.
    """
    if prompt is None or not str(prompt).strip():
        return []
    count = int(number_of_images) if number_of_images else 0
    if count <= 0:
        return []
    variations = generate_variations(prompt, count)
    return generate_images(variations)
# UI creation
# Two tabs share the same pipeline: choose a predefined defect or type a custom
# one, pick how many images to generate, and show the results in a gallery.
with gr.Blocks() as app:
    # gr.Tabs' first positional parameter is ``selected`` (the id of the tab to
    # open initially), not a title, so nothing is passed here.
    with gr.Tabs():
        with gr.Tab("Current Defects"):
            with gr.Row():
                prompt_input = gr.Dropdown(choices=predefined_prompts, label="Select a prompt")
                # precision=0 makes the component return an int rather than a
                # float; the count is used with range() downstream.
                number_input_dropdown = gr.Number(label="Number of images to generate", value=1, minimum=1, maximum=10, precision=0)
            submit_button_dropdown = gr.Button("Generate")
            image_outputs_dropdown = gr.Gallery()

            def on_submit_click_dropdown(prompt, number_of_images):
                """Click handler for the "Current Defects" tab."""
                images = process_railway_defects(prompt, number_of_images)
                return images

            submit_button_dropdown.click(
                fn=on_submit_click_dropdown,
                inputs=[prompt_input, number_input_dropdown],
                outputs=image_outputs_dropdown,
            )

        with gr.Tab("Custom Defect"):
            with gr.Row():
                custom_prompt_input = gr.Textbox(label="Custom Defect")
                number_input_custom = gr.Number(label="Number of images to generate", value=1, minimum=1, maximum=10, precision=0)
            submit_button_custom = gr.Button("Generate")
            image_outputs_custom = gr.Gallery()

            def on_submit_click_custom(custom_prompt, number_of_images):
                """Click handler for the "Custom Defect" tab."""
                images = process_railway_defects(custom_prompt, number_of_images)
                return images

            submit_button_custom.click(
                fn=on_submit_click_custom,
                inputs=[custom_prompt_input, number_input_custom],
                outputs=image_outputs_custom,
            )

if __name__ == "__main__":
    app.launch()