curt-park commited on
Commit
8d4a5a4
·
1 Parent(s): f17c02a
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -57,8 +57,8 @@ def get_scores(crops: List[PIL.Image.Image], query: str) -> torch.Tensor:
57
  txt_features = model.encode_text(token)
58
  img_features /= img_features.norm(dim=-1, keepdim=True)
59
  txt_features /= txt_features.norm(dim=-1, keepdim=True)
60
- probs = 100.0 * img_features @ txt_features.T
61
- return probs[:, 0].softmax(dim=0)
62
 
63
 
64
  def filter_masks(
@@ -82,7 +82,7 @@ def filter_masks(
82
  filtered_masks.append(mask)
83
 
84
  x, y, w, h = mask["bbox"]
85
- crop = image[y : y + h, x : x + w]
86
  crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
87
  crop = PIL.Image.fromarray(np.uint8(crop * 255)).convert("RGB")
88
  crop.resize((CLIP_WIDTH, CLIP_HEIGHT))
 
57
  txt_features = model.encode_text(token)
58
  img_features /= img_features.norm(dim=-1, keepdim=True)
59
  txt_features /= txt_features.norm(dim=-1, keepdim=True)
60
+ similarity = (100.0 * img_features @ txt_features.T).softmax(dim=0)
61
+ return similarity
62
 
63
 
64
  def filter_masks(
 
82
  filtered_masks.append(mask)
83
 
84
  x, y, w, h = mask["bbox"]
85
+ crop = image[y: y + h, x: x + w]
86
  crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
87
  crop = PIL.Image.fromarray(np.uint8(crop * 255)).convert("RGB")
88
  crop.resize((CLIP_WIDTH, CLIP_HEIGHT))