riyadifirman commited on
Commit
20115ca
1 Parent(s): b5e285e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -5,6 +5,10 @@ from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomH
5
  from PIL import Image
6
  import traceback
7
 
 
 
 
 
8
  # Load model and processor
9
  model_name = "riyadifirman/klasifikasiburung"
10
  processor = AutoImageProcessor.from_pretrained(model_name)
@@ -27,7 +31,8 @@ def predict(image):
27
  outputs = model(inputs)
28
  logits = outputs.logits
29
  predicted_class_idx = logits.argmax(-1).item()
30
- return processor.decode(predicted_class_idx)
 
31
  except Exception as e:
32
  # Menampilkan error
33
  print("An error occurred:", e)
@@ -44,4 +49,4 @@ interface = gr.Interface(
44
  )
45
 
46
  if __name__ == "__main__":
47
- interface.launch()
 
5
  from PIL import Image
6
  import traceback
7
 
8
+ # Load dataset to get labels
9
+ dataset = load_dataset("bentrevett/caltech-ucsd-birds-200-2011")
10
+ labels = dataset['train'].features['label'].names
11
+
12
  # Load model and processor
13
  model_name = "riyadifirman/klasifikasiburung"
14
  processor = AutoImageProcessor.from_pretrained(model_name)
 
31
  outputs = model(inputs)
32
  logits = outputs.logits
33
  predicted_class_idx = logits.argmax(-1).item()
34
+ predicted_class = labels[predicted_class_idx]
35
+ return predicted_class
36
  except Exception as e:
37
  # Menampilkan error
38
  print("An error occurred:", e)
 
49
  )
50
 
51
  if __name__ == "__main__":
52
+ interface.launch(share=True)