SkalskiP commited on
Commit
153394d
·
1 Parent(s): 2619d65
Files changed (1) hide show
  1. app.py +15 -1
app.py CHANGED
@@ -4,9 +4,11 @@ import gradio as gr
4
  import numpy as np
5
  import torch
6
  from transformers import CLIPProcessor, CLIPModel
 
7
 
8
  IMAGENET_CLASSES_FILE = "imagenet-classes.txt"
9
  EXAMPLES = ["dog.jpeg", "car.png"]
 
10
 
11
  MARKDOWN = """
12
  # Zero-Shot Image Classification with MetaCLIP
@@ -24,6 +26,18 @@ def load_text_lines(file_path: str) -> List[str]:
24
  return [line.rstrip() for line in lines]
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
  model = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(device)
29
  processor = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")
@@ -33,7 +47,7 @@ imagenet_classes = load_text_lines(IMAGENET_CLASSES_FILE)
33
  def classify_image(input_image) -> str:
34
  inputs = processor(
35
  text=imagenet_classes,
36
- images=input_image,
37
  return_tensors="pt",
38
  padding=True).to(device)
39
  outputs = model(**inputs)
 
4
  import numpy as np
5
  import torch
6
  from transformers import CLIPProcessor, CLIPModel
7
+ from PIL import Image
8
 
9
  IMAGENET_CLASSES_FILE = "imagenet-classes.txt"
10
  EXAMPLES = ["dog.jpeg", "car.png"]
11
+ RESIZED_IMAGE_SIZE = 640
12
 
13
  MARKDOWN = """
14
  # Zero-Shot Image Classification with MetaCLIP
 
26
  return [line.rstrip() for line in lines]
27
 
28
 
29
+ def resize_image(input_image):
30
+ aspect_ratio = input_image.width / input_image.height
31
+ if input_image.width > input_image.height:
32
+ new_width = RESIZED_IMAGE_SIZE
33
+ new_height = int(RESIZED_IMAGE_SIZE / aspect_ratio)
34
+ else:
35
+ new_height = RESIZED_IMAGE_SIZE
36
+ new_width = int(RESIZED_IMAGE_SIZE * aspect_ratio)
37
+
38
+ return input_image.resize((new_width, new_height), Image.LANCZOS)
39
+
40
+
41
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
  model = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(device)
43
  processor = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")
 
47
  def classify_image(input_image) -> str:
48
  inputs = processor(
49
  text=imagenet_classes,
50
+ images=resize_image(input_image),
51
  return_tensors="pt",
52
  padding=True).to(device)
53
  outputs = model(**inputs)