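# country-guesser / app.py
# The three lines below are residue from the hosting site's file view, kept
# as comments so the module parses:
#   shreyas2509's picture
#   Create app.py
#   f828007 verified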
import torch
from PIL import Image
from pathlib import Path
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from torchvision import transforms
import reverse_geocoder as rg
import folium
from geopy.exc import GeocoderTimedOut
from geopy.geocoders import Nominatim
# Checkpoint identifier shared by the model weights and the matching
# processor.  (An earlier revision passed a local checkpoint directory to
# `CLIPModel.from_pretrained` instead of this identifier.)
STREETCLIP_CHECKPOINT = "geolocal/StreetCLIP"

# Both objects are created once at import time and reused by
# `classify_streetclip` for every prediction.
model = CLIPModel.from_pretrained(STREETCLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(STREETCLIP_CHECKPOINT)
# Candidate country names scored against each image.  The order matters:
# position i in this list corresponds to column i of the model's per-image
# logits in `classify_streetclip`.
labels = [
    "Albania", "Andorra", "Argentina", "Australia", "Austria",
    "Bangladesh", "Belgium", "Bermuda", "Bhutan", "Bolivia",
    "Botswana", "Brazil", "Bulgaria", "Cambodia", "Canada",
    "Chile", "China", "Colombia", "Croatia", "Czech Republic",
    "Denmark", "Dominican Republic", "Ecuador", "Estonia", "Finland",
    "France", "Germany", "Ghana", "Greece", "Greenland",
    "Guam", "Guatemala", "Hungary", "Iceland", "India",
    "Indonesia", "Ireland", "Israel", "Italy", "Japan",
    "Jordan", "Kenya", "Kyrgyzstan", "Laos", "Latvia",
    "Lesotho", "Lithuania", "Luxembourg", "Macedonia", "Madagascar",
    "Malaysia", "Malta", "Mexico", "Monaco", "Mongolia",
    "Montenegro", "Netherlands", "New Zealand", "Nigeria", "Norway",
    "Pakistan", "Palestine", "Peru", "Philippines", "Poland",
    "Portugal", "Puerto Rico", "Romania", "Russia", "Rwanda",
    "Senegal", "Serbia", "Singapore", "Slovakia", "Slovenia",
    "South Africa", "South Korea", "Spain", "Sri Lanka", "Swaziland",
    "Sweden", "Switzerland", "Taiwan", "Thailand", "Tunisia",
    "Turkey", "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom",
    "United States", "Uruguay",
]
def create_map(lat, lon):
    """Build a folium map around one point and return it as HTML.

    Args:
        lat: Latitude of the point.
        lon: Longitude of the point.

    Returns:
        The string produced by the map's ``_repr_html_`` method.  The map is
        centred on ``[lat, lon]`` with an initial zoom of 4 and carries a
        single marker at that same position.
    """
    country_map = folium.Map(location=[lat, lon], zoom_start=4)
    folium.Marker(location=[lat, lon]).add_to(country_map)
    return country_map._repr_html_()
# Single Nominatim client shared by every lookup.
geolocator = Nominatim(user_agent="predictGeolocforImage")

# Coordinates that were already resolved, keyed by the exact query string.
# Only successful lookups are stored, so a failed or empty lookup is retried
# on the next call instead of being remembered.
_coordinate_cache = {}


def get_country_coordinates(country_name):
    """Look up the latitude/longitude that Nominatim reports for a name.

    The lookup is best-effort: the caller falls back to a placeholder when
    no coordinates are available, so geocoding service failures are turned
    into ``None`` rather than raised.

    Args:
        country_name: Free-text query sent to the geocoder (a country name).

    Returns:
        A ``(latitude, longitude)`` tuple, or ``None`` when the geocoder
        returns no match or the service request fails.
    """
    # Imported here so the block does not depend on a change to the
    # file-level import list.  GeocoderServiceError is the common base of
    # geopy's service failures, including the GeocoderTimedOut that was
    # handled before.
    from geopy.exc import GeocoderServiceError

    cached = _coordinate_cache.get(country_name)
    if cached is not None:
        return cached

    try:
        location = geolocator.geocode(country_name, timeout=10)
    except GeocoderServiceError:
        return None

    if not location:
        return None

    coordinates = (location.latitude, location.longitude)
    _coordinate_cache[country_name] = coordinates
    return coordinates
def classify_streetclip(image):
    """Score an image against every country label and map the best match.

    Args:
        image: The image to classify, passed straight to the processor as
            its ``images`` argument.  ``None`` is accepted and short-circuits
            the prediction.

    Returns:
        A 3-tuple ``(summary, map_html, confidences)``:

        * ``summary`` -- ``"Country: <label>"`` for the highest-scoring label.
        * ``map_html`` -- HTML from ``create_map`` for that label's
          coordinates, or ``"Map not available"`` when no coordinates could
          be obtained.
        * ``confidences`` -- dict mapping every label to its softmax
          probability as a Python float.

        When ``image`` is ``None`` the tuple is
        ``("No image provided", "Map not available", None)``.
    """
    # Guard against a missing image (e.g. nothing was uploaded) so the call
    # answers with placeholders instead of failing inside the processor.
    if image is None:
        return "No image provided", "Map not available", None

    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    # Softmax across the label axis, then take the row for the single image.
    probabilities = outputs.logits_per_image.softmax(dim=1)[0].tolist()
    confidences = dict(zip(labels, probabilities))

    # On ties `max` keeps the earliest label, matching a stable descending sort.
    top_label = max(confidences, key=confidences.get)

    coords = get_country_coordinates(top_label)
    map_html = create_map(*coords) if coords else "Map not available"
    return f"Country: {top_label}", map_html, confidences
# Description markup handed to the Gradio interface: the list of supported
# countries followed by a hint about the example images.
# NOTE(review): the "</p>" below has no matching opening tag anywhere in this
# string -- confirm whether that is intentional before changing it.
text = (
    "\n"
    '<b style="color: #F36912;">List of countries supported</b>: '
    "Albania, Andorra, Argentina, Australia, Austria, Bangladesh, Belgium, "
    "Bermuda, Bhutan, Bolivia, Botswana, Brazil, Bulgaria, Cambodia, Canada, "
    "Chile, China, Colombia, Croatia, Czech Republic, Denmark, "
    "Dominican Republic, Ecuador, Estonia, Finland, France, Germany, Ghana, "
    "Greece, Greenland, Guam, Guatemala, Hungary, Iceland, India, Indonesia, "
    "Ireland, Israel, Italy, Japan, Jordan, Kenya, Kyrgyzstan, Laos, Latvia, "
    "Lesotho, Lithuania, Luxembourg, Macedonia, Madagascar, Malaysia, Malta, "
    "Mexico, Monaco, Mongolia, Montenegro, Netherlands, New Zealand, Nigeria, "
    "Norway, Pakistan, Palestine, Peru, Philippines, Poland, Portugal, "
    "Puerto Rico, Romania, Russia, Rwanda, Senegal, Serbia, Singapore, "
    "Slovakia, Slovenia, South Africa, South Korea, Spain, Sri Lanka, "
    "Swaziland, Sweden, Switzerland, Taiwan, Thailand, Tunisia, Turkey, "
    "Uganda, Ukraine, United Arab Emirates, United Kingdom, United States, "
    "Uruguay\n"
    "</p>\n"
    "---<br>\n"
    '<span style="color: #F24F13;">You may choose to use the images provided '
    "below, or feel free to upload your own images.</span>\n"
)
# Input widget: the uploaded picture is delivered to the handler as a PIL image.
image_input = gr.Image(type="pil", label="Upload Image", elem_id="image_input")

# Output widgets, in the same order as the tuple returned by
# `classify_streetclip`: summary text, map HTML, per-label confidences.
result_components = [
    gr.Textbox(label="Prediction", elem_id="output"),
    gr.HTML(label="Map", elem_id="map_output"),
    gr.Label(num_top_classes=10, label="Top 10 countries"),
]

# Sample images offered below the form (paths as given to Gradio).
example_images = [
    "taj.jpg",
    "stockholm.jpeg",
    "palace-square-saint-petersburg.jpg",
    "monument.jpg",
]

interface = gr.Interface(
    fn=classify_streetclip,
    inputs=image_input,
    outputs=result_components,
    title="COUNTRY GUESSER",
    description=text,
    article="<span style='color: #F24F13;'>Model is not running on a GPU, so the interpretation takes some time. Thank you for your patience🙏🏻</span>",
    examples=example_images,
    allow_flagging="never",
)

# Start the web app as soon as the module is executed.
interface.launch()