brxerq commited on
Commit
7bd873d
1 Parent(s): 992d073

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -18
app.py CHANGED
@@ -1,38 +1,36 @@
1
  # -*- coding: utf-8 -*-
2
-
3
-
4
  import gradio as gr
5
  import numpy as np
 
6
  import tensorflow_hub as hub
7
- from tensorflow.keras.models import load_model
8
  import cv2
9
 
10
  # Define a dictionary to map the custom layer to its implementation
11
  custom_objects = {'KerasLayer': hub.KerasLayer}
12
 
13
- # Load your model (ensure the path is correct)
14
- model = load_model('bird_model4.h5', custom_objects=custom_objects)
15
-
16
- # Define your class labels or categories for predictions
17
- train_info = [] # Replace with your actual class labels
18
 
19
- # Read image names from the text file
 
20
  with open('labelwithspace.txt', 'r') as file:
21
- train_info = [line.strip() for line in file.read().splitlines()]
22
-
23
 
 
24
  def predict_image(image):
 
25
  img = cv2.resize(image, (224, 224))
26
- img = img / 255.0
 
27
  predictions = model.predict(img[np.newaxis, ...])[0]
28
- top_classes = np.argsort(predictions)[-3:][::-1]
29
- top_class = top_classes[0] # Get the index of the top prediction
30
- label = train_info[top_class] # Use the index to retrieve the label
31
  return label
32
 
33
-
34
- # Define Gradio interface
35
  input_image = gr.inputs.Image(shape=(224, 224))
36
  output_label = gr.outputs.Label()
37
 
38
- gr.Interface(fn=predict_image, inputs=input_image, outputs=output_label, capture_session=True).launch()
 
 
1
  # -*- coding: utf-8 -*-
 
 
2
  import gradio as gr
3
  import numpy as np
4
+ import tensorflow as tf
5
  import tensorflow_hub as hub
 
6
  import cv2
7
 
8
  # Define a dictionary to map the custom layer to its implementation
9
  custom_objects = {'KerasLayer': hub.KerasLayer}
10
 
11
+ # Load your model (ensure the path to the model file is correct)
12
+ model = tf.keras.models.load_model('bird_model4.h5', custom_objects=custom_objects)
 
 
 
13
 
14
+ # Load the class labels from the text file
15
+ train_info = []
16
  with open('labelwithspace.txt', 'r') as file:
17
+ train_info = [line.strip() for line in file.readlines()]
 
18
 
19
+ # Function to preprocess the image and make predictions
20
  def predict_image(image):
21
+ # Resize the image to the model's input size
22
  img = cv2.resize(image, (224, 224))
23
+ img = img / 255.0 # Normalize the image
24
+ # Make predictions using the model
25
  predictions = model.predict(img[np.newaxis, ...])[0]
26
+ top_classes = np.argsort(predictions)[-3:][::-1] # Get indices of top 3 predictions
27
+ top_class = top_classes[0] # Index of the top prediction
28
+ label = train_info[top_class] # Retrieve the label using the index
29
  return label
30
 
31
+ # Define the Gradio interface
 
32
  input_image = gr.inputs.Image(shape=(224, 224))
33
  output_label = gr.outputs.Label()
34
 
35
+ # Launch the Gradio app
36
+ gr.Interface(fn=predict_image, inputs=input_image, outputs=output_label, capture_session=True).launch()