# brxerq's picture
# Update app.py
# 9ea3afa verified
# app.py
import gradio as gr
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras.models import load_model
import os
# Hide CUDA devices from TensorFlow so inference runs on the CPU.
# NOTE(review): this is assigned after `import tensorflow`; the variable is
# conventionally set before the import -- confirm it takes effect as intended.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# The saved model contains a TensorFlow Hub layer, so Keras must be told how
# to resolve the 'KerasLayer' class name when deserializing it.
custom_objects = {'KerasLayer': hub.KerasLayer}

# Load the model once at start-up; abort with exit status 1 if it cannot be
# deserialized.
try:
    model = load_model('bird_model4.h5', custom_objects=custom_objects)
except ValueError as e:
    print("Model loading failed with error:", e)
    print("Please ensure the model was saved correctly and matches the KerasLayer structure.")
    # `exit()` is a helper injected by the `site` module for interactive use
    # and may be missing (e.g. `python -S`); raising SystemExit is equivalent
    # and always available.
    raise SystemExit(1) from e
# Read the class labels: one label per line of the text file, with
# surrounding whitespace removed. The list index is used as the class id.
with open('labelwithspace.txt', 'r') as label_file:
    raw_lines = label_file.read().splitlines()
train_info = list(map(str.strip, raw_lines))
# Function to preprocess the input image
def preprocess_image(image):
    """Resize *image* to 224x224 and scale its values by 1/255.

    Args:
        image: Image array to prepare for the model.
            NOTE(review): assumed to be a height x width x channels array
            with 0-255 values, as delivered by the Gradio image input --
            confirm against the Gradio version in use.

    Returns:
        A float NumPy array of shape (224, 224, channels).
    """
    # The original called `cv2.resize`, but `cv2` is never imported in this
    # file, so every prediction failed with NameError. TensorFlow is already
    # imported and its default resize method is bilinear as well.
    resized = tf.image.resize(image, (224, 224)).numpy()
    return resized / 255.0  # Normalize
# Prediction function
def predict_image(image):
    """Return the label of the class the model scores highest for *image*.

    The image is preprocessed, wrapped in a batch of one, passed through the
    model, and the index of the largest score is looked up in `train_info`.
    """
    batch = np.expand_dims(preprocess_image(image), axis=0)
    scores = model.predict(batch)[0]  # scores for the single batch item
    best_index = np.argmax(scores)
    return train_info[best_index]
# Gradio interface: one image input, one label output.
input_image = gr.Image(shape=(224, 224), label="Input Image")
output_label = gr.Label(label="Predicted Bird Species")

# Launch Gradio.
# The former `capture_session=True` argument is intentionally omitted: it is a
# deprecated TensorFlow 1.x session option that Gradio documents as having no
# effect, and passing it only triggers a deprecation warning.
gr.Interface(fn=predict_image, inputs=input_image, outputs=output_label).launch()