	Create app.py
app.py (ADDED)
import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor

# Check if flash_attn is available
def is_flash_attn_available():
    try:
        import flash_attn
        return True
    except ImportError:
        return False

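# flash_attn provides fused attention kernels that only run on CUDA, so the
# optimized path below is enabled only when a GPU and the package are both present.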
# Load model and tokenizer
@torch.inference_mode()
def load_model():
    use_optimized = torch.cuda.is_available() and is_flash_attn_available()

    model = AutoModel.from_pretrained(
        "visheratin/mexma-siglip2",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        optimized=use_optimized,
    )
    if torch.cuda.is_available():
        model = model.to("cuda")
    tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip2")
    processor = AutoImageProcessor.from_pretrained("visheratin/mexma-siglip2")
    return model, tokenizer, processor

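# Loading happens once at import time, so every Gradio request reuses the
# same weights instead of reloading them per call.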
model, tokenizer, processor = load_model()
device = "cuda" if torch.cuda.is_available() else "cpu"

def classify_image(image, text_queries):
    if image is None or not text_queries.strip():
        return None

    # Process image: cast pixel values to bfloat16 to match the model weights
    processed_image = processor(images=image, return_tensors="pt")["pixel_values"]
    processed_image = processed_image.to(torch.bfloat16).to(device)

    # Process text queries, one per line
    queries = [q.strip() for q in text_queries.split("\n") if q.strip()]
    if not queries:
        return None

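    # padding=True pads the queries to a common length so they can be scored
    # against the image in a single batched forward pass.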
    text_inputs = tokenizer(queries, return_tensors="pt", padding=True)
    text_inputs = text_inputs.to(device)

    # Get predictions
    with torch.inference_mode():
        image_logits, _ = model.get_logits(
            text_inputs["input_ids"],
            text_inputs["attention_mask"],
            processed_image
        )
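        # Softmax over dim=-1 (the query axis) turns the image-text similarity
        # logits into probabilities that sum to 1 across the queries.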
        probs = F.softmax(image_logits, dim=-1)[0].cpu().tolist()

    # Format results as {query: probability}; gr.Label expects float confidences
    results = dict(zip(queries, probs))
    return results
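
# gr.Label (used below) renders a label -> confidence mapping as ranked bars,
# which is why classify_image returns a plain dict keyed by the user's queries.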

# Create Gradio interface
with gr.Blocks(title="Mexma-SigLIP2 Zero-Shot Classification") as demo:
    gr.Markdown("# Mexma-SigLIP2 Zero-Shot Classification Demo")
    gr.Markdown("""
    This demo showcases the zero-shot classification capabilities of the Mexma-SigLIP2 model.

    ### Instructions:
    1. Upload or select an image
    2. Enter text queries (one per line) to classify the image
    3. Click 'Submit' to see the classification probabilities

    The model supports multilingual queries (English, Russian, Hindi, etc.).
    """)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(
                placeholder="Enter text queries (one per line)\nExample:\na cat\na dog\nEiffel Tower",
                label="Text Queries",
                lines=5
            )
            submit_btn = gr.Button("Submit", variant="primary")

        with gr.Column():
            output = gr.Label(label="Classification Results")

    submit_btn.click(
        fn=classify_image,
        inputs=[image_input, text_input],
        outputs=output
    )
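
    # The click handler receives (PIL image, raw textbox string); the dict it
    # returns populates the Label component in the right-hand column.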

    gr.Examples(
        [
            [
                "https://static.independent.co.uk/s3fs-public/thumbnails/image/2014/03/25/12/eiffel.jpg",
                "Eiffel Tower\nStatue of Liberty\nTaj Mahal\nкошка\nएफिल टॉवर"
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg",
                "a cat\na dog\na bird\nкошка\nсобака"
            ]
        ],
        inputs=[image_input, text_input]
    )
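
    # The example queries deliberately mix scripts (Latin, Cyrillic, Devanagari)
    # to exercise the model's multilingual text encoder.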

# Launch the demo
if __name__ == "__main__":
    demo.launch()
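
# To try the app locally (a sketch; exact package versions are assumptions):
#   pip install gradio torch transformers pillow
#   python app.py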