3v324v23 committed on
Commit
b67d94e
1 Parent(s): 8ef4f78

fixed a small bug on mask

Browse files
Files changed (1) hide show
  1. lib/model_zoo/clip.py +1 -24
lib/model_zoo/clip.py CHANGED
@@ -3,6 +3,7 @@ import torch.nn as nn
3
  import numpy as np
4
  from functools import partial
5
  from lib.model_zoo.common.get_model import register
 
6
 
7
  symbol = 'clip'
8
 
@@ -104,7 +105,6 @@ class CLIPImageContextEncoder(AbstractEncoder):
104
  assert isinstance(masks, torch.Tensor)
105
  assert (len(masks.shape)==4) and (masks.shape[1]==1)
106
  masks = torch.clamp(masks, 0, 1)
107
- masked_images = images*masks
108
  masks = masks.float()
109
  masks = F.interpolate(masks, [224, 224], mode='bilinear')
110
  if masks.sum() == masks.numel():
@@ -142,29 +142,6 @@ class CLIPImageContextEncoder(AbstractEncoder):
142
  z = z * vtoken_mask.to(dtype)
143
  return z
144
 
145
- # def _encode_wmask(self, images, masks):
146
- # assert isinstance(masks, torch.Tensor)
147
- # assert (len(masks.shape)==4) and (masks.shape[1]==1)
148
- # masks = torch.clamp(masks, 0, 1)
149
- # masks = masks.float()
150
- # masks = F.interpolate(masks, [224, 224], mode='bilinear')
151
- # if masks.sum() == masks.numel():
152
- # return self._encode(images)
153
-
154
- # device = images.device
155
- # dtype = images.dtype
156
-
157
- # vtoken_kernel_size = self.model.vision_model.embeddings.patch_embedding.kernel_size
158
- # vtoken_stride = self.model.vision_model.embeddings.patch_embedding.stride
159
- # mask_kernal = torch.ones([1, 1, *vtoken_kernel_size], device=device, requires_grad=False).float()
160
- # vtoken_mask = torch.nn.functional.conv2d(masks, mask_kernal, stride=vtoken_stride).flatten(2).transpose(1, 2)
161
- # vtoken_mask = vtoken_mask/np.prod(vtoken_kernel_size)
162
-
163
- # z = self._encode(images)
164
- # z[:, 1:, :] = z[:, 1:, :] * vtoken_mask.to(dtype)
165
- # z[:, 0, :] = 0
166
- # return z
167
-
168
  def encode(self, images, masks=None):
169
  if masks is None:
170
  return self._encode(images)
 
3
  import numpy as np
4
  from functools import partial
5
  from lib.model_zoo.common.get_model import register
6
+ import torch.nn.functional as F
7
 
8
  symbol = 'clip'
9
 
 
105
  assert isinstance(masks, torch.Tensor)
106
  assert (len(masks.shape)==4) and (masks.shape[1]==1)
107
  masks = torch.clamp(masks, 0, 1)
 
108
  masks = masks.float()
109
  masks = F.interpolate(masks, [224, 224], mode='bilinear')
110
  if masks.sum() == masks.numel():
 
142
  z = z * vtoken_mask.to(dtype)
143
  return z
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  def encode(self, images, masks=None):
146
  if masks is None:
147
  return self._encode(images)