import torch
from PIL import Image
from pathlib import Path
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from torchvision import transforms
import reverse_geocoder as rg
import folium
from geopy.exc import GeocoderTimedOut, GeopyError
from geopy.geocoders import Nominatim

# StreetCLIP is used zero-shot: the image is scored against each country name
# in `labels` used as a text prompt.
model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")

# Candidate countries. Each string is fed verbatim to the CLIP text encoder, so
# changing a spelling changes the model input (and the geocoding query).
labels = [
    'Albania', 'Andorra', 'Argentina', 'Australia', 'Austria', 'Bangladesh',
    'Belgium', 'Bermuda', 'Bhutan', 'Bolivia', 'Botswana', 'Brazil',
    'Bulgaria', 'Cambodia', 'Canada', 'Chile', 'China', 'Colombia',
    'Croatia', 'Czech Republic', 'Denmark', 'Dominican Republic', 'Ecuador',
    'Estonia', 'Finland', 'France', 'Germany', 'Ghana', 'Greece',
    'Greenland', 'Guam', 'Guatemala', 'Hungary', 'Iceland', 'India',
    'Indonesia', 'Ireland', 'Israel', 'Italy', 'Japan', 'Jordan', 'Kenya',
    'Kyrgyzstan', 'Laos', 'Latvia', 'Lesotho', 'Lithuania', 'Luxembourg',
    'Macedonia', 'Madagascar', 'Malaysia', 'Malta', 'Mexico', 'Monaco',
    'Mongolia', 'Montenegro', 'Netherlands', 'New Zealand', 'Nigeria',
    'Norway', 'Pakistan', 'Palestine', 'Peru', 'Philippines', 'Poland',
    'Portugal', 'Puerto Rico', 'Romania', 'Russia', 'Rwanda', 'Senegal',
    'Serbia', 'Singapore', 'Slovakia', 'Slovenia', 'South Africa',
    'South Korea', 'Spain', 'Sri Lanka', 'Swaziland', 'Sweden',
    'Switzerland', 'Taiwan', 'Thailand', 'Tunisia', 'Turkey', 'Uganda',
    'Ukraine', 'United Arab Emirates', 'United Kingdom', 'United States',
    'Uruguay',
]


def create_map(lat, lon):
    """Return embeddable HTML for a folium map centred on (lat, lon) with a marker there."""
    m = folium.Map(location=[lat, lon], zoom_start=4)
    folium.Marker([lat, lon]).add_to(m)
    return m._repr_html_()


geolocator = Nominatim(user_agent="predictGeolocforImage")

# Successful geocoding results keyed by country name. The key space is bounded
# by `labels`, so this cannot grow without limit. Failures are deliberately NOT
# cached so a transient outage does not permanently disable the map for a country.
_coordinate_cache = {}


def get_country_coordinates(country_name):
    """Look up (latitude, longitude) for *country_name* via Nominatim.

    Returns:
        A ``(lat, lon)`` tuple, or ``None`` when the geocoder finds nothing or
        the request fails. The map is optional output, so geocoding problems
        must not break the classification result.
    """
    if country_name in _coordinate_cache:
        return _coordinate_cache[country_name]
    try:
        location = geolocator.geocode(country_name, timeout=10)
    except GeopyError:
        # GeopyError is the base class of GeocoderTimedOut as well as
        # unavailable/service/rate-limit errors; degrade to "no map" for all.
        return None
    if location is None:
        return None
    coords = (location.latitude, location.longitude)
    _coordinate_cache[country_name] = coords
    return coords


def classify_streetclip(image):
    """Predict the country shown in *image* and build the Gradio outputs.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A 3-tuple for the Textbox, HTML and Label outputs: the prediction text,
        map HTML (or the string "Map not available"), and a dict mapping every
        label to its softmax probability.
    """
    if image is None:
        # Gradio calls fn(None) on an empty submit; the processor would raise.
        return "Please upload an image.", "Map not available", None

    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # logits_per_image has one row per image; softmax over the label axis and
    # take the single image's row as plain Python floats.
    probabilities = outputs.logits_per_image.softmax(dim=1)[0].tolist()
    confidences = dict(zip(labels, probabilities))
    # max() keeps the first label on ties, matching a stable descending sort.
    top_label = max(confidences, key=confidences.get)

    coords = get_country_coordinates(top_label)
    map_html = create_map(*coords) if coords else "Map not available"
    return f"Country: {top_label}", map_html, confidences


# Markdown description; the country list is derived from `labels` so the
# displayed list can never drift from what the model actually scores.
text = (
    f"List of countries supported: {', '.join(labels)}\n"
    "\n"
    "---\n"
    "You may choose to use the images provided below, or feel free to upload your own images."
)

interface = gr.Interface(
    fn=classify_streetclip,
    inputs=gr.Image(type="pil", label="Upload Image", elem_id="image_input"),
    outputs=[
        gr.Textbox(label="Prediction", elem_id="output"),
        gr.HTML(label="Map", elem_id="map_output"),
        gr.Label(num_top_classes=10, label="Top 10 countries"),
    ],
    title="COUNTRY GUESSER",
    description=text,
    article="Model is not running on a GPU, so the interpretation takes some time. Thank you for your patience🙏🏻",
    examples=["taj.jpg", "stockholm.jpeg", "palace-square-saint-petersburg.jpg", "monument.jpg"],
    allow_flagging="never",
)
interface.launch()