"""Gradio demo comparing the base OpenAI CLIP model with a CLIP model
fine-tuned on the UC Merced satellite image dataset."""

import gradio as gr
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

LIST_LABELS = ['agricultural land', 'airplane', 'baseball diamond', 'beach', 'buildings', 'chaparral',
               'dense residential area', 'forest', 'freeway', 'golf course', 'harbor', 'intersection',
               'medium residential area', 'mobilehome park', 'overpass', 'parking lot', 'river', 'runway',
               'sparse residential area', 'storage tanks', 'tennis court']

# Zero-shot text prompts, one per label, in the same order as LIST_LABELS.
CLIP_LABELS = [f"A satellite image of {label}" for label in LIST_LABELS]

MODEL_NAME = "NemesisAlm/clip-fine-tuned-satellite"

# Example images shown in the UI; they must also be whitelisted for serving.
EXAMPLE_IMAGES = ["0.jpg", "1.jpg", "2.jpg", "3.jpg"]

device = "cuda" if torch.cuda.is_available() else "cpu"

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

fine_tuned_model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
fine_tuned_processor = CLIPProcessor.from_pretrained(MODEL_NAME)


def classify(image_path, model_number):
    """Score a satellite image against every label in LIST_LABELS.

    Args:
        image_path: Filesystem path of the uploaded image (the input component
            uses ``type='filepath'``), or None when nothing was uploaded.
        model_number: Model selector string. "CLIP" selects the base OpenAI
            checkpoint; any other value selects the fine-tuned model.

    Returns:
        Dict mapping each label in LIST_LABELS to its softmax probability over
        all prompts, the format expected by ``gr.Label``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image_path is None:
        # Surface a readable message in the UI instead of a PIL traceback.
        raise gr.Error("Please upload an image before submitting.")

    if model_number == "CLIP":
        processor, model = clip_processor, clip_model
    else:
        processor, model = fine_tuned_processor, fine_tuned_model

    # The context manager closes the underlying file handle. convert() returns
    # an independent in-memory image, so it stays valid after the block.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    inputs = processor(text=CLIP_LABELS, images=image, return_tensors="pt", padding=True).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # One image, so take row 0: one probability per text prompt.
    probabilities = outputs.logits_per_image.softmax(dim=1)[0]
    return {label: float(prob) for label, prob in zip(LIST_LABELS, probabilities.tolist())}


# Rendered with gr.HTML, where plain newlines collapse; markup keeps the layout.
DESCRIPTION = """
<h1>CLIP Fine-Tuned Satellite Model Demo</h1>
<p>This space demonstrates the capabilities of a fine-tuned CLIP-based model in classifying satellite images. The model has been specifically trained on the UC Merced satellite image dataset.</p>
<p>After just 2 epochs of training, adjusting only 30% of the model parameters, the model's accuracy in classifying satellite images has significantly improved, from an initial accuracy of 58.8% to 96.9% on the test set.</p>
<p>Explore this space to see its performance and compare it with the initial CLIP model.</p>
"""

_MODEL_URL = f"https://huggingface.co/{MODEL_NAME}"
_DATASET_URL = "https://huggingface.co/datasets/blanchon/UC_Merced"

FOOTER = f"""
<p>
Link to model: <a href="{_MODEL_URL}">{_MODEL_URL}</a><br>
Link to dataset: <a href="{_DATASET_URL}">{_DATASET_URL}</a>
</p>
"""

with gr.Blocks(title="Satellite image classification", css="") as demo:
    # NOTE(review): empty placeholder; presumably meant to show logo_gradio.png
    # (whitelisted in allowed_paths below) — TODO confirm and fill in.
    logo = gr.HTML("")
    description = gr.HTML(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type='filepath', label='Input image')
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            title_1 = gr.HTML("<h3>Original CLIP Model</h3>")
            # Hidden textbox carrying the model selector passed to classify().
            model_1 = gr.Textbox("CLIP", visible=False)
            output_labels_clip = gr.Label(num_top_classes=10, label="Top 10 classes")
        with gr.Column():
            title_2 = gr.HTML("<h3>Fine-tuned Model</h3>")
            model_2 = gr.Textbox("Fine-tuned", visible=False)
            output_labels_finetuned = gr.Label(num_top_classes=10, label="Top 10 classes")

    examples = gr.Examples([[path] for path in EXAMPLE_IMAGES], input_image)
    footer = gr.HTML(FOOTER)

    # Run the base model first, then the fine-tuned model, on the same input.
    submit_btn.click(fn=classify, inputs=[input_image, model_1], outputs=output_labels_clip).then(
        classify, inputs=[input_image, model_2], outputs=[output_labels_finetuned]
    )

demo.queue()
demo.launch(server_name="0.0.0.0", favicon_path='favicon.ico',
            allowed_paths=["logo_gradio.png", *EXAMPLE_IMAGES])