File size: 3,691 Bytes
401f7f3 ee9ee13 401f7f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
import gradio as gr
# Checkpoint id shared by the model weights and their matching processor
# (see https://huggingface.co/geolocal/StreetCLIP).
_CHECKPOINT = "geolocal/StreetCLIP"

# Both objects are created once at import time and reused by every request.
model = CLIPModel.from_pretrained(_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CHECKPOINT)
# Candidate country names, in alphabetical order. Each entry is used both as a
# text prompt for the model and as a key of the dict returned by classify().
labels = [
    "Albania", "Andorra", "Argentina", "Australia", "Austria",
    "Bangladesh", "Belgium", "Bermuda", "Bhutan", "Bolivia",
    "Botswana", "Brazil", "Bulgaria", "Cambodia", "Canada",
    "Chile", "China", "Colombia", "Croatia", "Czech Republic",
    "Denmark", "Dominican Republic", "Ecuador", "Estonia", "Finland",
    "France", "Germany", "Ghana", "Greece", "Greenland",
    "Guam", "Guatemala", "Hungary", "Iceland", "India",
    "Indonesia", "Ireland", "Israel", "Italy", "Japan",
    "Jordan", "Kenya", "Kyrgyzstan", "Laos", "Latvia",
    "Lesotho", "Lithuania", "Luxembourg", "Macedonia", "Madagascar",
    "Malaysia", "Malta", "Mexico", "Monaco", "Mongolia",
    "Montenegro", "Netherlands", "New Zealand", "Nigeria", "Norway",
    "Pakistan", "Palestine", "Peru", "Philippines", "Poland",
    "Portugal", "Puerto Rico", "Romania", "Russia", "Rwanda",
    "Senegal", "Serbia", "Singapore", "Slovakia", "Slovenia",
    "South Africa", "South Korea", "Spain", "Sri Lanka", "Swaziland",
    "Sweden", "Switzerland", "Taiwan", "Thailand", "Tunisia",
    "Turkey", "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom",
    "United States", "Uruguay",
]
def classify(image):
    """Score an image against every country name in ``labels``.

    Args:
        image: The picture to classify. It is forwarded unchanged to
            ``processor`` as its ``images`` argument.

    Returns:
        A dict mapping each entry of ``labels`` to the softmax probability
        (a Python float) the model assigns to it for this image.
    """
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        outputs = model(**inputs)
    # Normalise the image-to-text logits along the label axis.
    probabilities = outputs.logits_per_image.softmax(dim=1)
    # Only row 0 (the first image) is reported. tolist() converts the whole
    # row to Python floats in one call instead of one .item() per label.
    return dict(zip(labels, probabilities[0].tolist()))
# HTML/Markdown blurb passed to gr.Interface(description=...): logo and title,
# a short explanation, the list of supported countries, and the licence notice.
DESCRIPTION = """
<img src='file/logo.jpg' alt='logo' style='width:100px;float:left'/><h1 style='padding-top:30px'> StreetClip Model used for classification - 92 countries supported</h1>
<h2 style='margin-top:50px'> This Space demonstrates how the StreetClip model (https://huggingface.co/geolocal/StreetCLIP) can be used to geolocate images (classification per country).</h2>
<p>
π <b>List of countries supported</b>: Albania, Andorra, Argentina, Australia, Austria, Bangladesh, Belgium, Bermuda, Bhutan, Bolivia, Botswana, Brazil, Bulgaria, Cambodia, Canada, Chile, China, Colombia, Croatia, Czech Republic, Denmark, Dominican Republic, Ecuador, Estonia, Finland, France, Germany, Ghana, Greece, Greenland, Guam, Guatemala, Hungary, Iceland, India, Indonesia, Ireland, Israel, Italy, Japan, Jordan, Kenya, Kyrgyzstan, Laos, Latvia, Lesotho, Lithuania, Luxembourg, Macedonia, Madagascar, Malaysia, Malta, Mexico, Monaco, Mongolia, Montenegro, Netherlands, New Zealand, Nigeria, Norway, Pakistan, Palestine, Peru, Philippines, Poland, Portugal, Puerto Rico, Romania, Russia, Rwanda, Senegal, Serbia, Singapore, Slovakia, Slovenia, South Africa, South Korea, Spain, Sri Lanka, Swaziland, Sweden, Switzerland, Taiwan, Thailand, Tunisia, Turkey, Uganda, Ukraine, United Arab Emirates, United Kingdom, United States, Uruguay
</p>
---<br>
As a derivative work of [StreetClip](https://huggingface.co/geolocal/StreetCLIP), this demo is governed by the original Creative Commons Attribution Non Commercial 4.0 license.
"""
# Input widget that supplies the picture handed to classify().
image = gr.Image(label="Input image")
# Output widget; num_top_classes=10 limits the display to the ten best scores.
label = gr.Label(label="Top 10 countries", num_top_classes=10)

# Wire classify() between the two widgets and add the page text and examples.
demo = gr.Interface(
    # Page text.
    title="StreetClip Classification",
    description=DESCRIPTION,
    article="Interpretation requires time to process, please be patient ππ»",
    # Function and its input/output components.
    fn=classify,
    inputs=image,
    outputs=label,
    # Behaviour options.
    examples=["turkey.jpg", "croatia.jpg"],
    interpretation="default",
    allow_flagging="never",
)

# Start the app as soon as the script runs.
demo.launch(favicon_path="favicon.ico")
|